# Chatbot

### Load necessary modules

In [None]:
from os import path, getcwd
import pickle
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
from tqdm import tqdm_notebook
from ipywidgets import widgets  # ipywidgets should be version 7 or higher
from IPython.display import display
from sympy import Symbol
from concurrent.futures import ProcessPoolExecutor
from torch.nn.utils import weight_norm
from dataset import *
from data_utils import *
from model_utils import *
from rnn import RNN, DecoderRNN

%matplotlib inline
# Resolve the project directories relative to the notebook's working directory.
# Only `path` and `getcwd` are imported from `os` above, so call `getcwd()`
# directly; `os.getcwd()` would depend on the `os` name leaking in from one of
# the star-imports.
ROOT_DIR = getcwd()
DATA_DIR = path.join(ROOT_DIR, 'data')  # raw and preprocessed chat corpora
MODEL_DIR = path.join(ROOT_DIR, 'model')  # saved model weights
EMBED_DIR = path.join(MODEL_DIR, '.vector_cache')  # pre-trained embeddings
# Every later cell keys its device placement off this flag.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    print("Using GPU")
else:
    print("Not using GPU")

### Download and Unpack Necessary Files, if needed
The data we use are twitter chatlogs from https://github.com/Marsan-Ma/chat_corpus

In [None]:
# Fetch the corpus archives into DATA_DIR. The second argument lists file names
# associated with each download (the archive and its unpacked forms);
# presumably `download` skips the fetch when one of them already exists --
# TODO confirm against data_utils.
download('https://github.com/Marsan-Ma/chat_corpus/raw/master/twitter_en.txt.gz',
         ['twitter_en.txt.gz', 'twitter_en.txt'], DATA_DIR)
# The big corpus is published as two parts (.partaa / .partab).
download('https://github.com/Marsan-Ma/chat_corpus/raw/master/twitter_en_big.txt.gz.partaa',
         ['twitter_en_big.txt.gz.partaa', 'twitter_en_big.txt.gz', 'twitter_en_big.txt'],
         DATA_DIR)
download('https://github.com/Marsan-Ma/chat_corpus/raw/master/twitter_en_big.txt.gz.partab',
         ['twitter_en_big.txt.gz.partab', 'twitter_en_big.txt.gz', 'twitter_en_big.txt'],
         DATA_DIR)

# concatenate twitter_en_big.txt.gz.partaa and .partab if needed
concatenate_two_gz(path.join(DATA_DIR, 'twitter_en_big.txt.gz'), '.partaa', '.partab')
# Shell escape: the path is relative, so this assumes the notebook runs from
# the project root (the same directory ROOT_DIR is derived from).
!chmod +w data  # make sure we have write permission in data directory
# unzip gz files, as needed
unzip_gz('twitter_en.txt.gz', DATA_DIR)
unzip_gz('twitter_en_big.txt.gz', DATA_DIR)
# create a short sample.txt file with only a few lines
create_sample('twitter_en.txt', 'sample.txt', DATA_DIR, 20000)
# Shell escape: plain `mkdir` reports an error if the directory already exists;
# that message is harmless on re-runs.
!mkdir model  # create directory for saving models

## Load and build dataset
Create Dataset and Split to Train, Validation, and Test

In [None]:
# Select the corpus by leaving exactly one FILE_NAME assignment uncommented
FILE_NAME = 'sample.txt'  # short text file for dev
#FILE_NAME = 'twitter_en.txt'  # medium length text file (754530 lines)
#FILE_NAME = 'twitter_en_big.txt'  # full text file (5202488 lines)

load_data = False  # read a previously pickled dataset instead of rebuilding
save_data = False  # pickle a freshly built dataset for later runs

EMBED_DIM = 200  # dimension of embedding vectors
FILE_PATH = path.join(DATA_DIR, FILE_NAME)
# The pickle cache is named after the corpus file minus its last four
# characters (the ".txt" extension).
CHAT_DATA_NAME = 'chat_data_' + FILE_NAME[:-4] + '.p'
CHAT_DATA_PATH = path.join(DATA_DIR, CHAT_DATA_NAME)

if not load_data:
    #glove = vocab.GloVe('twitter.27B', dim=EMBED_DIM, cache=EMBED_DIR)
    dataset = ChatDataset(
        data_path=FILE_PATH,  # path to data file
        max_length=12,  # maximum length of sentence
        max_vocab_size=8000,  # maximum size of vocabulary
        min_freq=6,  # minimum frequency to add word to vocabulary
        eos_token='<eos>',  # end of sentence token
        pad_token='<pad>',  # padding to keep sentence lengths equal
        unk_token='<unk>',  # unknown word (word not in vocabulary)
        special_tokens=[],  # any other tokens to add to vocabulary
        embed_dim=EMBED_DIM,  # dimension of embedding vectors
        threshold=3,  # count of unk required to remove sentence
        #pre_trained=glove,  # pre_trained word embeddings
    )
    if save_data:
        with open(CHAT_DATA_PATH, 'wb') as f:
            pickle.dump(dataset, f)
