<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Surprisingly_Effective_Fix_Latent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on *A Surprisingly Effective Fix for Deep Latent Variable Modeling of Text* (Li et al, Carnegie Mellon University)

In [98]:
import torch
from torch import nn, optim
from google.colab import files
from collections import defaultdict
from itertools import count, chain
import nltk
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
import numpy as np
from torch.autograd import Variable
import pdb

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [100]:
uploaded_training = files.upload()

Saving sample_train.txt to sample_train (1).txt


In [102]:
uploaded_val = files.upload()

Saving sample_dev.txt to sample_dev (1).txt


In [104]:
uploaded_test = files.upload()

Saving sample_test.txt to sample_test (1).txt


In [0]:
def load_data(file):
  w2i = defaultdict(lambda x=count(0): next(x))
  w2i["<s>"] 
  w2i["</s>"] 
  w2i["<unk>"] 
  data = []
  file = file.decode('utf-8')
  sentences = file.splitlines()

  for sentence in sentences:
    tokens = word_tokenize(sentence)
    for token in tokens:
      w2i[token]
    data.append(tokens)

  w2i = dict(w2i)
  i2w = {i:w for w,i in w2i.items()}

  return data, w2i, i2w

In [0]:
training_data, w2i, i2w = load_data(uploaded_training['sample_train.txt'])

In [0]:
def load_data_test(file):
  data = []
  file = file.decode('utf-8')
  sentences = file.splitlines()

  for sentence in sentences:
    tokens = word_tokenize(sentence)
    data.append(tokens)
    
  return data

In [0]:
val_data = load_data_test(uploaded_val['sample_dev.txt'])

In [0]:
test_data = load_data_test(uploaded_test['sample_test.txt'])

In [0]:
def prepare_sequence(seq, to_ix):
  idxs = [w2i["<s>"]]
  for w in seq:
    if w in to_ix:
      idxs.append(to_ix[w])
    else:
      idxs.append(w2i["<unk>"])
  idxs.append(w2i["</s>"])
  return torch.tensor(idxs, dtype=torch.long)

In [0]:
def batch_data(batch_size, data, w2i):

  sentence_lengths = np.array([len(sentence) for sentence in data])
  sorted_idx = np.argsort(sentence_lengths)
  sorted_lengths = sentence_lengths[sorted_idx]

  len_increase_idx = []
  for i in range(1, len(sorted_lengths)):
    if sorted_lengths[i] > sorted_lengths[i-1]:
      len_increase_idx.append(i)
  len_increase_idx.append(len(sorted_lengths))

  batch_data = []
  curr_idx = 0
  for i, idx in enumerate(len_increase_idx):
    while curr_idx < idx:
      batch_sentences = []
      new_idx = min(curr_idx + batch_size, idx)
      for i in range(curr_idx, new_idx):
        sent_to_vec = prepare_sequence(data[sorted_idx[i]], w2i)
        batch_sentences.append(sent_to_vec)
      curr_idx = new_idx
      batch_sentences = torch.stack(batch_sentences)
      batch_data.append(batch_sentences)

  i = 0
  j = len(batch_data)
  while i < j:
    if i != 0 and len(batch_data[i]) <= 10 and len(batch_data[i][0]) == len(batch_data[i-1][0]):
      batch_data.append(torch.cat((batch_data[i], batch_data[i-1])))
      batch_data.pop(i)
      batch_data.pop(i-1)
      i -= 1
      j = len(batch_data)
    elif len(batch_data[i]) == 1:
      batch_data.pop(i)
      j = len(batch_data)
    else:
      i += 1

  return batch_data

In [0]:
batch_size = 50

In [0]:
batch_training = batch_data(batch_size, training_data, w2i)

In [0]:
batch_val = batch_data(batch_size, val_data, w2i)

In [0]:
batch_test = batch_data(batch_size, test_data, w2i)

