<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Surprisingly_Effective_Fix_Latent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *A Surprisingly Effective Fix for Deep Latent Variable Modeling of Text* (Li et al., Carnegie Mellon University)

In [19]:
import torch
from torch import nn
from google.colab import files
from collections import defaultdict
from itertools import count, chain
import nltk
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
import numpy as np
from torch.autograd import Variable

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
# Open Colab's upload prompt; the result is indexed by filename further down.
uploaded = files.upload()

Saving sample_train.txt to sample_train.txt


In [0]:
def load_data(file):
  """Tokenize a corpus and build its vocabulary lookup tables.

  Args:
    file: Corpus contents with one sentence per line, either UTF-8 encoded
      bytes (decoded here) or an already-decoded ``str``.

  Returns:
    A tuple ``(data, w2i, i2w, padding_idx)`` where ``data`` is a list of
    token lists (one per input line), ``w2i`` maps token -> integer id,
    ``i2w`` is the inverse mapping, and ``padding_idx`` is the id of the
    reserved ``"<pad>"`` token.
  """
  # Each first lookup of an unseen key hands out the next consecutive id.
  next_id = count(0)
  w2i = defaultdict(lambda: next(next_id))

  # Reserve the special tokens first so their ids are fixed (0, 1, 2, 3)
  # regardless of the corpus. "<pad>" is appended after the three original
  # specials so that their ids do not move.
  for special in ("<s>", "</s>", "<unk>", "<pad>"):
    w2i[special]
  padding_idx = w2i["<pad>"]

  if isinstance(file, (bytes, bytearray)):
    file = file.decode('utf-8')

  data = []
  for sentence in file.splitlines():
    tokens = word_tokenize(sentence)
    for token in tokens:
      w2i[token]  # lookup only for its side effect of registering the token
    data.append(tokens)

  # Freeze into a plain dict so later lookups of unknown words cannot
  # silently grow the vocabulary.
  w2i = dict(w2i)
  i2w = {i: w for w, i in w2i.items()}

  return data, w2i, i2w, padding_idx

In [0]:
# Tokenize the uploaded corpus and build the word <-> index lookup tables.
training_data, w2i, i2w, padding_idx = load_data(uploaded['sample_train.txt'])

In [0]:
def prepare_sequence(seq, to_ix):
  """Convert a token list into a 1-D index tensor wrapped in <s> ... </s>.

  Args:
    seq: Iterable of token strings.
    to_ix: Mapping from token to integer id. Must contain ``"<s>"`` and
      ``"</s>"``; if it contains ``"<unk>"`` that id is used for tokens
      missing from the mapping.

  Returns:
    ``torch.long`` tensor of shape ``(len(seq) + 2,)``.

  Raises:
    KeyError: If a token is missing from ``to_ix`` and no ``"<unk>"`` entry
      exists to fall back on.
  """
  unk_idx = to_ix.get("<unk>")

  # All lookups go through the mapping that was passed in, never a global.
  idxs = [to_ix["<s>"]]
  for word in seq:
    if word in to_ix:
      idxs.append(to_ix[word])
    elif unk_idx is not None:
      idxs.append(unk_idx)
    else:
      raise KeyError(word)
  idxs.append(to_ix["</s>"])

  return torch.tensor(idxs, dtype=torch.long)

In [0]:
def batch_data(batch_size, data, w2i):
  """Group sentences into batches in which every sentence has the same length.

  Sentences are ordered by token count, split into runs of equal length, and
  each run is cut into chunks of at most ``batch_size`` sentences. Batches
  therefore never mix lengths, so the last chunk of a run may be smaller than
  ``batch_size``.

  Args:
    batch_size: Maximum number of sentences per batch; must be >= 1.
    data: List of token lists.
    w2i: Token -> id mapping forwarded to ``prepare_sequence``.

  Returns:
    List of batches; each batch is a list of whatever ``prepare_sequence``
    returns for its sentences, ordered from shortest to longest sentence.

  Raises:
    ValueError: If ``batch_size`` is smaller than 1 (the chunking loop could
      never advance).
  """
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer")

  sentence_lengths = np.array([len(sentence) for sentence in data], dtype=np.int64)
  # Stable sort keeps equal-length sentences in corpus order, which makes the
  # batch contents reproducible from run to run.
  sorted_idx = np.argsort(sentence_lengths, kind="stable")
  sorted_lengths = sentence_lengths[sorted_idx]

  # Exclusive end offset of each run of equal-length sentences.
  run_ends = [i for i in range(1, len(sorted_lengths))
              if sorted_lengths[i] > sorted_lengths[i - 1]]
  run_ends.append(len(sorted_lengths))

  batches = []
  start = 0
  for run_end in run_ends:
    while start < run_end:
      stop = min(start + batch_size, run_end)
      batches.append([prepare_sequence(data[i], w2i)
                      for i in sorted_idx[start:stop]])
      start = stop

  return batches

