<a href="https://colab.research.google.com/github/raphaelgbe/Variational-AutoEncoders/blob/master/VAEs_for_text.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE for text data trained on PTB dataset with ELBO & MMD objectives (with annealing & embedding dropout):

In [1]:
  import os, sys, math, re
  import numpy as np
  import pandas as pd
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torchvision
  
  from collections import Counter
  
  import random
  
  from google.colab import drive
  drive.mount('/content/drive')
  
  seed = 4056
  random.seed(seed)
  torch.manual_seed(seed)
  
  # really helpful points of reference, from which this notebook takes 
  # inspiration: 
  # #https://arxiv.org/abs/1511.06349
  # #https://github.com/timbmg/Sentence-VAE 

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


<torch._C.Generator at 0x7f753fc6bb50>

In [0]:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [0]:
data_folder = 'drive/My Drive/text_data'

def import_text_dataset(path):
  
  res = []
  
  with open(path, 'r') as f:
    for line in f:
      res.append(line)
      
  return res

data_train = import_text_dataset(os.path.join(data_folder, 'ptb.train.txt'))
data_val = import_text_dataset(os.path.join(data_folder, 'ptb.valid.txt'))
data_test = import_text_dataset(os.path.join(data_folder, 'ptb.test.txt'))


In [0]:
other_words = ['<start>', '<end>', '<pad>']
vocab = set([x for s in data_train for x in s.split()]).union(set(other_words))

vocab_size = len(vocab)

vocab.update(set([x for s in data_val for x in s.split()]))
vocab.update(set([x for s in data_test for x in s.split()]))

total_vocab_size = len(vocab)

In [0]:
sentence_lengths = Counter()

sentence_lengths.update(list(map(len, data_train + data_val + data_test)))

In [6]:
print(len(data_train + data_val + data_test))

# quick statistics of sentence lengths to determine a good max_length
print(sum([v for k, v in sentence_lengths.items() if k >= 200]))

49199
4795


In [0]:
batch_size = 32
lr = 0.001
epochs = 10

loss_type = 'elbo' # or 'mmd'
anneal_function = 'logistic'
anneal_coeff = 0.0025
anneal_bias = 4000#2500
variational_coeff = 0.5

z_dim = 16
embedding_dim = 300
embedding_dropout = 0.5
hidden_dim = 256

max_sequence_length = 60 #200
min_occurrences = 1

word_dropout = 0.

# intiialy only implemented LSTM but might be interesting to look at other options:
rnn_type = 'GRU'

# four tokens pre-allocated: 0 for UNK, 1 for START, 2 for END and 3 for PAD

In [0]:
def compute_kernel(x, y):
  
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]
    tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
    tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
    return torch.exp(- torch.mean((tiled_x - tiled_y)**2, dim=2) / float(dim))

def compute_mmd(x, y):
  
  x_kernel = compute_kernel(x, x)
  y_kernel = compute_kernel(y, y)
  xy_kernel = compute_kernel(x, y)
  return (x_kernel + y_kernel - 2 * xy_kernel).mean()


