In [1]:
import os
import pdb
from IPython.core.debugger import set_trace as bp
import argparse
import pickle as pkl

from collections import defaultdict

import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Local imports
import utils
import data_handler
from data_handler import *
from torch.distributions import Categorical

In [2]:
# Options/ Hyperparameters required to train and test the model
opts = EasyDict()

opts.n_epochs = 50
opts.batch_size = 16
opts.learning_rate = 0.005
opts.lr_decay = 0.99
opts.hidden_layer_size = 100
opts.model_name = "simple_rnn"
opts.checkpoints_dir = "./checkpoints/"+opts.model_name 
opts.temp = 0.4
TEST_SENTENCE = 'i love deep learning'

In [3]:
utils.create_dir_if_not_exists(opts.checkpoints_dir)

In [4]:
line_pairs, vocab_size, idx_dict = load_data()

In [5]:
for key in (idx_dict.keys()):
    print(key)
    print(idx_dict[key])

char_to_index
{'-': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'SOS': 27, 'EOS': 28}
index_to_char
{0: '-', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 27: 'SOS', 28: 'EOS'}
start_token
27
end_token
28


In [6]:
# dividing the line pairs into 8:2, train and val split
num_lines = len(line_pairs)
num_train = int(0.8 * num_lines)
train_pairs, val_pairs = line_pairs[:num_train], line_pairs[num_train:]

In [7]:
train_dict = create_dict(train_pairs)
val_dict = create_dict(val_pairs)

# Study the structure of the created train_dict and val_dict variables

In [8]:
for key in train_dict.keys():
    print("Pairs of length {}: {}".format(key,len(train_dict[key])))

Pairs of length (13, 17): 16
Pairs of length (9, 12): 218
Pairs of length (11, 14): 107
Pairs of length (9, 11): 406
Pairs of length (7, 10): 159
Pairs of length (11, 13): 153
Pairs of length (5, 7): 463
Pairs of length (8, 10): 552
Pairs of length (6, 8): 574
Pairs of length (2, 5): 16
Pairs of length (7, 9): 648
Pairs of length (12, 15): 70
Pairs of length (4, 6): 368
Pairs of length (6, 11): 1
Pairs of length (8, 11): 157
Pairs of length (13, 16): 33
Pairs of length (12, 14): 88
Pairs of length (5, 8): 80
Pairs of length (1, 4): 2
Pairs of length (10, 12): 308
Pairs of length (14, 16): 16
Pairs of length (3, 5): 109
Pairs of length (10, 13): 156
Pairs of length (4, 7): 33
Pairs of length (14, 21): 1
Pairs of length (6, 9): 142
Pairs of length (1, 3): 4
Pairs of length (15, 18): 4
Pairs of length (2, 2): 1
Pairs of length (13, 18): 5
Pairs of length (15, 19): 3
Pairs of length (13, 15): 35
Pairs of length (3, 6): 27
Pairs of length (12, 17): 6
Pairs of length (16, 22): 1
Pairs of len

In [9]:
# Implement your own LSTM cell. A sample class definition is given to you.