In [0]:
# Upper bound on the number of sentences placed in one batch.
batch_size = 50

In [0]:
# Group the training sentences into equal-length batches of index tensors.
batch_training = batch_data(batch_size, training_data, w2i)

In [44]:
# Number of batches actually produced.
len(batch_training)

2042

In [45]:
# Batch count if every batch were full. Batches never mix sentence lengths,
# so some end up smaller and the real count is at least this value.
len(training_data) / batch_size

2000.0

In [0]:
class Encoder(nn.Module):
  """LSTM encoder producing the parameters of the approximate posterior.

  Token ids are embedded, run through a single-layer LSTM, and the final
  hidden state is projected to the mean and log-variance of a diagonal
  Gaussian over the latent code.

  Args:
    vocab_size: Number of rows in the embedding table.
    embedding_size: Dimension of each token embedding.
    hidden_size: LSTM hidden state dimension.
    latent_size: Dimension of the latent code.
    padding_idx: Optional id of a padding token. Defaults to ``None`` (no
      padding row); a hard-coded ``-1`` would wrap around to the last
      vocabulary entry and silently freeze a real token's embedding.
  """

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size,
               padding_idx=None):
    super(Encoder, self).__init__()

    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size,
                                   padding_idx=padding_idx)

    self.rnn = nn.LSTM(input_size=embedding_size,
                       hidden_size=hidden_size,
                       batch_first=True)

    self.fc_mu = nn.Linear(hidden_size, latent_size)
    self.fc_var = nn.Linear(hidden_size, latent_size)

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def forward(self, x):
    """Encode a batch of token ids.

    Args:
      x: Long tensor of token ids; assumed to be (batch, seq_len) since the
        LSTM is built with ``batch_first=True`` -- TODO confirm with caller.

    Returns:
      ``(mu, logvar)``, each of shape (1, batch, latent_size). The leading
      dimension is the LSTM's layer axis, carried over from the final hidden
      state.
    """
    embedded = self.embeddings(x)
    _, (hidden, _) = self.rnn(embedded)
    mu = self.fc_mu(hidden)
    logvar = self.fc_var(hidden)

    return mu, logvar

  def _initialize_parameters(self, lstm_init, embed_init):
    """Uniformly initialise all parameters, with a separate range for embeddings.

    Args:
      lstm_init: ``[low, high]`` range applied to every parameter.
      embed_init: ``[low, high]`` range that then overrides the embedding table.
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

    # The uniform init above also overwrote the padding row; restore zeros.
    pad = self.embeddings.padding_idx
    if pad is not None:
      with torch.no_grad():
        self.embeddings.weight[pad].zero_()

In [0]:
class Decoder(nn.Module):
  """LSTM decoder conditioned on a latent code.

  The latent code is used twice: it initialises the LSTM state, and it is
  concatenated to the token embedding at every time step (hence the LSTM
  input size of ``embedding_size + latent_size``).

  Args:
    vocab_size: Number of rows in the embedding table and of output logits.
    embedding_size: Dimension of each token embedding.
    padding_idx: Id of the padding token, or ``None`` for no padding row.
    hidden_size: LSTM hidden state dimension.
    latent_size: Dimension of the latent code.
  """

  def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, latent_size):
    super(Decoder, self).__init__()

    # Honour the caller's padding_idx; a hard-coded -1 would wrap around to
    # the last vocabulary entry.
    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size,
                                   padding_idx=padding_idx)

    self.rnn = nn.LSTM(input_size=embedding_size + latent_size,
                       hidden_size=hidden_size,
                       batch_first=True)

    self.fc_hid = nn.Linear(latent_size, hidden_size)
    self.fc_voc = nn.Linear(hidden_size, vocab_size)

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def decode(self, z, input):
    """Compute vocabulary logits for every position of ``input`` given ``z``.

    Args:
      z: Latent code of shape (1, batch, latent_size) or (batch, latent_size).
      input: Long tensor of token ids, shape (batch, seq_len).

    Returns:
      Logits of shape (batch, seq_len, vocab_size).
    """
    if z.dim() == 2:
      z = z.unsqueeze(0)  # add the layer axis the LSTM state requires

    # Initial state: cell is a linear map of z, hidden is its tanh.
    cell = self.fc_hid(z)
    hidden = torch.tanh(cell)

    embed = self.embeddings(input)
    # Repeat z along the time axis so it can be appended to every embedding.
    z_per_step = z[0].unsqueeze(1).expand(-1, embed.size(1), -1)
    embed_lat = torch.cat((embed, z_per_step), dim=-1)

    outputs, _ = self.rnn(embed_lat, (hidden, cell))
    output_logits = self.fc_voc(outputs)

    return output_logits

  def forward(self, z, input):
    """Alias for ``decode`` so the module can be called directly."""
    return self.decode(z, input)

  def _initialize_parameters(self, lstm_init, embed_init):
    """Uniformly initialise all parameters, with a separate range for embeddings.

    Args:
      lstm_init: ``[low, high]`` range applied to every parameter.
      embed_init: ``[low, high]`` range that then overrides the embedding table.
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

    # The uniform init above also overwrote the padding row; restore zeros.
    pad = self.embeddings.padding_idx
    if pad is not None:
      with torch.no_grad():
        self.embeddings.weight[pad].zero_()

