In [15]:
import os
import pdb
import sys
# For interactive debugging
#from IPython.core.debugger import set_trace as bp

import argparse
import pickle as pkl

from collections import defaultdict

import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical

# Local imports
import utils
import data_handler
from data_handler import *

import attn_vis

In [16]:
# Hyperparameters required to train and test the model
opts = EasyDict()

opts.n_epochs = 100
opts.batch_size = 16
opts.optimizer = "Adam" #"SGD"
opts.learning_rate = 0.002
opts.lr_decay = 0.99
opts.hidden_layer_size = 10
opts.model_name = "attention_rnn"
opts.checkpoints_dir = "./checkpoints/{}_{}_lr_{}_epochs_{}_num_hidden_{}_batchsize_{}"\
    .format(opts.model_name,opts.optimizer,opts.learning_rate,opts.n_epochs,opts.hidden_layer_size,opts.batch_size)


opts.device = 'cuda' if torch.cuda.is_available() else 'cpu'

TEST_SENTENCE = 'i love deep learning'
TEST_WORD_ATTENTION = "attention"

In [17]:
# Checkpoint Directory
print(opts.checkpoints_dir)

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

./checkpoints/attention_rnn_Adam_lr_0.002_epochs_100_num_hidden_10_batchsize_16


In [18]:
utils.create_dir_if_not_exists(opts.checkpoints_dir)

In [19]:
line_pairs, vocab_size, idx_dict = load_data()

In [20]:
for key in (idx_dict.keys()):
    print(key) 
    print(idx_dict[key])

char_to_index
{'-': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'SOS': 27, 'EOS': 28}
index_to_char
{0: '-', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 27: 'SOS', 28: 'EOS'}
start_token
27
end_token
28


In [21]:
# dividing the line pairs into 8:2, train and val split
num_lines = len(line_pairs)
num_train = int(0.8 * num_lines)
train_pairs, val_pairs = line_pairs[:num_train], line_pairs[num_train:]

In [22]:
train_dict = create_dict(train_pairs)
val_dict = create_dict(val_pairs)

In [23]:
for key in train_dict.keys():
#    print("Pairs of length {}: {}".format(key,len(train_dict[key])))
    pass
#train_dict[(3, 5)]

