# Colab FAQ

For a basic overview of the features offered in Colab notebooks, check out: [Overview of Colaboratory Features](https://colab.research.google.com/notebooks/basic_features_overview.ipynb)

You need to use the Colab GPU for this assignment by selecting:

> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**

## Setup PyTorch
All files are stored in the /content/NLP/HW4/ folder.


In [None]:
######################################################################
# Setup python environment and change the current working directory
######################################################################
# Install torch and torchvision with pip from the notebook shell.
!pip install torch torchvision
# Pin Pillow to exactly version 4.0.0.
!pip install Pillow==4.0.0
# Create the assignment folder (-p: no error if it already exists), then make
# it the working directory for the rest of the notebook.
%mkdir -p /content/NLP/HW4/
%cd /content/NLP/HW4

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting Pillow==4.0.0
  Downloading Pillow-4.0.0.tar.gz (11.1 MB)
[K     |████████████████████████████████| 11.1 MB 8.1 MB/s 
[?25hCollecting olefile
  Downloading olefile-0.46.zip (112 kB)
[K     |████████████████████████████████| 112 kB 55.3 MB/s 
[?25hBuilding wheels for collected packages: Pillow, olefile
  Building wheel for Pillow (setup.py) ... [?25l[?25hdone
  Created wheel for Pillow: filename=Pillow-4.0.0-cp37-cp37m-linux_x86_64.whl size=1007276 sha256=d5964ce85f5d0ab26942d6a23d868183022db0721e97d0cd80f0ac71d4e39e75
  Stored in directory: /root/.cache/pip/wheels/ec/1e/24/dbc5e4964ea99cad93230a9013d934fb5adc322c3102f69e45
  Building wheel for olefile (setup.py) ... [?25l[?25hdone
  Created wheel for olefile: filename=olefile-0.46-py2.py3-none-any.whl size=35432 sha256=9b8

/content/NLP/HW4


# Helper code

## Utility functions

In [None]:
import os
import pdb
import argparse
import pickle as pkl

from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from six.moves.urllib.request import urlretrieve
import tarfile
import pickle
import sys


def get_file(fname,
             origin,
             untar=False,
             extract=False,
             archive_format='auto',
             cache_dir='data'):
    """Downloads a file into cache_dir unless it is already present there.

    Arguments:
        fname: Name of the file inside cache_dir. When untar is True this is
            the name of the extracted target and the archive itself is stored
            as fname + '.tar.gz'.
        origin: URL to download the file from.
        untar: If True, extract the downloaded tarball into cache_dir and
            return the path of the extracted target instead of the archive.
        extract: If True (and untar is False), pass the downloaded file to
            _extract_archive.
        archive_format: Forwarded unchanged to _extract_archive.
        cache_dir: Directory that holds the downloaded files; created if
            it does not exist.

    Returns:
        The path of the downloaded file, or of the extracted target when
        untar is True.

    Raises:
        Exception: If the download fails with an HTTP or URL error. Any
            partially written file is removed before the error propagates.
    """
    # Imported here so the handlers below can name the urllib error types.
    from urllib.error import HTTPError, URLError

    datadir = os.path.join(cache_dir)
    os.makedirs(datadir, exist_ok=True)

    if untar:
        untar_fpath = os.path.join(datadir, fname)
        fpath = untar_fpath + '.tar.gz'
    else:
        fpath = os.path.join(datadir, fname)

    print(fpath)
    if not os.path.exists(fpath):
        print('Downloading data from', origin)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath)
            # HTTPError is a subclass of URLError, so it has to be handled
            # first; otherwise its handler could never be reached.
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg)) from e
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
        except (Exception, KeyboardInterrupt):
            # Do not leave a truncated download behind: a later call would
            # mistake it for a complete cached file.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    if untar:
        if not os.path.exists(untar_fpath):
            print('Extracting file.')
            # NOTE(review): extractall trusts the member paths stored in the
            # archive; only use this with archives from a trusted origin.
            with tarfile.open(fpath) as archive:
                archive.extractall(datadir)
        return untar_fpath

    if extract:
        # NOTE(review): _extract_archive is not defined in the visible part of
        # this file -- confirm it exists before calling with extract=True.
        _extract_archive(fpath, datadir, archive_format)

    return fpath

class AttrDict(dict):
    """A dict whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so obj.key and
        # obj['key'] always refer to the same storage.
        self.__dict__ = self
        
def to_var(tensor, cuda):
    """Returns the given Tensor wrapped in a Variable, moved to the GPU on request.

        Arguments:
            tensor: The Tensor to wrap.
            cuda: Boolean flag; when true the tensor is moved to the GPU first.

        Returns:
            A Variable holding the tensor (on the GPU if cuda is true).
    """
    payload = tensor.cuda() if cuda else tensor
    return Variable(payload)


def create_dir_if_not_exists(directory):
    """Makes sure `directory` exists, creating it (and its parents) if needed.
    """
    if os.path.exists(directory):
        return
    os.makedirs(directory)


def save_loss_plot(train_losses, val_losses, opts):
    """Plots the training and validation losses per epoch and writes the
    figure to 'loss_plot.pdf' inside opts.checkpoint_path.
    """
    plt.figure()

    # Training curve first, validation curve second; x is the epoch index.
    for curve in (train_losses, val_losses):
        plt.plot(range(len(curve)), curve)

    heading = 'BS={}, nhid={}'.format(opts.batch_size, opts.hidden_size)
    plt.title(heading, fontsize=20)

    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel('Loss', fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.tight_layout()

    output_file = os.path.join(opts.checkpoint_path, 'loss_plot.pdf')
    plt.savefig(output_file)
    plt.close()


def checkpoint(encoder, decoder, idx_dict, opts):
    """Writes the encoder, the decoder and idx_dict into opts.checkpoint_path.

    The models are stored with torch.save as 'encoder.pt' and 'decoder.pt';
    idx_dict (char_to_index / index_to_char mappings plus the start_token and
    end_token values) is pickled to 'idx_dict.pkl'.
    """
    target_dir = opts.checkpoint_path

    for file_name, model in (('encoder.pt', encoder), ('decoder.pt', decoder)):
        with open(os.path.join(target_dir, file_name), 'wb') as handle:
            torch.save(model, handle)

    with open(os.path.join(target_dir, 'idx_dict.pkl'), 'wb') as handle:
        pkl.dump(idx_dict, handle)


## Data loader

In [None]:
def read_lines(filename):
    """Reads a text file and returns its lines, lower-cased.

    Leading and trailing whitespace of the whole file is removed before
    splitting, so the result has no empty entries at either end (an empty
    file yields ['']).

    Arguments:
        filename: Path of the file to read.

    Returns:
        A list with one lower-cased string per line.
    """
    # The context manager closes the handle even if reading fails.
    with open(filename) as f:
        content = f.read()
    return content.strip().lower().split('\n')


def read_pairs(filename):
    """Parses a file in which every non-blank line holds two
    whitespace-separated words.

    Returns:
        source_words: The first word of every non-blank line, in file order.
        target_words: The second word of every non-blank line, in file order.
    """
    source_words = []
    target_words = []

    for raw_line in read_lines(filename):
        entry = raw_line.strip()
        if not entry:
            continue  # Skip blank lines
        first, second = entry.split()
        source_words.append(first)
        target_words.append(second)

    return source_words, target_words


def all_alpha_or_dash(s):
    """Returns True when every character of s is a letter or a dash '-'
    (which also holds for the empty string).
    """
    for ch in s:
        if ch != '-' and not ch.isalpha():
            return False
    return True


def filter_lines(lines):
    """Keeps only the lines made up entirely of letters and dashes "-",
    preserving their order.
    """
    return list(filter(all_alpha_or_dash, lines))


def load_data():
    """Loads (English, Pig-Latin) word pairs, and creates mappings from characters to indexes.

    Reads 'data/pig_latin_data.txt' and keeps only the pairs in which both
    words consist of letters and dashes.

    Returns:
        line_pairs: A list of the unique (source, target) word pairs.
        vocab_size: Number of entries in char_to_index, including 'SOS' and 'EOS'.
        idx_dict: A dict with the keys 'char_to_index', 'index_to_char',
            'start_token' and 'end_token'.
    """

    source_lines, target_lines = read_pairs('data/pig_latin_data.txt')

    # Filter whole pairs rather than each list on its own: dropping a word
    # from only one of the two lists would shift every later word and make
    # zip() pair sources with the wrong targets.
    kept_pairs = [(source, target)
                  for (source, target) in zip(source_lines, target_lines)
                  if all_alpha_or_dash(source) and all_alpha_or_dash(target)]

    all_characters = set()
    for (source, target) in kept_pairs:
        all_characters.update(source, target)

    # Create a dictionary mapping each character to a unique index
    char_to_index = { char: index for (index, char) in enumerate(sorted(all_characters)) }

    # Add start and end tokens to the dictionary
    start_token = len(char_to_index)
    end_token = len(char_to_index) + 1
    char_to_index['SOS'] = start_token
    char_to_index['EOS'] = end_token

    # Create the inverse mapping, from indexes to characters (used to decode the model's predictions)
    index_to_char = { index: char for (char, index) in char_to_index.items() }

    # Store the final size of the vocabulary
    vocab_size = len(char_to_index)

    # De-duplicate the pairs; the resulting order follows set iteration order.
    line_pairs = list(set(kept_pairs))

    idx_dict = { 'char_to_index': char_to_index,
                 'index_to_char': index_to_char,
                 'start_token': start_token,
                 'end_token': end_token }

    return line_pairs, vocab_size, idx_dict


def create_dict(pairs):
    """Groups the unique (source, target) pairs by their lengths.

    The result maps (source_length, target_length) to the list of pairs with
    exactly those lengths. Batches are later drawn from one such list at a
    time, so that all source words in a batch share one length and all
    target words share one length, and each side can be stacked into a
    single tensor.
    """
    buckets = defaultdict(list)

    # set() removes duplicate pairs before they are grouped
    for source, target in set(pairs):
        length_key = (len(source), len(target))
        buckets[length_key].append((source, target))

    return buckets


## Training and evaluation code

In [None]:
def string_to_index_list(s, char_to_index, end_token):
    """Maps every character of s to its index and terminates the list with end_token.
    """
    indexes = [char_to_index[symbol] for symbol in s]
    indexes.append(end_token)  # Every index list ends with the end token
    return indexes


def translate_sentence(sentence, encoder, decoder, idx_dict, opts):
    """Translates an English sentence to Pig-Latin one word at a time.

    The sentence is split on whitespace, every word is passed through the
    encoder-decoder model independently, and the translated words are joined
    back together with single spaces. When idx_dict is None the mappings are
    obtained from load_data().
    """
    if idx_dict is None:
        _pairs, _vocab_size, idx_dict = load_data()

    translated_words = []
    for word in sentence.split():
        translated_words.append(translate(word, encoder, decoder, idx_dict, opts))
    return ' '.join(translated_words)


def translate(input_string, encoder, decoder, idx_dict, opts):
    """Greedily decodes the Pig-Latin translation of a single English string.

    Decoding stops when the model emits the end token or after 20 steps.
    The returned string is the prediction made at the last step that did not
    end with the end token.
    """
    char_to_index = idx_dict['char_to_index']
    index_to_char = idx_dict['index_to_char']
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']

    max_generated_chars = 20

    # Encode the source string as a batch of size 1
    source_ids = string_to_index_list(input_string, char_to_index, end_token)
    source_batch = to_var(torch.LongTensor(source_ids).unsqueeze(0), opts.cuda)
    encoder_annotations, encoder_last_hidden = encoder(source_batch)

    # The decoder always starts from the start token (batch of size 1)
    sos_input = to_var(torch.LongTensor([[start_token]]), opts.cuda)
    decoder_inputs = sos_input
    gen_string = ''

    for _ in range(max_generated_chars):
        # Slow decoding: the whole prefix is fed through the decoder again at every step
        decoder_outputs, _attention = decoder(decoder_inputs, encoder_annotations, encoder_last_hidden)
        predicted = F.softmax(decoder_outputs, dim=2).max(2)[1]
        predicted_ids = predicted.cpu().numpy().reshape(-1)

        # Next step's input: the start token followed by everything predicted so far
        decoder_inputs = torch.cat([sos_input, predicted], dim=1)

        if predicted_ids[-1] == end_token:  # latest output token
            break
        gen_string = "".join([index_to_char[int(token_id)] for token_id in predicted_ids])

    return gen_string


def visualize_attention(input_string, encoder, decoder, idx_dict, opts):
    """Generates a heatmap to show where attention is focused in each decoder step.

    The input is decoded greedily (at most 20 steps) and the attention weights
    returned by the final decoder call are shown with matshow, one figure per
    entry along their first dimension.

    Arguments:
        input_string: The source word to translate.
        encoder: The encoder model.
        decoder: The decoder model; its second return value holds the
            attention weights (or a tuple whose first item does).
        idx_dict: Char/index mappings and start/end tokens; loaded with
            load_data() when None.
        opts: Options object; opts.cuda selects the device.

    Returns:
        The generated string (without the end token).
    """
    if idx_dict is None:
        line_pairs, vocab_size, idx_dict = load_data()
    char_to_index = idx_dict['char_to_index']
    index_to_char = idx_dict['index_to_char']
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']

    max_generated_chars = 20
    gen_string = ''

    indexes = string_to_index_list(input_string, char_to_index, end_token)
    indexes = to_var(torch.LongTensor(indexes).unsqueeze(0), opts.cuda)  # Unsqueeze to make it like BS = 1

    encoder_annotations, encoder_hidden = encoder(indexes)

    decoder_hidden = encoder_hidden
    decoder_input = to_var(torch.LongTensor([[start_token]]), opts.cuda)  # For BS = 1
    decoder_inputs = decoder_input

    produced_end_token = False

    for i in range(max_generated_chars):
        ## slow decoding, recompute everything at each time
        decoder_outputs, attention_weights = decoder(decoder_inputs, encoder_annotations, decoder_hidden)
        generated_words = F.softmax(decoder_outputs, dim=2).max(2)[1]
        token_ids = generated_words.cpu().numpy().reshape(-1)

        decoder_inputs = torch.cat([decoder_input, generated_words], dim=1)

        if token_ids[-1] == end_token:  # latest output token
            # Remember that decoding ended with EOS: the attention weights of
            # this final step include the EOS position, which needs a label.
            produced_end_token = True
            break
        else:
            gen_string = "".join([index_to_char[int(item)] for item in token_ids])

    if isinstance(attention_weights, tuple):
        ## transformer's attention weights
        attention_weights, self_attention_weights = attention_weights

    all_attention_weights = attention_weights.data.cpu().numpy()

    for i in range(len(all_attention_weights)):
        attention_weights_matrix = all_attention_weights[i].squeeze()
        fig = plt.figure()
        ax = fig.add_subplot(111)
        cax = ax.matshow(attention_weights_matrix, cmap='bone')
        fig.colorbar(cax)

        # Set up axes
        ax.set_yticklabels([''] + list(input_string) + ['EOS'], rotation=90)
        ax.set_xticklabels([''] + list(gen_string) + (['EOS'] if produced_end_token else []))

        # Show label at every tick
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
        # Add title
        plt.xlabel('Attention weights to the source sentence in layer {}'.format(i+1))
        plt.tight_layout()
        # Pass a real boolean: recent Matplotlib versions no longer interpret
        # the string 'off' as False.
        plt.grid(False)
        plt.show()
        #plt.savefig(save)

        #plt.close(fig)

    return gen_string


def compute_loss(data_dict, encoder, decoder, idx_dict, criterion, optimizer, opts):
    """Runs one pass over a dataset, training the model when an optimizer is given.

    Arguments:
        data_dict: Word pairs grouped by (source length, target length).
        encoder: An encoder model returning annotations and a final hidden state.
        decoder: A decoder model (with or without attention) returning output scores.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        criterion: Loss applied to the flattened decoder outputs and targets.
        optimizer: If given, the weights are updated after every batch; pass None to only evaluate.
        opts: Options object; batch_size and cuda are read from it.

    Returns:
        mean_loss: The average of the per-batch losses over all of data_dict.
    """
    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    def as_index_tensors(words):
        # One LongTensor of character indexes (ending in end_token) per word
        return [torch.LongTensor(string_to_index_list(word, char_to_index, end_token)) for word in words]

    losses = []
    for bucket in data_dict.values():
        input_strings, target_strings = zip(*bucket)
        input_tensors = as_index_tensors(input_strings)
        target_tensors = as_index_tensors(target_strings)

        num_batches = int(np.ceil(len(input_tensors) / float(opts.batch_size)))

        for batch_index in range(num_batches):
            lo = batch_index * opts.batch_size
            hi = lo + opts.batch_size

            inputs = to_var(torch.stack(input_tensors[lo:hi]), opts.cuda)
            targets = to_var(torch.stack(target_tensors[lo:hi]), opts.cuda)

            # The last batch of a bucket may be smaller than opts.batch_size
            current_batch_size = inputs.size(0)

            encoder_annotations, encoder_hidden = encoder(inputs)

            # A column of start tokens, one per sequence in the batch (BS x 1)
            sos_column = to_var(torch.ones(current_batch_size).long().unsqueeze(1) * start_token, opts.cuda)

            # Teacher forcing: the decoder sees the targets shifted right by one position
            decoder_inputs = torch.cat([sos_column, targets[:, 0:-1]], dim=1)

            # The final encoder hidden state initializes the decoder
            decoder_outputs, _attention = decoder(decoder_inputs, encoder_annotations, encoder_hidden)

            flat_scores = decoder_outputs.view(-1, decoder_outputs.size(2))
            flat_targets = targets.view(-1)
            loss = criterion(flat_scores, flat_targets)
            losses.append(loss.item())

            ## training if an optimizer is provided
            if optimizer:
                optimizer.zero_grad()  # Zero gradients
                loss.backward()        # Compute gradients
                optimizer.step()       # Update the parameters of the encoder and decoder

    return np.mean(losses)

  

def training_loop(train_dict, val_dict, idx_dict, encoder, decoder, criterion, optimizer, opts):
    """Runs the main training loop; evaluates the model on the val set every epoch.
        * Prints training and val loss each epoch.
        * Prints qualitative translation results each epoch using TEST_SENTENCE
          (a module-level name that must be defined before this is called).
        * Appends the losses to 'loss_log.txt' and refreshes the loss plot each epoch.
        * Checkpoints the models whenever the validation loss improves.

    Arguments:
        train_dict: The training word pairs, organized by source and target lengths.
        val_dict: The validation word pairs, organized by source and target lengths.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model (with or without attention) to generate output tokens.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        optimizer: Implements a step rule to update the parameters of the encoder and decoder.
        opts: The command-line arguments.
    """

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # The context manager closes the log even when training is interrupted.
    with open(os.path.join(opts.checkpoint_path, 'loss_log.txt'), 'w') as loss_log:
        for epoch in range(opts.nepochs):

            # Decay the learning rate of the first parameter group at the start of every epoch
            optimizer.param_groups[0]['lr'] *= opts.lr_decay

            train_loss = compute_loss(train_dict, encoder, decoder, idx_dict, criterion, optimizer, opts)
            val_loss = compute_loss(val_dict, encoder, decoder, idx_dict, criterion, None, opts)

            if val_loss < best_val_loss:
                # Track the best value seen so far; without this update every
                # epoch would overwrite the checkpoint, not only the best one.
                best_val_loss = val_loss
                checkpoint(encoder, decoder, idx_dict, opts)

            gen_string = translate_sentence(TEST_SENTENCE, encoder, decoder, idx_dict, opts)
            print("Epoch: {:3d} | Train loss: {:.3f} | Val loss: {:.3f} | Gen: {:20s}".format(epoch, train_loss, val_loss, gen_string))

            loss_log.write('{} {} {}\n'.format(epoch, train_loss, val_loss))
            loss_log.flush()

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            save_loss_plot(train_losses, val_losses, opts)


def print_data_stats(line_pairs, vocab_size, idx_dict):
    """Prints a short summary of the dataset: the first five word pairs, the
    number of pairs, the vocabulary keys and the vocabulary size.
    """
    width = 80
    heavy_rule = '=' * width

    print(heavy_rule)
    print('Data Stats'.center(width))
    print('-' * width)

    sample_pairs = line_pairs[:5]
    for sample in sample_pairs:
        print(sample)

    num_pairs = len(line_pairs)
    print('Num unique word pairs: {}'.format(num_pairs))
    vocabulary = idx_dict['char_to_index'].keys()
    print('Vocabulary: {}'.format(vocabulary))
    print('Vocab size: {}'.format(vocab_size))
    print(heavy_rule)


def train(opts):
    """Builds the data split, the models and the optimizer described by opts,
    runs the training loop and returns the (encoder, decoder) pair.

    A KeyboardInterrupt during training stops it early; the models trained
    so far are still returned.
    """
    line_pairs, vocab_size, idx_dict = load_data()
    print_data_stats(line_pairs, vocab_size, idx_dict)

    # First 80% of the pairs for training, the remaining 20% for validation
    split_at = int(0.8 * len(line_pairs))
    train_pairs = line_pairs[:split_at]
    val_pairs = line_pairs[split_at:]

    # Bucket the pairs by (source length, target length) so they can be batched
    train_dict = create_dict(train_pairs)
    val_dict = create_dict(val_pairs)

    ##########################################################################
    ### Setup: Create Encoder, Decoder, Learning Criterion, and Optimizers ###
    ##########################################################################
    hidden_size = opts.hidden_size
    encoder = GRUEncoder(vocab_size=vocab_size, hidden_size=hidden_size, opts=opts)

    if opts.decoder_type == 'rnn':
        decoder = RNNDecoder(vocab_size=vocab_size, hidden_size=opts.hidden_size)
    elif opts.decoder_type == 'rnn_attention':
        decoder = RNNAttentionDecoder(vocab_size=vocab_size,
                                      hidden_size=opts.hidden_size,
                                      attention_type=opts.attention_type)
    else:
        raise NotImplementedError

    #### setup checkpoint path
    opts.checkpoint_path = 'h{}-bs{}-{}'.format(opts.hidden_size, opts.batch_size, opts.decoder_type)
    create_dir_if_not_exists(opts.checkpoint_path)
    ####

    if opts.cuda:
        encoder.cuda()
        decoder.cuda()
        print("Moved models to GPU!")

    criterion = nn.CrossEntropyLoss()
    # A single optimizer updates the parameters of both models
    trainable = [*encoder.parameters(), *decoder.parameters()]
    optimizer = optim.Adam(trainable, lr=opts.learning_rate)

    try:
        training_loop(train_dict, val_dict, idx_dict, encoder, decoder, criterion, optimizer, opts)
    except KeyboardInterrupt:
        print('Exiting early from training.')

    return encoder, decoder


def print_opts(opts):
    """Prints every option stored on opts as a centered 'name: value' table.
    """
    width = 80
    frame = '=' * width

    print(frame)
    print('Opts'.center(width))
    print('-' * width)
    for name, value in opts.__dict__.items():
        row = '{:>30}: {:<30}'.format(name, value)
        print(row.center(width))
    print(frame)


# Your code for NMT models

## GRU cell

In [None]:
class MyGRUCell(nn.Module):
    """A single GRU cell built from linear layers.

    For an input x and a previous hidden state h_prev one step computes:

        z = sigmoid(Wiz x + Whz h_prev)        (update gate)
        r = sigmoid(Wir x + Whr h_prev)        (reset gate)
        g = tanh(Win x + r * (Whn h_prev))     (candidate state)
        h_new = (1 - z) * g + z * h_prev

    The input projections carry no bias; the hidden projections do, so each
    gate has exactly one bias term.
    """

    def __init__(self, input_size, hidden_size):
        """Creates the six linear projections of the cell.

        Arguments:
            input_size: Size of the last dimension of the input x.
            hidden_size: Size of the hidden state.
        """
        super(MyGRUCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        ## Input linear layers
        self.Wiz = nn.Linear(input_size, hidden_size, bias = False)
        self.Wir = nn.Linear(input_size, hidden_size, bias = False)
        self.Win = nn.Linear(input_size, hidden_size, bias = False)
        ## Hidden linear layers
        self.Whz = nn.Linear(hidden_size, hidden_size)
        self.Whr = nn.Linear(hidden_size, hidden_size)
        self.Whn = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, h_prev):
        """Forward pass of the GRU computation for one time step.

        Arguments
            x: batch_size x input_size
            h_prev: batch_size x hidden_size

        Returns:
            h_new: batch_size x hidden_size
        """
        # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh;
        # the computed values are the same.
        z = torch.sigmoid(self.Wiz(x) + self.Whz(h_prev))
        r = torch.sigmoid(self.Wir(x) + self.Whr(h_prev))
        # The reset gate scales the hidden contribution to the candidate state
        g = torch.tanh(self.Win(x) + r * self.Whn(h_prev))
        # Interpolate between the candidate and the previous state
        h_new = (1 - z) * g + z * h_prev
        return h_new

        

### GRU encoder / decoder

In [None]:
class GRUEncoder(nn.Module):
    """Embeds a batch of token sequences and runs a GRU cell over them step by step."""

    def __init__(self, vocab_size, hidden_size, opts):
        super(GRUEncoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.opts = opts

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRUCell(hidden_size, hidden_size)

    def forward(self, inputs):
        """Encodes a batch of input sequences.

        Arguments:
            inputs: Input token indexes across a batch for all time steps in the sequence. (batch_size x seq_len)

        Returns:
            annotations: The hidden state after every step of the input sequence. (batch_size x seq_len x hidden_size)
            hidden: The hidden state after the last step, for each sequence in the batch. (batch_size x hidden_size)
        """
        batch_size, seq_len = inputs.size()

        embedded = self.embedding(inputs)  # batch_size x seq_len x hidden_size
        hidden = self.init_hidden(batch_size)

        step_states = []
        for step in range(seq_len):
            # Feed the embeddings of one time step for the whole batch
            hidden = self.gru(embedded[:, step, :], hidden)
            step_states.append(hidden)

        return torch.stack(step_states, dim=1), hidden

    def init_hidden(self, bs):
        """Builds the all-zero hidden state that encoding starts from.

        Arguments:
            bs: The batch size for the initial hidden state.

        Returns:
            hidden: An initial hidden state of all zeros. (batch_size x hidden_size)
        """
        zeros = torch.zeros(bs, self.hidden_size)
        return to_var(zeros, self.opts.cuda)


class RNNDecoder(nn.Module):
    """GRU decoder without attention: embedding -> GRU cell -> linear output layer."""

    def __init__(self, vocab_size, hidden_size):
        super(RNNDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRUCell(input_size=hidden_size, hidden_size=hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, annotations, hidden_init):
        """Decodes a batch of input token sequences without attention.

        Arguments:
            inputs: Input token indexes across a batch. (batch_size x seq_len)
            annotations: Ignored. Accepted only so that the call signature
                    matches the one of the attention-based decoder.
            hidden_init: The hidden states from the last step of encoder, across a batch. (batch_size x hidden_size)

        Returns:
            output: Un-normalized scores for each token in the vocabulary, across a batch for all the decoding time steps. (batch_size x decoder_seq_len x vocab_size)
            None
        """
        batch_size, seq_len = inputs.size()
        embedded = self.embedding(inputs)  # batch_size x seq_len x hidden_size

        state = hidden_init
        step_states = []
        for step in range(seq_len):
            # One GRU step on this time step's embeddings for the whole batch
            state = self.rnn(embedded[:, step, :], state)  # batch_size x hidden_size
            step_states.append(state)

        stacked_states = torch.stack(step_states, dim=1)  # batch_size x seq_len x hidden_size
        scores = self.out(stacked_states)  # batch_size x seq_len x vocab_size
        return scores, None


## Attention

In [None]:
class AdditiveAttention(nn.Module):
    """Additive attention: a small MLP scores every (query, key) combination."""

    def __init__(self, hidden_size):
        super(AdditiveAttention, self).__init__()

        self.hidden_size = hidden_size

        # Scoring network with two fully-connected layers:
        # hidden_size*2 --> hidden_size, ReLU, hidden_size --> 1
        self.attention_network = nn.Sequential(
            nn.Linear(hidden_size*2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        # Normalizes the scores over the seq_len dimension
        self.softmax = nn.Softmax(dim=1)

    def forward(self, queries, keys, values):
        """Computes the attention weights and the resulting context vector.

        Arguments:
            queries: The current decoder hidden state. (batch_size x hidden_size)
            keys: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)
            values: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)

        Returns:
            context: weighted average of the values (batch_size x 1 x hidden_size)
            attention_weights: Normalized attention weights for each encoder hidden state. (batch_size x seq_len x 1)

            The attention_weights are a softmax weighting over the seq_len annotations.
        """
        batch_size, seq_len, hidden_size = keys.size()

        # Repeat the query along seq_len so it can be paired with every key
        tiled_queries = queries.unsqueeze(1).expand_as(keys)  # batch_size x seq_len x hidden_size
        query_key_pairs = torch.cat([tiled_queries, keys], dim=2)

        # Score all pairs in one pass through the MLP, then restore the batch layout
        flat_pairs = query_key_pairs.view(-1, hidden_size * 2)
        scores = self.attention_network(flat_pairs).view(batch_size, seq_len, 1)

        attention_weights = self.softmax(scores)  # batch_size x seq_len x 1

        # (batch_size x 1 x seq_len) @ (batch_size x seq_len x hidden_size)
        weights_as_rows = attention_weights.transpose(1, 2)
        context = torch.bmm(weights_as_rows, values)
        return context, attention_weights
         

### Attention decoder

In [None]:
class RNNAttentionDecoder(nn.Module):
    """GRU decoder that attends over the encoder annotations at every decoding step.

    At each step the previous decoder hidden state is used as the attention query, the
    resulting context vector is concatenated with the current token embedding, and the
    pair is fed to the GRU cell.
    """

    def __init__(self, vocab_size, hidden_size, attention_type='additive'):
        """Build the embedding, GRU cell, attention module and output projection.

        Arguments:
            vocab_size: Number of tokens in the vocabulary (embedding rows / output scores).
            hidden_size: Size of the embeddings and of the decoder hidden state.
            attention_type: Accepted for interface compatibility; currently not consulted,
                            additive attention is always used.
        """
        super(RNNAttentionDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # The cell input is [token embedding; attention context], hence hidden_size*2.
        self.rnn = MyGRUCell(input_size=hidden_size*2, hidden_size=hidden_size)
        self.attention = AdditiveAttention(hidden_size=hidden_size)

        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, annotations, hidden_init):
        """Forward pass of the attention-based decoder RNN.

        Arguments:
            inputs: Input token indexes across a batch for all the time steps. (batch_size x decoder_seq_len)
            annotations: The encoder hidden states for each step of the input
                         sequence. (batch_size x encoder_seq_len x hidden_size)
            hidden_init: The final hidden states from the encoder, across a batch. (batch_size x hidden_size)

        Returns:
            output: Un-normalized scores for each token in the vocabulary, across a batch for all the decoding time steps. (batch_size x decoder_seq_len x vocab_size)
            attentions: The stacked attention weights applied to the encoder annotations (batch_size x encoder_seq_len x decoder_seq_len)
        """

        # Unpacking two values also checks that inputs is 2-D.
        _, seq_len = inputs.size()
        embed = self.embedding(inputs)  # batch_size x decoder_seq_len x hidden_size

        hiddens = []
        attentions = []
        h_prev = hidden_init
        for i in range(seq_len):
            embed_current = embed[:, i, :]  # batch_size x hidden_size

            # The query must be the decoder state, not the token embedding: the attention
            # module documents its query as the current decoder hidden state, and querying
            # with the embedding would make the alignment blind to what has been decoded.
            context, attention_weights = self.attention(h_prev, annotations, annotations)

            # context is batch_size x 1 x hidden_size; drop the singleton step axis before
            # concatenating with the embedding.
            embed_and_context = torch.cat((embed_current, context.squeeze(1)), 1)
            h_prev = self.rnn(embed_and_context, h_prev)

            hiddens.append(h_prev)
            attentions.append(attention_weights)

        hiddens = torch.stack(hiddens, dim=1)  # batch_size x decoder_seq_len x hidden_size
        # Each step contributes a (batch_size x encoder_seq_len x 1) slice.
        attentions = torch.cat(attentions, dim=2)  # batch_size x encoder_seq_len x decoder_seq_len

        output = self.out(hiddens)  # batch_size x decoder_seq_len x vocab_size
        return output, attentions
        

# Training


## Download dataset

In [None]:
######################################################################
# Download Translation datasets
######################################################################
# Fetch the Pig Latin corpus through get_file; the result is kept in data_fpath.
data_fpath = get_file(
    fname='pig_latin_data.txt',
    origin='http://www.cs.toronto.edu/~jba/pig_latin_data.txt',
    untar=False,
)

data/pig_latin_data.txt
Downloading data from http://www.cs.toronto.edu/~jba/pig_latin_data.txt


## RNN decoder

In [None]:
# Sentence translated once training finishes, as a quick qualitative check.
TEST_SENTENCE = 'the air conditioning is working'

# Options for the decoder_type='rnn' run (no attention type selected).
args_dict = dict(
    cuda=True,
    nepochs=100,
    checkpoint_dir="checkpoints",
    learning_rate=0.005,
    lr_decay=0.99,
    batch_size=64,
    hidden_size=20,
    decoder_type='rnn',  # options: rnn / rnn_attention / transformer
    attention_type='',  # options: additive / scaled_dot
)
args = AttrDict()
args.update(args_dict)

print_opts(args)
rnn_encoder, rnn_decoder = train(args)

translated = translate_sentence(
    TEST_SENTENCE, rnn_encoder, rnn_decoder, None, args
)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))

                                      Opts                                      
--------------------------------------------------------------------------------
                                   cuda: 1                                      
                                nepochs: 100                                    
                         checkpoint_dir: checkpoints                            
                          learning_rate: 0.005                                  
                               lr_decay: 0.99                                   
                             batch_size: 64                                     
                            hidden_size: 20                                     
                           decoder_type: rnn                                    
                         attention_type:                                        
                                   Data Stats                                   
----------------------------

In [None]:
# Re-run translation with the already trained rnn_encoder / rnn_decoder pair.
TEST_SENTENCE = 'the air conditioning is working'
translated = translate_sentence(
    TEST_SENTENCE, rnn_encoder, rnn_decoder, None, args
)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))

source:		the air conditioning is working 
translated:	ethay ainway onditionationway isway-ybay oullway


## RNN attention decoder

In [None]:
# Sentence translated once training finishes, as a quick qualitative check.
TEST_SENTENCE = 'the air conditioning is working'

# Options for the decoder_type='rnn_attention' run with additive attention.
args_dict = dict(
    cuda=True,
    nepochs=100,
    checkpoint_dir="checkpoints",
    learning_rate=0.005,
    lr_decay=0.99,
    batch_size=64,
    hidden_size=20,
    decoder_type='rnn_attention',  # options: rnn / rnn_attention / transformer
    attention_type='additive',  # options: additive / scaled_dot
)
args = AttrDict()
args.update(args_dict)

print_opts(args)
rnn_attn_encoder, rnn_attn_decoder = train(args)

translated = translate_sentence(
    TEST_SENTENCE, rnn_attn_encoder, rnn_attn_decoder, None, args
)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))

                                      Opts                                      
--------------------------------------------------------------------------------
                                   cuda: 1                                      
                                nepochs: 100                                    
                         checkpoint_dir: checkpoints                            
                          learning_rate: 0.005                                  
                               lr_decay: 0.99                                   
                             batch_size: 64                                     
                            hidden_size: 20                                     
                           decoder_type: rnn_attention                          
                         attention_type: additive                               
                                   Data Stats                                   
----------------------------



Epoch:   0 | Train loss: 2.237 | Val loss: 2.080 | Gen: ay-ay-ay atay-ay ontay-intintintintin ontay-intintintintay onssay-ontay-ontay-a




Epoch:   1 | Train loss: 1.775 | Val loss: 1.912 | Gen: away-away-away-away- away ontingingtay-ontay-o ingway-ingway ontingway-ontay-onta




Epoch:   2 | Train loss: 1.554 | Val loss: 1.753 | Gen: eatehay-iay-ay-ay-ay ayway ontingtingay inway oulingway




Epoch:   3 | Train loss: 1.373 | Val loss: 1.654 | Gen: atieway-iay-iay-iay- ilay otingdingday isisay oninghay-inghway




Epoch:   4 | Train loss: 1.195 | Val loss: 1.526 | Gen: eatheway-iay away ontingongingday issway oninghway




Epoch:   5 | Train loss: 1.031 | Val loss: 1.453 | Gen: eathiay ayirway ontingingingdingding isisway onghinghinghway




Epoch:   6 | Train loss: 0.898 | Val loss: 1.287 | Gen: eatheway airway ontingdingdingdingdw isay ongingway




Epoch:   7 | Train loss: 0.776 | Val loss: 1.150 | Gen: eatheway airway ongingtiongtingtingt isisway oruingway




Epoch:   8 | Train loss: 0.689 | Val loss: 1.214 | Gen: eway-iesthay ayiray ongingingday-ingday- isway orkingway




Epoch:   9 | Train loss: 0.648 | Val loss: 1.165 | Gen: eway-awhay airway ondistiondiondway-is isisway orkingingway




Epoch:  10 | Train loss: 0.575 | Val loss: 1.037 | Gen: eway-eway airway ongay-ingtingway isisway orkingway




Epoch:  11 | Train loss: 0.518 | Val loss: 0.965 | Gen: eway-iahway airway oningtingdingdinway isway orkingway




Epoch:  12 | Train loss: 0.465 | Val loss: 0.833 | Gen: eway-eway-eway airway ongtingsingtay isway orkingway




Epoch:  13 | Train loss: 0.428 | Val loss: 0.930 | Gen: eway aiway ongingningngditingwa isway orkingway




Epoch:  14 | Train loss: 0.418 | Val loss: 0.871 | Gen: eway-eway-iway-eway- airway ongtingtingway-ingta isway orkingway




Epoch:  15 | Train loss: 0.381 | Val loss: 0.801 | Gen: ehay-ehay airway ongtingingtingtingwa isway orkingway




Epoch:  16 | Train loss: 0.353 | Val loss: 0.836 | Gen: esthay airway ongtingtingcay isway orkingway




Epoch:  17 | Train loss: 0.357 | Val loss: 0.978 | Gen: eathay airway ongingniongnitiongni isway orkingway




Epoch:  18 | Train loss: 0.365 | Val loss: 0.850 | Gen: eway airway ongcingcay isway orkingway




Epoch:  19 | Train loss: 0.320 | Val loss: 0.806 | Gen: ehay-eway airway ongcay-ingcay isway orkingway




Epoch:  20 | Train loss: 0.296 | Val loss: 0.846 | Gen: eqsthay airway ongcay-ongcay isway orkingway




Epoch:  21 | Train loss: 0.298 | Val loss: 0.814 | Gen: ethay airway oninininiciningcay isway orkingway




Epoch:  22 | Train loss: 0.281 | Val loss: 0.799 | Gen: eathay airway ongcingcay isway orkingway




Epoch:  23 | Train loss: 0.274 | Val loss: 0.799 | Gen: eqsthay airway onditiongcay isway orkingway




Epoch:  24 | Train loss: 0.278 | Val loss: 0.788 | Gen: eqpay airway ongdingdingcay isway orkingway




Epoch:  25 | Train loss: 0.245 | Val loss: 0.674 | Gen: eathay airway ongciningcay isway orkingway




Epoch:  26 | Train loss: 0.231 | Val loss: 0.711 | Gen: eqstay airway ongcay isway orkingway




Epoch:  27 | Train loss: 0.228 | Val loss: 0.673 | Gen: eachay airway oningcay-iningcay isway orkingway




Epoch:  28 | Train loss: 0.216 | Val loss: 0.683 | Gen: eacthay airway ongcay-ingcay isway orkingway




Epoch:  29 | Train loss: 0.205 | Val loss: 0.668 | Gen: eachay airway ondingcay isway orkingway




Epoch:  30 | Train loss: 0.199 | Val loss: 0.690 | Gen: eachay airway ongcay-ingcay isway orkingway




Epoch:  31 | Train loss: 0.201 | Val loss: 0.701 | Gen: eacthay airway ongcondingcay isway orkingway




Epoch:  32 | Train loss: 0.205 | Val loss: 0.695 | Gen: eacthay airway ongcay-ingcay isway orkingway




Epoch:  33 | Train loss: 0.202 | Val loss: 0.706 | Gen: eachay airway ongciingcay isway orkingway




Epoch:  34 | Train loss: 0.185 | Val loss: 0.653 | Gen: eachay airway ongciingcay isway orkingway




Epoch:  35 | Train loss: 0.173 | Val loss: 0.658 | Gen: eacthay airway ongcongcay isway orkingway




Epoch:  36 | Train loss: 0.164 | Val loss: 0.647 | Gen: ehay airway ongciingcay isway orkingway




Epoch:  37 | Train loss: 0.156 | Val loss: 0.669 | Gen: eayc airway ongcay-ingcay isway orkingway




Epoch:  38 | Train loss: 0.165 | Val loss: 0.692 | Gen: eacthay airway ondingcay-ingcay isway orkingway




Epoch:  39 | Train loss: 0.188 | Val loss: 0.930 | Gen: eachay airway oningdiningcay isisway orkingway




Epoch:  40 | Train loss: 0.217 | Val loss: 0.825 | Gen: eaypay airway ongcay isway orkingway




Epoch:  41 | Train loss: 0.208 | Val loss: 0.831 | Gen: eachay airway ongcay isway orkingway




Epoch:  42 | Train loss: 0.178 | Val loss: 0.704 | Gen: eactlay airway ondindingcay isway orkingway




Epoch:  43 | Train loss: 0.164 | Val loss: 0.657 | Gen: eactlay airway ongciongcay isway orkingway




Epoch:  44 | Train loss: 0.155 | Val loss: 0.662 | Gen: eacthay airway oningcay-iningcay isway orkingway




Epoch:  45 | Train loss: 0.151 | Val loss: 0.643 | Gen: eactlay airway oningciongcay isway orkingway




Epoch:  46 | Train loss: 0.137 | Val loss: 0.676 | Gen: eactlay airway ongciningcay isway orkingway




Epoch:  47 | Train loss: 0.132 | Val loss: 0.636 | Gen: eactlay airway ongciongcay isway orkingway




Epoch:  48 | Train loss: 0.123 | Val loss: 0.651 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  49 | Train loss: 0.125 | Val loss: 0.644 | Gen: eactlay airway ongciningcay isway orkingway




Epoch:  50 | Train loss: 0.149 | Val loss: 0.766 | Gen: eactlay airway ongcitingcay isway orkingway




Epoch:  51 | Train loss: 0.149 | Val loss: 0.694 | Gen: ecay airway ongcay isway orkingway




Epoch:  52 | Train loss: 0.147 | Val loss: 0.670 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  53 | Train loss: 0.126 | Val loss: 0.611 | Gen: eactlay airway ongciongciongcay isway orkingway




Epoch:  54 | Train loss: 0.118 | Val loss: 0.620 | Gen: eactlay airway ongcongcay isway orkingway




Epoch:  55 | Train loss: 0.114 | Val loss: 0.634 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  56 | Train loss: 0.112 | Val loss: 0.614 | Gen: ecay airway ondingcay isway orkingway




Epoch:  57 | Train loss: 0.106 | Val loss: 0.628 | Gen: eactlay airway ondingcay isway orkingway




Epoch:  58 | Train loss: 0.104 | Val loss: 0.622 | Gen: eactlay airway ondingciningcay isway orkingway




Epoch:  59 | Train loss: 0.100 | Val loss: 0.640 | Gen: eactlay airway ondiningcay-inwendwa isway orkingway




Epoch:  60 | Train loss: 0.100 | Val loss: 0.655 | Gen: eactlay airway ondiningcay isway orkingway




Epoch:  61 | Train loss: 0.100 | Val loss: 0.640 | Gen: egtay airway oningcay-iningcay isway orkingway




Epoch:  62 | Train loss: 0.099 | Val loss: 0.652 | Gen: eactlay airway ondinicay-iningcay isway orkingway




Epoch:  63 | Train loss: 0.107 | Val loss: 0.757 | Gen: eayc airway oningciongcay isway orkingway




Epoch:  64 | Train loss: 0.157 | Val loss: 0.963 | Gen: eway-awttgay airway ondcicay isway orkingway




Epoch:  65 | Train loss: 0.217 | Val loss: 1.041 | Gen: eactlay airway ondingcitingcay isway orkingway




Epoch:  66 | Train loss: 0.203 | Val loss: 0.697 | Gen: eactlay airway ongcitingcay isway orkingway




Epoch:  67 | Train loss: 0.129 | Val loss: 0.575 | Gen: eactlay airway ongconingcay isway orkingway




Epoch:  68 | Train loss: 0.100 | Val loss: 0.592 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  69 | Train loss: 0.093 | Val loss: 0.592 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  70 | Train loss: 0.088 | Val loss: 0.596 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  71 | Train loss: 0.086 | Val loss: 0.596 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  72 | Train loss: 0.084 | Val loss: 0.608 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  73 | Train loss: 0.083 | Val loss: 0.604 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  74 | Train loss: 0.081 | Val loss: 0.613 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  75 | Train loss: 0.080 | Val loss: 0.621 | Gen: eactlay airway onindicay-iningcay isway orkingway




Epoch:  76 | Train loss: 0.079 | Val loss: 0.607 | Gen: eactlay airway oningcay-ininway isway orkingway




Epoch:  77 | Train loss: 0.079 | Val loss: 0.636 | Gen: eactlay airway oningciongcay isway orkingway




Epoch:  78 | Train loss: 0.079 | Val loss: 0.620 | Gen: eactlay airway oningcay-ininway isway orkingway




Epoch:  79 | Train loss: 0.080 | Val loss: 0.616 | Gen: eactlay airway ondingcondicway isway orkingway




Epoch:  80 | Train loss: 0.086 | Val loss: 0.768 | Gen: eactlay airway onindininicay-ininwa isway orkingway




Epoch:  81 | Train loss: 0.102 | Val loss: 0.684 | Gen: eawtlay airway onindingcay isway orkingway




Epoch:  82 | Train loss: 0.084 | Val loss: 0.635 | Gen: egtay airway ondiningcay isway orkingway




Epoch:  83 | Train loss: 0.078 | Val loss: 0.608 | Gen: egtay airway ondinicay-inindway isway orkingway




Epoch:  84 | Train loss: 0.073 | Val loss: 0.613 | Gen: egtay airway onindicay-iningcay isway orkingway




Epoch:  85 | Train loss: 0.073 | Val loss: 0.622 | Gen: egtay airway onindicay-iningcay isway orkingway




Epoch:  86 | Train loss: 0.070 | Val loss: 0.627 | Gen: egthay airway ondiningcay isway orkingway




Epoch:  87 | Train loss: 0.070 | Val loss: 0.628 | Gen: egtay airway ondiningcay isway orkingway




Epoch:  88 | Train loss: 0.070 | Val loss: 0.625 | Gen: eactlay airway ondiningcay isway orkingway




Epoch:  89 | Train loss: 0.077 | Val loss: 0.767 | Gen: egthay airway ondiningcay isway orkingway




Epoch:  90 | Train loss: 0.112 | Val loss: 0.751 | Gen: ethay airway ongconingcay isway orkingweyway




Epoch:  91 | Train loss: 0.100 | Val loss: 0.762 | Gen: eactlay airway ondingcay-ininway isway orkingway




Epoch:  92 | Train loss: 0.091 | Val loss: 0.733 | Gen: eactlay airway oningcay-iningcay isway orkingway




Epoch:  93 | Train loss: 0.078 | Val loss: 0.658 | Gen: eactlay airway ondiningcay isway orkingway




Epoch:  94 | Train loss: 0.074 | Val loss: 0.637 | Gen: egthay airway ondinicay-inniscay isway orkingway




Epoch:  95 | Train loss: 0.067 | Val loss: 0.647 | Gen: eactlay airway onindicay-ininway isway orkingway




Epoch:  96 | Train loss: 0.063 | Val loss: 0.630 | Gen: eactlay airway onindicay-inindgcay isway orkingway




Epoch:  97 | Train loss: 0.060 | Val loss: 0.641 | Gen: eactlay airway onindicay-iningcay isway orkingway




Epoch:  98 | Train loss: 0.059 | Val loss: 0.643 | Gen: eactlay airway onindicay-inindgcay isway orkingway




Epoch:  99 | Train loss: 0.058 | Val loss: 0.638 | Gen: ethay airway onindicay-inindgcay isway orkingway




source:		the air conditioning is working 
translated:	ethay airway onindicay-inindgcay isway orkingway


In [None]:
# Re-run translation with the already trained rnn_attn_encoder / rnn_attn_decoder pair.
TEST_SENTENCE = 'the air conditioning is working'
translated = translate_sentence(
    TEST_SENTENCE, rnn_attn_encoder, rnn_attn_decoder, None, args
)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))



source:		the air conditioning is working 
translated:	ethay airway onindicay-inindgcay isway orkingway
