<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Surprisingly_Effective_Fix_Latent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *A Surprisingly Effective Fix for Deep Latent Variable Modeling of Text* (Li et al., Carnegie Mellon University)

In [1]:
import torch
from torch import nn, optim
from google.colab import files
from collections import defaultdict
from itertools import count, chain
import nltk
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
import numpy as np
from torch.autograd import Variable

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [3]:
# Colab upload prompt; the result is indexed by file name when the corpus is loaded below.
uploaded = files.upload()

Saving sample_train.txt to sample_train.txt


In [0]:
def load_data(file):
  """Tokenize an uploaded text file and build its vocabulary.

  Args:
    file: UTF-8 encoded bytes; every line is treated as one sentence.

  Returns:
    A tuple (data, w2i, i2w). data is a list with one token list per line,
    w2i maps each token to an integer id and i2w is the inverse mapping.
  """
  # The special symbols take ids 0, 1 and 2. Every other token receives the
  # next free id the first time it is seen, so ids are dense and follow
  # first-occurrence order.
  w2i = {}
  for special in ("<s>", "</s>", "<unk>"):
    w2i[special] = len(w2i)

  data = []
  for line in file.decode('utf-8').splitlines():
    tokens = word_tokenize(line)
    for token in tokens:
      w2i.setdefault(token, len(w2i))
    data.append(tokens)

  i2w = {index: word for word, index in w2i.items()}
  return data, w2i, i2w

In [0]:
# Tokenize the uploaded corpus and build the token<->id vocabularies.
training_data, w2i, i2w = load_data(uploaded['sample_train.txt'])

In [0]:
def prepare_sequence(seq, to_ix):
  """Convert a tokenized sentence into a 1-D tensor of vocabulary ids.

  The ids are wrapped in the sentence-boundary markers: the id of "<s>" comes
  first and the id of "</s>" comes last.

  Args:
    seq: iterable of token strings.
    to_ix: mapping from token to integer id. Must contain "<s>" and "</s>".
      When it also contains "<unk>", tokens missing from the mapping are
      encoded with that id.

  Returns:
    torch.LongTensor of shape (len(seq) + 2,).

  Raises:
    KeyError: if a marker is missing, or if a token is missing from to_ix
      and to_ix has no "<unk>" entry.
  """
  # Every lookup goes through the mapping that was passed in (not the
  # notebook-level w2i), so the function works with any vocabulary.
  if "<unk>" in to_ix:
    unk = to_ix["<unk>"]
    body = [to_ix[w] if w in to_ix else unk for w in seq]
  else:
    body = [to_ix[w] for w in seq]
  idxs = [to_ix["<s>"], *body, to_ix["</s>"]]
  return torch.tensor(idxs, dtype=torch.long)

In [0]:
def batch_data(batch_size, data, w2i):
  """Group sentences of identical length into stacked index tensors.

  Sentences are ordered by token count. Each run of equal-length sentences is
  cut into chunks of at most batch_size, so the rows of a chunk can be stacked
  without padding; the last chunk of a run may hold fewer sentences.

  Args:
    batch_size: maximum number of sentences per batch.
    data: list of tokenized sentences (lists of tokens).
    w2i: token-to-id mapping handed to prepare_sequence.

  Returns:
    List of tensors, one per batch, each the stack of the per-sentence
    tensors produced by prepare_sequence.
  """
  lengths = np.array([len(tokens) for tokens in data])
  order = np.argsort(lengths)
  lengths_sorted = lengths[order]

  # Exclusive end position of every run of equal-length sentences.
  run_ends = [pos for pos in range(1, len(lengths_sorted))
              if lengths_sorted[pos] > lengths_sorted[pos - 1]]
  run_ends.append(len(lengths_sorted))

  batches = []
  start = 0
  for run_end in run_ends:
    while start < run_end:
      stop = min(start + batch_size, run_end)
      rows = [prepare_sequence(data[order[k]], w2i) for k in range(start, stop)]
      start = stop
      batches.append(torch.stack(rows))

  return batches

In [0]:
# Upper bound on the number of sentences per batch (passed to batch_data).
batch_size = 50

In [0]:
# Build every training batch once up front; train() iterates over this list.
batch_training = batch_data(batch_size, training_data, w2i)

In [10]:
# Number of batches actually produced.
len(batch_training)

2042

In [11]:
# Batch count if every batch were full. The real count can be higher because
# batch_data only groups sentences of identical length, leaving partial batches.
len(training_data) / batch_size

2000.0