In [0]:
class TextVAE(nn.Module):
  
  def __init__(self, embedding_dim=embedding_dim, hidden_dim=hidden_dim, 
               z_dim=z_dim, word_dropout=word_dropout, vocab_size=vocab_size, 
               embedding_dropout = embedding_dropout, min_occ=min_occurrences, 
               max_len=max_sequence_length, anneal_bias=anneal_bias, 
               anneal_coeff=anneal_coeff, variational_coeff=variational_coeff, 
               loss_type=loss_type, anneal_function=anneal_function, 
               rnn_type=rnn_type, generate_cell_state=False):
    
    super(TextVAE, self).__init__()
    
    # Attributes definition:
    
    self.max_len = max_len
    self.min_occ = min_occ
    self.word_dropout = word_dropout
    self.anneal_coeff = anneal_coeff
    self.variational_coeff = variational_coeff
    self.loss_type = loss_type
    self.anneal_function = anneal_function
    self.hidden_dim = hidden_dim
    self.vocab_size = vocab_size
    self.rnn_type = rnn_type
    if self.rnn_type == 'LSTM':
      self.generate_cell_state = generate_cell_state
    
    # Layers definition: 
    
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.embedding_dropout = nn.Dropout(p=embedding_dropout)
    
    if self.rnn_type == 'LSTM':
      self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
      self.decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
    elif self.rnn_type == 'GRU':
      self.encoder = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
      self.decoder = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    
    if loss_type == 'elbo':
      self.latent_dim = z_dim
      self.fc1 = nn.Linear(hidden_dim, self.latent_dim)
      self.fc2 = nn.Linear(hidden_dim, self.latent_dim)
    elif loss_type == 'mmd':
      self.latent_dim = 2 * z_dim
      self.fc1 = nn.Linear(hidden_dim, self.latent_dim)
    else:
      raise ValueError("Not a valid loss: choose between 'elbo' and 'mmd'.")
      
    self.fc_generate_hidden = nn.Linear(self.latent_dim, hidden_dim)
    if self.rnn_type == 'LSTM':
      if self.generate_cell_state:
        self.fc_generate_cell = nn.Linear(self.latent_dim, hidden_dim)
    self.final_layer = nn.Linear(hidden_dim, vocab_size)
    
    
  def forward(self, x, batch_lengths, batch_size=batch_size):
    
    # On keeping sigmoid activations in LSTMs:
    # https://www.quora.com/Why-does-an-LSTM-with-ReLU-activations-diverge
    
    original_input = x.clone() # needs to be used when decoding
    
    # Encoding:
    
    x = self.embedding(x)
      
    x = self.preprocess_sequential_input(x, batch_lengths)
    
    # hidden input set to 0:
    if self.rnn_type == 'LSTM':
      hidden_initial = (torch.zeros(1, batch_size, self.hidden_dim),
                            torch.zeros(1, batch_size, self.hidden_dim))
    
      outputs, hidden = self.encoder(x, hidden_initial)
    
      hidden = hidden[0] # hidden contains last (hidden_state, cell_state) pair
    
    else:
      hidden_initial = torch.zeros(1, batch_size, self.hidden_dim)
    
      outputs, hidden = self.encoder(x, hidden_initial)
    
    
    if self.loss_type == 'elbo':
      mean = self.fc1(hidden)
      logvar = self.fc2(hidden)
      z = self.reparametrize(mean, logvar)
    elif self.loss_type == 'mmd':
      z = self.fc1(hidden) 
      
    # Decoding:
    
    hidden_state = self.fc_generate_hidden(z)
    if self.rnn_type == 'LSTM':
      if self.generate_cell_state:
        cell_state = self.fc_generate_cell(z)
    
    
    if self.word_dropout:
      # the 'words' that are not to be predicted are excluded: START & PAD:
      exclusion_condition = original_input.data != 1 or original_input.data != 3
      
      modif_entries_shape = original_input[exclusion_condition].size()
      
      dropout_filter = torch.rand(modif_entries_shape)
      dropout_filter[dropout_filter.data < self.word_dropout] = 0
      dropout_filter[dropout_filter.data >= self.word_dropout] = 1
      
      original_input[exclusion_condition] = \
         original_input[exclusion_condition].mul(dropout_filter)
      
    
    x = self.embedding(original_input)
    
    if self.training:
      x = self.embedding_dropout(x)
      
    x = self.preprocess_sequential_input(x, batch_lengths)
    
    if self.rnn_type == 'LSTM':
      if self.generate_cell_state:
        hidden = (hidden_state, cell_state)
      else:
        hidden = (hidden_state, torch.zeros(1, batch_size, self.hidden_dim))
    else:
      hidden = hidden_state
    
    
    outputs, hidden = self.decoder(x, hidden) 
    
    padded_outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    
    seq_len = padded_outputs.size(1)
    
    results = F.log_softmax(self.final_layer(padded_outputs.contiguous().view(-1, self.hidden_dim)), dim=-1)
      
    results = results.view(batch_size, seq_len, self.vocab_size)
    
    if self.training:
      if self.loss_type == 'elbo':
        return mean, logvar, z, results
      elif self.loss_type == 'mmd':
        return z, results
    else:
      return results
    
  def sample_sentence(self, nb_examples=1, z=None, start_tok=1, end_tok=2, pad_tok=3):
    
    if z is not None:
      nb_examples = z.size(0)
    else:
      z = torch.randn(nb_examples, self.latent_dim)
      
    z = z.unsqueeze(0)
     
    hidden_state = self.fc_generate_hidden(z)
    
    if self.rnn_type == 'LSTM':
      if self.generate_cell_state:
        cell_state = self.fc_generate_cell(z)
        hidden = (hidden_state, cell_state)
      else:
        hidden = (hidden_state, torch.zeros(1, nb_examples, self.hidden_dim))
    else:
      hidden = hidden_state
      
    t = 0
    
    input_seq = start_tok * torch.ones(nb_examples, 1)
    
    res_seq = pad_tok * torch.ones(nb_examples, self.max_len)
    
    ended_seqs = [0] * nb_examples
    
    while t < self.max_len and 0 in ended_seqs:
      
      x = self.embedding(input_seq.long())
      
      output, hidden = self.decoder(x, hidden)
      
      results = self.final_layer(output)
      
      generated_token = torch.max(results, dim=-1)[1][:, 0]
      
      res_seq[:, t] = generated_token
      
      if t < self.max_len - 1:
        input_seq = generated_token.unsqueeze(1).float()
        
      ended_seqs += (generated_token == end_tok).tolist()
      
      t += 1
      
    return z, res_seq
    
    
  def preprocess_sequential_input(self, input_sequence, batch_lengths):
    
    """
    For computational efficiency, we pad and pack the input sequence of the 
    encoder; note that when using batches, we make sure to sort input sequences
    by hand before the forward pass:
    https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
    https://discuss.pytorch.org/t/understanding-pack-padded-sequence-and-pad-packed-sequence/4099/4
    """
    
    return nn.utils.rnn.pack_padded_sequence(input_sequence, batch_lengths,
                                            batch_first = True)
    
  
  def reparametrize(self, mean, logvar, batch_size=batch_size):
    
    """
    Reparametrization trick: from mean, variance,and normal noise, we issue a
    Gaussian variable of corresponding mean and variance:
    https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important
    http://stillbreeze.github.io/REINFORCE-vs-Reparameterization-trick/
    """
    
    return mean + torch.randn((batch_size, self.latent_dim), 
                              requires_grad=True) * torch.exp(0.5 * logvar)
  
  def loss_nll(self, x, x_recon, batch_length):
    
    # for NLLLoss, we shape input as [-1, vocab_size] and target as [-1]
    # first, lengths have to be the same:
    x = x[:,:torch.max(batch_lengths).item()]
    x = x.contiguous().view(-1)
    x_recon = x_recon.contiguous().view(-1, x_recon.size(-1))
    
    NLL_loss = nn.NLLLoss(ignore_index=3)
    
    return NLL_loss(x_recon, x)
  
  def loss_mmd(self, z, batch_size=batch_size):
    
    true_samples = torch.randn(batch_size, self.latent_dim)
    return compute_mmd(true_samples, z.squeeze())
    
  def loss_elbo(self, mean, logvar):
    
    return - self.variational_coeff * torch.sum(logvar + 1 
                                                - logvar.exp() - mean**2)
  
  def annealing(self, step,anneal_function=anneal_function, 
                anneal_coeff=anneal_coeff, anneal_bias=anneal_bias):
    
    if anneal_function == 'logistic':
      return 1 / (1 + np.exp(-anneal_coeff* (step - anneal_bias)))
    else:
      raise NotImplementedError
  
  