else:
    # NOTE: unpickling executes arbitrary code; only load caches that were
    # produced locally by this notebook.
    with open(CHAT_DATA_PATH, 'rb') as f:
        dataset = pickle.load(f)

# Corpus statistics
print("Number of words in vocabulary: %d" % dataset.nwords)
print("Number of sentences in data: %d" % len(dataset))
print("Number of unknown words in data: %d" % dataset.unk_count)
print("Total number of words in data: %d\n" % dataset.total_tokens)

# Partition into training / validation / test samplers (60% / 20% / 20%)
train_sampler, valid_sampler, test_sampler = split_data(dataset, 0.6, 0.2, 0.2)

print("Training set size: %d" % len(train_sampler))
print("Validation set size: %d" % len(valid_sampler))
print("Test set size: %d" % len(test_sampler))

## Sequence to Sequence Model

In [None]:
class EncoderRNN(RNN):
    """
    Recurrent encoder whose per-time-step outputs feed the attention decoder.

    The encoder is always built as bidirectional: the `bidirect` argument is
    accepted only to keep the constructor signature parallel with the other
    RNN subclasses and is not forwarded to the base class.
    """

    def __init__(self, input_size, hidden_size, nlayers, embed_dim,
                 rnn_type, pad_idx, use_cuda, dropout, bidirect=True):
        super().__init__(input_size, hidden_size, nlayers, embed_dim,
                         rnn_type, pad_idx, use_cuda, dropout, True)  # bidirectional
        self.init_weights()  # initialize weights when initializing rnn

    def forward(self, input, hidden, lengths, max_len):
        """
        Encode a padded batch and return outputs padded back to `max_len`.

        Args:
            input (Variable): batch-first word indexes
            hidden (Variable or tuple): initial hidden state for the rnn
            lengths (list): true sentence lengths, as required by
                `pack_padded_sequence` (assumed sorted longest first --
                TODO confirm against collate_fn)
            max_len (int): time dimension the returned output is padded to

        Returns:
            output: batch size * max_len * feature size
            hidden: final hidden state returned by the rnn
        """
        embedded = self.embedding(input)
        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
        packed, hidden = self.rnn(packed, hidden)
        # unpack packed sequence for use in decoder with attention
        output, out_lens = pad_packed_sequence(packed, batch_first=True)
        # Unpacking only pads up to the longest sentence in this batch, so
        # refill with zeros up to max_len to give the decoder a fixed width.
        n_pad = max_len - max(out_lens)
        if n_pad > 0:
            # Take batch and feature sizes from the rnn output itself so the
            # padding always matches it, whatever width the rnn emits.
            padding = Variable(torch.zeros(output.size(0), n_pad, output.size(2)))
            padding = padding.cuda() if self.is_cuda() else padding
            output = torch.cat((output, padding), 1)
        return output, hidden

We implement a Decoder model with Attention, where we wish to learn how much to peek into the encoder's topmost hidden state to gain some additional information. In order to learn where to look and how much to look, we use the scoring function $\mathbf{h}_t^T \mathbf{W} \mathbf{\bar{h}}_s$ from Luong et al., Effective Approaches to Attention-based Neural Machine Translation, 2015. The first vector, $\mathbf{h}_t$, is the top hidden layer from the previous time step, and the second, $\mathbf{\bar{h}}_s$, is a top hidden layer from the encoder.

With the attention mechanism in place, whenever the decoder outputs the unknown word token, we can replace the output with the input word given at the time index with the highest attention score, following Luong et al, Addressing the Rare Word Problem in Neural Machine Translation, 2015.