In [0]:
class Encoder(nn.Module):
  """LSTM inference network that maps a batch of token ids to Gaussian parameters.

  Args:
    vocab_size: number of rows in the embedding table.
    embedding_size: dimensionality of the token embeddings.
    hidden_size: LSTM hidden-state size.
    latent_size: dimensionality of the latent code.
  """

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Encoder, self).__init__()

    # No padding_idx here: a negative padding_idx counts from the end of the
    # table, so padding_idx=-1 marked the last vocabulary entry as padding and
    # that token's embedding never received a gradient.
    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size)

    self.rnn = nn.LSTM(input_size=embedding_size,
                        hidden_size=hidden_size,
                        batch_first=True)

    self.fc_mu = nn.Linear(hidden_size, latent_size)
    self.fc_var = nn.Linear(hidden_size, latent_size)

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def encode(self, x):
    """Encode a batch of id sequences.

    Args:
      x: LongTensor of token ids laid out as (batch, seq_len), since the LSTM
        is built with batch_first=True.

    Returns:
      (mu, logvar), each computed from the LSTM's final hidden state and
      therefore shaped (1, batch, latent_size) for this single-layer LSTM.
    """
    x = self.embeddings(x)
    outputs, (hidden, cell) = self.rnn(x)
    mu = self.fc_mu(hidden)
    logvar = self.fc_var(hidden)

    return mu, logvar

  def _initialize_parameters(self, lstm_init, embed_init):
    """Re-initialize all weights uniformly.

    Every parameter (LSTM, linear layers, biases) is drawn from
    U(lstm_init[0], lstm_init[1]); the embedding table is then overwritten
    with draws from U(embed_init[0], embed_init[1]).
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class Decoder(nn.Module):
  """LSTM generator that predicts next-token logits conditioned on a latent code.

  The latent code initializes the LSTM state and is also concatenated to the
  token embedding at every time step.

  Args:
    vocab_size: number of rows in the embedding table and of output logits.
    embedding_size: dimensionality of the token embeddings.
    hidden_size: LSTM hidden-state size.
    latent_size: dimensionality of the latent code.
  """

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Decoder, self).__init__()

    # No padding_idx here: a negative padding_idx counts from the end of the
    # table, so padding_idx=-1 marked the last vocabulary entry as padding and
    # that token's embedding never received a gradient.
    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size)

    self.rnn = nn.LSTM(input_size=embedding_size + latent_size,
                        hidden_size=hidden_size,
                        batch_first=True)

    self.fc_hid = nn.Linear(latent_size, hidden_size)
    self.fc_voc = nn.Linear(hidden_size, vocab_size)

    self.dropout = nn.Dropout()

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def decode(self, z, input):
    """Compute vocabulary logits for every input position.

    Args:
      z: latent codes shaped (batch, latent_size).
      input: LongTensor of token ids shaped (batch, seq_len).

    Returns:
      Tensor of logits shaped (batch, seq_len, vocab_size).
    """
    # The latent code seeds the LSTM: cell state is a linear map of z and the
    # hidden state is its tanh. unsqueeze(0) adds the layer axis the LSTM expects.
    cell = self.fc_hid(z)
    cell = cell.unsqueeze(0)
    hidden = torch.tanh(cell)

    embed = self.embeddings(input)
    embed = self.dropout(embed)
    # Repeat z along the time axis so it can be appended to every embedding.
    z_steps = z.unsqueeze(1).expand(-1, embed.size(1), -1)
    embed_lat = torch.cat((embed, z_steps), 2)

    outputs, (hidden, cell) = self.rnn(embed_lat, (hidden, cell))
    outputs = self.dropout(outputs)
    output_logits = self.fc_voc(outputs)

    return output_logits

  def _initialize_parameters(self, lstm_init, embed_init):
    """Re-initialize all weights uniformly.

    Every parameter (LSTM, linear layers, biases) is drawn from
    U(lstm_init[0], lstm_init[1]); the embedding table is then overwritten
    with draws from U(embed_init[0], embed_init[1]).
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class VAE(nn.Module):
  """Sentence VAE that ties an encoder and a decoder together.

  Args:
    encoder: module exposing encode(x) -> (mu, logvar).
    decoder: module exposing decode(z, source) -> logits.
  """

  def __init__(self, encoder, decoder):
    super(VAE, self).__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.re_loss = nn.CrossEntropyLoss()

  def forward(self, x):
    """Run one encode / sample / decode pass.

    Args:
      x: LongTensor of token ids shaped (batch, seq_len).

    Returns:
      (kl, re): kl is the element-wise value from get_kl (same shape as the
      encoder's mu); re is the scalar mean cross-entropy of predicting
      x[:, 1:] from x[:, :-1].
    """
    mu, logvar = self.encoder.encode(x)
    kl = self.get_kl(mu, logvar)
    z = self._reparameterize(mu, logvar)
    # Drop only the leading axis. A bare squeeze() would also remove a batch
    # axis of size 1 and hand the decoder a 1-D z for single-sentence batches.
    z = z.squeeze(0)

    source = x[:,:-1]
    target = x[:, 1:]
    output_logits = self.decoder.decode(z, source)
    re = self.get_reconstruction_error(output_logits, target)

    return kl, re

  def _reparameterize(self, mu, logvar):
    """Sample z = mu + std * eps with eps ~ N(0, I), where std = exp(logvar / 2)."""
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + std * eps

  def get_kl(self, mu, logvar):
    """Return the NEGATIVE KL divergence to N(0, I), element-wise.

    Computes 0.5 * (1 + logvar - mu^2 - exp(logvar)), which is zero when
    mu = 0 and logvar = 0 and negative otherwise. Callers negate it to obtain
    the KL penalty. The 1/2 factor applies to the whole expression, not only
    to the exp(logvar) term.
    """
    return (1 + logvar - mu**2 - torch.exp(logvar)) / 2

  def get_reconstruction_error(self, output_logits, target):
    """Mean cross-entropy between the logits and the target ids, over all positions."""
    target = target.contiguous().view(-1)
    output_logits = output_logits.view(-1, output_logits.size(2))
    return self.re_loss(output_logits, target)