In [0]:
from collections import OrderedDict
set_of_words = set([word for ls_words in [x.split() for x in data_train] for 
                    word in ls_words])
nb_words = len(set_of_words) + 3 # '<unk>' is in data_train but not <start>,
                                 # <end>, <pad>
  
word2token = OrderedDict({'<unk>': 0, '<start>': 1, '<end>': 2, '<pad>': 3})

for word in sorted(list(set_of_words)):
  if word != '<unk>':
    word2token[word] = len(word2token)

In [0]:
def convert_sentence_to_sequence_of_tokens(sentence, word2token=word2token):
  
  return [word2token[word] if word in word2token.keys() 
            else word2token['<unk>'] for word in sentence.split()]

def number_to_categorical(nb, vec_size):
  
  return [0] * nb  + [1] + [0] * (vec_size - 1 - nb)

def convert_sentence_to_sequence_of_cat_tokens(sentence, word2token=word2token):
  
  seq = convert_sentence_to_sequence_of_tokens(sentence, word2token)
  return [number_to_categorical(i, len(word2token)) for i in seq]

def convert_sequence_of_tokens_to_sentence(seq, word2token=word2token):
  
  return ' '.join([list(word2token.keys())[token] for token in seq 
                    if token != word2token['<pad>']])

def truncate_or_pad_input(input_sequence, max_len=max_sequence_length, pad_tok=3):
  
  if len(input_sequence) < max_len:
    return input_sequence + [pad_tok] * (max_len - len(input_sequence))
  else:
    return input_sequence[:max_len]

