In [1]:
import argparse
import io
import os
import pdb
import random
import string
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchtext import data
from torchtext import data
from torchtext import datasets

import loader
import rnn_models
import string

In [2]:
# Command-line definition of the data-preprocessing options. parse_args() is not called in the
# cells shown here; the `Args` class in the next cell carries the same defaults instead.
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument(
    "--max_sentence_length", type=int, default=50,
    help="maximum sentence length",
)
parser.add_argument(
    "--min_freq", type=int, default=3,
    help="filter out tokens less than min frequency",
)
parser.add_argument(
    "--max_vocab_size", type=int, default=100000,
    help="at most n tokens in vocabulary",
)

_StoreAction(option_strings=['--max_vocab_size'], dest='max_vocab_size', nargs=None, const=None, default=100000, type=<class 'int'>, choices=None, help='at most n tokens in vocabulary', metavar=None)

In [3]:
class Args:
    """Hyper-parameters for data loading, the seq2seq model and training.

    Stands in for an argparse namespace so the notebook runs without a command line; the first
    three fields repeat the parser's defaults.
    """

    # data
    max_sentence_length = 50
    min_freq = 3
    max_vocab_size = 100000
    data = 'data'  # presumably the data directory read by loader.load_data -- confirm
    # model
    hidden_size = 256
    embedding_size = 256
    bidirectional = True
    num_encoder_layers = 2
    num_decoder_layers = 2
    attn_model = 'general'
    # optimization
    lr = 1e-3
    epochs = 5
    batch_size = 32
    clip = 1  # max gradient norm passed to nn.utils.clip_grad_norm_ in run_batch
    seed = 1234  # seeds the python, numpy and torch RNGs below so reruns are reproducible


args = Args()
device = 'cpu'

# Weight initialization (xavier_normal_) and data shuffling are random; seed every RNG the
# notebook may draw from so Restart & Run All reproduces the same numbers.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [4]:
# loader.load_data yields the train/val/test splits followed by the source and target fields
# (presumably torchtext Fields -- they expose `.vocab` in the next cell).
(train_data, val_data, test_data,
 src, trg) = loader.load_data(args)

most common source vocabs: [(',', 128638), ('.', 120849), ('là', 51451), ('và', 47993), ('một', 40378), ('tôi', 38381), ('những', 37809), ('của', 36330), ('có', 26166), ('bạn', 26111)]
source vocab size: 20125
most common english vocabs: [(',', 156165), ('.', 132505), ('the', 109723), ('and', 79673), ('to', 65979), ('of', 60510), ('a', 55374), ('that', 49320), ('i', 43629), ('in', 41318)]
english vocab size: 22443


In [5]:
src_padding_idx = src.vocab.stoi['<pad>']
trg_padding_idx = trg.vocab.stoi['<pad>']
EOS_IDX = trg.vocab.stoi['EOS']
# A stoi lookup of a token missing from the vocab may silently return the <unk> index instead of
# raising, so check that the index really maps back to the end-of-sentence token.
assert trg.vocab.itos[EOS_IDX] == 'EOS', (
    f"target vocab has no 'EOS' token (index {EOS_IDX} is {trg.vocab.itos[EOS_IDX]!r}); "
    "check the eos token used by loader"
)

encoder = rnn_models.Encoder(args, src_padding_idx, len(src.vocab))
decoder = rnn_models.LuongAttnDecoderRNN(args, trg_padding_idx, len(trg.vocab))


def init_weights(net):
    """Re-initialize `net` in place: zero every bias, Xavier-normal every weight matrix.

    Weights with fewer than 2 dimensions keep their default init, because Xavier needs a fan-in and
    a fan-out. The padding row of every ``nn.Embedding`` that has a ``padding_idx`` is reset to
    zeros, because the Xavier init overwrites the zero row PyTorch reserves for padding.
    """
    for name, param in net.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0.0)
        elif 'weight' in name and param.dim() >= 2:
            nn.init.xavier_normal_(param)
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                module.weight[module.padding_idx].fill_(0.0)


# Xavier-normal init for weight matrices (rather than the N(0, 0.01) init from the paper), zero biases.
for net in [encoder, decoder]:
    init_weights(net)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr)

# Expects log-probabilities: run_batch applies F.log_softmax to the decoder logits first.
loss_func = nn.NLLLoss()

loss_history = defaultdict(list)
bleu_history = defaultdict(list)

# TODO: enable the epoch loop once `train_and_val` is implemented.
# for i in range(args.epochs):
#     train_loss, val_loss, val_bleu = train_and_val(args, encoder, decoder, encoder_optimizer,
#                                                    decoder_optimizer, loss_func, device, i,
#                                                    train_data, val_data, trg, encoder_embedding_dict,
#                                                    decoder_embedding_dict)

In [6]:
# Scratch check of the attention shapes the decoder relies on: a (1, batch, hidden) decoder state
# attending over (src_len, batch, hidden) encoder outputs. Sizes are named instead of hard-coded.
demo_src_len, demo_batch_size, demo_hidden_size = 17, 32, 256
h = torch.randn(1, demo_batch_size, demo_hidden_size)             # decoder hidden state
e = torch.randn(demo_src_len, demo_batch_size, demo_hidden_size)  # encoder outputs
# (batch, src_len, hidden) @ (batch, hidden, 1) -> energy: (batch, src_len, 1)
energy = torch.bmm(e.transpose(1, 0), h.squeeze(0).unsqueeze(2))
# attention weights normalized over the source positions
score = F.softmax(energy, dim=1)
# (batch, 1, src_len) @ (batch, src_len, hidden) -> context: (batch, 1, hidden)
context_vector = torch.bmm(score.transpose(1, 2), e.transpose(1, 0))
context_vector.shape

torch.Size([32, 1, 256])

In [7]:
# Shape of the decoder state once reshaped into the right-hand operand of torch.bmm.
torch.unsqueeze(torch.squeeze(h, 0), 2).shape

torch.Size([32, 256, 1])

In [8]:
def run_batch(phase, args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func, batch, device,
              eos_idx=None):
    """Run the encoder/decoder over one batch with teacher forcing and return the mean token loss.

    Only in the "train" phase are gradients computed, clipped to ``args.clip`` and applied by both
    optimizers. In any other phase (e.g. "val") the forward pass runs without autograd and the
    parameters are left untouched.

    Args:
        phase: "train" to update parameters; any other value only evaluates.
        args: config object; only ``args.clip`` is read here.
        encoder: exposes ``random_init_hidden(device, batch_size)`` and is called as
            ``encoder(hidden, src_tokens, src_lengths) -> (encoder_outputs, hidden)``.
        decoder: exposes ``n_layers`` and is called as
            ``decoder(hidden, decoder_input, encoder_outputs) -> (logits, attn, hidden)``.
        encoder_optimizer, decoder_optimizer: optimizers for the two networks.
        loss_func: called once per scored token as ``loss_func(log_probs[1, vocab], target[1])``,
            e.g. ``nn.NLLLoss()``.
        batch: batch whose ``src`` is a ``(tokens, lengths)`` pair and whose ``trg[0]`` is a
            ``(max_trg_seq_len, batch_size)`` tensor of target indices (row 0 = start tokens).
        device: forwarded to ``encoder.random_init_hidden``.
        eos_idx: target-vocab index of the end-of-sentence token; defaults to the module-level
            ``EOS_IDX``.

    Returns:
        The summed token loss divided by the number of scored tokens (each sequence is scored up
        to and including its first EOS), as a float. Returns 0.0 when no token could be scored.
    """
    if eos_idx is None:
        eos_idx = EOS_IDX
    is_train = phase == "train"
    if is_train:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

    trg_tokens = batch.trg[0]
    # batch_size comes from the tensor itself, so a smaller final batch is handled as well.
    max_trg_seq_len, batch_size = trg_tokens.shape

    loss = 0
    number_of_loss_calculation = 0

    # Only build the autograd graph when we are going to backpropagate through it.
    with torch.set_grad_enabled(is_train):
        hidden = encoder.random_init_hidden(device, batch_size)
        encoder_outputs, hidden = encoder(hidden, batch.src[0], batch.src[1])

        # Seed the decoder with the first decoder.n_layers slices of the encoder's final state.
        # NOTE(review): for a bidirectional torch RNN these slices are ordered (layer0 fwd,
        # layer0 bwd, layer1 fwd, ...), so this is not necessarily "the last forward layer";
        # depends on how rnn_models.Encoder builds `hidden` -- confirm.
        hidden = hidden[:decoder.n_layers]

        # Teacher forcing: start from row 0 (start tokens) and at each step feed the
        # ground-truth token of the position just predicted.
        decoder_input = trg_tokens[0, :]
        eos_encountered = [False] * batch_size
        i = 0
        while i + 1 < max_trg_seq_len and not all(eos_encountered):
            logits, _, hidden = decoder(hidden, decoder_input, encoder_outputs)
            output = F.log_softmax(logits.unsqueeze(0), dim=2)
            targets = trg_tokens[i + 1, :]  # row i+1 is the position predicted at this step

            for j in range(batch_size):
                if eos_encountered[j]:
                    continue  # sequence already ended; its remaining positions are padding
                loss = loss + loss_func(output[0, j, :].view(1, -1), targets[j].view(1))
                number_of_loss_calculation += 1
                if targets[j] == eos_idx:
                    eos_encountered[j] = True

            decoder_input = targets
            i += 1

    if number_of_loss_calculation == 0:
        # Targets hold only start tokens: there is nothing to score or learn from.
        return 0.0

    if is_train:
        loss.backward()  # gradients for both encoder and decoder
        nn.utils.clip_grad_norm_(encoder.parameters(), args.clip)
        nn.utils.clip_grad_norm_(decoder.parameters(), args.clip)
        encoder_optimizer.step()
        decoder_optimizer.step()

    return loss.item() / number_of_loss_calculation


#               #
# Loss function #
#               #

# loss += loss_func(output[0, j, :].view(1, -1), batch.trg[0][i+1, j].view(1))
                
# NLLLoss takes as its target simply the index of the token you want to predict,
# and as its input the *log*-probabilities over the entire output vocabulary
# (hence F.log_softmax on the decoder logits, not a plain softmax).
# It returns -input[target], i.e. the cross-entropy between the predicted distribution and the
# one-hot vector e_target_idx (zeros everywhere except a 1 at the target index position).
    

In [9]:
# Training iterator: buckets source sentences of similar length together and reshuffles each pass.
train_iter = data.BucketIterator(
    dataset=train_data,
    batch_size=args.batch_size,
    repeat=False,
    sort_key=lambda x: len(x.src),
    sort_within_batch=True,
    device=device,
    train=True,
)

# Validation iterator: fixed order. `sort=True` sorts examples by `sort_key` so that examples
# of similar length are batched together, minimizing padding.
val_iter = data.BucketIterator(
    dataset=val_data,
    batch_size=args.batch_size,
    train=False,
    shuffle=False,
    sort=True,
    sort_key=lambda x: len(x.src),
    repeat=False,
    sort_within_batch=True,
    device=device,
)

encoder.train()
decoder.train()

# Smoke test: a few training steps on consecutive batches from a single pass over train_iter.
NUM_SMOKE_TEST_BATCHES = 1
train_losses = []
train_batches = iter(train_iter)  # one iterator, so successive steps see different batches
for _ in range(NUM_SMOKE_TEST_BATCHES):
    batch = next(train_batches)
    train_losses.append(run_batch('train', args, encoder, decoder, encoder_optimizer,
                                  decoder_optimizer, loss_func, batch, device))
    print(train_losses[-1])


    

#                    #
# Batch & Dimensions #
#                    #
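# `batch` represents a batch of examples.
# `batch.src` consists of two tensors.
# The first, `batch.src[0]`, holds the `src` sentences of the batch as a tensor of shape (max_seq_len, batch_size).
# These sequences have already been numericalized (mapped to vocabulary indices) and padded.
# The second, `batch.src[1]`, holds the actual (unpadded) length of each sequence as a 1-D tensor of shape (batch_size,).
# `batch.trg[0]` is the analogous (max_trg_seq_len, batch_size) tensor for the target sentences.

# data.BucketIterator automatically batches together sequences of similar length (by `sort_key`) to minimize padding.
# Because `sort_within_batch=True`, it also sorts the examples inside each batch by `sort_key` in descending order
# (longest first), which is what `pack_padded_sequence` expects by default if the encoder packs its input.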

# Say you have a bidirectional, 2-layer RNN encoder. A single batch has max length 19 and batch size 32. 
# The encoder_outputs will have shape: (19, 32, 512). 
# Basically, it only returns the topmost layer's hidden states at each step of the sequence. 
# And it concatenates both directional outputs (hidden states) for the topmost layer. 

torch.Size([37, 32, 512])




10.018727222040905


In [10]:
# Pull one batch from a fresh (reshuffled) pass over the training data to inspect its shapes.
inspection_iter = iter(train_iter)
b = next(inspection_iter)

In [11]:
# (max_trg_seq_len, batch_size) of the inspected batch's padded target tokens.
b.trg[0].size()

torch.Size([50, 32])

In [12]:
# def train_and_val(args, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_function, device, i, 
#                   train_data, val_data, trg, encoder_embedding_dict, decoder_embedding_dict):
    
    
        
    
    