In [30]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, opts):
        super(Encoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.opts = opts

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm_cell = nn.LSTMCell(hidden_size, hidden_size)

    def forward(self, inputs):
        """Forward pass of the encoder RNN.

        Arguments:
            inputs: Input token indexes across a batch for all time steps in the sequence. (batch_size x seq_len)

        Returns:
            annotations: The hidden states computed at each step of the input sequence. (batch_size x seq_len x hidden_size)
            hidden: The final hidden state of the encoder, for each sequence in a batch. (batch_size x hidden_size)
        """

        batch_size, seq_len = inputs.size()
        hidden = self.init_hidden(batch_size)
        cell = self.init_hidden(batch_size)
        encoded = self.embedding(inputs)  # batch_size x seq_len x hidden_size
        annotations = []

        for i in range(seq_len):
            x = encoded[:,i,:]  # Get the current time step, across the whole batch
            hidden, cell = self.lstm_cell(x, (hidden, cell))
            annotations.append(hidden)

        annotations = torch.stack(annotations, dim=1)
        return annotations, hidden

    def init_hidden(self, bs):
        """Creates a tensor of zeros to represent the initial hidden states
        of a batch of sequences.

        Arguments:
            bs: The batch size for the initial hidden state.

        Returns:
            hidden: An initial hidden state of all zeros. (batch_size x hidden_size)
        """
        return Variable(torch.zeros(bs, self.hidden_size))


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.hidden_size = hidden_size
        
        self.attention_network = nn.Linear(2*hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size,1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, hidden, annotations):
        """The forward pass of the attention mechanism.

        Arguments:
            hidden: The current decoder hidden state. (batch_size x hidden_size)
            annotations: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)

        Returns:
            output: Normalized attention weights for each encoder hidden state. (batch_size x seq_len x 1)

            The output must be a softmax weighting over the seq_len annotations.
        """

        batch_size, seq_len, hid_size = annotations.size()
        expanded_hidden = hidden.unsqueeze(1).expand_as(annotations)

        concat = torch.cat((expanded_hidden, annotations), dim = 2)
        reshaped_for_attention_net = concat.reshape(-1, 2*hid_size)
        attention_net_output = F.relu(self.attention_network(reshaped_for_attention_net))
        unnormalized_attention = self.output_layer(attention_net_output).reshape(batch_size, seq_len, 1)  

        return self.softmax(unnormalized_attention)


class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(AttentionDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, hidden_size)

        self.lstm_cell = nn.LSTMCell(input_size=hidden_size*2, hidden_size=hidden_size)
        self.attention = Attention(hidden_size=hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h_prev, c_prev, annotations):
        """Forward pass of the attention-based decoder RNN.

        Arguments:
            x: Input token indexes across a batch for a single time step. (batch_size x 1)
            h_prev: The hidden states from the previous step, across a batch. (batch_size x hidden_size)
            annotations: The encoder hidden states for each step of the input.
                         sequence. (batch_size x seq_len x hidden_size)

        Returns:
            output: Un-normalized scores for each token in the vocabulary, across a batch. (batch_size x vocab_size)
            h_new: The new hidden states, across a batch. (batch_size x hidden_size)
            attention_weights: The weights applied to the encoder annotations, across a batch. (batch_size x encoder_seq_len x 1)
        """
        embed = self.embedding(x)    # batch_size x 1 x hidden_size
        embed = embed.squeeze(1)     # batch_size x hidden_size

        attention_weights = self.attention.forward(h_prev, annotations).permute(0,2,1)
        context = torch.matmul(attention_weights, annotations).squeeze(1)
        embed_and_context = torch.cat((context, embed), axis=1)
        h_new, c_new = self.lstm_cell(embed_and_context)
        output = self.out(h_new)
        return output, h_new, c_new, attention_weights

In [31]:
encoder = Encoder(vocab_size=vocab_size, hidden_size=opts.hidden_layer_size, opts = opts)
decoder = AttentionDecoder(vocab_size=vocab_size, hidden_size=opts.hidden_layer_size)

In [36]:
def train_model(train_dict, val_dict, idx_dict, encoder, decoder, opts):
    """Runs the main training loop; evaluates the model on the val set every epoch.
        * Prints training and val loss each epoch.
        * Prints qualitative translation results each epoch using TEST_SENTENCE

    Arguments:
        train_dict: The training word pairs, organized by source and target lengths.
        val_dict: The validation word pairs, organized by source and target lengths.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model to generate output tokens.
        opts: The input arguments for hyper-parameters and others.
    """
    
    criterion = nn.CrossEntropyLoss()
    
    parameters = list(encoder.parameters())+list(decoder.parameters())
    
    if opts.optimizer == "RMSprop":
        optimizer = optim.RMSprop(parameters ,lr=opts.learning_rate)
    elif opts.optimizer == "Adam":
        optimizer = optim.Adam(parameters ,lr=opts.learning_rate)
    else:
        optimizer = optim.SGD(parameters, lr=opts.learning_rate, momentum=0.9)

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    loss_log = open(os.path.join(opts.checkpoints_dir, 'loss_log.txt'), 'w')

    best_val_loss = 1e6
    train_losses = []
    val_losses = []

    for epoch in range(opts.n_epochs):
        
        # decay the learning rate of the optimizer
        optimizer.param_groups[0]['lr'] *= opts.lr_decay

        epoch_losses = []

        for key in train_dict:

            input_strings, target_strings = zip(*train_dict[key])
            
            input_tensors, output_tensors = [],[]
            for s in input_strings:
                input_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))

            for s in target_strings:
                output_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))
            
            num_tensors = len(input_tensors)
            num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

            for i in range(num_batches):

                start = i * opts.batch_size
                end = start + opts.batch_size

                # Inputs and targets for THIS batch, beginning at index 'start' to 'end'
                inputs = torch.stack(input_tensors[start:end]).to(opts.device)
                targets = torch.stack(output_tensors[start:end]).to(opts.device)
               
                # The batch size may be different in each epoch
                BS = inputs.size(0)

                encoder_annotations, encoder_hidden = encoder.forward(inputs)

                # The last hidden state of the encoder becomes the first hidden state of the decoder
                decoder_hidden = encoder_hidden.to(opts.device)
                decoder_cell = encoder_hidden.to(opts.device)
                # First decoder input. This is the essentially be the start_token
                decoder_input = start_token * torch.ones(BS,1).long().to(opts.device)

                loss = 0.0

                seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

                for i in range(seq_len):
                    decoder_output, decoder_hidden, decoder_cell, attention_weights = decoder.forward(decoder_input, \
                                                                                        decoder_hidden, \
                                                                                        decoder_cell, \
                                                                                        encoder_annotations)

                    current_target = targets[:,i]
                    
                    # Calculate the cross entropy between the decoder distribution and Ground truth (current_target)
                    loss += criterion(decoder_output, current_target)

                    decoder_input = targets[:,i].unsqueeze(1)

                loss /= float(seq_len)
                epoch_losses.append(loss.item())

                # Compute gradients
                loss.backward()

                # Update the parameters of the encoder and decoder
                optimizer.step()
                optimizer.zero_grad()

        train_loss = np.mean(epoch_losses)
        val_loss = evaluate(val_dict, encoder, decoder, idx_dict, criterion, opts)

        if val_loss < best_val_loss:
            utils.store_checkpoints(encoder, decoder, idx_dict, opts)
            
        attention_dir = os.path.join(opts.checkpoints_dir,'train_attns')
        utils.create_dir_if_not_exists(attention_dir)
        
        attn_vis.visualize_attention(TEST_WORD_ATTENTION,encoder,decoder,idx_dict,opts,
                                      save=attention_dir+"/attn-epoch-{}.png".format(epoch))

        gen_string = find_pig_latin(TEST_SENTENCE, encoder, decoder, idx_dict, opts)
        print("Epoch: {:3d} | Train loss: {:.3f} | Val loss: {:.3f} | Gen: {:20s}".format(epoch, train_loss, val_loss, gen_string))

        loss_log.write('{} {} {}\n'.format(epoch, train_loss, val_loss))
        loss_log.flush()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        utils.store_loss_plots(train_losses, val_losses, opts)