class MyLSTMCell():
    def __init__(self, input_size, hidden_size):
        super(MyLSTMCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)
                        
    def forward(self, x, h_prev, c_prev):
        """Forward pass of the LSTM computation for one time step.

        Arguments
            x: batch_size x input_size
            h_prev: batch_size x hidden_size
            c_prev: batch_size x hidden_size
        Returns:
            h_new: batch_size x hidden_size
            c_new: batch_size x hidden_size
        """
        x_combined = torch.cat((x, h_prev), 1)

        f = torch.sigmoid(self.W_f(x_combined))
        i = torch.sigmoid(self.W_i(x_combined))
        o = torch.sigmoid(self.W_o(x_combined))
        c_dash = torch.tanh(self.W_c(x_combined))
        
        c_new = f*c_prev + i*c_dash
        h_new = o*torch.tanh(c_new)
        
        return h_new, c_new
    

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(Encoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.LSTMCell = MyLSTMCell(hidden_size, hidden_size)

    def forward(self, inputs):
        """Forward pass of the encoder RNN.

        Arguments:
            inputs: Input token indexes across a batch for all time steps in the sequence. (batch_size x seq_len)

        Returns:
            annotations: The hidden states computed at each step of the input sequence. (batch_size x seq_len x hidden_size)
            hidden: The final hidden state of the encoder, for each sequence in a batch. (batch_size x hidden_size)
        """

        batch_size, seq_len = inputs.size()
        hidden = self.init_hidden(batch_size)
        cell = self.init_hidden(batch_size)
        annotations = torch.empty((batch_size, seq_len, self.hidden_size), dtype=hidden.dtype, device=hidden.device)
        
        # The encoded embeddings should be of size batch_size x seq_len x hidden_size        
        encoded_embeddings = self.embeddings(inputs)
        
        # Loop over the time-steps do forward pass through LSTM cell at each ti
        for i in range(seq_len):
            hidden,cell = self.LSTMCell.forward(encoded_embeddings[:,i,:], hidden, cell)
            annotations[:,i,:] = hidden
            
        return annotations, hidden

    def init_hidden(self, bs):
        """Creates a tensor of zeros to represent the initial hidden states
        of a batch of sequences.

        Arguments:
            bs: The batch size for the initial hidden state.

        Returns:
            hidden: An initial hidden state of all zeros. (batch_size x hidden_size)
        """
        return torch.zeros(bs, self.hidden_size)


# Implement your Decoder RNN using instances of LSTM Cell you just created.
# You would need a character embedding layer for this. 
# In addition you would also require an activation function applied to the output of the LSTM Cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.LSTMCell = MyLSTMCell(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h_prev, c_prev):
        """Forward pass of the decoder RNN.

        Arguments:
            x: Input token indexes across a batch for a single time step. (batch_size x 1)
            h_prev: The hidden states from the previous step, across a batch. (batch_size x hidden_size)
            c_prev: The cell states from the previous step, across a batch. (batch_size x hidden_size)

        Returns:
            output: Un-normalized scores for each token in the vocabulary, across a batch. (batch_size x vocab_size)
            h_new: The new hidden states, across a batch. (batch_size x hidden_size)
            c_new: The new cell states, across a batch. (batch_size x hidden_size)
        """
    
        encoded_embeddings = self.embeddings(x).squeeze(1)
        h_new, c_new = self.LSTMCell.forward(encoded_embeddings, h_prev, c_prev)
        output = self.output(h_new)
        
        return output, h_new, c_new

In [10]:
##########################################################################
### Setup: Create Encoder, Decoder Objects ###
##########################################################################
encoder = Encoder(vocab_size=vocab_size, hidden_size=opts.hidden_layer_size)
decoder = Decoder(vocab_size=vocab_size, hidden_size=opts.hidden_layer_size)

In [11]:
def train_model(train_dict, val_dict, idx_dict, encoder, decoder, opts):
    """Runs the main training loop; evaluates the model on the val set every epoch.
        * Prints training and val loss each epoch.
        * Prints qualitative translation results each epoch using TEST_SENTENCE

    Arguments:
        train_dict: The training word pairs, organized by source and target lengths.
        val_dict: The validation word pairs, organized by source and target lengths.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model to generate output tokens.
        opts: The input arguments for hyper-parameters and others.
    """
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(list(encoder.parameters())+list(decoder.parameters()) ,lr=opts.learning_rate)
    
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    loss_log = open(os.path.join(opts.checkpoints_dir, 'loss_log.txt'), 'w')

    best_val_loss = 1e6
    train_losses = []
    val_losses = []

    for epoch in range(opts.n_epochs):
        
        # decay the learning rate of the optimizer
        optimizer.param_groups[0]['lr'] *= opts.lr_decay

        epoch_losses = []

        for key in train_dict:
            input_strings, target_strings = zip(*train_dict[key])
            
            input_tensors, output_tensors = [],[]
            for s in input_strings:
                input_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))
            
            for s in target_strings:
                output_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))
                
            num_tensors = len(input_tensors)
            num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))
            
            for i in range(num_batches):
                start = i * opts.batch_size
                end = start + opts.batch_size

                # Define inputs and targets for THIS batch, beginning at index 'start' to 'end'
                inputs = torch.stack(input_tensors[start:end])
                # outputs = ....
                targets = torch.stack(output_tensors[start:end])
            
                
                #print ("targets", targets.size(), "\n" ,targets)
                # The batch size may be different in each epoch
                BS = inputs.size(0)

                encoder_annotations, encoder_hidden = encoder.forward(inputs)

                # The last hidden state of the encoder becomes the first hidden state of the decoder
                decoder_hidden = encoder_hidden
                decoder_cell = torch.zeros(BS, encoder.hidden_size)

                # Define the first decoder input. This would essentially be the start_token
                decoder_input = start_token * torch.ones(BS,1).long()

                loss = 0.0

                seq_len = targets.size(1)  # Gets seq_len from BS x seq_len
                
                for si in range(seq_len):
                    decoder_output, decoder_hidden, decoder_cell = decoder.forward(decoder_input, decoder_hidden, decoder_cell)

                    current_target = targets[:,si]
                    # Calculate the cross entropy between the decoder distribution and Ground truth (current_target)
                    loss += criterion(decoder_output, current_target)
                    
                    decoder_input = targets[:,si].unsqueeze(1)

                loss /= float(seq_len)
                epoch_losses.append(loss.item())
                
                # Compute gradients
                loss.backward()

                # Update the parameters of the encoder and decoder
                optimizer.step()

        train_loss = np.mean(epoch_losses)
        val_loss = evaluate(val_dict, encoder, decoder, idx_dict, criterion, opts)

        if val_loss < best_val_loss:
            utils.store_checkpoints(encoder, decoder, idx_dict, opts)

        gen_string = find_pig_latin(TEST_SENTENCE, encoder, decoder, idx_dict, opts)
        print(gen_string)
        print("Epoch: {:3d} | Train loss: {:.3f} | Val loss: {:.3f} | Gen: {:20s}".format(epoch, train_loss, val_loss, gen_string))

        loss_log.write('{} {} {}\n'.format(epoch, train_loss, val_loss))
        loss_log.flush()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        utils.store_loss_plots(train_losses, val_losses, opts)


