In [1]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
import io
from torch.nn.utils.rnn import pad_sequence
from MyTransformer import MyTransformer
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# Given code
def build_vocab(filepath, tokenizer):
    '''
    Builds a torchtext ``Vocab`` from all tokens in the UTF-8 text file at ``filepath``.

    Every line is passed through ``tokenizer`` and the resulting token
    frequencies are tallied in a ``Counter``. The specials '<unk>', '<pad>',
    '<bos>' and '<eos>' are handed to ``Vocab`` as ``specials``.
    '''
    with io.open(filepath, encoding="utf8") as corpus:
        # `line` (not `str`) so the builtin name is not shadowed.
        token_counts = Counter(
            token for line in corpus for token in tokenizer(line)
        )
    return Vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

def tokenize_sentence(sentence : str, vocab, tokenizer):
    '''
    Converts ``sentence`` into a list of vocabulary indices.

    Trailing newline characters are stripped before tokenizing. Tokens that
    are not in ``vocab`` map to the '<unk>' index, and the result is wrapped
    as [<bos>, ...tokens..., <eos>].
    '''
    bos_index = vocab['<bos>']
    body = [
        vocab[token] if token in vocab else vocab['<unk>']
        for token in tokenizer(sentence.rstrip("\n"))
    ]
    return [bos_index, *body, vocab['<eos>']]

def tokenize_text(iterator, vocab, tokenizer) -> list:
    '''
    Tokenizes every sentence yielded by ``iterator``.

    Each sentence is converted with ``tokenize_sentence`` and stored as a 1-D
    ``torch.long`` tensor of vocabulary indices. Token ids are integers, so
    they are kept in an integer dtype rather than the float32 produced by
    ``torch.Tensor(...)``. Embedding lookups and loss targets need integer
    indices anyway.

    Returns:
        A list with one tensor per sentence, in iteration order.
    '''
    return [
        torch.tensor(tokenize_sentence(sentence, vocab, tokenizer), dtype=torch.long)
        for sentence in iterator
    ]

def create_batch(each_data_batch, PAD_IDX):
    '''
    Collates a sequence of (source, target) tensor pairs into two padded tensors.

    Both sides are padded with ``PAD_IDX`` via ``pad_sequence`` using its
    default ``batch_first=False``. For 1-D input sequences, each result is
    therefore (max_len, batch_size).

    Returns:
        (source_batch, target_batch)
    '''
    pairs = list(each_data_batch)
    sources = [src_seq for src_seq, _ in pairs]
    targets = [trg_seq for _, trg_seq in pairs]
    return (
        pad_sequence(sources, padding_value=PAD_IDX),
        pad_sequence(targets, padding_value=PAD_IDX),
    )