In [37]:
def evaluate(data_dict, encoder, decoder, idx_dict, criterion, opts):
    """Evaluates the model on a held-out validation or test set. 
    This should be pretty straight-forward if you have figured out how to do the training correctly.
    From then, it's just copy and paste.

    Arguments:
        data_dict: The validation/test word pairs, organized by source and target lengths.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model to generate output tokens.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        opts: The command-line arguments.

    Returns:
        mean_loss: The average loss over all batches from data_dict.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    losses = []

    for key in data_dict:

        input_strings, target_strings = zip(*data_dict[key])

        input_tensors, output_tensors = [],[]
        for s in input_strings:
            input_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))

        for s in target_strings:
            output_tensors.append(torch.LongTensor(string_to_index_list(s, char_to_index, end_token)))

        
        num_tensors = len(input_tensors)
        num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

        for i in range(num_batches):

            start = i * opts.batch_size
            end = start + opts.batch_size

            # Inputs and targets for THIS batch, beginning at index 'start' to 'end'
            inputs = torch.stack(input_tensors[start:end]).to(opts.device)
            targets = torch.stack(output_tensors[start:end]).to(opts.device)

            # The batch size may be different in each epoch
            BS = inputs.size(0)

            encoder_annotations, encoder_hidden = encoder.forward(inputs)
            
            decoder_hidden = encoder_hidden.to(opts.device)
            decoder_cell = encoder_hidden.to(opts.device)
            # First decoder input. This is the essentially be the start_token
            decoder_input = start_token * torch.ones(BS,1).long().to(opts.device)
            
            loss = 0.0

            seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

            for i in range(seq_len):
                decoder_output, decoder_hidden, decoder_cell, attention_weights = decoder.forward(decoder_input,\
                                                                                    decoder_hidden,\
                                                                                    decoder_cell, \
                                                                                    encoder_annotations)

                current_target = targets[:,i]

                # Calculate the cross entropy between the decoder distribution and Ground truth (current_target)
                loss += criterion(decoder_output, current_target)

                # Find out the most probable character (ni) from the softmax distribution produced
                ni = F.softmax(decoder_output, dim=1).argmax(1)
                
                # Update decoder_input at the next time step to be this time-step's target 
                decoder_input = ni

            loss /= float(seq_len)
            losses.append(loss.item())

    mean_loss = np.mean(losses)

    return mean_loss


In [38]:
def find_pig_latin(sentence, encoder, decoder, idx_dict, opts):
    """Translates a sentence from English to Pig-Latin, by splitting the sentence into
    words (whitespace-separated), running the encoder-decoder model to translate each
    word independently, and then stitching the words back together with spaces between them.
    """
    return ' '.join([translate(word, encoder, decoder, idx_dict, opts) for word in sentence.split()])


def translate(input_string, encoder, decoder, idx_dict, opts):
    """Translates a given string from English to Pig-Latin.
    Not much to do here as well. Follows basically the same structure as that of the function evaluate.
    """

    char_to_index = idx_dict['char_to_index']
    index_to_char = idx_dict['index_to_char']
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']

    max_generated_chars = 20
    gen_string = ''

    # convert given string to an array of indexes
    indexes = torch.LongTensor(string_to_index_list(input_string, char_to_index, end_token)).unsqueeze(0)
    indexes = indexes.to(opts.device)

    encoder_annotations, encoder_last_hidden = encoder.forward(indexes)

    # The last hidden state of the encoder becomes the first hidden state of the decoder
    decoder_hidden = encoder_last_hidden.to(opts.device)
    decoder_cell = encoder_last_hidden.to(opts.device)

    # Define the first decoder input. This would essentially be the start_token
    decoder_input = start_token * torch.ones(1,1).long().to(opts.device)

    for i in range(max_generated_chars):
        decoder_output, decoder_hidden, decoder_cell, attention_weights = decoder.forward(decoder_input,\
                                                                            decoder_hidden,\
                                                                            decoder_cell, \
                                                                            encoder_annotations)
    
        # Find out the most probable character (ni) from the softmax distribution produced
        ni = F.softmax(decoder_output, dim=1).argmax(1)
        #ni = F.softmax(decoder_output, dim=1).multinomial(num_samples = 1)
        #ni = Categorical(F.softmax(decoder_output, dim=1)).sample()

        if ni == end_token:
            break
        else:
            gen_string += index_to_char[ni.item()]
            
            # update decoder_input at the next time step to be ni 
            decoder_input = ni

    return gen_string

In [39]:
try:
    train_model(train_dict, val_dict, idx_dict, encoder, decoder, opts)
except KeyboardInterrupt:
    print('Exiting early from training.')

Epoch:   0 | Train loss: 2.162 | Val loss: 3.562 | Gen: onay onay ay ay     
Epoch:   1 | Train loss: 2.063 | Val loss: 3.503 | Gen: ongingingingingingin ongay intentententententen ay
Epoch:   2 | Train loss: 1.978 | Val loss: 3.586 | Gen: ongay ongay eday eday
Epoch:   3 | Train loss: 1.904 | Val loss: 3.657 | Gen: ingay ongay eday eday
Epoch:   4 | Train loss: 1.835 | Val loss: 3.780 | Gen: ingay ongay eday eday
Epoch:   5 | Train loss: 1.767 | Val loss: 3.826 | Gen: ingay oongay eday eday
Epoch:   6 | Train loss: 1.710 | Val loss: 3.807 | Gen: ingay ontay eray eday
Epoch:   7 | Train loss: 1.661 | Val loss: 3.730 | Gen: ingay ontay erestererererererere ereday
Epoch:   8 | Train loss: 1.614 | Val loss: 3.739 | Gen: ingay ontay eesteray eday
Epoch:   9 | Train loss: 1.564 | Val loss: 3.675 | Gen: itay ontay eestesay eay
Epoch:  10 | Train loss: 1.492 | Val loss: 3.711 | Gen: itay oneday eeseray eay
Epoch:  11 | Train loss: 1.433 | Val loss: 3.689 | Gen: itay ouesesay eestay eay
Epoch:

Epoch:  96 | Train loss: 0.335 | Val loss: 1.902 | Gen: itay oveway eepday earningway
Epoch:  97 | Train loss: 0.334 | Val loss: 1.634 | Gen: itay ovetay eeppay earninggay
Epoch:  98 | Train loss: 0.340 | Val loss: 1.770 | Gen: imay oveway eepway earningway
Epoch:  99 | Train loss: 0.336 | Val loss: 1.832 | Gen: itay oveway eepway earninggay


In [40]:
test_sentences = ["i am happy","this is not pig latin","this model is working"]

for test in test_sentences:
    print(find_pig_latin(test, encoder, decoder, idx_dict, opts))

itay amay appy
istay isray otnay iggay atinlay
istay odelmay isray orkingway


In [41]:
def get_accuracy(data_dict, encoder, decoder, opts, idx_dict):
    word_accuracy, letter_accuracy, total_words, total_letters = 0,0,0,0
    for key in data_dict:
        for k,v in data_dict[key]:
            v_ = translate(k, encoder, decoder, idx_dict, opts)
            total_words += 1
            total_letters += len(v_)
            if v_==v:
                word_accuracy += 1
                letter_accuracy += len(v)
            else:
                if len(v_)<len(v):
                    v_ += (len(v)-len(v_)) * '0'
                elif len(v_)>len(v):
                    v += (len(v_)-len(v)) * '0'
                
                for i in range(len(v)):
                    if v[i]==v_[i]:
                        letter_accuracy += 1
        
    return word_accuracy/total_words, letter_accuracy/total_letters


get_accuracy(val_dict, encoder, decoder, opts, idx_dict)

(0.37089201877934275, 0.8539570757880617)