In [0]:
data_train_tokens = [convert_sentence_to_sequence_of_tokens(s) for s in data_train]
data_train_tokens = [truncate_or_pad_input(s) for s in data_train_tokens]

data_val_tokens = [convert_sentence_to_sequence_of_tokens(s) for s in data_val]
data_val_tokens = [truncate_or_pad_input(s) for s in data_val_tokens]

data_test_tokens = [convert_sentence_to_sequence_of_tokens(s) for s in data_test]
data_test_tokens = [truncate_or_pad_input(s) for s in data_test_tokens]

In [0]:
def shuffle_lists_in_parallel(*ls):
  
  zipped_ls = list(zip(*ls))
  
  random.shuffle(zipped_ls)
  
  return zip(*zipped_ls)

In [0]:
random.shuffle(data_train_tokens)

ELBO training:

In [0]:
vae = TextVAE(rnn_type='LSTM')
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

In [18]:
from copy import deepcopy

nb_batches = len(data_train_tokens)//batch_size
pad_tok = word2token['<pad>']
start_tok = word2token['<start>']
end_tok = word2token['<end>']

for epoch in range(epochs):
    for id_batch in range(nb_batches):
        data_batch_as_list = data_train_tokens[(batch_size * id_batch):
                                        ((id_batch + 1) * batch_size)]
        batch_lengths = [len(seq[:seq.index(pad_tok)]) if pad_tok in seq 
                              else len(seq) for seq in data_batch_as_list]
        batch_lengths = torch.IntTensor(batch_lengths)
        sorted_lengths, sorted_idx = torch.sort(batch_lengths, descending=True)
        
        input_batch = deepcopy(data_batch_as_list)
        input_batch = [[start_tok] + seq[:-1] for seq in input_batch]
        input_batch = torch.LongTensor(input_batch)
        input_batch = input_batch.view(batch_size, max_sequence_length)
        input_batch = input_batch[sorted_idx]
        
        target_batch = deepcopy(data_batch_as_list)
        target_batch = [seq[:-1] + [end_tok] if seq[-1] != pad_tok 
                          else seq[:seq.index(pad_tok)] + [end_tok] + 
                            seq[seq.index(pad_tok) + 1:] for seq in target_batch]
        target_batch = torch.LongTensor(target_batch)
        target_batch = target_batch.view(batch_size, max_sequence_length)
        target_batch = target_batch[sorted_idx]        
        
        
        mu_batch, logvar_batch,\
            latent_var, data_batch_recon = vae(input_batch, sorted_lengths)
        
        
        elbo_loss = vae.loss_elbo(mu_batch, logvar_batch)
        nll_loss = vae.loss_nll(target_batch, data_batch_recon, batch_lengths)
        elbo_weight = vae.annealing(id_batch + nb_batches * epoch)
        
        loss = (nll_loss + elbo_weight * elbo_loss) / batch_size
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not id_batch % 10:
            print(f'Epoch {epoch + 1}/{epochs}, Batch {id_batch}/{nb_batches}, \
             Loss ELBO: {elbo_loss}; Loss NLL: {nll_loss}, Anneal: {elbo_weight}')
            
    # save model at each epoch:
    torch.save(vae, 'textVAE_elbo_after_epoch_{}_incorp_reparam.pt'.format(epoch))
    # https://stackoverflow.com/questions/52277083/pytorch-saving-model-userwarning-couldnt-retrieve-source-code-for-container-of