def run_one_epoch(epoch_index : int, model, dataloader : DataLoader, loss_func, optimizer, pad_idx, split : str, device : str = 'cpu'):
    '''
    Runs one pass over ``dataloader`` and returns the mean per-batch loss.

    Args:
        epoch_index: index of the current epoch (unused here; kept for the caller's interface).
        model: called as ``model(src=..., trg=..., memory_key_padding_mask=None, PAD_IDX=...)``;
            its output is flattened to (-1, logits.shape[-1]) before ``loss_func``.
        dataloader: iterable of (src, tgt) batches. ``tgt[:-1]`` is fed to the decoder and
            ``tgt[1:]`` is the target, so sequences are presumably laid out as
            (seq_len, batch) -- TODO confirm against the model.
        loss_func: loss taking (flattened logits, flattened target ids).
        optimizer: zeroed/stepped only when ``split == 'TRAIN'``.
        pad_idx: padding index, passed to the model as a tensor on ``device``.
        split: 'TRAIN' to backpropagate and update; any other value only evaluates
            (the caller is responsible for ``model.eval()`` / ``torch.no_grad()``).
        device: device the batches are moved to.

    Returns:
        Average of the per-batch losses, or NaN if ``dataloader`` yields no batches.
    '''
    is_training = split == 'TRAIN'
    # Loop-invariant: build the padding-index tensor once, not once per batch.
    pad_tensor = torch.tensor(pad_idx, device=device)

    running_loss = 0.
    num_batches = 0

    for src, tgt in dataloader:
        # Token ids must be integral for embedding lookups and loss targets.
        src = src.to(device=device, dtype=torch.long)
        tgt = tgt.to(device=device, dtype=torch.long)

        # Teacher forcing: the decoder sees positions [0, T-1) and predicts [1, T).
        tgt_input = tgt[:-1]
        tgt_out = tgt[1:]

        if is_training:
            # Zero the gradients for every batch
            optimizer.zero_grad()

        # Make the predictions from the forward pass
        logits = model(src=src, trg=tgt_input, memory_key_padding_mask=None, PAD_IDX=pad_tensor)

        # Compute the loss
        loss = loss_func(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

        if is_training:
            # Compute loss gradients and adjust the learning weights
            loss.backward()
            optimizer.step()

        # Sum the per-batch losses
        running_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    if num_batches == 0:
        return float('nan')
    return running_loss / num_batches

def train(model: nn.Module, train_dataloader : DataLoader, valid_dataloader : DataLoader, loss_func, optimizer, pad_idx, device = 'cpu', epochs: int = 5):
    '''
    Trains ``model`` for ``epochs`` epochs, evaluating on the validation loader after each one.

    Before training, every parameter with more than one dimension is re-initialized with
    Xavier-uniform. Calling this on an already-trained model therefore discards its weights.

    Args:
        model: the model to train (updated in place).
        train_dataloader / valid_dataloader: loaders of (src, tgt) batches.
        loss_func, optimizer, pad_idx, device: forwarded to ``run_one_epoch``.
        epochs: number of epochs to run (default 5).

    Returns:
        dict of per-epoch average losses::

            {'training':   {'losses': [...]},
             'validation': {'losses': [...]},
             'epochs': np.arange(0, epochs)}
    '''
    avg_train_losses = []
    avg_valid_losses = []

    # Initialize parameters
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    for epoch in range(epochs):
        print(f"---Epoch {epoch + 1}---")
        # Training mode (dropout etc. active); the model is switched to eval below each epoch.
        model.train()

        avg_train_loss = run_one_epoch(
            epoch,
            model,
            train_dataloader,
            loss_func,
            optimizer,
            pad_idx,
            'TRAIN',
            device
        )

        # Evaluate the model on the validation dataset
        model.eval()
        with torch.no_grad():
            avg_valid_loss = run_one_epoch(
                epoch,
                model,
                valid_dataloader,
                loss_func,
                optimizer,
                pad_idx,
                'VALID',
                device
            )

        # Record the per-epoch losses so they are returned in `stats`.
        avg_train_losses.append(avg_train_loss)
        avg_valid_losses.append(avg_valid_loss)

        # TODO: save the best model parameters (e.g. lowest avg_valid_loss) with
        # torch.save(model.state_dict(), 'saved_model.pt').

        # Print the per epoch losses
        print(f"Loss:     | Train: {avg_train_loss:.5f}     | Validation: {avg_valid_loss:.5f}")

    stats = {
        'training' : {
            'losses' : avg_train_losses,
        },
        'validation' : {
            'losses' : avg_valid_losses,
        },
        'epochs' : np.arange(0, epochs)
    }

    return stats

def calc_acc(prediction, actual):
    '''
    Returns the fraction of rows whose arg-max along dim 1 agrees between the two tensors.

    ``actual`` is compared through its own arg-max, so it presumably holds one-hot / score
    rows rather than class indices. The count of matches is divided by ``actual.shape[0]``
    and returned as a Python float.
    '''
    predicted_cls = prediction.argmax(dim=1)
    true_cls = actual.argmax(dim=1)
    n_rows = actual.shape[0]
    return ((predicted_cls == true_cls).sum() / n_rows).item()

def get_data(src_data_path, trg_data_path, src_vocab, trg_vocab, src_tokenizer, trg_tokenizer):
    '''
    Generic function to get data. src is the language you're translating from,
    and trg is the language you're translating into.

    Both UTF-8 files are read line by line and each line is tokenized with the
    matching vocabulary/tokenizer. The i-th source line is paired with the i-th
    target line.

    Returns:
        A list of (src_tensor, trg_tensor) tuples.

    NOTE: ``zip`` stops at the shorter file, so extra lines in either file are
    silently dropped. The two files are assumed to be line-aligned.
    '''
    # Context managers make sure both file handles are closed once tokenized.
    with io.open(src_data_path, encoding="utf8") as src_file:
        src_tokenized = tokenize_text(src_file, src_vocab, src_tokenizer)
    with io.open(trg_data_path, encoding="utf8") as trg_file:
        trg_tokenized = tokenize_text(trg_file, trg_vocab, trg_tokenizer)

    # Group the data together
    return list(zip(src_tokenized, trg_tokenized))

def get_dataloader(src_data_path, trg_data_path, src_vocab, trg_vocab, src_tokenizer, trg_tokenizer, batch_size, shuffle=True):
    '''
    Builds a DataLoader of padded (src, trg) batches from two parallel text files.

    Args:
        src_data_path / trg_data_path: source- and target-language files.
        src_vocab / trg_vocab: string -> index mappings for each language.
        src_tokenizer / trg_tokenizer: tokenizers for each language.
        batch_size: number of sentence pairs per batch.
        shuffle: reshuffle the pairs every epoch (default True, the original
            behavior). Pass False for deterministic evaluation order.

    Raises:
        ValueError: if the two vocabularies disagree on the '<pad>' index.
            ``create_batch`` pads both sides with one index, so a mismatch
            would silently corrupt the target padding.
    '''
    pad_idx = src_vocab['<pad>']
    # Check before the (expensive) tokenization so misconfiguration fails fast.
    if trg_vocab['<pad>'] != pad_idx:
        raise ValueError(
            f"source and target '<pad>' indices differ ({pad_idx} vs {trg_vocab['<pad>']})"
        )

    data = get_data(
        src_data_path=src_data_path,
        trg_data_path=trg_data_path,
        src_vocab=src_vocab,
        trg_vocab=trg_vocab,
        src_tokenizer=src_tokenizer,
        trg_tokenizer=trg_tokenizer
    )

    dataloader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: create_batch(batch, pad_idx)
    )

    return dataloader