In [0]:
# Model sizes handed to the Encoder and Decoder constructors.
# vocab_size is the number of distinct ids in w2i, including the special symbols.
vocab_size = len(w2i)
embedding_size = 128
hidden_size = 512 
latent_size = 32

In [0]:
def train(epoch, pretraining=False, fb=False, target_rate=2.0, anneal=0.0):
  """Run one pass over batch_training and update the encoder and decoder.

  Reads the notebook-level names model, optimizer_e, optimizer_d and
  batch_training. model(batch) must return (kl, re); kl is treated as the
  element-wise negative KL term and is negated before use.

  Args:
    epoch: zero-based epoch number; controls KL annealing and log output.
    pretraining: if True (and fb is False), optimize the reconstruction loss only.
    fb: if True, use the free-bits objective: each KL element is replaced by
      max(KL, target_rate) before summing. Takes precedence over pretraining.
    target_rate: free-bits threshold per KL element.
    anneal: starting KL weight for this pass. While epoch < 10 the weight is
      raised by 0.1 after every batch and capped at 1.0.

  With both flags False the plain objective (summed KL + reconstruction) is used.
  """
  model.train()

  running_loss = 0.0
  for batch_idx, batch in enumerate(batch_training):
    optimizer_e.zero_grad()
    optimizer_d.zero_grad()
    neg_kl, re = model(batch)
    # Flip the sign so the thresholding below compares an actual (non-negative)
    # KL value against target_rate.
    kl = -neg_kl

    if fb:
      # Elements below the target contribute a constant, so they produce no gradient.
      kl = torch.clamp(kl, min=target_rate).sum()
      if epoch < 10:
        loss = anneal * kl + re
        # Cap the weight: without the cap it keeps growing by 0.1 per batch.
        anneal = min(1.0, anneal + 0.1)
      else:
        loss = kl + re
    elif pretraining:
      loss = re
    else:
      loss = kl.sum() + re
    loss.backward()
    optimizer_e.step()
    optimizer_d.step()

    # .item() keeps the running total a plain float instead of a tensor that
    # holds on to autograd history.
    running_loss += loss.item()
    if (epoch == 0 or epoch % 10 == 9) and batch_idx % 100 == 99:    # print every 100 mini-batches
      print('[%d, %5d] Train loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
      running_loss = 0.0

In [98]:
# Build the model with a separate SGD optimizer for the encoder and the decoder,
# then run one pass in reconstruction-only mode (pretraining=True, fb=False).
encoder = Encoder(vocab_size, embedding_size, hidden_size, latent_size)
decoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)
model = VAE(encoder, decoder)
optimizer_e = optim.SGD(encoder.parameters(), lr=0.5)
optimizer_d = optim.SGD(decoder.parameters(), lr=0.5)
train(0, True, False)

[1,   100] Train loss: 7.499
[1,   200] Train loss: 5.271
[1,   300] Train loss: 4.959
[1,   400] Train loss: 4.652
[1,   500] Train loss: 4.405
[1,   600] Train loss: 4.437


KeyboardInterrupt: ignored