In [None]:
class AttDecoder(RNN):
    """
    Decoder with attention.

    At every step the top hidden layer from the previous time step is scored
    against each encoder output with a bilinear form; the softmax-weighted sum
    of encoder outputs (the context vector) then replaces the top hidden layer
    before the rnn is run for one step.

    The `bidirect` argument is accepted for signature compatibility only; the
    decoder is always unidirectional.
    """

    def __init__(self, input_size, hidden_size, nlayers, embed_dim,
                 rnn_type, pad_idx, max_len, use_cuda, dropout, bidirect=False):
        super().__init__(input_size, hidden_size, nlayers, embed_dim,
                         rnn_type, pad_idx, use_cuda, dropout, False)
        self.seq_len = max_len
        # function to calculate scores of which time step to attend
        self.score_fn = nn.Bilinear(hidden_size, hidden_size, 1, False)
        # calculates weights of how much each time step to attend
        self.score_softmax = nn.Softmax()

        # Intermediate linear mapping between output embedding and rnn
        # in order to tie input and output embedding
        self.linear = nn.Linear(hidden_size, embed_dim, bias=False)

        self.output_embed = nn.Linear(embed_dim, input_size, bias=False)
        self.softmax = nn.LogSoftmax()
        self.init_weights()

    def init_weights(self):
        """Initialize the attention and output layers uniformly in a small range."""
        super().init_weights()
        init_range = 0.05
        self.score_fn.weight.data.uniform_(-init_range, init_range)
        self.linear.weight.data.uniform_(-init_range, init_range)
        self.output_embed.weight.data.uniform_(-init_range, init_range)
        # reparameterize the scoring weights after they have been initialized
        weight_norm(self.score_fn)

    def forward(self, input, prev_hidden, encoder_hidden):
        """
        Run one decoding step with attention over the encoder outputs.

        Args:
            input (Variable): input words to go into rnn
            prev_hidden (Variable): hidden layer from previous time step
                (a tuple whose first element is the hidden state for LSTM)
            encoder_hidden (Variable) : top level hidden layers from encoder

        Dimensions:
            input: batch size * 1 (one word at a time)
            prev_hidden: number of layers * batch size * hidden size
            encoder_hidden: batch size * max sentence length * hidden size

        Returns:
            output: log-probabilities over the vocabulary for this step
            hidden: hidden state returned by the rnn
            max_atten_idx: encoder time index with the highest attention score
                for the FIRST sentence of the batch only
        """
        # input is one word, so unsqueeze second dimension (sequence length)
        embedded = self.embedding(input).unsqueeze(1)

        is_lstm = isinstance(prev_hidden, tuple)
        prev_h = prev_hidden[0] if is_lstm else prev_hidden
        # Layers are stacked bottom-to-top along dim 0, so the last entry is
        # the top layer and everything before it are the lower layers.
        top_prev_hidden = prev_h[-1]

        # Pairwise scores of encoder hidden at time i and top_prev_hidden.
        # Building the tensor with cat (rather than writing into a
        # preallocated buffer) keeps the whole computation differentiable.
        # Dimension is: batch_size * seq_len
        scores = torch.cat(
            [self.score_fn(top_prev_hidden, encoder_hidden[:, i])
             for i in range(self.seq_len)], 1)
        # normalize scores to get probabilities, which will be used to weigh
        # relevance of encoder hidden layer at time i
        # Dimension is: batch_size * 1 * seq_len
        normalized_scores = self.score_softmax(scores).unsqueeze(1)
        # weighted sum of scores and encoder output over sentence indexes
        # Dimension is: batch_size * 1 * hidden_size
        context_vector = torch.bmm(normalized_scores, encoder_hidden)
        # Rearrange such that number of layers (1) comes first
        context_vector = context_vector.permute(1, 0, 2)

        # Set the highest hidden layer to the context vector, keeping the
        # lower layers (all but the last) in their original positions.
        if self.nlayers > 1:
            curr_hidden = torch.cat((prev_h[:-1], context_vector), 0)
        else:
            curr_hidden = context_vector.contiguous()

        # Hand the new state to the rnn directly instead of overwriting
        # prev_hidden's `.data`: that would detach the context vector from the
        # graph (the scoring function would never receive gradients) and
        # would mutate the caller's hidden state as a side effect.
        rnn_hidden = (curr_hidden, prev_hidden[1]) if is_lstm else curr_hidden

        output, hidden = self.rnn(embedded, rnn_hidden)
        output = self.softmax(self.output_embed(self.linear(output[:, 0, :])))

        # return index with highest attention score
        _, max_atten_idx = torch.max(scores[0], 0)
        return output, hidden, max_atten_idx[0]

## Train Sequence to Sequence Model
First, we initialize the necessary components of the model, including the encoder, decoder, optimizer and learning rate scheduler of optimizer.

We employ three methods to regularize our model.
First, we apply dropout to the connections between layers; the dropout masks are drawn independently at each time step, just as in feed-forward nets, as in Melis et al., On the State of the Art of Evaluation in Neural Language Models, 2017.
We also tie the weights of the encoder embedding matrix and the decoder linear output mapping matrix, following Press and Wolf, Using the Output Embedding to Improve Language Models, 2016.
In addition, we employ weight decay to prevent weights from becoming too big.

