<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Posterior_Collapse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn, optim
from google.colab import files
from collections import defaultdict
from itertools import count, chain
import nltk
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
import numpy as np
import pdb
import re

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [3]:
uploaded_training = files.upload()

Saving sample_train.txt to sample_train.txt


In [5]:
uploaded_val = files.upload()

Saving sample_dev.txt to sample_dev.txt


In [7]:
uploaded_test = files.upload()

Saving sample_test.txt to sample_test.txt


In [0]:
def load_data(file):
  """
  Load training data and output vocabulary dictionaries
  """
  w2i = defaultdict(lambda x=count(0): next(x))
  w2i["<s>"] 
  w2i["</s>"] 
  w2i["<unk>"] 
  w2i['<unk>twoDigitNum']
  w2i['<unk>fourDigitNum']
  w2i['<unk>containsDigitAndAlpha']
  w2i['<unk>containsDigitAndDash']
  w2i['<unk>containsDigitAndSlash']
  w2i['<unk>containsDigitAndComma']
  w2i['<unk>containsDigitAndPeriod']
  w2i['<unk>othernum']
  w2i['<unk>allCaps']
  w2i['<unk>initCap']
  w2i['<unk>lowercase']
  w2i['<unk>time']
  data = []
  file = file.decode('utf-8')
  sentences = file.splitlines()

  for sentence in sentences:
    tokens = word_tokenize(sentence)
    for token in tokens:
      w2i[token]
    data.append(tokens)

  freq_dist = nltk.FreqDist([item for sublist in data for item in sublist])
  freq1 = set(list(freq_dist.keys())[-4000:])

  w2i = dict(w2i)
  for key in list(w2i.keys()):
    if key in freq1:
      w2i.pop(key)
  i2w = {i:w for w,i in w2i.items()}

  return data, w2i, i2w, freq_dist

In [0]:
training_data, w2i, i2w, freq_dist = load_data(uploaded_training['sample_train.txt'])

In [0]:
def load_data_test(file):
  """
  Load test and validation data
  """
  data = []
  file = file.decode('utf-8')
  sentences = file.splitlines()

  for sentence in sentences:
    tokens = word_tokenize(sentence)
    data.append(tokens)
    
  return data

In [0]:
val_data = load_data_test(uploaded_val['sample_dev.txt'])

In [0]:
test_data = load_data_test(uploaded_test['sample_test.txt'])

In [0]:
def prepare_sequence(seq, to_ix):
  idxs = [to_ix["<s>"]]
  for w in seq:
    if w in to_ix:
      idxs.append(to_ix[w])
    else:
      unk = get_word_class(w)
      idxs.append(to_ix[unk])
  idxs.append(to_ix["</s>"])
  return torch.tensor(idxs, dtype=torch.long)

In [0]:
def get_word_class(x):
    """
    Get fword class for a given word.
    Parameters
    ----------
    x : str
        word to be replaced
    Returns
    -------
    str
        corresponding word class
    """
    if re.fullmatch(r'[0-9]{2}', x):
        return '<unk>twoDigitNum'
    elif re.fullmatch(r'[0-9]{4}', x):
        return '<unk>fourDigitNum'
    elif re.fullmatch(r'A[0-9\-]+', x):
        return '<unk>containsDigitAndAlpha'
    elif re.fullmatch(r'[0-9]+\-[0-9]+', x):
        return '<unk>containsDigitAndDash'
    elif re.fullmatch(r'[0-9]+/[0-9]+/[0-9]+', x):
        return '<unk>containsDigitAndSlash'
    elif re.fullmatch(r'([0-9]+,[0-9]+)+\.[0-9]+', x):
        return '<unk>containsDigitAndComma'
    elif re.fullmatch(r'[0-9]+\.[0-9]+', x):
        return '<unk>containsDigitAndPeriod'
    elif re.fullmatch(r'[0-9]+', x):
        return '<unk>othernum'
    elif re.fullmatch(r'[A-Z]+', x):
        return '<unk>allCaps'
    elif re.fullmatch(r'[A-Z][a-z]+', x):
        return '<unk>initCap'
    elif re.fullmatch(r'[a-z]+', x):
        return '<unk>lowercase'
    elif re.match(r'[0-9]+:[0-9]+', x):
        return '<unk>time'

    return '<unk>'

