In [1]:
import pickle as pkl
import numpy as np
import gzip
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import time
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Import Dictionaries and Data

In [2]:
def load_zipped_pickle(filename):
    with gzip.open(filename, 'rb') as f:
        loaded_object = pkl.load(f)
        return loaded_object

In [3]:
# Print sentence given numbers
def ids2sentence(sentence, dictionary):
    return [dictionary[i] for i in sentence]
#ids2sentence(en_train_num[0], id2word_en_dic)

def add_symbol(id2word_dic, word2id_dic):
    symbols = ['<pad>', '<unk>', '<sos>', '<eos>']
    for i, symbol in enumerate(symbols):
        id2word_dic[i] = symbol
        word2id_dic[symbol] = i
    return id2word_dic, word2id_dic

id2word_vi_dic = load_zipped_pickle("../embeddings/id2word_vi_dic.p")
word2id_vi_dic = load_zipped_pickle("../embeddings/word2id_vi_dic.p")

id2word_en_dic = load_zipped_pickle("../embeddings/id2word_en_dic.p")
word2id_en_dic = load_zipped_pickle("../embeddings/word2id_en_dic.p")

id2word_vi_dic, word2id_vi_dic = add_symbol(id2word_vi_dic, word2id_vi_dic)
id2word_en_dic, word2id_en_dic = add_symbol(id2word_en_dic, word2id_en_dic)

vi_train = load_zipped_pickle("../data/vi-en-tokens/train_vi_tok.p")
en_train = load_zipped_pickle("../data/vi-en-tokens/train_en_tok.p") # Already Processed for symbols

vi_train_num = load_zipped_pickle("../data/vi-en-tokens/train_vi_tok_num.p")
en_train_num = load_zipped_pickle("../data/vi-en-tokens/train_en_tok_num.p") # Already Processed for symbols

## Padding Data

### Sort by input data length

In [4]:
def sort_by_length(data_input, target_data):
    input_size = [len(data) for data in data_input]
    size_index = np.argsort(input_size)
    return list(np.array(data_input)[size_index]), list(np.array(target_data)[size_index])

vi_train_num, en_train_num = sort_by_length(vi_train_num, en_train_num)

### Padding Data given batch size

In [5]:
def pad(data, length):
    # Cap maximum length at 100
    length = min(100, length)
    for i, line in enumerate(data):
        if len(line) < length:
            for i in range(len(line), length):
                line.append(0)
        else:
            data[i] = line[0:length]
    return data

# Return the batch data and target
def get_batch(i, batch_size, train_data, train_target):
    if i * batch_size > len(train_data):
        raise Exception('Incorrect batch index')
    start_idx = i * batch_size
    end_idx = (i + 1) * batch_size
    batch_data = list(np.array(train_data)[start_idx:end_idx])
    batch_target = list(np.array(train_target)[start_idx:end_idx])
    batch_data = pad(batch_data, len(batch_data[batch_size - 1]))
    max_target = max([len(data) for data in batch_data])
    batch_target = pad(batch_target, max_target)
    return batch_data, batch_target

# get_batch(5, 64, vi_train_num, en_train_num)

## Models

### Encoder

In [6]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # input_size: input dictionary size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.num_layers = num_layers
        self.gru = nn.GRU(hidden_size, 
                          hidden_size, 
                          num_layers= num_layers, 
                          batch_first = True) # BATCH FIRST

    def forward(self, encoder_input, hidden_input):
        # encoder_input: batch * 1 (for 1 word each time)
        embedded_input = self.embedding(encoder_input)
        # embedded_input: batch * 1 * emb_dim
        # hidden_input: batch * 1(layer) * hidden_size
        output, hidden = self.gru(embedded_input, hidden_input)
        return output, hidden

    def initHidden(self):
        return torch.zeros(self.num_layers, self.batch_size, self.hidden_size, device=device)

### Decoder

In [7]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers, batch_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # output_size: input dictionary size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, 
                          hidden_size,
                          num_layers= num_layers, 
                          batch_first = True) # BATCH_FRIST
        self.out = nn.Linear(hidden_size, output_size)
        # self.softmax = nn.LogSoftmax(dim=1) # Use cross entropy loss outside

    def forward(self, decoder_input, hidden_input):
        # decoder_input: batch * 1
        embedded_input = self.embedding(decoder_input)
        # embedded_input: batch * 1 * emb_dim
        embedded_input = F.relu(embedded_input)
        # hidden_input: batch * hidden_size
        output, hidden = self.gru(embedded_input, hidden_input)
        output = self.out(output)
        # output = self.softmax(output) # not using softmax
        return output, hidden

    def initHidden(self):
        return torch.zeros(self.num_layers, self.batch_size, self.hidden_size, device=device)

## Training