Since we use a bidirectional encoder, we do not use variational dropout from Gal and Ghahramani 2016 (keeping the same dropout mask with each time step between hidden states), as implementing variational dropout for a bidirectional encoder will require major workarounds.

In [None]:
# Hyperparameters shared by the encoder and the decoder.
NHIDDEN = 514  # hidden layer size
NLAYERS = 2  # number of recurrent layers
EMBED_DIM = dataset.embed_dim  # dimension of embedding vectors
LEARNING_RATE = 0.001  # initial learning rate
DROP_L = 0.2  # probability of dropping connections between layers

encoder = EncoderRNN(dataset.nwords, NHIDDEN, NLAYERS, EMBED_DIM, 'GRU',
                     dataset.pad_idx, USE_CUDA, DROP_L)
decoder = AttDecoder(dataset.nwords, NHIDDEN, NLAYERS, EMBED_DIM, 'GRU',
                     dataset.pad_idx, dataset.max_len, USE_CUDA, DROP_L)
# Use pre-trained embedding weights, if needed
#encoder.embedding.weight = dataset.vocab.vectors
#decoder.embedding.weight = dataset.vocab.vectors
# Tie input embedding and output embedding
# (the encoder's input embedding and the decoder's output projection now share
# one Parameter object)
# NOTE(review): the shared parameter is handed to both optimizers below, so it
# is stepped and weight-decayed by each of them -- confirm this is intended.
encoder.embedding.weight = decoder.output_embed.weight

# Do not calculate loss when target is padding, encouraging model to say more
# (size_average=False: the loss is summed over targets rather than averaged)
loss = nn.NLLLoss(ignore_index=dataset.pad_idx, size_average=False)

if USE_CUDA:
    encoder.cuda()
    decoder.cuda()
    loss = loss.cuda()

# need to create optimizer after .cuda() call on models
e_opt = optim.SGD(encoder.parameters(), lr=LEARNING_RATE, weight_decay=0.001)
d_opt = optim.SGD(decoder.parameters(), lr=LEARNING_RATE, weight_decay=0.001)
# halve lr if validation loss is not reduced after patience*plot_every epochs
# (the schedulers are stepped only when validation runs in the training cell)
e_sched = ReduceLROnPlateau(e_opt, factor=0.5, patience=5)
d_sched = ReduceLROnPlateau(d_opt, factor=0.5, patience=5)

best_val_loss = float('inf')  # best validation loss so far
SAVE_MODEL = False  # choose whether to save best performing model on disk
ENCODER_SAVE_PATH = path.join(MODEL_DIR, 'encoder.pth')
DECODER_SAVE_PATH = path.join(MODEL_DIR, 'decoder.pth')

In order to help the model learn, we employ teacher forcing, whereby we feed the target labels as input to the decoder at each time step. Initially, we feed the target labels frequently. However, as the model continues to train, we decrease our use of teacher forcing. The teacher forcing probabilities are given by the inverse sigmoid function, and teacher forcing decisions are made at each time step, following Bengio et al., Scheduled Sampling for Sequence Prediction with
Recurrent Neural Networks, 2015.

In [None]:
NEPOCH = 30000  # number of epochs
# k_val is the rate of convergence of teacher forcing probability to 0.
# initially, choose high k_val, but on subsequent runs of training cell, choose low k_val
k_val = 2000

# Set teacher forcing probabilities
# One x value per epoch, exactly 1..NEPOCH. (np.arange avoids both the removed
# `np.int` alias and the truncated linspace grid, whose last point landed on
# NEPOCH + 1.)
x_list = np.arange(1, NEPOCH + 1)
# inverse sigmoid decay: k / (k + exp(i / k))
tf_probs = k_val / (np.exp(x_list / k_val) + k_val)

# For every epoch and every decoding step, decide up front whether to use
# teacher forcing: draw u ~ U[0, 1) and force when u <= that epoch's probability.
draws = torch.rand(NEPOCH, dataset.max_len)
tf_thresholds = torch.from_numpy(tf_probs).float().unsqueeze(1).expand_as(draws)
# One vectorized comparison replaces the per-epoch loop (and the process pool
# that was opened but never given any work). Stored as a byte mask.
choices = torch.le(draws, tf_thresholds).byte()

plot(x_list, tf_probs, 'epoch', 'teacher forcing probability')

Run the following cell to train sequence to sequence model.
Interrupt the kernel at any time to stop training before completion.
Run cell again to resume training.

In [None]:
import copy  # cell-local: used to snapshot the best-performing weights