Epoch 1/10, Batch 0/1314,              Loss ELBO: 3.5825061798095703; Loss NLL: 9.200987815856934, Anneal: 4.5397868702434395e-05
Epoch 1/10, Batch 10/1314,              Loss ELBO: 2.9815711975097656; Loss NLL: 9.059290885925293, Anneal: 4.654706772585396e-05
Epoch 1/10, Batch 20/1314,              Loss ELBO: 7.619480133056641; Loss NLL: 8.073885917663574, Anneal: 4.772535611847556e-05
Epoch 1/10, Batch 30/1314,              Loss ELBO: 128.17884826660156; Loss NLL: 6.9002461433410645, Anneal: 4.893347014058169e-05
Epoch 1/10, Batch 40/1314,              Loss ELBO: 39.09267044067383; Loss NLL: 6.901294231414795, Anneal: 5.0172164683764205e-05
Epoch 1/10, Batch 50/1314,              Loss ELBO: 9.143732070922852; Loss NLL: 6.731674671173096, Anneal: 5.144221374220898e-05
Epoch 1/10, Batch 60/1314,              Loss ELBO: 6.511533737182617; Loss NLL: 6.680471420288086, Anneal: 5.274441089589183e-05
Epoch 1/10, Batch 70/1314,              Loss ELBO: 5.466487407684326; Loss NLL: 6.6334257125

KeyboardInterrupt: ignored

In [0]:
# results after 4 epochs (posterior collapse is clear in the greedy 
# sampling of sentences)

pad_tok = word2token['<pad>']
start_tok = word2token['<start>']
end_tok = word2token['<end>']

from copy import deepcopy
data_batch_as_list = data_train_tokens[:batch_size]
batch_lengths = [len(seq[:seq.index(pad_tok)]) if pad_tok in seq 
                      else seq for seq in data_batch_as_list]
batch_lengths = torch.IntTensor(batch_lengths)
sorted_lengths, sorted_idx = torch.sort(batch_lengths, descending=True)

example_batch = deepcopy(data_batch_as_list)
example_batch = [[start_tok] + seq[:-1] for seq in example_batch]
example_batch = torch.LongTensor(example_batch)
example_batch = example_batch.view(batch_size, max_sequence_length)
example_batch = example_batch[sorted_idx]
vae.train()
for i in range(example_batch.size(0)):
  print(i, convert_sequence_of_tokens_to_sentence((torch.max((vae(example_batch, sorted_lengths)[-1]), dim=-1)[1][i]).tolist()),
          convert_sequence_of_tokens_to_sentence(example_batch[i].tolist()))
list(map(convert_sequence_of_tokens_to_sentence, vae.sample_sentence(z=vae(example_batch, sorted_lengths)[0].squeeze())[1].int()))

0 the the the the are the is are to in a as <unk> <unk> to year been been a <unk> in <unk> the <unk> of the and <unk> to the <unk> <unk> <unk> <unk> of to the and the of the <unk> <unk> <unk> <unk> <unk> <unk> <start> for journalists however who write what they <unk> view as history 's first draft this has also been a week to give a lot of space and time to ron and nancy 's sales appearance in japan on behalf of a communications giant and its controversial founder
1 the <unk> <unk> spokesman president of chief of the <unk> and the boston of the 's to the the of <unk> the <unk> <unk> court are to be to be N N of and N N of N N N N N N of N N to <start> robert white a vice president and manager of corporate trade at first interstate of california agreed with that view and predicted the u.s. federal funds rate will drop to between N N N and N N within N days from its current level at N N N
2 the and at a N N increase profit of in was have the <unk> of of the hanover and said a N N million