In [3]:
# Parallel German (source) / English (target) corpora, read one sentence per line.
german_train_path, english_train_path = 'data/train.de', 'data/train.en'
german_test_path, english_test_path = 'data/test.de', 'data/test.en'
german_valid_path, english_valid_path = 'data/val.de', 'data/val.en'

# Serialized vocabularies (read with torch.load in the next cell).
german_vocab_path, english_vocab_path = 'German_vocab.pth', 'English_vocab.pth'

In [4]:
BATCH_SIZE = 32
SEED = 0

# Seed torch's global RNG so DataLoader shuffling and weight initialization are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get the tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# Load the vocabulary objects and keep their string -> index mappings.
# torch.load's second positional parameter is `map_location`, so no tokenizer is passed.
# torch.load unpickles the file: only load vocab files from a trusted source.
de_vocab = torch.load(german_vocab_path).stoi
en_vocab = torch.load(english_vocab_path).stoi

PAD_IDX = de_vocab['<pad>']
# PAD_IDX is also used as the loss ignore_index for English targets, so both vocabs must agree.
assert en_vocab['<pad>'] == PAD_IDX, "German and English vocabularies use different <pad> indices"

# Get the tokenized dataloader
train_dataloader = get_dataloader(
    src_data_path=german_train_path,
    trg_data_path=english_train_path,
    src_vocab=de_vocab,
    trg_vocab=en_vocab,
    src_tokenizer=de_tokenizer,
    trg_tokenizer=en_tokenizer,
    batch_size=BATCH_SIZE
)

# Get the validation dataloader
valid_dataloader = get_dataloader(
    src_data_path=german_valid_path,
    trg_data_path=english_valid_path,
    src_vocab=de_vocab,
    trg_vocab=en_vocab,
    src_tokenizer=de_tokenizer,
    trg_tokenizer=en_tokenizer,
    batch_size=BATCH_SIZE
)

In [5]:
# Define the model architecture:
#   1. No. of encoder layers: 3.
#   2. No. of decoder layers: 3.
#   3. Embedding dimension: 512.
#   4. Feedforward dimension: 512.
#   5. No. of Attention head: 8.
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512

model = MyTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    nhead=NHEAD,
    src_vocab_size=len(de_vocab),
    tgt_vocab_size=len(en_vocab),
    dim_feedforward=FFN_HID_DIM
).to(device)

# Padding positions in the targets do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# Keep the per-epoch loss history returned by train() for later inspection/plotting.
stats = train(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    loss_func=loss_fn,
    optimizer=optimizer,
    pad_idx=PAD_IDX,
    device=device
)



---Epoch 1---




Loss:     | Train: 4.27580     | Validation: 3.27088
---Epoch 2---
Loss:     | Train: 3.03120     | Validation: 2.65624
---Epoch 3---