TRAIN_BATCH = 32  # batch size for training
# volatile flag allows greater validation set batch size
if (len(valid_sampler) // (TRAIN_BATCH * 20)) >= 1:
    VALID_BATCH = TRAIN_BATCH * 20
else:
    VALID_BATCH = len(valid_sampler)

train_batches = DataLoader(dataset, batch_size=TRAIN_BATCH,
                           sampler=train_sampler, collate_fn=collate_fn,
                           pin_memory=USE_CUDA,
                           drop_last=True)
valid_batches = DataLoader(dataset, batch_size=VALID_BATCH,
                           sampler=valid_sampler, collate_fn=collate_fn,
                           pin_memory=USE_CUDA, drop_last=True)
ntrain_batches = len(train_batches)
nvalid_batches = len(valid_batches)

# Setting up information to print and plot
plot_every = 50  # plot one data point per plot_every training losses
print_every = 50  # frequency of printing loss information
train_loss_plot, train_loss_print = 0, 0
train_losses, valid_losses, epoch_list = [], [], []

# When this cell is re-run, resume from the best weights recorded so far.
# `best_encoder` / `best_decoder` only exist after an earlier run stored them,
# so a NameError here simply means this is the first run.
try:
    encoder.load_state_dict(best_encoder)
    decoder.load_state_dict(best_decoder)
    print("Loaded up weights from previous training.")
except NameError:
    print("First run of this training cell.")

try:
    for epoch in tqdm_notebook(range(1, NEPOCH + 1), unit=' epochs'):
        # teacher forcing decisions (one per decoding step) for this epoch
        choice = choices[epoch-1]
        if encoder.is_cuda():
            choice = choice.cuda()
        train_loss = train(encoder, decoder, TRAIN_BATCH, train_batches,
                           e_opt, d_opt, dataset, choice, loss) / ntrain_batches
        del choice
        train_loss_plot += train_loss
        train_loss_print += train_loss
        if epoch % plot_every == 0:
            # Calculate validation loss
            val_loss = evaluate(encoder, decoder, VALID_BATCH,
                                valid_batches, dataset, loss)
            val_loss /= nvalid_batches
            # update learning rate scheduler
            e_sched.step(val_loss)
            d_sched.step(val_loss)
            # Update losses to plot
            train_avg = train_loss_plot / plot_every
            train_losses.append(train_avg)
            valid_losses.append(val_loss)
            epoch_list.append(epoch)
            train_loss_plot = 0

            # Copy model if validation loss is the best so far
            # (only during the second half of training)
            if best_val_loss > val_loss and epoch > (NEPOCH // 2):
                best_val_loss = val_loss
                # state_dict() returns references to the live parameters, so
                # deep-copy it; otherwise the "best" snapshot would keep
                # changing with every subsequent optimizer step.
                best_encoder = copy.deepcopy(encoder.state_dict())
                best_decoder = copy.deepcopy(decoder.state_dict())

        if epoch % print_every == 0:
            train_avg = train_loss_print / print_every
            # report 0 until the first validation pass has produced a value
            valid_loss_print = valid_losses[-1] if len(valid_losses) > 0 else 0
            print('Epoch: %d  Avg Training loss: %.4f, Avg Validation loss %.4f'
                  % (epoch, train_avg, valid_loss_print))
            train_loss_print = 0

except KeyboardInterrupt:
    print("Training stopped. Run cell again to continue training.")

if len(epoch_list) > 5:
    plot(epoch_list, train_losses, 'epoch', 'training loss')
    plot(epoch_list, valid_losses, 'epoch', 'validation loss')

if SAVE_MODEL:
    # A best snapshot exists only if validation improved during the second
    # half of training; do not crash when training stopped before that.
    try:
        torch.save(best_encoder, ENCODER_SAVE_PATH)
        torch.save(best_decoder, DECODER_SAVE_PATH)
    except NameError:
        print("No best model recorded yet; nothing was saved.")

## Sample responses from Model

In [None]:
# Load best performing model so far
# Preference order: weights kept in memory by the training cell, then weights
# saved on disk, otherwise keep whatever weights the models currently hold.
try:
    encoder.load_state_dict(best_encoder)
    decoder.load_state_dict(best_decoder)
except NameError:  # training cell was not run previously
    # `best_encoder` / `best_decoder` are not defined in this session
    try:
        encoder.load_state_dict(torch.load(ENCODER_SAVE_PATH))
        decoder.load_state_dict(torch.load(DECODER_SAVE_PATH))
    except FileNotFoundError:
        # nothing in memory and nothing on disk
        print("No model found. Recommend running the training segment.")

In [None]:
# Get random subset of test data
# pin_memory follows USE_CUDA like the training/validation loaders do; pinning
# host memory is only meaningful when a GPU is present.
# NOTE(review): unlike the other loaders, no collate_fn is passed here --
# confirm the default collation yields (inputs, targets, lengths).
test_batches = DataLoader(dataset, batch_size=5, sampler=test_sampler,
                          pin_memory=USE_CUDA, drop_last=True)
test_batch = next(iter(test_batches))
test_input, _, input_lens = test_batch

for i in range(test_input.size()[0]):
    # respond() is given one sentence at a time, with its length wrapped in a list
    response = respond(encoder, decoder, test_input[i], dataset, [input_lens[i]])
    # get human readable words from input embedding indexes
    input_line = get_input(test_input[i], dataset)
    print("Input: %s" % input_line)
    print("Response: %s\n" % response)

# Chat with Model

In [None]:
# Single-line text box used as the chat prompt
text = widgets.Text()
display(text)


def chat(sender):
    """
    Answer the text currently in the input box and print the exchange.

    Registered as the submit callback of the `text` widget.

    Args:
        sender: the widget that triggered the callback (unused; the message is
            read from the module-level `text` widget)
    """
    try:
        output = respond(encoder, decoder, text.value, dataset)
        if not output:
            # the model produced nothing; fall back to a canned reply
            output = "Sorry, I didn't get you. Let's talk about something else."
    except UserInputTooLongError:
        output = "Calm down, please talk a bit slower."
    print("You: %s" % text.value)
    print("Bot: %s" % output)

text.on_submit(chat)

# Unsupervised Sentence Generation with a VAE
So far, we have worked with a supervised generative model, where we have provided target labels to train the encoder-decoder model. We can also consider an unsupervised encoder-decoder model. With an autoencoder, we replace the target labels with the inputs and train the model so that the model produces good reconstruction of the input.

We will implement a Variational Autoencoder (VAE), where we have the encoder learn latent representations of sentences, which can be used to generate sentences with different latent attributes such as writing style or topic. This is in contrast to a regular autoencoder, which does dimensionality reduction (assuming model capacity is limited).

In [None]:
import pyro
from pyro.infer import SVI
from pyro.optim import ClippedAdam
import pyro.distributions as dist
from pyro.util import ng_zeros, ng_ones
import pdb


class Encoder(RNN):
    """
    Recurrent inference network for the VAE.

    Maps a batch of sentences to the mean and standard deviation of a
    diagonal Gaussian over the latent code.
    """

    def __init__(self, input_size, hidden_size, nlayers, embed_dim, rnn_type,
                 pad_idx, use_cuda, dropout, bidirect=False, z_dim=50):
        super().__init__(input_size, hidden_size, nlayers, embed_dim, rnn_type,
                         pad_idx, use_cuda, dropout, bidirect)
        # projection producing the latent mean
        self.fc21 = nn.Linear(hidden_size, z_dim)
        # projection producing the log of the latent standard deviation
        # (forward() exponentiates it, which keeps sigma positive)
        self.fc22 = nn.Linear(hidden_size, z_dim)
        # nonlinearity applied to the final rnn state before both projections
        self.softplus = nn.Softplus()
        self.init_weights()

    def init_weights(self):
        """Uniformly initialize both projection weights in [-0.05, 0.05]."""
        super().init_weights()
        for projection in (self.fc21, self.fc22):
            projection.weight.data.uniform_(-0.05, 0.05)

    def forward(self, input):
        """
        Encode `input` and return `(z_mu, z_sigma)`.

        The final hidden state has its leading dimension squeezed away, so
        this assumes a single-layer, unidirectional, non-LSTM rnn --
        TODO confirm for other configurations.
        """
        initial_state = self.init_hidden(input.size(0))
        _, final_state = self.rnn(self.embedding(input), initial_state)

        features = self.softplus(final_state.squeeze(0))
        return self.fc21(features), torch.exp(self.fc22(features))


class Decoder(DecoderRNN):
    """
    Generative network for the VAE.

    Greedily decodes a sentence from a latent code and returns the word
    distribution produced at every time step.

    The `bidirect` argument is accepted for signature compatibility only.
    `dataset` must provide `eos_tensor` and `max_len`; forward() cannot run
    with the default of None.
    """

    def __init__(self, input_size, hidden_size, nlayers, embed_dim, rnn_type,
                 pad_idx, use_cuda, dropout, bidirect=False, dataset=None):
        super().__init__(input_size, hidden_size, nlayers, embed_dim,
                         rnn_type, pad_idx, use_cuda, dropout)
        self.softmax = nn.Softmax()  # need valid probabilities for dist.categorical
        self.dataset = dataset

    def forward(self, z):
        """
        Args:
            z (Variable): latent codes, batch size * hidden size; used as the
                initial hidden state of the (single-layer) rnn

        Returns:
            Variable of word probabilities with the time steps stacked along
            dimension 1 (`dataset.max_len` steps).

        Raises:
            ValueError: if `self.rnn_type` is neither 'LSTM' nor 'GRU'
        """
        batch_size = z.size()[0]
        # decoding starts from the end-of-sentence token
        output = Variable(self.dataset.eos_tensor(batch_size, self.is_cuda()),
                          requires_grad=False)
        if self.rnn_type == 'LSTM':
            # Build the state tuple around z itself rather than copying into
            # `.data`, so gradients flow from the decoder back to z.
            hidden = self.init_hidden(batch_size)
            hidden = (z.unsqueeze(0), hidden[1])
        elif self.rnn_type == 'GRU':
            hidden = z.unsqueeze(0)
        else:
            raise ValueError("Unsupported rnn_type: %s" % self.rnn_type)

        mu_word_list = []

        for i in range(self.dataset.max_len):
            output = self.embedding(output).unsqueeze(1)
            output, hidden = self.rnn(output, hidden)

            output = self.linear(output[:, 0, :])
            output = self.softmax(output)
            mu_word_list.append(output)
            # greedy decoding: the most probable word is fed back as next input
            _, idx = torch.max(output, 1)
            output = idx

        return torch.stack(mu_word_list, 1)


class VAE(nn.Module):
    """
    Sentence variational autoencoder trained with Pyro.

    `model` is the generative process (standard normal prior over the latent
    code, decoder producing word distributions) and `guide` is the
    variational posterior given by the encoder.
    """

    def __init__(self, input_size, embed_dim, z_dim=50, hidden_size=256,
                 use_cuda=False, dropout=0, chat_dataset=None):
        """
        Args:
            input_size: vocabulary size
            embed_dim: dimension of embedding vectors
            z_dim: dimension of the latent code
            hidden_size: hidden size of the encoder rnn
            use_cuda: move the module to the GPU when True
            dropout: dropout probability passed to encoder and decoder
            chat_dataset: dataset handed to the decoder; when omitted, the
                notebook-level `dataset` is used
        """
        super().__init__()
        if chat_dataset is None:
            # fall back to the notebook-level global
            chat_dataset = dataset
        # NOTE(review): pad_idx is passed as the literal 0 for both networks --
        # confirm it matches the dataset's pad index.
        self.encoder = Encoder(input_size, hidden_size, 1, embed_dim, 'GRU',
                               0, use_cuda, dropout, False, z_dim)
        # the decoder's hidden size is z_dim because the latent code is used
        # directly as its initial hidden state
        self.decoder = Decoder(input_size, z_dim, 1, embed_dim, 'GRU',
                               0, use_cuda, dropout, False, chat_dataset)
        if use_cuda:
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, input, annealing_factor=1.0):
        """Generative model p(x | z) p(z), with `input` as the observed sentences."""
        pyro.module("decoder", self.decoder)
        # standard normal prior over the latent code
        z_mu = ng_zeros([input.size(0), self.z_dim], type_as=input.data)
        z_sigma = ng_ones([input.size(0), self.z_dim], type_as=input.data)

        z = pyro.sample("latent", dist.normal, z_mu.float(), z_sigma.float())
        mu_sentence = self.decoder.forward(z)
        inp = Variable(one_hot_conversion(input.data,
                                          self.decoder.dataset.nwords, self.use_cuda))
        # NOTE(review): the annealing factor scales the observation term here
        # but the latent term in the guide -- confirm this is the intended
        # annealing scheme.
        pyro.sample("obs", dist.categorical, mu_sentence, obs=inp,
                    log_pdf_mask=annealing_factor)

    def guide(self, input, annealing_factor=1.0):
        """Variational posterior q(z | x) parameterized by the encoder."""
        pyro.module("encoder", self.encoder)
        z_mu, z_sigma = self.encoder.forward(input)
        # Pass the factor by keyword, as `model` does; positionally it would
        # be bound to the distribution's third parameter instead of the mask.
        pyro.sample("latent", dist.normal, z_mu, z_sigma,
                    log_pdf_mask=annealing_factor)

    def reconstruct_sentence(self, input):
        """Encode `input`, sample a latent code and decode it again."""
        z_mu, z_sigma = self.encoder(input)

        z = dist.normal(z_mu, z_sigma)

        mu_sentence = self.decoder(z)
        return mu_sentence

    def model_sample(self, batch_size=1):
        """
        Sample sentences from the prior.

        Returns:
            xs: words sampled from the decoded distributions
            mu: the decoded word distributions themselves
        """
        prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
        prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
        if self.use_cuda:
            prior_mu, prior_sigma = prior_mu.cuda(), prior_sigma.cuda()
        zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)
        mu = self.decoder.forward(zs)
        xs = pyro.sample("sample", dist.categorical, mu)
        return xs, mu