In [0]:
class Encoder(nn.Module):

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Encoder, self).__init__()

    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size,
                                   padding_idx=-1)
    
    self.rnn = nn.LSTM(input_size=embedding_size,
                        hidden_size=hidden_size,
                        batch_first=True)
    
    self.fc_mu = nn.Linear(hidden_size, latent_size)
    self.fc_var = nn.Linear(hidden_size, latent_size)

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1]) 

  def encode(self, x):
    x = self.embeddings(x)
    outputs, (hidden, cell) = self.rnn(x)
    mu = self.fc_mu(hidden)
    logvar = self.fc_var(hidden)
    
    return mu, logvar 

  def _initialize_parameters(self, lstm_init, embed_init):
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class Decoder(nn.Module):

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Decoder, self).__init__()

    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size,
                                   padding_idx=-1)
    
    self.rnn = nn.LSTM(input_size=embedding_size + latent_size,
                        hidden_size=hidden_size,
                        batch_first=True)    

    self.fc_hid = nn.Linear(latent_size, hidden_size)
    self.fc_voc = nn.Linear(hidden_size, vocab_size)

    self.dropout = nn.Dropout()
    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def decode(self, z, inputs, teacher_forcing):
    cell = self.fc_hid(z)
    cell = cell.unsqueeze(0)
    hidden = torch.tanh(cell)

    if teacher_forcing:
      embed = self.embeddings(inputs)
      embed = self.dropout(embed)
      z = z.expand(embed.size(1), z.size(0), z.size(1))
      z = z.transpose(1,0)
      embed_lat = torch.cat((embed, z), 2)

      outputs, (hidden, cell) = self.rnn(embed_lat, (hidden, cell))
      outputs = self.dropout(outputs)
      output_logits = self.fc_voc(outputs)
      return output_logits, _
    else:
      batch_size = inputs.size(0)
      input_d = inputs[:,1]
      output_logit_prev = None
      sent_len = inputs.size(1)
      z = z.unsqueeze(1)
      for i in range(sent_len):
        embed = self.embeddings(input_d)
        embed = embed.unsqueeze(1)
        embed_lat = torch.cat((embed, z), 2)
        outputs, (hidden, cell) = self.rnn(embed_lat, (hidden, cell))  
        output_logit = self.fc_voc(outputs)
        if output_logit_prev is not None:
          output_logits = torch.cat((output_logit_prev, output_logit), dim=1)
          output_logit_prev = output_logits
        else:
          output_logit_prev = output_logit
        input_d = torch.argmax(output_logit, dim=2).flatten()

        batch_decoded = [[] for j in range(batch_size)]
        for k in range(batch_size):
          token = i2w[input_d[k].item()]
          batch_decoded[k].append(token)
      #pdb.set_trace() 
      return output_logits, batch_decoded

  def _initialize_parameters(self, lstm_init, embed_init):
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class VAE(nn.Module):

  def __init__(self, encoder, decoder):
    super(VAE, self).__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.re_loss = nn.CrossEntropyLoss()

  def forward(self, x, teacher_forcing):
    mu, logvar = self.encoder.encode(x)
    kl = self.get_kl(mu, logvar)
    z = self._reparameterize(mu, logvar)
    z = z.squeeze()

    source = x[:,:-1]
    target = x[:, 1:]
    teacher_forcing = self.training
    output_logits, batch_decoded = self.decoder.decode(z, source, teacher_forcing)
    re = self.get_reconstruction_error(output_logits, target)

    return kl, re, batch_decoded

  def _reparameterize(self, mu, logvar):
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + std * eps  

  def get_kl(self, mu, logvar):
    return (mu**2 + logvar.exp() - 1 - logvar) / 2

  def get_reconstruction_error(self, output_logits, target):
    target = target.contiguous().view(-1)
    output_logits = output_logits.view(-1, output_logits.size(2))
    #pdb.set_trace() 
    return self.re_loss(output_logits, target)


In [0]:
vocab_size = len(w2i)
embedding_size = 128
hidden_size = 512 
latent_size = 32
use_cuda = True

In [0]:
def train(epoch, pretraining=False, fb=False, target_rate=2.0, anneal=0.0):
  
  model.train()

  running_loss = 0.0
  for batch_idx, data in enumerate(batch_training):
    if use_cuda and torch.cuda.is_available():
      data = data.cuda()
      
    optimizer_e.zero_grad()
    optimizer_d.zero_grad()
    kl, re, _ = model(data, True)
    if pretraining:
      loss = re
    if fb:
      kl_mask = (kl > target_rate).float()
      fb_mask = (kl <= target_rate).float()
      free_b = kl_mask + target_rate
      kl = (kl * kl_mask + free_b * fb_mask).sum()
      if epoch < 10:
        loss = kl * anneal + re
        anneal += 0.1
      else:
        loss = kl + re
    loss.backward()
    optimizer_e.step()
    optimizer_d.step()

    running_loss += loss
    #if (epoch == 0 or epoch % 10 == 9) and batch_idx % 100 == 99:    # print every 100 mini-batches
    if batch_idx % 100 == 99:
      print('[%d, %5d] Train loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
      running_loss = 0.0 

In [0]:
def test(epoch, validation=False, pretraining=False, fb=False, target_rate=2.0, anneal=0.0):
  
  model.eval()

  running_loss = 0.0
  data = batch_val if validation else batch_test
  for batch_idx, data in enumerate(data):
    if use_cuda and torch.cuda.is_available():
      data = data.cuda()
      
    kl, re, batch_decoded = model(data, False)
    if pretraining:
      loss = re
    if fb:
      kl_mask = (kl > target_rate).float()
      fb_mask = (kl <= target_rate).float()
      free_b = kl_mask + target_rate
      kl = (kl * kl_mask + free_b * fb_mask).sum()
      if epoch < 10:
        loss = kl * anneal + re
        anneal += 0.1
      else:
        loss = kl + re

  #if epoch == 0 or epoch % 10 == 9:
  if validation:
    print('[%d, %5d] Validation loss: %.3f' % (epoch + 1, batch_idx + 1, loss))
  else:
    print('[%d, %5d] Test loss: %.3f' % (epoch + 1, batch_idx + 1, loss))

  return loss

In [160]:
encoder = Encoder(vocab_size, embedding_size, hidden_size, latent_size)
decoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)
model = VAE(encoder, decoder)
if use_cuda and torch.cuda.is_available():
  model.cuda()