In [0]:
def batch_data(batch_size, data, w2i):
  """
  Batches data with sequences of the same length
  """
  sentence_lengths = np.array([len(sentence) for sentence in data])
  sorted_idx = np.argsort(sentence_lengths)
  sorted_lengths = sentence_lengths[sorted_idx]

  len_increase_idx = []
  for i in range(1, len(sorted_lengths)):
    if sorted_lengths[i] > sorted_lengths[i-1]:
      len_increase_idx.append(i)
  len_increase_idx.append(len(sorted_lengths))

  batch_data = []
  curr_idx = 0
  for i, idx in enumerate(len_increase_idx):
    while curr_idx < idx:
      batch_sentences = []
      new_idx = min(curr_idx + batch_size, idx)
      for i in range(curr_idx, new_idx):
        sent_to_vec = prepare_sequence(data[sorted_idx[i]], w2i)
        batch_sentences.append(sent_to_vec)
      curr_idx = new_idx
      batch_sentences = torch.stack(batch_sentences).to(device=cuda)
      batch_data.append(batch_sentences)

  i = 0
  j = len(batch_data)
  while i < j:
    if i != 0 and len(batch_data[i]) <= 2 and len(batch_data[i][0]) == len(batch_data[i-1][0]):
      batch_data.append(torch.cat((batch_data[i], batch_data[i-1])))
      batch_data.pop(i)
      batch_data.pop(i-1)
      i -= 1
      j = len(batch_data)
    elif len(batch_data[i]) == 1:
      batch_data.pop(i)
      j = len(batch_data)
    else:
      i += 1

  return batch_data

In [0]:
batch_size = 4

In [0]:
use_cuda = True
cuda = torch.device('cuda')

In [0]:
batch_training = batch_data(batch_size, training_data, w2i)

In [0]:
batch_val = batch_data(batch_size, val_data, w2i)

In [0]:
batch_test = batch_data(batch_size, test_data, w2i)

