In [7]:
import loader
import argparse
import rnn_models
# from beam_search import *
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torch
from torchtext import data
from collections import defaultdict
import numpy as np
import pdb
import sacrebleu
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torchtext import data
from torchtext import datasets

import io
import os
import string

In [8]:
# Command-line options for the data pipeline. All three are optional integers.
# NOTE(review): parse_args() is not called anywhere in the cells shown here.
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument(
    "--max_sentence_length",
    type=int,
    default=50,
    help="maximum sentence length",
)
parser.add_argument(
    "--min_freq",
    type=int,
    default=3,
    help="filter out tokens less than min frequency",
)
parser.add_argument(
    "--max_vocab_size",
    type=int,
    default=100000,
    help="at most n tokens in vocabulary",
)

_StoreAction(option_strings=['--max_vocab_size'], dest='max_vocab_size', nargs=None, const=None, default=100000, type=<class 'int'>, choices=None, help='at most n tokens in vocabulary', metavar=None)

In [27]:
class Args:
    """Plain namespace of run settings, read as attributes by the rest of the notebook."""

    # ---- Paths ------------------------------------------------------------
    # Data directory plus the file-name prefixes/extensions of each split.
    # Training and validation use the same prefix here.
    data = '/scratch/vr1059/vi-en/'
    train_prefix = 'train_500'
    val_prefix = 'train_500'
    test_prefix = 'test'
    src_ext = '.tok.vi'
    trg_ext = '.tok.en'

    # ---- Vocabulary / preprocessing limits --------------------------------
    max_sentence_length = 50
    min_freq = 1
    max_vocab_size = 100000

    # ---- Model params -----------------------------------------------------
    hidden_size = 500
    embedding_size = 500
    bidirectional = True
    num_encoder_layers = 2
    num_decoder_layers = 2
    attn_model = 'general'

    # ---- Optimisation / logging -------------------------------------------
    lr = 5e-3
    epochs = 50
    batch_size = 32
    print_every = 10   # log every N-th training batch
    clip = 1           # max gradient norm passed to clip_grad_norm_


args = Args()

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [10]:
# Load the three dataset splits together with the source and target field
# objects (later cells read `.vocab` from `src` and `trg`). `loader` is a
# project module whose internals are not shown in this notebook.
train_data, val_data, test_data, src, trg = loader.load_data(args)

most common source vocabs: [(',', 472), ('.', 432), ('và', 177), ('của', 165), ('những', 143), ('là', 140), ('tôi', 139), ('một', 138), ('bạn', 111), ('"', 103)]
source vocab size: 2013
most common english vocabs: [(',', 563), ('.', 476), ('the', 368), ('and', 286), ('to', 240), ('of', 220), ('a', 206), ('you', 159), ('we', 159), ('that', 156)]
english vocab size: 1935


In [29]:
# Number of examples in the training and validation splits, one per line.
print(len(train_data), len(val_data), sep='\n')

499
499