In [12]:
def evaluate(data_dict, encoder, decoder, idx_dict, criterion, opts):
    """Evaluates the model on a held-out validation or test set. 
    Arguments:
        data_dict: The validation/test word pairs, organized by source and target lengths.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model to generate output tokens.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        opts: The command-line arguments.

    Returns:
        mean_loss: The average loss over all batches from data_dict.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    losses = []

    for key in data_dict:

        input_strings, target_strings = zip(*data_dict[key])
        
        input_tensors, output_tensors = [],[]
        for s in input_strings:
            input_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))
            
        for s in target_strings:
            output_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))
            
        num_tensors = len(input_tensors)
        num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

        for i in range(num_batches):
            start = i * opts.batch_size
            end = start + opts.batch_size

            # Define inputs and targets for THIS batch, beginning at index 'start' to 'end'
            inputs = torch.stack(input_tensors[start:end])
            # outputs = ....
            targets = torch.stack(output_tensors[start:end])
            
            # The batch size may be different in each epoch
            BS = inputs.size(0)

            encoder_annotations, encoder_hidden = encoder.forward(inputs)
            
            # The last hidden state of the encoder becomes the first hidden state of the decoder
            decoder_hidden = encoder_hidden
            decoder_cell = torch.zeros(BS, encoder.hidden_size)

            # Define the first decoder input. This would essentially be the start_token
            decoder_input = start_token * torch.ones(BS,1).long()

            loss = 0.0

            seq_len = targets.size(1)  # Gets seq_len from BS x seq_len
            #print("seq len", seq_len)
            
            for i in range(seq_len):
                decoder_output, decoder_hidden, decoder_cell = decoder.forward(decoder_input, decoder_hidden, decoder_cell)

                current_target = targets[:,i]

                # Calculate the cross entropy between the decoder distribution and Ground truth (current_target)
                loss += criterion(decoder_output, current_target)
                #bp()
                # Find out the most probable character (ni) from the softmax distribution produced
                ni = F.softmax(decoder_output, dim=1).max(1)
                
                # Update decoder_input at the next time step to be this time-step's target 
                decoder_input = targets[:,i].unsqueeze(1)

            loss /= float(seq_len)
            losses.append(loss.item())

    mean_loss = np.mean(losses)

    return mean_loss


In [13]:
def find_pig_latin(sentence, encoder, decoder, idx_dict, opts):
    """Translates a sentence from English to Pig-Latin, by splitting the sentence into
    words (whitespace-separated), running the encoder-decoder model to translate each
    word independently, and then stitching the words back together with spaces between them.
    """
    return ' '.join([translate(word, encoder, decoder, idx_dict, opts) for word in sentence.split()])


def translate(input_string, encoder, decoder, idx_dict, opts):
    """Translates a given string from English to Pig-Latin.
    """

    char_to_index = idx_dict['char_to_index']
    index_to_char = idx_dict['index_to_char']
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    
    #print("end token", end_token)
    
    max_generated_chars = 20
    gen_string = ''
    
    # convert given string to an array of indexes
    indexes = torch.LongTensor(string_to_index_list(input_string, char_to_index, end_token)).unsqueeze(0)

    encoder_annotations, encoder_last_hidden = encoder.forward(indexes)

    # The last hidden state of the encoder becomes the first hidden state of the decoder
    decoder_hidden = encoder_last_hidden
    decoder_cell = torch.zeros(1, encoder.hidden_size)

    # Define the first decoder input. This would essentially be the start_token
    decoder_input = start_token * torch.ones(1,1).long()
    
    for i in range(max_generated_chars):
        decoder_output, decoder_hidden, decoder_cell = decoder.forward(decoder_input, decoder_hidden, decoder_cell)
        
        # Calculate the cross entropy between the decoder distribution and Ground truth (current_target)
        # loss += criterion(decoder_output, current_target)
        
        # Find out the most probable character (ni) from the softmax distribution produced
        ni = F.softmax(decoder_output, dim=1).argmax(1)
        
        if int(ni) == end_token:
            break
        else:
            gen_string += index_to_char[ni.item()]
            
            # update decoder_input at the next time step to be ni 
            decoder_input = ni

    return gen_string

In [None]:
try:
    train_model(train_dict, val_dict, idx_dict, encoder, decoder, opts)
except KeyboardInterrupt:
    print('Exiting early from training.')

elway elway onday onday
Epoch:   0 | Train loss: 2.163 | Val loss: 2.304 | Gen: elway elway onday onday
arsway-ay-idday-idwa osterway-ay-idday-id arsway-ay-idday-idwa arsway-ay-idday-idwa
Epoch:   1 | Train loss: 2.120 | Val loss: 2.395 | Gen: arsway-ay-idday-idwa osterway-ay-idday-id arsway-ay-idday-idwa arsway-ay-idday-idwa
antay antay eallyway antay
Epoch:   2 | Train loss: 2.297 | Val loss: 2.498 | Gen: antay antay eallyway antay
ishay-ay-ay-oncay-ay allway-ay-ay-oncay-a elllyway-ay-ay-oncay andeway-ay-ay-oncay-
Epoch:   3 | Train loss: 2.149 | Val loss: 2.403 | Gen: ishay-ay-ay-oncay-ay allway-ay-ay-oncay-a elllyway-ay-ay-oncay andeway-ay-ay-oncay-
ellway owornsay ellway actioncay
Epoch:   4 | Train loss: 2.367 | Val loss: 2.471 | Gen: ellway owornsay ellway actioncay
ingray-oway-ay-odway ellway-ay-ayday-odwa ellway-ay-ayday-odwa ationway-ay-oway-ay-
Epoch:   5 | Train loss: 2.061 | Val loss: 2.260 | Gen: ingray-oway-ay-odway ellway-ay-ayday-odwa ellway-ay-ayday-odwa ationway-ay-o