In [0]:
class Encoder(nn.Module):

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Encoder, self).__init__()

    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size)
    
    self.rnn = nn.LSTM(input_size=embedding_size,
                        hidden_size=hidden_size,
                        batch_first=True)
    
    self.fc_mu = nn.Linear(hidden_size, latent_size)
    self.fc_var = nn.Linear(hidden_size, latent_size)

    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1]) 

  def encode(self, x):
    """
    Produces a Gaussian distribution over the possible values of the code z 
    from which x could have been generated
        
    Parameters
    ----------
      x: batch size x sequence length Tensor
        observed data

    Returns
    -------
      mu: batch size x latent size Tensor
        mean of Gaussian distribution
        
      logvar: batch size x latent size Tensor
        log of variance of Gaussian distribution     
    """
    x = self.embeddings(x)
    outputs, (hidden, cell) = self.rnn(x)
    mu = self.fc_mu(hidden)
    logvar = self.fc_var(hidden)
    mu = mu.squeeze()
    logvar = logvar.squeeze()
    return mu, logvar 

  def _initialize_parameters(self, lstm_init, embed_init):
    """
    Initializes the LSTM parameters and embeddings with uniform distributions
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class Decoder(nn.Module):

  def __init__(self, vocab_size, embedding_size, hidden_size, latent_size):
    super(Decoder, self).__init__()

    self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                   embedding_dim=embedding_size)
    
    self.rnn = nn.LSTM(input_size=embedding_size + latent_size,
                        hidden_size=hidden_size,
                        batch_first=True)    

    self.fc_hid = nn.Linear(latent_size, hidden_size, bias=False)
    self.fc_voc = nn.Linear(hidden_size, vocab_size, bias=False)

    self.dropout = nn.Dropout()
    self._initialize_parameters([-0.01, 0.01], [-0.1, 0.1])

  def decode(self, z, inputs):
    """
    Given a code z it produces unscaled output corresponding to the vocabulary

    Parameters
    ----------
      z: batch size x latent size Tensor
        latent variables
      
      inputs: batch size x sequence length Tensor
        source sequence

    Returns
    -------
      output_logits: batch size x sequence length x vocab size Tensor
        unscaled output
    """
    cell = self.fc_hid(z)
    cell = cell.unsqueeze(0)
    hidden = torch.tanh(cell)

    embed = self.embeddings(inputs)
    embed = self.dropout(embed)
    z = z.expand(embed.size(1), z.size(0), z.size(1))
    z = z.transpose(1,0)
    embed_lat = torch.cat((embed, z), 2)

    outputs, (hidden, cell) = self.rnn(embed_lat, (hidden, cell))
    outputs = self.dropout(outputs)
    output_logits = self.fc_voc(outputs)
    return output_logits

  def decode_greedy(self, z, inputs, interpolation=False):
    """
    Given a code z it produces unscaled output corresponding to the vocabulary

    Parameters
    ----------
      z: batch size x latent size Tensor
        latent variables
      
      inputs: batch size x sequence length Tensor
        source sequence

    Returns
    -------
      output_logits: batch size x sequence length x vocab size Tensor
        unscaled output
      
      batch_decoded: batch size x output sequence length list
        decoded output sequence
    """
    cell = self.fc_hid(z)
    cell = cell.unsqueeze(0)
    hidden = torch.tanh(cell)

    batch_size = inputs.size(0)
    input_d = inputs[:,0]
    output_logit_prev = None
    seq_len = inputs.size(1)
    batch_decoded = [[] for j in range(batch_size)]

    end_mask = torch.ones(batch_size)
    counter = 0
    while end_mask.sum() != 0 and counter < seq_len:
      embed = self.embeddings(input_d)
      embed_lat = torch.cat((embed, z), 1)
      embed_lat = embed_lat.unsqueeze(1)
      outputs, (hidden, cell) = self.rnn(embed_lat, (hidden, cell))  
      output_logit = self.fc_voc(outputs)
      if output_logit_prev is not None:
        output_logits = torch.cat((output_logit_prev, output_logit), dim=1)
        output_logit_prev = output_logits
      else:
        output_logit_prev = output_logit
      input_d = torch.argmax(output_logit, dim=2).flatten()

      for k in range(batch_size):
        if end_mask[k] != 0:
          if interpolation and input_d[k].item() == w2i["</s>"] :
            end_mask[k] = 0
          else:
            token = i2w[input_d[k].item()]
            batch_decoded[k].append(token)
      counter += 1
    
    return output_logits, batch_decoded

  def _initialize_parameters(self, lstm_init, embed_init):
    """
    Initializes the LSTM parameters and embeddings with uniform distributions
    """
    for param in self.parameters():
      nn.init.uniform_(param, a=lstm_init[0], b=lstm_init[1])
    nn.init.uniform_(self.embeddings.weight, a=embed_init[0], b=embed_init[1])

In [0]:
class VAE(nn.Module):

  def __init__(self, encoder, decoder):
    super(VAE, self).__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.re_loss = nn.CrossEntropyLoss(reduction='none')

  def forward(self, x, greedy=False):
    """
    Forward pass of the model 
    """
    mu, logvar = self.encoder.encode(x)
    kl = self.get_kl(mu, logvar)
    z = self._reparameterize(mu, logvar)

    source = x[:,:-1]
    target = x[:, 1:]
    if greedy:
      output_logits, batch_decoded = self.decoder.decode_greedy(z, source)
      re = self.get_reconstruction_error(output_logits, target)
      return kl, re, batch_decoded
    else: 
      output_logits = self.decoder.decode(z, source)
      re = self.get_reconstruction_error(output_logits, target)
      return kl, re

  def _reparameterize(self, mu, logvar):
    """
    Reparameterize the random variable z to express as a deterministic variable
    
    Parameters
    ----------
      mu: batch size x latent size Tensor
        mean of Gaussian distribution
        
      logvar: batch size x latent size Tensor
        log of variance of Gaussian distribution     
    
    Returns
    -------
      z: batch size x latent size Tensor
        reparameterization of latent variables
    """
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + std * eps  

  def get_kl(self, mu, logvar):
    """
    Returns the KLD between posterior and prior

    Parameters
    ----------
      mu: batch size x latent size Tensor
        mean of Gaussian distribution
        
      logvar: batch size x latent size Tensor
        log of variance of Gaussian distribution
    
    Returns
    -------
      kl: batch size x latent size Tensor
        kl divergence
    """
    return (mu**2 + logvar.exp() - 1 - logvar) / 2

  def get_reconstruction_error(self, output_logits, target):
    """
    Returns the reconstruction error

    Parameters
    ----------
      output_logits: batch size x sequence length x vocab size Tensor
        unscaled output
      
      target: batch size x sequence length Tensor
        target sequence
    """
    target = target.contiguous().view(-1)
    output_logits = output_logits.view(-1, output_logits.size(2))
    return self.re_loss(output_logits, target)


In [0]:
vocab_size = len(w2i)
embedding_size = 128
hidden_size = 512 
latent_size = 32

In [0]:
print_every = round(len(batch_training) / batch_size)

In [0]:
class DecayLearning:
  """
  Class updated from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
  """
  def __init__(self, patience=2):
    self.patience = patience
    self.counter = 0
    self.best_score = None
    self.update_lr = False

  def __call__(self, val_loss):
    score = -val_loss

    if self.best_score is None:
      self.best_score = score
    elif score < self.best_score:
      self.counter += 1
      if self.counter >= self.patience:
        self.update_lr = True
    else:
      self.best_score = score
      self.counter = 0

In [0]:
def interpolate(latent_size, sample1, sample2):
  
  for w in range(11):
    weight = w * 0.1
    sample = weight * sample2 + (1-weight) * sample1
    _, batch_decoded = decoder.decode_greedy(sample, torch.zeros(13, dtype=torch.long, device=cuda).unsqueeze(0), True)
    print(*batch_decoded[0])

Pre-training on Autoencoder objective

Pre-training ran for 3 hours

In [0]:
def train(epoch):
  """
  Trains the model
  """
  model.train()
  running_loss = 0.0
  for batch_idx, data in enumerate(batch_training):
    optimizer_e.zero_grad()
    optimizer_d.zero_grad()
    re = model(data)
    loss = re
    loss.backward()
    optimizer_e.step()
    optimizer_d.step()

    running_loss += loss
  
    if (epoch == 0 or epoch % 10 == 9) and batch_idx % print_every == print_every-1:    
      print('[%d, %5d] Train loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / print_every))
      running_loss = 0.0 

In [0]:
def test(epoch, validation=False):
  """
  Run the model on validation or test dataset
  """
  model.eval()
  anneal = 0.1 * epoch if epoch < 10 else 1.0

  data = batch_val if validation else batch_test
  running_loss = 0.0
  for batch_idx, data in enumerate(data):
    if validation:  
      re = model(data)
    else:
      re, batch_decoded = model(data, True)
    loss = re

    running_loss += loss

  avg_loss = running_loss / (batch_idx + 1)
  if validation and (epoch == 0 or epoch % 10 == 9):
    print('[%d, %5d] Validation loss: %.3f' % (epoch + 1, batch_idx + 1, avg_loss))
  else:
    print('[%d, %5d] Test loss: %.3f' % (epoch + 1, batch_idx + 1, avg_loss))
    print(*batch_decoded[0])

  return avg_loss

In [0]:
encoder = Encoder(vocab_size, embedding_size, hidden_size, latent_size)
decoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)
model = VAE(encoder, decoder)
if use_cuda and torch.cuda.is_available():
  model.cuda()
  
lr = 0.5
optimizer_e = optim.SGD(encoder.parameters(), lr=lr)
optimizer_d = optim.SGD(decoder.parameters(), lr=lr)

decay_learning = DecayLearning()
num_decays = 0
early_stop = 5

for epoch in range(10):
  train(epoch)
  val_loss = test(epoch, True)
  decay_learning(val_loss)
  if decay_learning.update_lr:
    lr *= 0.5
    num_decays += 1
    if num_decays == early_stop + 1:
      print("Stopping early at epoch", epoch+1)
      break
    print("Learning rate has been decayed to", lr, "at epoch", epoch+1)
    optimizer_e = optim.SGD(encoder.parameters(), lr=lr)
    optimizer_d = optim.SGD(decoder.parameters(), lr=lr)
    decay_learning = DecayLearning()

[1,  3125] Train loss: 3.947
[1,  6250] Train loss: 3.595
[1,  9375] Train loss: 3.789
[1, 12500] Train loss: 3.809
[1,  1248] Validation loss: 4.795
Learning rate has been decayed to 0.25 at epoch 4
[10,  3125] Train loss: 1.061
[10,  6250] Train loss: 1.135
[10,  9375] Train loss: 1.453
[10, 12500] Train loss: 1.745
[10,  1248] Validation loss: 2.527


In [0]:
# epoch in range(10,20)

Learning rate has been decayed to 0.125 at epoch 17
[20,  3125] Train loss: 0.276
[20,  6250] Train loss: 0.382
[20,  9375] Train loss: 0.614
[20, 12500] Train loss: 0.917
[20,  1248] Validation loss: 2.240


In [0]:
# epoch in range(20,30)

Learning rate has been decayed to 0.0625 at epoch 23
[30,  3125] Train loss: 0.136
[30,  6250] Train loss: 0.214
[30,  9375] Train loss: 0.391
[30, 12500] Train loss: 0.666
[30,  1248] Validation loss: 1.204


In [0]:
# epoch in range(30,40)

Learning rate has been decayed to 0.03125 at epoch 32
[40,  3125] Train loss: 0.096
[40,  6250] Train loss: 0.158
[40,  9375] Train loss: 0.311
[40, 12500] Train loss: 0.554
[40,  1248] Validation loss: 0.786


In [0]:
# epoch in range(40,50)

Learning rate has been decayed to 0.015625 at epoch 43
[50,  3125] Train loss: 0.081
[50,  6250] Train loss: 0.137
[50,  9375] Train loss: 0.273
[50, 12500] Train loss: 0.504
[50,  1248] Validation loss: 0.678


In [0]:
# epoch in range(50,60):

Stopping early at epoch 58


In [0]:
test(58, False)

[59,  1249] Test loss: 1.142
A caucasian wearing a green jacket and a hat and yellow shirts . </s>


tensor(1.1421, device='cuda:0', grad_fn=<DivBackward0>)

In [0]:
interpolate(latent_size)

friends are celebrating a white and one on another shirt
the children were a yellow dress and riding an object
the children were wearing yellow and shoes with a helmet .
the children have a long , green on and no .
the children chews and riding and have on another , outside
An animal was doing karate and wearing no and sandals outside .
An artist was balancing and boots at a small dog vendor .
An artist was balancing and vegetables at another boy and shovel .
A mother was selling vegetables and on crafts and vegetables together .
A mother does not seen and at something , cooking flowers .
A blond man shoots wine while wearing no hat and clothes .


In [0]:
interpolate(latent_size)

<unk>initCap , people fight
<unk>initCap , people in dogs are running
<unk>initCap , people in snow are looking
Three firemen , a car are running across the car .
Three firemen , a car are running across the car .
Three firemen are a hot dog looking along the ocean .
Three firemen are holding two dogs along along <unk>lowercase .
Two firemen are holding a bus back along the sand .
Two sheep are having two fishing walking across the .
Two sheep are having two fishing going across leaves .
Two sheep are having two fishing to the left .


In [0]:
interpolate(latent_size)

some tourists run
some actors has work
some actors has work while look
some spectators on opposing people are tired .
two guys of flowers are posing while both
two guys in helmets are posing , waits outside and .
two guys in helmets are posing , waits outside and .
Three guys in uniform are posing while both outside .
Three guys in pink clothes are posing while holding bags .
Two guys in pink uniforms are running , outside .
Two guys in pink uniforms are running outside and outside .


In [0]:
interpolate(latent_size)

This water rides through train tracks , his dog 's wings .
a bucket from <unk>lowercase of pigeons , while his feet is life
a bucket from <unk>lowercase of people because she was all in her
a light boat of some passengers while wearing shorts 's mouth .
a lone boat full of several people plays an white cone .
a lone lot of some passengers while a brown car sleeps .
a lone bicyclist riding a bucking dog and an obstacle stand .
a bull , near a grocery store while an angel .
a bull , riding a bucking dog outside of flowers .
a small soldier riding a unicycle while wearing ear gear .
a black cat carrying a flag while wearing ear gear .


Free Bits Training

Free Bits ran for 1 and a half hours

In [28]:
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
"""
Helpful Link: https://pytorch.org/tutorials/beginner/saving_loading_models.html
"""
model_save_name = 'pretraining.pt'
path = F"/content/gdrive/My Drive/{model_save_name}" 
checkpoint = torch.load(path)

In [0]:
def train(epoch, target_rate=(4.0 / float(latent_size))):
  """
  Trains the model
  Helpful link for free bits: https://stats.stackexchange.com/questions/267924/explanation-of-the-free-bits-technique-for-variational-autoencoders
  """
  model.train()
  anneal = 0.1 * epoch if epoch < 10 else 1.0

  running_loss = running_kl = running_re = 0.0
  num_seq = 0
  for batch_idx, data in enumerate(batch_training):
    batch_size, seq_len = data.size()
    num_seq += batch_size
    optimizer_e.zero_grad()
    optimizer_d.zero_grad()
    kl, re = model(data)
    re_mean = re.view(batch_size, -1).mean(dim=0)
    re_loss = re_mean.sum()
    kl_mean = kl.mean(dim=0)
    kl_mask = (kl_mean > target_rate).float()
    fb_mask = (kl_mean <= target_rate).float()
    free_b = kl_mask + target_rate
    kl_fb = (kl_mean * kl_mask + free_b * fb_mask).sum()
    loss = kl_fb * anneal + re_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    optimizer_e.step()
    optimizer_d.step()

    running_kl += kl.sum()
    running_re += re.sum()
    running_loss += kl.sum() + re.sum()
  
    if (epoch == 0 or epoch % 10 == 9) and  batch_idx % print_every == print_every-1:    
      print('[%d, %5d] Train loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / num_seq))
      print('[%d, %5d] Train KL: %.3f' % (epoch + 1, batch_idx + 1, running_kl / num_seq))
      print('[%d, %5d] Train RE: %.3f' % (epoch + 1, batch_idx + 1, running_re / num_seq))
      running_loss = running_kl = running_re = 0.0

In [0]:
def test(epoch, validation=False, target_rate = (4.0 / float(latent_size))):
  """
  Run the model on validation or test dataset
  """
  model.eval()
  anneal = 0.1 * epoch if epoch < 10 else 1.0

  data = batch_val if validation else batch_test
  running_loss = running_kl = running_re = 0.0
  num_words = num_seq = 0
  for batch_idx, data in enumerate(data):
    if validation:  
      kl, re = model(data)
    else:
      kl, re, batch_decoded = model(data, True)
    batch_size, seq_len = data.size()
    num_words += batch_size * (seq_len - 1)
    num_seq += batch_size
    running_kl += kl.sum()
    running_re += re.sum()

  avg_loss = (running_kl + running_re) / num_seq
  avg_kl = running_kl / num_seq
  avg_re = running_re / num_seq
  if validation and (epoch == 0 or epoch % 10 == 9):
    print('[%d, %5d] Validation loss: %.3f' % (epoch + 1, batch_idx + 1, avg_loss))
    print('[%d, %5d] Validation KL: %.3f' % (epoch + 1, batch_idx + 1, avg_kl))
    print('[%d, %5d] Validation RE: %.3f' % (epoch + 1, batch_idx + 1, avg_re))
  elif not validation:
    print('[%d, %5d] Test loss: %.3f' % (epoch + 1, batch_idx + 1, avg_loss))
    print('[%d, %5d] Test KL: %.3f' % (epoch + 1, batch_idx + 1, avg_kl))
    print('[%d, %5d] Test RE: %.3f' % (epoch + 1, batch_idx + 1, avg_re))
    nll = (running_kl + running_re) / num_seq
    pdb.set_trace()
    ppl = torch.exp((running_kl + running_re) / num_words)
    print('[%d, %5d] Test NLL: %.3f' % (epoch + 1, batch_idx + 1, nll))
    print('[%d, %5d] Test PPL: %.3f' % (epoch + 1, batch_idx + 1, ppl))
    print(*batch_decoded[0])

  return avg_loss

In [0]:
encoder = Encoder(vocab_size, embedding_size, hidden_size, latent_size)
encoder.load_state_dict(checkpoint['encoder'])
decoder = Decoder(vocab_size, embedding_size, hidden_size, latent_size)
model = VAE(encoder, decoder)

if use_cuda and torch.cuda.is_available():
  model.cuda()

In [38]:
lr = 0.5
optimizer_e = optim.SGD(encoder.parameters(), lr=lr)
optimizer_d = optim.SGD(decoder.parameters(), lr=lr)
decay_learning = DecayLearning()
num_decays = 0
early_stop = 5
clip_grad = 5.0

for epoch in range(10):
  train(epoch)
  val_loss = test(epoch, True)
  decay_learning(val_loss)
  if decay_learning.update_lr:
    lr *= 0.5
    num_decays += 1
    if num_decays == early_stop + 1:
      print("Stopping early at epoch", epoch+1)
      break
    print("Learning rate has been decayed to", lr, "at epoch", epoch+1)
    optimizer_e = optim.SGD(encoder.parameters(), lr=lr)
    optimizer_d = optim.SGD(decoder.parameters(), lr=lr)
    decay_learning = DecayLearning()

[1,  3125] Train loss: 127.056
[1,  3125] Train KL: 119.090
[1,  3125] Train RE: 7.966
[1,  6250] Train loss: 63.078
[1,  6250] Train KL: 56.953
[1,  6250] Train RE: 6.125
[1,  9375] Train loss: 43.746
[1,  9375] Train KL: 36.464
[1,  9375] Train RE: 7.282
[1, 12500] Train loss: 35.303
[1, 12500] Train KL: 27.303
[1, 12500] Train RE: 8.000
[1,  1248] Validation loss: 146.207
[1,  1248] Validation KL: 120.718
[1,  1248] Validation RE: 25.489
[10,  3125] Train loss: 24.723
[10,  3125] Train KL: 5.652
[10,  3125] Train RE: 19.071
[10,  6250] Train loss: 15.919
[10,  6250] Train KL: 3.192
[10,  6250] Train RE: 12.727
[10,  9375] Train loss: 13.754
[10,  9375] Train KL: 2.340
[10,  9375] Train RE: 11.414
[10, 12500] Train loss: 12.884
[10, 12500] Train KL: 1.865
[10, 12500] Train RE: 11.019
[10,  1248] Validation loss: 45.523
[10,  1248] Validation KL: 8.090
[10,  1248] Validation RE: 37.434


In [39]:
# epoch in range(10,20)

Learning rate has been decayed to 0.25 at epoch 13
[20,  3125] Train loss: 21.244
[20,  3125] Train KL: 4.252
[20,  3125] Train RE: 16.992
[20,  6250] Train loss: 13.835
[20,  6250] Train KL: 2.380
[20,  6250] Train RE: 11.454
[20,  9375] Train loss: 12.016
[20,  9375] Train KL: 1.683
[20,  9375] Train RE: 10.333
[20, 12500] Train loss: 11.349
[20, 12500] Train KL: 1.302
[20, 12500] Train RE: 10.046
[20,  1248] Validation loss: 41.028
[20,  1248] Validation KL: 5.381
[20,  1248] Validation RE: 35.647


In [40]:
# epoch in range(20,30):

Learning rate has been decayed to 0.125 at epoch 22
Learning rate has been decayed to 0.0625 at epoch 25
[30,  3125] Train loss: 19.931
[30,  3125] Train KL: 4.241
[30,  3125] Train RE: 15.689
[30,  6250] Train loss: 12.676
[30,  6250] Train KL: 2.328
[30,  6250] Train RE: 10.348
[30,  9375] Train loss: 11.080
[30,  9375] Train KL: 1.650
[30,  9375] Train RE: 9.430
[30, 12500] Train loss: 10.416
[30, 12500] Train KL: 1.295
[30, 12500] Train RE: 9.120
[30,  1248] Validation loss: 41.414
[30,  1248] Validation KL: 5.076
[30,  1248] Validation RE: 36.338


In [41]:
# epoch in range(30,40):

Learning rate has been decayed to 0.03125 at epoch 31
Learning rate has been decayed to 0.015625 at epoch 37
[40,  3125] Train loss: 20.112
[40,  3125] Train KL: 4.295
[40,  3125] Train RE: 15.816
[40,  6250] Train loss: 12.550
[40,  6250] Train KL: 2.330
[40,  6250] Train RE: 10.221
[40,  9375] Train loss: 10.884
[40,  9375] Train KL: 1.651
[40,  9375] Train RE: 9.233
[40, 12500] Train loss: 10.272
[40, 12500] Train KL: 1.297
[40, 12500] Train RE: 8.975
[40,  1248] Validation loss: 37.370
[40,  1248] Validation KL: 4.883
[40,  1248] Validation RE: 32.487
Stopping early at epoch 40


In [42]:
test(39, False)

[40,  1249] Test loss: 76.540
[40,  1249] Test KL: 4.834
[40,  1249] Test RE: 71.705
[40,  1249] Test NLL: 76.540
[40,  1249] Test PPL: 3068.368
The two men are wearing a red shirt , and white pants , smiling


tensor(76.5395, device='cuda:0', grad_fn=<DivBackward0>)

In [0]:
sample1 = torch.randn(1, latent_size, device=cuda)

In [0]:
sample2 = torch.randn(1, latent_size, device=cuda)

In [62]:
interpolate(latent_size, sample1, sample2)

An old man stands in front of a crowd of people .
An old man stands in front of a crowd of people .
An old man stands in front of a crowd of people .
An old man stands in front of a crowd of people .
The man is standing in front of a crowd of people .
The man is standing in front of a crowd of people .
The man is standing in front of a building , wearing a suit
The man is not wearing a red shirt , and black pants ,
The man is wearing a red shirt , and a black hat smiles
The man is in the kitchen , wearing a hat , and jeans
The man is in the kitchen , while others look on the sidewalk


In [0]:
sample1 = torch.randn(1, latent_size, device=cuda)

In [0]:
sample2 = torch.randn(1, latent_size, device=cuda)

In [69]:
interpolate(latent_size, sample1, sample2)

A man in a blue shirt is in front of a building .
A man in a blue shirt is in front of a building .
A man in a blue shirt is walking down the street .
A man in a blue shirt is walking down a street corner .
A man in a blue shirt is walking down a street corner .
A man in a black shirt is walking down a street corner .
A man in a black shirt is walking down a street corner .
A man with a beard is walking down a street , holding a
A man wearing a black shirt , standing in front of a building
A man is playing guitar , and singing into a microphone .
A man is playing guitar , and singing into a microphone .


In [0]:
sample1 = torch.randn(1, latent_size, device=cuda)

In [0]:
sample2 = torch.randn(1, latent_size, device=cuda)

In [73]:
interpolate(latent_size, sample1, sample2)

A man is about to get ready to go to work .
A man is about to get ready to go to work .
A man is having a conversation , while sitting on a bench .
A man is having a conversation on a sunny day 's shoulder .
A man is having fun in a classroom , is running on the
A man is sitting on a bench , while others watch him .
A man is sitting on a bench , while others watch him .
A man is sitting on the couch , while others watch him .
A man is sitting on the couch , while others watch him .
A man is sitting on the couch , while others watch him .
A man sits on the couch , while another man watches him .


In [0]:
import random

# Training
f = open('./snli_1.0/snli_1.0_train.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentence1 = example.split('\t')[5]
    sentence2 = example.split('\t')[6]
    sent1_len = len(sentence1.split())
    sent2_len = len(sentence2.split())
    if sent1_len <= 12: sentences.append(sentence1)
    if sent2_len <= 12: sentences.append(sentence2)

sample = random.sample(sentences, 50000)

f = open('./snli_1.0/sample_train.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()

# Development
f = open('./snli_1.0/snli_1.0_dev.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentence1 = example.split('\t')[5]
    sentence2 = example.split('\t')[6]
    sent1_len = len(sentence1.split())
    sent2_len = len(sentence2.split())
    if sent1_len <= 12: sentences.append(sentence1)
    if sent2_len <= 12: sentences.append(sentence2)

sample = random.sample(sentences, 5000)

f = open('./snli_1.0/sample_dev.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()

# Test
f = open('./snli_1.0/snli_1.0_test.txt')
contents = f.read()
file_as_list = contents.splitlines()
sentences = []
for example in file_as_list[1:]:
    sentence1 = example.split('\t')[5]
    sentence2 = example.split('\t')[6]
    sent1_len = len(sentence1.split())
    sent2_len = len(sentence2.split())
    if sent1_len <= 12: sentences.append(sentence1)
    if sent2_len <= 12: sentences.append(sentence2)

sample = random.sample(sentences, 5000)

f = open('./snli_1.0/sample_test.txt', "w")
for sentence in sample:
    f.write(sentence)
    f.write("\n")
f.close()

In [0]:
model_save_name = 'pretraining.pt'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save({
    'model':model.state_dict(), 
    'encoder':encoder.state_dict(), 
    'decoder':decoder.state_dict(), 
    'optimizer_e':optimizer_e.state_dict(), 
    'optimizer_d':optimizer_d.state_dict()
    }, 
    path)

In [0]:
model_save_name = 'fb.pt'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save({
    'model':model.state_dict(), 
    'encoder':encoder.state_dict(), 
    'decoder':decoder.state_dict(), 
    'optimizer_e':optimizer_e.state_dict(), 
    'optimizer_d':optimizer_d.state_dict()
    }, 
    path)