# Sampling the VAE

By [Allison Parrish](http://www.decontextualize.com/)

I wrote a little helper class to make it easier to sample strings from the VAE model, in particular models trained on tokens from `bpemb`. This notebook walks through its functionality using the included `poetry_500k_sample` model.

In [48]:
%load_ext autoreload
%autoreload 2


In [49]:
import argparse, importlib
import torch
from vaesampler import BPEmbVaeSampler

First, load the configuration module and wrap its `params` dictionary in a `Namespace` object. Then, create the `BPEmbVaeSampler` object with the same `bpemb` parameters used to train the model and the path to the pre-trained model.

In [52]:
config_file = "config.config_poetry_500k_sample"
params = argparse.Namespace(**importlib.import_module(config_file).params)
bpvs = BPEmbVaeSampler(lang='en', vs=10000, dim=100,
                       decode_from="./models/poetry_500k_sample/2019-08-09T08:27:43.289493-011.pt",
                       params=params)
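
`argparse.Namespace(**d)` turns each key of a dictionary into an attribute, so model code can read settings as `params.some_key`. The sketch below uses a hypothetical stand-in dictionary; its keys and values are illustrative only and are not taken from the real `config_poetry_500k_sample` module.

```python
import argparse

# Hypothetical stand-in for the `params` dict defined in the config module;
# the real keys and values may differ.
params_dict = {"nz": 32, "dec_nh": 1024, "dec_dropout_in": 0.5}

# Each dictionary key becomes an attribute on the Namespace.
params = argparse.Namespace(**params_dict)

print(params.nz)     # → 32
print(vars(params))  # converts the Namespace back to a plain dict
```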



## Decoding

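The main thing you'll want to do is decode strings from a latent variable `z`. This variable should follow a standard Gaussian distribution (that's the whole point of a VAE, right?), so you can draw values for it with `torch.randn`. For example, `torch.randn(14, 32)` below produces 14 vectors in the model's 32-dimensional latent space, one per line of output. There are three methods for decoding strings from `z`: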

* `.sample()` samples from the (softmax) output distribution at each step, using the given temperature;
* `.greedy()` always picks the most likely next token;
* `.beam()` expands multiple "branches" of the output and returns the most likely branch.

(These methods use the underlying implementations in the `LSTMDecoder` class.)

Below you'll find some examples of each. First, `.sample()` with a temperature of 1.0. (Higher temperatures produce less likely output; as the temperature approaches 0, sampling approaches `.greedy()`.)

In [54]:
with torch.no_grad():
    print("\n".join(bpvs.sample(torch.randn(14, 32), temperature=1.0)))

Some painted
As the mind would reach in funeral hand!
On three vain gives nitence minstrels poets and any solemerus fly:
The glides pass in air to side their one;
The throne that kill of king,
'tis dread melanse of the deep white sublime,
Nor flame like life or wood.
To line the meadowing throat like vigorous flower. For even o'er the
(nell for I can have exultes,
But the spiritual king of greece;
Thus thus they wilt the greatest house they gain.
And slept at times in the secret loud
In least essence worship the fearfully plumes;
A wife in youth: these exads, honest two shagci
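
To see why temperature has this effect, here is a minimal, pure-Python sketch of temperature-scaled softmax sampling over a toy three-token vocabulary. It assumes the common convention of dividing the logits by the temperature before the softmax (not verified against `LSTMDecoder`), and the logits are made up.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Convert raw scores to probabilities, sharpened (T < 1) or flattened (T > 1)."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, temperature, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.5]  # toy scores for a three-token vocabulary
for t in (2.0, 1.0, 0.1):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

At 0.1 nearly all of the probability mass sits on the top-scoring token, which is why low-temperature sampling behaves almost like greedy decoding.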


Greedy decoding (usually boring):

In [31]:
with torch.no_grad():
    print("\n".join(bpvs.greedy(torch.randn(14, 32))))

To the king, and the great and the night,
The wildly arose.
"and is thee, and thee, and the _thee_
When the first of the world.
The old man's sons of the earth
And I know thee and thee,
The wind of the old man's hand.
When the great of the old man's feet
I know it, and the same,
I have seen, and a young man,
And I have been in the air.
The mists of the old-born eyes.
A thousand years in the house.
That I have been in the world.
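
Greedy decoding is boring partly because it can fall into loops: once the most likely continuation leads back to a token it has already produced, it keeps repeating itself. The toy sketch below illustrates this with a hand-written, hypothetical next-token table that has nothing to do with the trained model. Unlike the real LSTM decoder, which conditions on the whole prefix and on `z`, this table conditions only on the previous token.

```python
# Toy next-token probabilities (hypothetical; not taken from the trained model).
NEXT = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"wind": 0.5, "sky": 0.3, "</s>": 0.2},
    "a": {"sky": 0.9, "</s>": 0.1},
    "wind": {"of": 0.7, "</s>": 0.3},
    "sky": {"</s>": 1.0},
    "of": {"the": 1.0},
}

def greedy_decode(table, max_len=10):
    """Repeatedly append the single most likely next token."""
    tokens = ["<s>"]
    for _ in range(max_len):
        dist = table[tokens[-1]]
        best = max(dist, key=dist.get)  # argmax over next tokens
        if best == "</s>":
            break
        tokens.append(best)
    return tokens[1:]  # drop the start token

print(" ".join(greedy_decode(NEXT, max_len=6)))  # → the wind of the wind of
```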


Beam search (a good compromise, but slow):

In [55]:
with torch.no_grad():
    print("\n".join(bpvs.beam(torch.randn(14, 32), 4)))

To make thee with all the world of the earth
There came in the sea
If the wind of the wind,
And ráma's self--and's bliss.
On the fiery wrought, and
As I'll have seen thee.
When I'll give thee,
But, if I know, and they
This is the sunset
To whom I have not not, I'll see.
Beheld them in the darkness, and in all,
A woman's eyes, and with anger
I have the golden
If I have seen them, and in his hand.
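
Beam search keeps the `beam_width` highest-scoring partial outputs at each step instead of just one, which is why it is slower but less short-sighted than greedy decoding. Below is a toy sketch on the same hypothetical bigram table as the greedy example (repeated so the block stands alone). Scoring by summed log-probability is the textbook formulation; the model's own beam search may differ in details such as length normalization.

```python
import math

# Toy next-token probabilities (hypothetical; not taken from the trained model).
NEXT = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"wind": 0.5, "sky": 0.3, "</s>": 0.2},
    "a": {"sky": 0.9, "</s>": 0.1},
    "wind": {"of": 0.7, "</s>": 0.3},
    "sky": {"</s>": 1.0},
    "of": {"the": 1.0},
}

def beam_decode(table, beam_width, max_len=10):
    """Return the highest-probability token sequence found and its probability."""
    beams = [(0.0, ["<s>"])]  # (summed log-probability, tokens)
    finished = []
    for _ in range(max_len):
        # Expand every live beam by every possible next token.
        candidates = []
        for logp, tokens in beams:
            for tok, p in table[tokens[-1]].items():
                candidates.append((logp + math.log(p), tokens + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        # Keep the top `beam_width`; completed sequences leave the beam.
        beams = []
        for logp, tokens in candidates[:beam_width]:
            if tokens[-1] == "</s>":
                finished.append((logp, tokens))
            else:
                beams.append((logp, tokens))
        if not beams:
            break
    finished.extend(beams)  # fall back to unfinished beams if max_len was hit
    best_logp, best_tokens = max(finished, key=lambda c: c[0])
    return [t for t in best_tokens[1:] if t != "</s>"], math.exp(best_logp)

tokens, prob = beam_decode(NEXT, beam_width=2)
print(" ".join(tokens), round(prob, 2))  # → a sky 0.36
```

With a width of 2, the search finds "a sky" (probability 0.36), which greedy decoding misses because it commits to "the" (0.6) at the first step. With `beam_width=1` the sketch reduces to greedy decoding.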


## Homotopies (linear interpolation)

Using the VAE, you can explore linear interpolations between two lines of poetry. The code in the cell below picks two random points in the latent space and decodes evenly spaced points along the line between them:

In [93]:
with torch.no_grad():
    x = torch.randn(1, 32)
    y = torch.randn(1, 32)
    steps = 10
    for i in range(steps + 1):
        z = (x * (i/steps)) + (y * (1 - i/steps))  # y at i=0, x at i=steps
        #print(bpvs.sample(z, 0.2)[0])
        #print(bpvs.greedy(z)[0])
        print(bpvs.beam(z, 3)[0])

And, when the leaves of the sky.
That, the leaves of the wind.
As the wind of the wind.
As the leaves of the sky.
As the wind of the wind.
That, the wind of the sky.
That, the wind of the sky.
That, the wind of the sky.
When, when the sun
When the sun of the sky.
I saw the sun
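
Here is the same interpolation written out for plain Python lists, which makes the endpoints explicit. It covers only the arithmetic, using made-up stand-in vectors; in the notebook `x` and `y` are `torch` tensors and each point is decoded with `bpvs`.

```python
def interpolate(x, y, steps):
    """Return steps + 1 evenly spaced points from y (i = 0) to x (i = steps)."""
    points = []
    for i in range(steps + 1):
        t = i / steps
        # Parenthesize (1 - t): writing `y * 1 - t` would compute y - t instead.
        points.append([t * xi + (1 - t) * yi for xi, yi in zip(x, y)])
    return points

path = interpolate([1.0, 2.0], [3.0, 6.0], steps=2)
print(path)  # → [[3.0, 6.0], [2.0, 4.0], [1.0, 2.0]]
```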


Along the same lines, you can produce variations on a line of poetry by adding a little random noise to its latent vector:

In [85]:
with torch.no_grad():
    x = torch.randn(1, 32)
    n_variations = 15
    for _ in range(n_variations):
        z = x + (torch.randn(1, 32) * 0.1)  # small Gaussian perturbation of x
        print(bpvs.sample(z, 0.35)[0])
        #print(bpvs.greedy(z)[0])
        #print(bpvs.beam(z, 4)[0])

From the windless and strange,
When I have heard the rustling, and with the earthly flame,
The red and a moment, and a vain;
A man and the sunshine,
When the wildness of the night, and the lips of the dear,
A little face, and the great and the same.
Than the first birds in the darkness, and the careless,
When I have been the sunset, and the music of the earth.
With the great and weeping, and the lips of the earth.
Into the fields of the glorious breeze,
To the glory of the grass and dome,
Against the sunset, and the long foolish,
Down to the world's bosom and aught.
When the night to the dreadful eyes.
Into the shrillness of the earth and the wall,
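
Each coordinate of a standard-normal `x` has standard deviation 1, so noise with a scale of 0.1 keeps every variant in a small neighborhood of the original point. The noise scale (together with the sampling temperature) controls how far the decoded lines can drift. Below is a stdlib-only sketch of the same perturbation, with a random stand-in vector in place of a real latent; `perturb` is a hypothetical helper, not part of `vaesampler`.

```python
import math
import random

def perturb(x, scale, rng=random):
    """Add independent Gaussian noise with standard deviation `scale` to each coordinate."""
    return [xi + rng.gauss(0.0, scale) for xi in x]

rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(32)]  # stand-in for a 32-dimensional latent
variants = [perturb(x, 0.1, rng) for _ in range(15)]

# Root-mean-square per-coordinate distance of each variant from x; should be near 0.1.
rms = [math.sqrt(sum((vi - xi) ** 2 for vi, xi in zip(v, x)) / len(x)) for v in variants]
print([round(r, 3) for r in rms])
```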


## Reconstructions

You can ask the model to produce the latent vector for any given input. (Because `BPEmb` segments text into byte-pair subword units, it helps ensure that arbitrary inputs won't contain out-of-vocabulary tokens.)

The `.z()` method returns a sample from the latent Gaussian for each input string, while `.mu()` returns its mean. You can pass either result to `.sample()`, `.beam()`, or `.greedy()` to produce strings. The model's reconstructions aren't very accurate, but you can usually see some hint of the original string's meaning or structure in the output.

In [72]:
strs = ["This is just to say",
        "I have eaten the plums",
        "That were in the icebox"]

In [74]:
bpvs.sample(bpvs.z(strs), 0.5)

['That is not to the sons of the light,',
 'This was the shrine of the dead,',
 'To the sinkling of the sea']

In [75]:
bpvs.beam(bpvs.mu(strs), 2)

['To whom I know,',
 "I'll have seen them in the air.",
 'As I have seen the sunset']

In [76]:
bpvs.greedy(bpvs.mu(strs))

['To whom the heart of the earth',
 "I'll have been the golden breeze,",
 'As the same of the old man,']
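
The difference between `.z()` and `.mu()` is the usual VAE one: the encoder maps each string to a Gaussian, `.mu()` returns its center, and `.z()` draws a point from it. The minimal sketch below shows such a draw, assuming (as in a standard VAE, but not verified against `vaesampler`) that the encoder outputs a mean and a log-variance per latent dimension. The numbers are made up, and `sample_latent` is a hypothetical helper.

```python
import math
import random

def sample_latent(mu, logvar, rng=random):
    """Reparameterized draw: z = mu + sigma * eps, with sigma = exp(logvar / 2), eps ~ N(0, 1)."""
    return [m + math.exp(lv / 2) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]

mu = [0.5, -1.0, 0.2]        # hypothetical posterior mean for one input string
logvar = [-2.0, -2.0, -2.0]  # hypothetical log-variances (sigma = exp(-1), about 0.37)
print(sample_latent(mu, logvar, random.Random(0)))
```

Because `.z()` is random, decoding its output can vary from run to run even with `.greedy()`, whereas `.mu()` should give the same vector for the same input.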