In [8]:
def train(train_input, train_target, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, batch_size):
    
    start_time = time.time()
    
    # Batch
    for i in range(len(train_input) // batch_size):
        loss = 0
        encoder_hidden = encoder.initHidden()
        
        batch = get_batch(i, batch_size, train_input, train_target)
        # size batch_size * seq_length
        batch_input = torch.tensor(batch[0], device=device)
        batch_target = torch.tensor(batch[1], device=device)
        input_length = batch_input.shape[1] ## should be seq length
        target_length = batch_target.shape[1]

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        
        encoder_outputs = torch.zeros(input_length, batch_size, 1, 256, device=device)
        encoder_hiddens = torch.zeros(input_length, 1, batch_size, 256, device=device)
        
        # Encode
        for ec_idx in range(input_length):
            # input batch_size * 1
            encoder_output, encoder_hidden = encoder(batch_input[:, ec_idx].unsqueeze(1), encoder_hidden)
            encoder_outputs[ec_idx] = encoder_output
            encoder_hiddens[ec_idx] = encoder_hidden
        
        # Decode
        decoder_input = torch.tensor([2] * batch_size, device=device) # SOS token 2
        decoder_hidden = encoder_hidden
        
        ## Print Value
        sample_sentence = []
        
        # Always use Teacher Forcing
        for dc_idx in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input.unsqueeze(1), decoder_hidden)
            decoder_output = decoder_output.squeeze(1) # get rid of the seq dimention
            loss += criterion(decoder_output, batch_target[:, dc_idx])
            decoder_input = batch_target[:, dc_idx]
            
            if i % 100 == 0:
                ## Print Value
                sample_sentence.append(torch.argmax(decoder_output[0]).item())
            
        loss.backward()
        
        encoder_optimizer.step()
        decoder_optimizer.step()
        
        if i % 50 == 0:
            s = int(time.time() - start_time)
            m = math.floor(s / 60)
            s = s - m * 60
            print('Time: ', m, 'mins', s, 'seconds' , ' Training Loss: ', loss.item() / target_length, 'Progress: ', round(i / (len(train_input) // batch_size) * 100, 2), '%')
            if i % 200 == 0:
                print("Predict: ", ids2sentence(sample_sentence, id2word_en_dic))
                print("Actual: ", ids2sentence(batch_target[0].cpu().numpy(), id2word_en_dic))
        
    print('Training Complete')

## Parameters

In [9]:
dic_size_vi = len(id2word_vi_dic.keys())
dic_size_en = len(id2word_en_dic.keys())
hidden_size = 256
learning_rate = 0.01
batch_size = 64

## Add ignore index
criterion = nn.CrossEntropyLoss()

#encoder = EncoderRNN(input_size = dic_size_vi, hidden_size = hidden_size, num_layers = 1, batch_size = batch_size).to(device)
#decoder = DecoderRNN(hidden_size = hidden_size, output_size = dic_size_en, num_layers = 1, batch_size = batch_size).to(device)

encoder = pkl.load(open("./model/encoder.p", "rb"))
decoder = pkl.load(open("./model/decoder.p", "rb"))

encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

In [12]:
for i in range(50):
    train(vi_train_num, en_train_num, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, batch_size)
    if i % 2 == 0:
        pkl.dump(encoder, open("./model/encoder.p", "wb"))
        pkl.dump(decoder, open("./model/decoder.p", "wb"))



Time:  0 mins 0 seconds  Training Loss:  8.157244682312012 Progress:  0.0 %
Predict:  ['<sos>', 'the']
Actual:  ['<sos>', '<eos>']
Time:  0 mins 9 seconds  Training Loss:  2.1544600895472934 Progress:  2.4 %
Time:  0 mins 18 seconds  Training Loss:  2.984735276963976 Progress:  4.8 %
Time:  0 mins 27 seconds  Training Loss:  2.839649200439453 Progress:  7.2 %
Time:  0 mins 36 seconds  Training Loss:  2.814463806152344 Progress:  9.6 %
Predict:  ['<sos>', 'and', 'i', 'm', 'walking', 'to', '<eos>', '<pad>', '<pad>', '<pad>']
Actual:  ['<sos>', 'and', 'i', 'was', 'thrilled', '.', '<eos>', '<pad>', '<pad>', '<pad>']
Time:  0 mins 45 seconds  Training Loss:  2.8500487587668677 Progress:  12.0 %
Time:  0 mins 55 seconds  Training Loss:  2.8859275182088218 Progress:  14.4 %
Time:  1 mins 4 seconds  Training Loss:  2.8084589640299478 Progress:  16.8 %
Time:  1 mins 14 seconds  Training Loss:  2.781269366924579 Progress:  19.2 %
Predict:  ['<sos>', 'the', 'same', 'is', 'the', 'of', 'the', 'litt

KeyboardInterrupt: 

In [11]:
# overfit_vi_train = vi_train_num[10000:10002]
# overfit_en_train = en_train_num[10000:10002]

# for i in range(100):
#     train(overfit_vi_train, overfit_en_train, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, batch_size)