optimizer_e = optim.SGD(encoder.parameters(), lr=0.5)
optimizer_d = optim.SGD(decoder.parameters(), lr=0.5)

max_loss = -1e10
unimproved_count = 0
patience = 2

for i in range(3):
  train(i, True, False)
  val_loss = test(i, True, True, False)
  if max_loss < val_loss:
    max_loss = val_loss
  elif abs(max_loss - val_loss) < 0.001:
    unimproved_count += 1
    if unimproved_count == patience:
      break
"""
decoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)
model = VAE(encoder, decoder)
if use_cuda and torch.cuda.is_available():
  model.cuda()
optimizer_d = optim.SGD(decoder.parameters(), lr=0.5)
for i in range(100):
  train(i, False, True)
  test(i)
"""

[1,   100] Train loss: 7.549
[1,   200] Train loss: 5.308
[1,   300] Train loss: 5.011
[1,   400] Train loss: 4.674
[1,   500] Train loss: 4.436
[1,   600] Train loss: 4.438
[1,   700] Train loss: 4.393
[1,   800] Train loss: 4.492
[1,   900] Train loss: 4.680
[1,  1000] Train loss: 4.704
[1,  1100] Train loss: 4.642
[1,  1200] Train loss: 4.799
[1,  1300] Train loss: 4.837
[1,  1400] Train loss: 4.822
[1,  1500] Train loss: 4.875
[1,  1600] Train loss: 4.935
[1,  1700] Train loss: 4.951
[1,  1800] Train loss: 5.027
[1,  1900] Train loss: 5.064
[1,  2000] Train loss: 5.143
[1,   220] Validation loss: 6.160
[2,   100] Train loss: 4.434
[2,   200] Train loss: 3.733
[2,   300] Train loss: 3.967
[2,   400] Train loss: 3.956
[2,   500] Train loss: 3.856
[2,   600] Train loss: 3.991
[2,   700] Train loss: 3.966
[2,   800] Train loss: 4.048
[2,   900] Train loss: 4.238
[2,  1000] Train loss: 4.241
[2,  1100] Train loss: 4.220
[2,  1200] Train loss: 4.307
[2,  1300] Train loss: 4.328
[2,  1400

'\ndecoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)\nmodel = VAE(encoder, decoder)\nif use_cuda and torch.cuda.is_available():\n  model.cuda()\noptimizer_d = optim.SGD(decoder.parameters(), lr=0.5)\nfor i in range(100):\n  train(i, False, True)\n  test(i)\n'

In [145]:
%debug

> [0;32m<ipython-input-130-0283a3440ee5>[0m(57)[0;36mdecode[0;34m()[0m
[0;32m     55 [0;31m        [0mbatch_decoded[0m [0;34m=[0m [0;34m[[0m[0;34m[[0m[0;34m][0m [0;32mfor[0m [0mj[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mbatch_size[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0;32mfor[0m [0mk[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mbatch_size[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 57 [0;31m          [0mtoken[0m [0;34m=[0m [0mi2w[0m[0;34m[[0m[0minput_d[0m[0;34m[[0m[0mk[0m[0;34m][0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m          [0mbatch_decoded[0m[0;34m[[0m[0mk[0m[0;34m][0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mtoken[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     59 [0;31m      [0;31m#pdb.set_trace()[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> quit


In [0]:
"""
Downsampling the SNLI datasets
"""
import random

# Training
f = open('./snli_1.0/snli_1.0_train.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentences.append(example.split('\t')[5])
    sentences.append(example.split('\t')[6])

sample = random.sample(sentences, 100000)

f = open('./snli_1.0/sample_train.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()

# Development
f = open('./snli_1.0/snli_1.0_dev.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentences.append(example.split('\t')[5])
    sentences.append(example.split('\t')[6])

sample = random.sample(sentences, 10000)

f = open('./snli_1.0/sample_dev.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()

# Test
f = open('./snli_1.0/snli_1.0_test.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentences.append(example.split('\t')[5])
    sentences.append(example.split('\t')[6])

sample = random.sample(sentences, 10000)

f = open('./snli_1.0/sample_test.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()