["the company said it expects to sell its <unk> operations in the first half of N million shares outstanding by the company 's <unk> unit of $ N million of N N to $ N million from $ N million in the third quarter compared with a N N stake in the company 's <unk> unit of $ N million",
 "the company said it expects to sell its <unk> operations in the first half of N million shares outstanding by the company 's <unk> unit of $ N million of N N to $ N million from $ N million in the third quarter compared with a N N stake in the company 's <unk> unit of $ N million",
 "the company said it expects to sell its <unk> operations in the first half of N million shares outstanding by the company 's <unk> unit of $ N million of N N to $ N million from $ N million in the third quarter compared with a N N stake in the company 's <unk> unit of $ N million",
 "the company said it expects to sell its <unk> operations in the first half of N million shares outstanding by the company 's <unk> unit of $ N 

MMD training:

In [0]:
vae = TextVAE(rnn_type='LSTM', loss_type='mmd')
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

In [20]:
from copy import deepcopy

nb_batches = len(data_train_tokens)//batch_size
pad_tok = word2token['<pad>']
start_tok = word2token['<start>']
end_tok = word2token['<end>']

for epoch in range(epochs):
    for id_batch in range(nb_batches):
        data_batch_as_list = data_train_tokens[(batch_size * id_batch):
                                        ((id_batch + 1) * batch_size)]
        batch_lengths = [len(seq[:seq.index(pad_tok)]) if pad_tok in seq 
                              else len(seq) for seq in data_batch_as_list]
        batch_lengths = torch.IntTensor(batch_lengths)
        sorted_lengths, sorted_idx = torch.sort(batch_lengths, descending=True)
        
        input_batch = deepcopy(data_batch_as_list)
        input_batch = [[start_tok] + seq[:-1] for seq in input_batch]
        input_batch = torch.LongTensor(input_batch)
        input_batch = input_batch.view(batch_size, max_sequence_length)
        input_batch = input_batch[sorted_idx]
        
        target_batch = deepcopy(data_batch_as_list)
        target_batch = [seq[:-1] + [end_tok] if seq[-1] != pad_tok 
                          else seq[:seq.index(pad_tok)] + [end_tok] + 
                            seq[seq.index(pad_tok) + 1:] for seq in target_batch]
        target_batch = torch.LongTensor(target_batch)
        target_batch = target_batch.view(batch_size, max_sequence_length)
        target_batch = target_batch[sorted_idx]        
        
        
        latent_var, data_batch_recon = vae(input_batch, sorted_lengths)
        
        
        mmd_loss = vae.loss_mmd(latent_var)
        nll_loss = vae.loss_nll(target_batch, data_batch_recon, batch_lengths)
        mmd_weight = vae.annealing(id_batch + nb_batches * epoch)
        
        loss = (nll_loss + mmd_weight * mmd_loss) / batch_size
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not id_batch % 10:
            print(f'Epoch {epoch + 1}/{epochs}, Batch {id_batch}/{nb_batches}, \
             Loss MMD: {mmd_loss}; Loss NLL: {nll_loss}, Anneal: {mmd_weight}')
            
    # save model at each epoch:
    torch.save(vae, 'textVAE_elbo_after_epoch_{}_incorp_reparam.pt'.format(epoch))
    # https://stackoverflow.com/questions/52277083/pytorch-saving-model-userwarning-couldnt-retrieve-source-code-for-container-of

KeyboardInterrupt: ignored

In [27]:
print('Model trained with MMD objective stopped at the beginning of the third epoch:')

pad_tok = word2token['<pad>']
start_tok = word2token['<start>']
end_tok = word2token['<end>']

from copy import deepcopy
data_batch_as_list = data_train_tokens[:batch_size]
batch_lengths = [len(seq[:seq.index(pad_tok)]) if pad_tok in seq 
                      else seq for seq in data_batch_as_list]
batch_lengths = torch.IntTensor(batch_lengths)
sorted_lengths, sorted_idx = torch.sort(batch_lengths, descending=True)

example_batch = deepcopy(data_batch_as_list)
example_batch = [[start_tok] + seq[:-1] for seq in example_batch]
example_batch = torch.LongTensor(example_batch)
example_batch = example_batch.view(batch_size, max_sequence_length)
example_batch = example_batch[sorted_idx]
vae.train()

print('Sentences generated by a forward pass of a batch of data through the VAE:')

for i in range(example_batch.size(0)):
  print(i, convert_sequence_of_tokens_to_sentence((torch.max((vae(example_batch, sorted_lengths)[-1]), dim=-1)[1][i]).tolist()),
          convert_sequence_of_tokens_to_sentence(example_batch[i].tolist()))
  
print('\n \n Sentences greedily generated from the latent code generated by the same batch of data:')
list(map(convert_sequence_of_tokens_to_sentence, vae.sample_sentence(z=vae(example_batch, sorted_lengths)[0].squeeze())[1].int()))

Model trained with MMD objective stopped at the beginning of the third epoch:
Sentences generated by a forward pass of a batch of data through the VAE:
0 for the the the <unk> that the have the the a as <unk> <unk> to year been been <unk> <unk> of the the <unk> of the and <unk> to the the <unk> <unk> <unk> in and the 's the and the <unk> <unk> in the <unk> to <start> for journalists however who write what they <unk> view as history 's first draft this has also been a week to give a lot of space and time to ron and nancy 's sales appearance in japan on behalf of a communications giant and its controversial founder
1 one <unk> <unk> <unk> president of chief of the <unk> & the boston corp. the 's to the the of the of company company reserve were to be to be N N from to N N in the N N N N N in the N of <start> robert white a vice president and manager of corporate trade at first interstate of california agreed with that view and predicted the u.s. federal funds rate will drop to between N 

['for the past few years ago the <unk> of the <unk> of the <unk> <unk> <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk>',
 "one of the company 's <unk> <unk> <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> & co. 's <unk> <unk> unit of $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a",
 'separately <unk> corp. said it expects to $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a share from $ N million or N cents a share',
 "it 's a <unk> <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <un

In [0]:
# READ: https://arxiv.org/abs/1402.0030  https://www.reddit.com/r/MachineLearning/comments/46xjtw/how_to_understand_the_kl_divergence_term_in/ 
# annealing function where it comes from?
#[i for i in list(torch.max((model(input_batch, sorted_lengths)[-1]), dim=-1)[1])]
for i in range(input_batch.size(0)):
  print(i, convert_sequence_of_tokens_to_sentence((torch.max((model(input_batch, sorted_lengths)), dim=-1)[1][i]).tolist()),
          convert_sequence_of_tokens_to_sentence(input_batch[i].tolist()))

0 the the interview <unk> the <unk> 's been in a <unk> <unk> in <unk> <unk> of <unk> <unk> president of a. <unk> of a to <unk> <unk> <unk> <unk> of the 's in and N of the a <unk> $ of the $ <unk> of the <unk> market the and <start> in an <unk> indictment the government has charged gaf a wayne n.j. specialty chemical maker and its vice chairman james t. sherwin with attempting to manipulate the common stock of union carbide corp. in advance of gaf 's planned sale of a large block of the stock in november N
1 the <unk> of <unk> <unk> <unk> been to for the <unk> <unk> in the <unk> <unk> <unk> <unk> <unk> week the years ago the and <unk> <unk> the the N old the and and the <unk> years and and the <unk> 's <end> <end> <end> <end> <end> <end> <end> <end> <start> dallas district judge jack <unk> had sparked calls for a judicial inquiry with his remarks to the press last december two weeks after sentencing an <unk> defendant to N years in state prison for killing two homosexual men in a city p

In [41]:
convert_sequence_of_tokens_to_sentence(vae.sample_sentence()[1].squeeze().int().tolist())

'this year is a <unk> of the <unk> of the <unk> of the <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk>'

In [44]:
ls_gen_sentences_tokens = vae.sample_sentence(z=torch.randn(4, 2* z_dim))[1].squeeze().int().tolist()

for ls_tokens in ls_gen_sentences_tokens:
  print(convert_sequence_of_tokens_to_sentence(ls_tokens))

it is n't <unk> by the <unk> of the <unk> of the <unk> of the <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and
people who have been <unk> by the <unk> of the <unk> of the <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> <unk> and <unk> <unk> <unk>
william <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk>
that 's <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> an

In [53]:
ls_gen_sentences_tokens = vae.sample_sentence(nb_examples=7)[1].squeeze().int().tolist()

for ls_tokens in ls_gen_sentences_tokens:
  print(convert_sequence_of_tokens_to_sentence(ls_tokens))

it is n't likely to be able to be <unk> by the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> of the <unk> <unk> and <unk> <unk> and
and the <unk> of the <unk> of the <unk> <unk> <unk> <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk>
<unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and <unk> <unk> and
the company said it will be a <unk> of the <unk> of the <unk> of the company 's <unk> <unk> and <unk> <un