def one_hot_conversion(batch, input_size, use_cuda):
    """
    Expand a batch of word indexes into one-hot byte vectors.

    Args:
        batch: tensor of indexes, batch size * sequence length
        input_size: length of each one-hot vector
        use_cuda: allocate the result on the GPU when True

    Returns:
        ByteTensor of dimension batch size * sequence length * input_size
    """
    n_sentences, n_steps = batch.size(0), batch.size(1)
    one_hot = torch.ByteTensor(n_sentences, n_steps, input_size).zero_()
    if use_cuda:
        one_hot = one_hot.cuda()
    # along the last axis, set the position named by each index to 1
    positions = batch.unsqueeze(2)
    one_hot.scatter_(2, positions, 1)
    return one_hot

At the beginning of training, we want to ease on the regularization. So we set an annealing factor between 0 and 1. We start with 0 and gradually increase the annealing factor to 1. We use a linear function to schedule this increase.

In [None]:
TRAIN_BATCH = 2  # batch size for training
NEPOCH = 5000
EVAL_EVERY = 5  # run validation once per EVAL_EVERY epochs
# volatile flag allows greater validation set batch size
if (len(valid_sampler) // (TRAIN_BATCH * 20)) >= 1:
    VALID_BATCH = TRAIN_BATCH * 20
else:
    VALID_BATCH = len(valid_sampler)
ANNEALING_END = 4000  # epoch when annealing defaults back to 1

train_batches = DataLoader(dataset, batch_size=TRAIN_BATCH,
                           sampler=train_sampler, collate_fn=collate_fn,
                           pin_memory=USE_CUDA,
                           drop_last=True)
valid_batches = DataLoader(dataset, batch_size=VALID_BATCH,
                           sampler=valid_sampler, collate_fn=collate_fn,
                           pin_memory=USE_CUDA, drop_last=True)
vae = VAE(dataset.nwords, dataset.embed_dim, 128, 256, USE_CUDA)
adam_args = {"lr": LEARNING_RATE, "clip_norm": 5}
optimizer = ClippedAdam(adam_args)
svi = SVI(vae.model, vae.guide, optimizer, loss='ELBO')

train_elbo = []
valid_elbo = []

try:
    for epoch in tqdm_notebook(range(1, NEPOCH + 1), unit=' epochs'):
        epoch_loss = 0.
        # Set annealing_factor: linear ramp up to 1 at epoch ANNEALING_END
        if ANNEALING_END > 0 and epoch < ANNEALING_END:
            annealing_factor = epoch / ANNEALING_END
        else:
            annealing_factor = 1.0

        for lines, _, _ in train_batches:
            if USE_CUDA:
                lines = lines.cuda()
            lines = Variable(lines.long(), requires_grad=False)
            # Train with the scheduled factor; a constant 1.0 here would
            # leave the schedule above without any effect on training.
            epoch_loss += svi.step(lines, annealing_factor)

        normalizer_train = len(train_sampler)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % EVAL_EVERY == 0:
            valid_loss = 0.
            for i, (lines, _, _) in enumerate(valid_batches):
                if USE_CUDA:
                    lines = lines.cuda()
                lines = Variable(lines.long(), volatile=True)
                valid_loss += svi.evaluate_loss(lines, annealing_factor)

                if i == 0:
                    # print VAE samples
                    _, sample_mu = vae.model_sample()
                    # drop the batch dimension (a single sentence is sampled)
                    sample_mu = sample_mu.squeeze(0)
                    # most probable word at every time step, computed once
                    _, word_idxs = torch.max(sample_mu, 1)
                    response = []
                    # `step` (not `i`) so the batch counter is not shadowed;
                    # read the word of each time step, not only the first
                    for step in range(dataset.max_len):
                        response_word = dataset.vocab.itos[word_idxs.data[step]]
                        if response_word == dataset.eos_token:
                            break
                        response.append(response_word)
                    print("%s" % ' '.join(response))

            normalizer_valid = len(valid_sampler)
            total_epoch_loss_valid = valid_loss / normalizer_valid
            valid_elbo.append(total_epoch_loss_valid)
            print("[epoch %03d] average validation loss: %.4f"
                  % (epoch, total_epoch_loss_valid))

except KeyboardInterrupt:
    print("Training stopped. Run cell again to continue training.")