In [12]:
def train_batch(phase, args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, batch, device):
    """Run one teacher-forced optimisation step on a single minibatch.

    The encoder is run once over the source batch; the decoder is then stepped
    one target position at a time, always fed the ground-truth previous token.
    A sequence stops contributing to the loss once its EOS target has been
    scored, so padding positions after EOS are never scored.

    Args:
        phase: Unused; kept so existing call sites keep working.
        args: Settings object; only ``args.clip`` (max gradient norm) is read.
        encoder: Provides ``random_init_hidden(device, batch_size)`` and is
            called as ``encoder(hidden, src_tokens, src_lengths)``, returning
            ``(encoder_outputs, hidden)``.
        decoder: Provides ``n_layers`` and is called as
            ``decoder(hidden, input_tokens, encoder_outputs)``, returning
            ``(logits, _, hidden)``.
        encoder_optimizer: Optimizer over the encoder parameters.
        decoder_optimizer: Optimizer over the decoder parameters.
        loss_func: Called once per scored token with a ``(1, vocab)`` tensor of
            log-probabilities and a ``(1,)`` target index.
        batch: Object whose ``src`` and ``trg`` are indexable pairs; ``trg[0]``
            is the ``(seq_len, batch_size)`` target index tensor.
        device: Forwarded to ``encoder.random_init_hidden``.

    Returns:
        float: Summed token loss divided by the number of scored tokens, or
        ``0.0`` (without touching the weights) when no token could be scored.

    Relies on the module-level ``EOS_IDX``.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    ###########
    # Encoder #
    ###########

    targets = batch.trg[0]
    seq_len, batch_size = targets.shape
    hidden = encoder.random_init_hidden(device, batch_size)
    encoder_outputs, hidden = encoder(hidden, batch.src[0], batch.src[1])

    ###########
    # Decoder #
    ###########

    # Teacher-forcing always ON: the first decoder input is row 0 of the
    # targets (the SOS tokens), one per sequence in the batch.
    decoder_input = targets[0, :]
    finished = [False] * batch_size

    loss = 0
    n_scored = 0

    # Seed the decoder with the first `decoder.n_layers` encoder hidden slices.
    # TODO (from the original author): verify this picks the intended states.
    hidden = hidden[:decoder.n_layers]

    step = 0
    while step + 1 < seq_len and not all(finished):
        logits, _, hidden = decoder(hidden, decoder_input, encoder_outputs)
        log_probs = F.log_softmax(logits.unsqueeze(0), dim=2)

        # Row step+1 is both what we score against and the next decoder input.
        step_targets = targets[step + 1, :]
        decoder_input = step_targets

        # One comparison per step instead of one per (step, sequence) pair.
        hits_eos = (step_targets == EOS_IDX).tolist()

        for j in range(batch_size):
            if finished[j]:
                continue
            loss += loss_func(log_probs[0, j, :].view(1, -1), step_targets[j].view(1))
            n_scored += 1
            if hits_eos[j]:
                finished[j] = True

        step += 1

    # Nothing was scored (target shorter than two rows, or an empty batch):
    # `loss` is still the int 0, so there is nothing to backpropagate and no
    # meaningful average to report.
    if n_scored == 0:
        return 0.0

    # calculate gradients on each parameter
    loss.backward()

    # clip if too large
    nn.utils.clip_grad_norm_(encoder.parameters(), args.clip)
    nn.utils.clip_grad_norm_(decoder.parameters(), args.clip)

    # take gradient step
    encoder_optimizer.step()
    decoder_optimizer.step()

    # report avg loss over the scored tokens of this minibatch
    return loss.item() / n_scored


#               #
# Loss function #
#               #

# loss += loss_func(output[0, j, :].view(1, -1), batch.trg[0][i+1, j].view(1))
                
# The way NLLLoss is set up, the target is simply the index of the class you want to predict,
# and the input must be log-probabilities (log_softmax, not plain softmax) over the entire output vocabulary.
# NLLLoss returns the negative log-probability at the target index, i.e. the cross-entropy between
# the predicted distribution and the one-hot vector e_target_idx (zeros everywhere except a 1 at the target index).
    

In [13]:
def train(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, device, epoch_idx, 
                 train_data, val_data, trg):
    """Train both networks for one pass over ``train_data``.

    Batches are drawn from a ``data.BucketIterator`` keyed on source length and
    each one is handed to ``train_batch``. Every ``args.print_every``-th batch
    is logged (no per-batch logging when ``args.print_every`` is falsy), and a
    summary line is printed after the last batch.

    ``val_data`` and ``trg`` are accepted but not used here.

    Returns:
        The mean of the per-batch losses reported by ``train_batch``.
    """

    def src_length(example):
        return len(example.src)

    # Batches of similar-length sequences, sorted within each batch.
    batches = data.BucketIterator(
        dataset=train_data,
        batch_size=args.batch_size,
        repeat=False,
        sort_key=src_length,
        sort_within_batch=True,
        device=device,
        train=True,
    )

    # Put both networks into training mode.
    for module in (encoder, decoder):
        module.train()

    epoch_losses = []
    for batch_no, batch in enumerate(batches):
        batch_loss = train_batch('train', args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, batch, device)
        epoch_losses.append(batch_loss)

        should_log = args.print_every and batch_no % args.print_every == 0
        if should_log:
            print("train, epoch: {}, batch number: {}, batch loss: {}".format(epoch_idx, batch_no, batch_loss))

    # `batch` is still bound to the final minibatch of the loop above.
    last_batch_size = batch.src[0].shape[1]
    print("epoch: {}, average loss for epoch: {}, size of last batch {}".format(
        epoch_idx, np.mean(epoch_losses), last_batch_size))

    return np.mean(epoch_losses)

In [14]:
def calculate_bleu(predictions, labels):
    """Score ``predictions`` against ``labels`` with sacrebleu.

    Both arguments must be lists of strings. ``labels`` is wrapped in a list
    and passed as the references argument, i.e. it is treated as a single
    reference set.

    Returns:
        The ``.score`` attribute of the sacrebleu result.
    """
    references = [labels]
    # The third positional argument (0.01) is presumably the smoothing floor
    # of raw_corpus_bleu -- confirm against the installed sacrebleu version.
    result = sacrebleu.raw_corpus_bleu(predictions, references, .01)
    return result.score

In [15]:
# Sanity check: hypothesis and reference differ only by trailing whitespace.
calculate_bleu(['I am rich. '], ['I am rich.'])

100.00000000000004

In [16]:
def beam_search(decoder, decoder_input, encoder_outputs, hidden, max_length, k, trg):
    """Decode a single sequence with beam search and return the best-scoring ids.

    Each hypothesis is a ``(sequence, score, hidden)`` triple, where ``score``
    is the running sum of token log-probabilities. At every step each live
    hypothesis is extended with its ``k`` most likely next tokens and the best
    ``k`` extensions survive. A hypothesis whose last token is ``EOS_IDX`` is
    moved to the finished pool and the beam width shrinks by one, so decoding
    stops as soon as the beam is exhausted.

    Args:
        decoder: Provides ``n_layers`` and is called as
            ``decoder(hidden, last_token, encoder_outputs)``, returning
            ``(logits, _, hidden)``.
        decoder_input: Start sequence; extended along dim 0 with ``(1, 1)``
            token tensors.
        encoder_outputs: Passed through to ``decoder`` unchanged.
        hidden: Initial hidden state; its first ``decoder.n_layers`` slices
            are fed to the decoder.
        max_length: Maximum number of expansion steps.
        k: Initial beam width.
        trg: Unused; kept for call-site compatibility.

    Returns:
        The sequence with the highest score among finished and still-live
        hypotheses, or ``decoder_input`` itself if nothing could be expanded
        (e.g. ``k <= 0``).

    Relies on the module-level ``EOS_IDX``.
    """
    candidates = [(decoder_input, 0, hidden)]
    completed_translations = []

    # Pure inference: do not record an autograd graph for every hypothesis.
    with torch.no_grad():
        # put a cap on sentence length
        for _ in range(max_length):
            potential_candidates = []

            for c_sequence, c_score, c_hidden in candidates:
                # Finished hypothesis: retire it and narrow the beam.
                if c_sequence[-1] == EOS_IDX:
                    completed_translations.append((c_sequence, c_score))
                    k = k - 1
                    continue

                # Beam already fully consumed by finished hypotheses.
                if k <= 0:
                    continue

                logits, _, next_hidden = decoder(
                    c_hidden.contiguous()[:decoder.n_layers],
                    c_sequence.contiguous()[-1].unsqueeze(0),
                    encoder_outputs,
                )
                next_word_probs = F.log_softmax(logits, dim=1)

                # In the worst case one hypothesis owns all k best extensions,
                # so k per hypothesis is enough. Never ask for more entries
                # than the vocabulary dimension holds.
                width = min(k, next_word_probs.size(-1))
                top_probs, top_idx = torch.topk(next_word_probs, width)

                for prob, idx in zip(top_probs.reshape(-1), top_idx.reshape(-1)):
                    # Keep the new token on the same device as the sequence
                    # instead of depending on a module-level `device`.
                    word = idx.reshape(1, 1).to(c_sequence.device)
                    potential_candidates.append(
                        (torch.cat((c_sequence, word)), c_score + prob, next_hidden)
                    )

            candidates = sorted(potential_candidates, key=lambda x: x[1], reverse=True)[0:max(k, 0)]

            # Nothing left to extend: further steps would be no-ops.
            if not candidates:
                break

    pool = completed_translations + candidates
    if not pool:
        return decoder_input
    # max() returns the first maximal element, matching a stable descending sort.
    return max(pool, key=lambda x: x[1])[0]

In [17]:
def ids_to_words(ids, trg):
    """Turn an iterable of token-id tensors into a space-separated string.

    Each id is reduced with ``.squeeze().item()`` and looked up in
    ``trg.vocab.itos``; surrounding whitespace is stripped from the result.
    """
    index_to_token = trg.vocab.itos
    tokens = [index_to_token[token_id.squeeze().item()] for token_id in ids]
    return ' '.join(tokens).strip()

In [18]:
# Chaitra said the best way to get hyperparameters for your model
# was to take like 100, 200, 500 examples from your dataset
# and see which hyperparameter combination overfits the best/fastest on that example. 
# And then use that as your hyperparameter combination to train the model. 

In [19]:
def val_batch(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, batch, trg, device):
    """Translate every sequence of one validation batch with beam search.

    Both networks are switched to eval mode, the encoder is run once over the
    whole batch, and each sequence is then decoded on its own with
    ``beam_search``. The whole pass runs under ``torch.no_grad()``.

    Args:
        args, encoder_optimizer, decoder_optimizer, loss_func: Unused; kept for
            call-site compatibility.
        encoder: Provides ``random_init_hidden(device, batch_size)`` and is
            called as ``encoder(hidden, src_tokens, src_lengths)``.
        decoder: Passed through to ``beam_search``.
        batch: ``batch.src`` / ``batch.trg`` are pairs; ``trg[0]`` is indexed
            as ``[position, sequence]`` and ``trg[1][i]`` bounds the reference
            slice for sequence ``i``.
        trg: Target field; supplies the ``'SOS'`` index and the id-to-word table.
        device: Device for the initial hidden state and the start token.

    Returns:
        ``(translations, trg_translations)``: two parallel lists of strings,
        predicted and reference, one entry per sequence in the batch.

    NOTE(review): the reference strings are built from the raw target ids, so
    they include the SOS/EOS markers that appear in the logged outputs; any
    metric computed on these strings counts those markers too.
    """
    encoder.eval()
    decoder.eval()

    #################
    #  beam search  #
    #################

    max_length = 30   # cap on decoding steps per sequence
    k = 2             # beam width

    # The decoder emits *target* tokens, so the start symbol must come from
    # the target vocabulary that was passed in, not from a global source field.
    sos_idx = trg.vocab.stoi['SOS']

    translations = []
    trg_translations = []

    # Inference only: no gradients are needed for encoding or decoding.
    with torch.no_grad():
        ############
        #  encode  #
        ############
        _, batch_size = batch.trg[0].shape
        hidden = encoder.random_init_hidden(device, batch_size)
        encoder_outputs, hidden = encoder(hidden, batch.src[0], batch.src[1])

        for i in range(batch_size):
            decoder_input = torch.tensor([[sos_idx]], device=device)
            # Slice out sequence i, keeping a batch dimension of size 1.
            decoder_hidden = hidden[:, i, :].unsqueeze(1)
            encoder_outputs_i = encoder_outputs[:, i, :].unsqueeze(1)

            seq_of_ids = beam_search(decoder, decoder_input, encoder_outputs_i, decoder_hidden, max_length, k, trg)
            translations.append(ids_to_words(seq_of_ids, trg))
            trg_translations.append(ids_to_words(batch.trg[0][:batch.trg[1][i], i], trg))

    return translations, trg_translations

In [20]:
def val(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, device, epoch_idx, 
        val_data, trg):
    """Translate all of ``val_data`` and report its BLEU score.

    Batches come from an unshuffled ``data.BucketIterator`` sorted by source
    length. Each batch is translated by ``val_batch``; the first ten
    predictions, the first ten references and the score are printed.

    ``epoch_idx`` is accepted but not used.

    Returns:
        The value returned by ``calculate_bleu`` for the full set.
    """
    # Minibatches over the validation data; sorting groups similar lengths
    # together to minimise padding.
    batches = data.BucketIterator(
        dataset=val_data,
        batch_size=args.batch_size,
        train=False,
        shuffle=False,
        sort=True,
        sort_key=lambda example: len(example.src),
        repeat=False,
        sort_within_batch=True,
        device=device,
    )

    hypotheses, references = [], []
    for batch in batches:
        batch_hyps, batch_refs = val_batch(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, batch, trg, device)
        hypotheses.extend(batch_hyps)
        references.extend(batch_refs)

    print(hypotheses[:10])
    print(references[:10])

    score = calculate_bleu(hypotheses, references)
    print(score)
    return score
    
    
#                    #
# Batch & Dimensions #
#                    #
# `batch` represents a batch of examples. 
# `batch.src` consists of two tensors. 
# The first, `batch.src[0]`, holds the `src` examples from the batch; it is a tensor of shape (max_seq_len, batch_size). 
# The sequences have already been indexed and padded. 
# The second, `batch.src[1]`, holds the actual (unpadded) length of each sequence. It is of shape (batch_size, 1). 

# data.BucketIterator automatically batches sequences of similar lengths together. 
# it also automatically sorts in reverse order. 

# Say you have a bidirectional, 2-layer RNN encoder. A single batch has max length 19 and batch size 32. 
# The encoder_outputs will have shape: (19, 32, 512). 
# Basically, it only returns the topmost layer's hidden states at each step of the sequence. 
# And it concatenates both directional outputs (hidden states) for the topmost layer. 

In [28]:
# Special-token indices. EOS_IDX is read as a module-level name by
# train_batch and beam_search.
src_padding_idx = src.vocab.stoi['<pad>']
trg_padding_idx = trg.vocab.stoi['<pad>']
EOS_IDX = trg.vocab.stoi['EOS']

encoder = rnn_models.Encoder(args, src_padding_idx, len(src.vocab)).to(device)
decoder = rnn_models.LuongAttnDecoderRNN(args, trg_padding_idx, len(trg.vocab)).to(device)

# Initialisation: biases are zeroed and every parameter whose name contains
# 'weight' gets Xavier-normal initialisation.
for net in [encoder, decoder]:
    for name, param in net.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0.0)
        elif 'weight' in name:
            nn.init.xavier_normal_(param)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
# Halve the learning rate as soon as the monitored value stops improving
# (patience=0). Both schedulers use the default mode ('min').
enc_scheduler = ReduceLROnPlateau(encoder_optimizer, min_lr=1e-10, factor=0.5, patience=0)
dec_scheduler = ReduceLROnPlateau(decoder_optimizer, min_lr=1e-10, factor=0.5, patience=0)

loss_func = nn.NLLLoss()

loss_history = []
bleu_history = []

for epoch in range(args.epochs):
    train_loss = train(args, encoder, decoder, encoder_optimizer,
                       decoder_optimizer, loss_func, device, epoch,
                       train_data, val_data, trg)
    loss_history.append(train_loss)

    # Pass the real epoch index (it used to be hard-coded to 0).
    bleu = val(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, device, epoch, val_data, trg)
    bleu_history.append(bleu)

    # ReduceLROnPlateau only acts when step() is given a metric; without these
    # calls the schedulers never touch the learning rate. The epoch's mean
    # training loss is used because the schedulers run in 'min' mode and it is
    # the only loss value available here.
    enc_scheduler.step(train_loss)
    dec_scheduler.step(train_loss)
    

train, epoch: 0, batch number: 0, batch loss: 7.567677804049194
train, epoch: 0, batch number: 10, batch loss: 7.559014344078633
epoch: 0, average loss for epoch: 7.196130547167625, size of last batch 32
['SOS and and and nurses nurses nurses . . . . . . . . . . EOS', 'SOS and and and and . . . . . . . . . . . . . . . . . . . . . . . . . .', 'SOS and . EOS', 'SOS so i . EOS', 'SOS so i . . . EOS', 'SOS so i . . . EOS', 'SOS so i . EOS', 'SOS so i . . EOS', 'SOS so i . . EOS', 'SOS so i . . EOS']
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in hebrew , by the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS he was looking at us . EOS', 'SOS and what went wrong ? EOS', 'SOS so we were pretty reassured by this . EOS', "SOS it 's easy , isn 't it ? EOS", 'SOS okay , here it is . EOS']
0.026495725316691955
train, epoch: 1, batch number: 0, batch loss: 6.703671155782509
train,

train, epoch: 8, batch number: 0, batch loss: 5.9683431926064445
train, epoch: 8, batch number: 10, batch loss: 6.549615163420172
epoch: 8, average loss for epoch: 5.59719942830187, size of last batch 32
["SOS i 's a do , do , do , do , do , do , do , do ? . EOS", "SOS i 's a do , do , do , do , do , do ? . EOS", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a", "SOS and i 's a a a a a a a a a a a a a a a a a a a a a a a a a a a"]
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in 

train, epoch: 14, batch number: 0, batch loss: 5.959371516290276
train, epoch: 14, batch number: 10, batch loss: 5.819982933496654
epoch: 14, average loss for epoch: 5.241161253957555, size of last batch 32
["SOS and i 'm do , the do , the do , the do , the do , the do , the do , the do , the do , the", "SOS i 'm do , the do , do , do , do , do , do , do , do , do , do , do , do , the", "SOS and i 's the the the the the the the the the the the the the the the the the the the the the the the the the the the", "SOS and i 's the the the the the the the the the the the the the the the the the the the the the the the the the the the", "SOS and i 's the the the the the the the the the the the the the the the the the the the the the the the the the the the", "SOS and i 's the the the the the the the the the the the the the the the the the the the the the the the the the the the", "SOS and i 's the the the the the the the the the the the the the the the the the the the the the the the the the 

train, epoch: 20, batch number: 0, batch loss: 5.621074645785385
train, epoch: 20, batch number: 10, batch loss: 5.608870380467256
epoch: 20, average loss for epoch: 5.015632431457921, size of last batch 32
['SOS but me : me . EOS', "SOS you 's 's 's 's to a a a a a a a a his ? EOS", 'SOS i one to lot to risk . EOS', 'SOS so you didn here years . EOS', "SOS i 's so that 's here brain . EOS", 'SOS and i one , here to reviewed . EOS', 'SOS and i one years to his . EOS', "SOS i 's so that 's here brain . EOS", "SOS you one to look that 's a look , and and and and and and i have to brain . EOS", "SOS i 's so that 's here brain . EOS"]
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in hebrew , by the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS he was looking at us . EOS', 'SOS and what went wrong ? EOS', 'SOS so we were pretty reassured by this . EOS', "SOS it 's easy , is

train, epoch: 26, batch number: 0, batch loss: 4.9013525783947065
train, epoch: 26, batch number: 10, batch loss: 4.830190158938814
epoch: 26, average loss for epoch: 4.325341474762391, size of last batch 32
["SOS but i look , but that 's to look , but that about to look that about to look that 's up ? EOS", "SOS and you 's years . EOS", "SOS it 's easy , and trials , you 're look . EOS", "SOS and i 're right feral . EOS", 'SOS so i are very simple . EOS', "SOS and i want later , but i 're going , but i 're going , but i 're going , but i 're going , but i 're going to", "SOS and i want later , but i 're going , but i 're going , but i 're going , but i 're going , but i 're going to", 'SOS i are neither feral . EOS', 'SOS you are neither feral . EOS', 'SOS so , you very . EOS']
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in hebrew , by the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , non

train, epoch: 33, batch number: 0, batch loss: 4.0849199196274455
train, epoch: 33, batch number: 10, batch loss: 4.107539043827677
epoch: 33, average loss for epoch: 3.525125920016394, size of last batch 32
["SOS i mean , i 'm going to course . EOS", 'SOS no , i do a experiences . EOS', "SOS it 's an expression art . EOS", "SOS so i effect to knowledge builds aircraft , isn trials , and i hated , and i 're going to do nothin , and and and and and they 're", "SOS it 's an riefenstahl , and and and and and and and and and and and and and and and and and and and and and and and and and", "SOS i 've collapsed index . EOS", 'SOS but i want to understand the university . EOS', "SOS it 's an riefenstahl in the cell . EOS", "SOS so i effect to knowledge builds works , and and and and and and and they 're study like grandparents . EOS", 'SOS thank you going to do today . EOS']
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not j

train, epoch: 41, batch number: 0, batch loss: 3.114648325122267
train, epoch: 41, batch number: 10, batch loss: 3.0908028071970604
epoch: 41, average loss for epoch: 2.477969124902123, size of last batch 32
["SOS i 're going to tell you about 110 degrees . EOS", 'SOS rachel peter moves his arm . EOS', 'SOS this is the matrix . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS i love to do is a different , in malaysia . EOS', 'SOS we love to do . EOS', 'SOS so maybe you took to get lives . EOS', "SOS it could make 't . EOS", 'SOS francesca fedeli , ciao . EOS']
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in hebrew , by the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS he was looking at us . EOS', 'SOS and what went wrong ? EOS', 'SOS so we were pretty reassured by this . EOS', "SOS it 's easy , isn 't it ? EOS", 'SOS okay , here it is

train, epoch: 49, batch number: 0, batch loss: 1.974234735732523
train, epoch: 49, batch number: 10, batch loss: 1.7828185918454231
epoch: 49, average loss for epoch: 1.5408733154548062, size of last batch 32
['SOS i discovered a hidden message . EOS', 'SOS when peter moves his arm . EOS', 'SOS this is the tower in the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS he was looking at us . EOS', 'SOS what was looking at us . EOS', 'SOS so i were pretty reassured by this . EOS', "SOS it 's easy , isn 't it ? EOS", 'SOS okay , here it is . EOS']
['SOS was the tale told well ? EOS', 'SOS rachel pike : the science behind a climate headline EOS', 'SOS this is not just in hebrew , by the way . EOS', 'SOS it could be anti-bacterial . EOS', 'SOS porous , nonporous . EOS', 'SOS he was looking at us . EOS', 'SOS and what went wrong ? EOS', 'SOS so we were pretty reassured by this . EOS', "SOS it 's easy , isn 't it ? EOS", 'SOS okay , here it is . EOS']
9.