In [3]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
import io
from torch.nn.utils.rnn import pad_sequence
from MyTransformer import MyTransformer
from torch.utils.data import DataLoader
import numpy as np
from util.bleu import get_bleu
'''
Author: Philip Paterson

Referenced PyTorch documentation
'''

'\nAuthor: Philip Paterson\n\nReferenced PyTorch documentation\n'

In [4]:
# Given code
def build_vocab(filepath, tokenizer):
    my_counter = Counter()
    with io.open(filepath, encoding="utf8") as filehandle:
        for str in filehandle:
            my_counter.update(tokenizer(str))
    return Vocab(my_counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

def tokenize_sentence(sentence : str, vocab, tokenizer):
    '''
    Tokenizes the given sentence
    '''
    tok_arr = [vocab['<bos>']]
    for token in tokenizer(sentence.rstrip("\n")):
        if token in vocab:
            tok_arr.append(vocab[token])
        else:
            tok_arr.append(vocab['<unk>'])
    tok_arr.append(vocab['<eos>'])

    return tok_arr

def tokenize_text(iterator, vocab, tokenizer) -> list:
    tokenized_text = []
    for sentence in iterator:
        tokenized_sentence = torch.Tensor(tokenize_sentence(sentence, vocab, tokenizer))
        tokenized_text.append(tokenized_sentence)
    return tokenized_text

def create_batch(each_data_batch, PAD_IDX):
    '''
    Creates a batch
    '''
    de_batch, en_batch = [], []
    for (de_item, en_item) in each_data_batch:
        de_batch.append(de_item)
        en_batch.append(en_item)

    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return de_batch, en_batch

def run_one_epoch(epoch_index : int, model, dataloader : DataLoader, loss_func, optimizer, pad_idx, split : str, device : str = 'cpu'):
    running_loss = 0.

    for src, tgt in dataloader:
        # Change the device and format
        tgt = tgt.type(torch.LongTensor)
        src = src.to(device=device)
        tgt = tgt.to(device=device)
        
        # Zero the gradients for every batch
        optimizer.zero_grad()

        tgt_input = tgt[:-1]
        tgt_out = tgt[1:]

        # Make the predictions from the forward pass
        logits = model(src=src, trg=tgt_input, memory_key_padding_mask = None, PAD_IDX=torch.tensor(pad_idx, device=device))

        # Compute the loss
        loss = loss_func(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        
        # If we're training the model
        if split == 'TRAIN':
            # Compute loss gradients using backward
            loss.backward()

            # Make adjustments to the learning weights
            optimizer.step()

        # Sum the losses and accuracies
        running_loss += loss.item()

    avg_loss = running_loss / len(dataloader)

    return avg_loss

def train(
        model: nn.Module,
        train_dataloader : DataLoader,
        valid_dataloader : DataLoader,
        test_dataloader : DataLoader,
        loss_func, optimizer,
        pad_idx,
        tgt_vocab,
        device = 'cpu'
    ):

    EPOCHS = 5

    # Referenced https://pieriantraining.com/reversing-keys-and-values-in-a-python-dictionary/ 
    # for reversing a dictionary
    reversed_tgt_vocab = {value: key for key, value in tgt_vocab.items()}
    
    # Initialize parameters
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # Loop variables
    avg_train_losses = []
    avg_valid_losses = []
    bleu_scores = []
    best_bleu_score = 0

    for epoch in range(EPOCHS):
        print(f"---Epoch {epoch + 1}---")
        
        # Run the training
        model.train(True)
        
        avg_train_loss = run_one_epoch(
            epoch,
            model,
            train_dataloader,
            loss_func,
            optimizer,
            pad_idx,
            'TRAIN',
            device
        )
        
        # Evaluate the model off of the validation dataset
        model.eval()
        with torch.no_grad():
            avg_valid_loss = run_one_epoch(
                epoch,
                model,
                valid_dataloader,
                loss_func,
                optimizer,
                pad_idx,
                'VALID',
                device
            )

        # Calculate the bleu score
        with torch.no_grad():
            bleu_score = calc_bleu_score(
                model=model,
                dataloader=test_dataloader,
                device=device,
                tgt_vocab=tgt_vocab,
                reversed_tgt_vocab=reversed_tgt_vocab
            )

        # Save the best model parameters
        if bleu_score > best_bleu_score:
            best_bleu_score = bleu_score
            torch.save(model.state_dict(), 'saved_model.pt')
        
        # Append the losses and bleu score
        avg_train_losses.append(avg_train_loss)
        avg_valid_losses.append(avg_valid_loss)
        bleu_scores.append(bleu_score)

        # Print the per epoch losses and accuracies
        print(f"Loss: | Train: {avg_train_loss:.5f} | Validation: {avg_valid_loss:.5f}")
        print(f"BLEU Score: {bleu_score}")
    
    stats = {
        'training' : {
            'losses' : avg_train_losses,
        },
        'validation' : {
            'losses' : avg_valid_losses,
        },
        'testing:' : {
            'bleu' : bleu_scores
        },
        'epochs' : np.arange(0, EPOCHS)
    }

    return stats

def test_translator(src_data_path : str, tgt_data_path : str, src_vocab_path : str, tgt_vocab_path : str, model_param_path : str):
    '''
    TEST FUNCTION For testing the translator model.

    Only works for testing translation from German to English.
    '''

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    
    SRC_TOKENIZER_LANGUAGE = 'de_core_news_sm'
    TGT_TOKENIZER_LANGUAGE = 'en_core_web_sm'

    # Get the tokenizer
    src_tokenizer = get_tokenizer('spacy', language=SRC_TOKENIZER_LANGUAGE)
    tgt_tokenizer = get_tokenizer('spacy', language=TGT_TOKENIZER_LANGUAGE)

    # Get the vocabulary
    src_vocab = torch.load(src_vocab_path, src_tokenizer).stoi
    tgt_vocab = torch.load(tgt_vocab_path, tgt_tokenizer).stoi
    
    # Get the data
    data = get_data(
        src_data_path=src_data_path,
        tgt_data_path=tgt_data_path,
        src_vocab=src_vocab,
        tgt_vocab=tgt_vocab,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
    )
    dataloader = DataLoader(
        data,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda batch: create_batch(batch, src_vocab['<pad>'])
    )

    # Define and create the model
    NUM_ENCODER_LAYERS = 3
    NUM_DECODER_LAYERS = 3
    EMB_SIZE = 512
    NHEAD = 8
    FFN_HID_DIM = 512

    model = MyTransformer(
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        emb_size=EMB_SIZE,
        nhead=NHEAD,
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(tgt_vocab),
        dim_feedforward=FFN_HID_DIM
    ).to(device)

    model.load_state_dict(torch.load(model_param_path, map_location=device))

    model.eval()

    # Create a reversed dictionary for calculating the bleu score
    reversed_tgt_vocab = {value: key for key, value in tgt_vocab.items()}

    # Perform inference
    with torch.no_grad():
        bleu_score = calc_bleu_score(
                model=model,
                dataloader=dataloader,
                device=device,
                tgt_vocab=tgt_vocab,
                reversed_tgt_vocab=reversed_tgt_vocab
            )
        
    print(f"Testing BLEU Score: {bleu_score}")
    return bleu_score
    

def get_data(src_data_path, tgt_data_path, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer):
    '''
    Generic function to get data. src would be the you're translating from,
    and tgt is the language you're translating into.
    '''

    src_data_raw_iter = iter(io.open(src_data_path, encoding="utf8"))
    tgt_data_raw_iter = iter(io.open(tgt_data_path, encoding="utf8"))

    # Tokenize the sentences
    src_tokenized = tokenize_text(src_data_raw_iter, src_vocab, src_tokenizer)
    tgt_tokenized = tokenize_text(tgt_data_raw_iter, tgt_vocab, tgt_tokenizer)

    # Group the data together
    data = list(zip(src_tokenized, tgt_tokenized))
    
    return data

def get_dataloader(src_data_path, tgt_data_path, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer, batch_size):
    data = get_data(
        src_data_path=src_data_path,
        tgt_data_path=tgt_data_path,
        src_vocab=src_vocab,
        tgt_vocab=tgt_vocab,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer
    )

    dataloader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: create_batch(batch, src_vocab['<pad>'])
    )

    return dataloader

def calc_bleu_score(model : nn.Module, dataloader : DataLoader, device, tgt_vocab, reversed_tgt_vocab):
    SENTENCE_GENERATION_MAX = 500
    
    pad_idx = tgt_vocab['<pad>']

    running_bleu = 0.
    sent_idx = 0 # Keeps track of how many sentences generated
    for src, tgt in dataloader:
        # Stops the generation of sentences
        if sent_idx >= SENTENCE_GENERATION_MAX:
            break

        tgt = tgt.type(torch.LongTensor)
        src = src.to(device=device)
        tgt = tgt.to(device=device)

        # print("SRC SHAPE:", src.shape) # TODO: Remove

        # pred_sent_beginning = [[tgt_vocab['<bos>']]] * src.size(0)
        next_tok = tgt_vocab['<bos>']
        pred_sent_beginning = [[tgt_vocab['<bos>']]]

        pred_sent = torch.tensor(pred_sent_beginning).to(device=device) # Initialize the predicted sentence

        # Iteratively test to build the prediction sentence
        while len(pred_sent) <= len(tgt) and next_tok != tgt_vocab['<eos>']:
            logits = model(src=src, trg=pred_sent, memory_key_padding_mask=None, PAD_IDX=torch.tensor(pad_idx, device=device))
            next_tok = torch.argmax(logits[-1][0], dim=-1)
            pred_sent = torch.cat((pred_sent, next_tok.unsqueeze(0).unsqueeze(0)), dim=0)

        # Convert the tokenized predicted and target sentences to list of words
        trg_sent_list = detokenize(tgt.squeeze(1).tolist(), reversed_tgt_vocab)
        pred_sent_list = detokenize(pred_sent.squeeze(1).tolist(), reversed_tgt_vocab)

        # Get the bleu scores
        bleu = get_bleu(hypotheses=pred_sent_list, reference=trg_sent_list)

        running_bleu += bleu
        
        sent_idx += 1
    
    average_total_bleu = running_bleu / len(dataloader) # Calculate the average bleu score

    return average_total_bleu

def detokenize(tokenized_sentence, reversed_vocab):
    '''
    Takes a tokenized sentence and converts it to a sentence of words.
    '''

    # Build the translated sentence
    sentence = []
    for token in tokenized_sentence:
        sentence.append(reversed_vocab[token])

    return sentence

In [5]:
# TODO: Define a main function!!!

german_train_path = 'data/train.de'
english_train_path = 'data/train.en'
german_test_path = 'data/test.de'
english_test_path = 'data/test.en'
german_valid_path = 'data/val.de'
english_valid_path = 'data/val.en'

german_vocab_path = 'German_vocab.pth'
english_vocab_path = 'English_vocab.pth'

In [6]:
BATCH_SIZE = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get the tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# Get the vocabulary
de_vocab = torch.load(german_vocab_path, de_tokenizer).stoi
en_vocab = torch.load(english_vocab_path, en_tokenizer).stoi

PAD_IDX = de_vocab['<pad>']

# Get the tokenized dataloader
train_dataloader = get_dataloader(
    src_data_path=german_train_path,
    tgt_data_path=english_train_path,
    src_vocab=de_vocab,
    tgt_vocab=en_vocab,
    src_tokenizer=de_tokenizer,
    tgt_tokenizer=en_tokenizer,
    batch_size=BATCH_SIZE
)

# Get the validation dataloader
valid_dataloader = get_dataloader(
    src_data_path=german_valid_path,
    tgt_data_path=english_valid_path,
    src_vocab=de_vocab,
    tgt_vocab=en_vocab,
    src_tokenizer=de_tokenizer,
    tgt_tokenizer=en_tokenizer,
    batch_size=BATCH_SIZE
)

# Get the test datalaoder
test_data = get_data(
    src_data_path=german_test_path,
    tgt_data_path=english_test_path,
    src_vocab=de_vocab,
    tgt_vocab=en_vocab,
    src_tokenizer=de_tokenizer,
    tgt_tokenizer=en_tokenizer,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    collate_fn=lambda batch: create_batch(batch, de_vocab['<pad>'])
)

In [7]:
# Define the model architecture
'''
1. No. of encoder layers: 3.
2. No. of decoder layers: 3.
3. Embedding dimension: 512.
4. Feedforward dimension: 512.
5. No. of Attention head: 8.
'''
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512

model = MyTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    nhead=NHEAD,
    src_vocab_size=len(de_vocab),
    tgt_vocab_size=len(en_vocab),
    dim_feedforward=FFN_HID_DIM
).to(device)

# Define the loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# Train the model
stats = train(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    test_dataloader=test_dataloader,
    loss_func=loss_fn,
    optimizer=optimizer,
    pad_idx=PAD_IDX,
    tgt_vocab=de_vocab,
    device=device
)



---Epoch 1---




KeyboardInterrupt: 

In [8]:
# Test the saved model parameters
test_translator(
    src_data_path=german_test_path,
    tgt_data_path=english_test_path,
    src_vocab_path=german_vocab_path,
    tgt_vocab_path=english_vocab_path,
    model_param_path='saved_model.pt'
)



Testing BLEU Score: 38.08720844546422


38.08720844546422