In [0]:
class VAE(nn.Module):
  """Sentence VAE: an ``Encoder`` and a ``Decoder`` joined by a sampled latent code.

  Args:
    vocab_size: Vocabulary size shared by encoder and decoder.
    embedding_size: Dimension of each token embedding.
    padding_idx: Padding token id forwarded to the decoder (``None`` for no
      padding row).
    hidden_size: LSTM hidden state dimension.
    latent_size: Dimension of the latent code.
  """

  def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, latent_size):
    super(VAE, self).__init__()

    self.encoder = Encoder(vocab_size, embedding_size, hidden_size, latent_size)
    self.decoder = Decoder(vocab_size, embedding_size, padding_idx, hidden_size, latent_size)

  def encode(self, x):
    """Return the posterior parameters ``(mu, logvar)`` for token ids ``x``."""
    return self.encoder(x)

  def decode(self, z, x):
    """Return vocabulary logits for each position of ``x`` given latent ``z``."""
    return self.decoder.decode(z, x)

  def forward(self, x):
    """Encode ``x``, sample a latent code, and decode with ``x`` as decoder input.

    Args:
      x: Long tensor of token ids, assumed (batch, seq_len) -- TODO confirm.

    Returns:
      ``(logits, mu, logvar)``. The decoder is fed the full sequence ``x``,
      so a next-token loss has to compare the logits at position t with the
      token at position t + 1.
    """
    mu, logvar = self.encode(x)
    z = self._reparameterize(mu, logvar)

    return self.decode(z, x), mu, logvar

  def _reparameterize(self, mu, logvar):
    """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

    Writing the sample this way keeps it differentiable with respect to
    ``mu`` and ``logvar``.
    """
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)

    return mu + std * eps


In [0]:
# Model hyperparameters.
# Number of distinct token ids (special tokens included).
vocab_size = len(w2i)
# Dimension of each token embedding.
embedding_size = 128
# LSTM hidden state dimension.
hidden_size = 512 
# Dimension of the latent code.
latent_size = 32

In [0]:
def train(epoch, optimizer=None):
  """Run one pass over ``batch_training`` with the global ``model``.

  For every batch the per-sentence tensors are stacked into one
  (batch, seq_len) tensor, passed through the model, and scored with the
  negative ELBO: summed next-token cross-entropy plus the KL divergence of
  the diagonal-Gaussian posterior from a standard normal.

  Assumes ``model(batch)`` returns ``(logits, mu, logvar)`` with logits of
  shape (batch, seq_len, vocab) aligned to the input positions -- TODO
  confirm against the model definition.

  Args:
    epoch: Epoch number supplied by the caller; not read here.
    optimizer: Optional ``torch.optim`` optimizer. When given, a gradient
      step is taken after every batch; when ``None`` the pass only
      evaluates the loss.

  Returns:
    Average negative ELBO per sentence over the pass (0.0 if there were no
    batches).
  """
  model.train()

  total_loss = 0.0
  total_sentences = 0

  for batch_sentences in batch_training:
    if not batch_sentences:
      continue

    # Every tensor in a batch has the same length, so stacking is safe.
    batch_input = torch.stack(batch_sentences)
    y, mu, logvar = model(batch_input)

    # Position t predicts token t + 1, so drop the last logit and first token.
    recon = nn.functional.cross_entropy(
        y[:, :-1].reshape(-1, y.size(-1)),
        batch_input[:, 1:].reshape(-1),
        reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = recon + kl

    if optimizer is not None:
      optimizer.zero_grad()
      # Normalise by batch size so the step does not scale with batch size.
      (loss / batch_input.size(0)).backward()
      optimizer.step()

    total_loss += loss.item()
    total_sentences += batch_input.size(0)

  return total_loss / total_sentences if total_sentences else 0.0

In [0]:
# Instantiate the VAE with the vocabulary and hyperparameters defined above.
model = VAE(vocab_size, embedding_size, padding_idx, hidden_size, latent_size)