In [1]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from timeit import default_timer as timer
from attention import transformer

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

In [2]:
# Set seed.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

In [3]:
# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Place-holders
token_transform = {}
vocab_transform = {}

In [4]:
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

    
# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

2023-06-06 07:40:46.238670: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-06 07:40:46.734559: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.4/lib64:
2023-06-06 07:40:46.734692: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda-11.4/lib64:
2023-06-06 07:40:47.320748: I 

In [5]:
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
MAX_LEN = 512
NUM_ENCODER_LAYERS = 3
DEVICE = 'cuda'
NUM_EPOCHS = 30
# DEVICE = 'cpu'

model = transformer.Transformer(
    embed_dim=EMB_SIZE,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    seq_len=MAX_LEN,
    num_layers=NUM_ENCODER_LAYERS,
    n_heads=NHEAD,
    device=DEVICE
).to(DEVICE)

# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

36,035,349 total parameters.
36,035,349 training parameters.


In [6]:
# for p in model.parameters():
#     if p.dim() > 1:
#         nn.init.xavier_uniform_(p)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [7]:
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch

In [8]:
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
def train_epoch(model, optimizer):
    model.train()
    losses = 0

    for src, tgt in train_dataloader:
        # print(" ".join(vocab_transform[SRC_LANGUAGE].lookup_tokens(list(src[0].cpu().numpy()))).replace("<bos>", "").replace("<eos>", ""))
        # print(" ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt[0].cpu().numpy()))).replace("<bos>", "").replace("<eos>", ""))
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        
        tgt_input = tgt[:, :-1]

        logits = model(src, tgt_input)

        optimizer.zero_grad()

        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.view(-1, TGT_VOCAB_SIZE), tgt_out.contiguous().view(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
def evaluate(model):
    model.eval()
    losses = 0

    for src, tgt in val_dataloader:
        # print(" ".join(vocab_transform[SRC_LANGUAGE].lookup_tokens(list(src[0].cpu().numpy()))).replace("<bos>", "").replace("<eos>", ""))
        # print(" ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt[0].cpu().numpy()))).replace("<bos>", "").replace("<eos>", ""))
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        
        tgt_input = tgt[:, :-1]
        
        logits = model(src, tgt_input)

        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.view(-1, TGT_VOCAB_SIZE), tgt_out.contiguous().view(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [9]:
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    val_loss = evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

Epoch: 1, Train loss: 5.828, Val loss: 5.038, Epoch time = 15.031s
Epoch: 2, Train loss: 4.868, Val loss: 4.625, Epoch time = 14.707s
Epoch: 3, Train loss: 4.472, Val loss: 4.246, Epoch time = 14.515s
Epoch: 4, Train loss: 4.185, Val loss: 4.010, Epoch time = 14.542s
Epoch: 5, Train loss: 3.977, Val loss: 3.802, Epoch time = 14.630s
Epoch: 6, Train loss: 3.804, Val loss: 3.657, Epoch time = 14.668s
Epoch: 7, Train loss: 3.657, Val loss: 3.512, Epoch time = 14.672s
Epoch: 8, Train loss: 3.533, Val loss: 3.393, Epoch time = 14.439s
Epoch: 9, Train loss: 3.423, Val loss: 3.293, Epoch time = 14.514s
Epoch: 10, Train loss: 3.329, Val loss: 3.200, Epoch time = 14.858s
Epoch: 11, Train loss: 3.246, Val loss: 3.138, Epoch time = 14.746s
Epoch: 12, Train loss: 3.169, Val loss: 3.058, Epoch time = 14.853s
Epoch: 13, Train loss: 3.096, Val loss: 3.004, Epoch time = 14.839s
Epoch: 14, Train loss: 3.030, Val loss: 2.942, Epoch time = 14.758s
Epoch: 15, Train loss: 2.973, Val loss: 2.888, Epoch time

In [10]:
import os
os.makedirs('outputs', exist_ok=True)
torch.save(model, 'outputs/model.pth')

## Inference

In [11]:
# # function to generate output sequence using greedy algorithm
# def greedy_decode(model, src, start_symbol):
#     src = src.to(DEVICE)
#     ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
#     out = model.decode(torch.ravel(src).unsqueeze(0), ys)
#     print(out)
#     return out

# # actual function to translate input sentence into target language
# def translate(model: torch.nn.Module, src_sentence: str):
#     model.eval()
#     src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
#     num_tokens = src.shape[0]
#     tgt_tokens = greedy_decode(
#         model,  src, start_symbol=BOS_IDX)
#     return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens))).replace("<bos>", "").replace("<eos>", "")

In [12]:
# print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))

In [13]:
from attention.transformer import TransformerDecoder, TransformerEncoder

In [14]:
def make_tgt_mask(tgt, pad_token_id=1):
    """
    :param tgt: Target sequence.
    Returns:
        tgt_mask: Target mask.
    """
    batch_size = tgt.shape[0]
    device = tgt.device

    # Same as src_mask but we additionally want to mask tokens from looking forward into the future tokens
    # Note: wherever the mask value is true we want to attend to that token, otherwise we mask (ignore) it.
    sequence_length = tgt.shape[1]  # trg_token_ids shape = (B, T) where T max trg token-sequence length
    trg_padding_mask = (tgt != pad_token_id).view(batch_size, 1, 1, -1)  # shape = (B, 1, 1, T)
    trg_no_look_forward_mask = torch.triu(torch.ones((1, 1, sequence_length, sequence_length), device=device) == 1).transpose(2, 3)

    # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token)
    tgt_mask = trg_padding_mask & trg_no_look_forward_mask  # final shape = (B, 1, T, T)
    return tgt_mask
    
def make_src_mask(src, pad_token_id=1):
    """
    :param src: Source sequence.

    Returns:
        src_mask: Source mask.
    """
    batch_size = src.shape[0]

    # src_mask shape = (B, 1, 1, S) check out attention function in transformer_model.py where masks are applied
    # src_mask only masks pad tokens as we want to ignore their representations (no information in there...)
    src_mask = (src != pad_token_id).view(batch_size, 1, 1, -1)
    return src_mask

In [15]:
decoder = TransformerDecoder(
            TGT_VOCAB_SIZE,
            EMB_SIZE,
            MAX_LEN,
            NUM_ENCODER_LAYERS,
            expansion_factor=4,
            n_heads=NHEAD
        ).to(DEVICE).eval()

In [16]:
decoder.load_state_dict(model.decoder.state_dict())

<All keys matched successfully>

In [17]:
encoder = TransformerEncoder(
            MAX_LEN,
            SRC_VOCAB_SIZE,
            EMB_SIZE,
            NUM_ENCODER_LAYERS,
            expansion_factor=4,
            n_heads=NHEAD
        ).to(DEVICE).eval()

In [18]:
encoder.load_state_dict(model.encoder.state_dict())

<All keys matched successfully>

In [19]:
def decode(src, tgt):
    """
    :param src: Encoder input
    :param tgt: Decoder input

    Returns:
        out_labels: Final prediction sequence
    """
    tgt_mask = make_tgt_mask(tgt).to(DEVICE)
    src_mask = make_src_mask(src).to(DEVICE)
    enc_out = encoder(src)
    out_labels = []
    batch_size, seq_len = src.shape[0], src.shape[1]
    out = tgt
    with torch.no_grad():
        for i in range(seq_len):
            if i != 0:
                tgt = torch.tensor(out_labels, dtype=torch.long).unsqueeze(0).to(DEVICE)
                print(tgt)
                out = decoder(torch.tensor(tgt).to(DEVICE), enc_out, src_mask, tgt_mask)
            else:
                out = decoder(out, enc_out, src_mask, tgt_mask)
            # print(out.shape)
            # out = out[:, -1, :]
            out = out.reshape(-1, out.shape[-1])
            # print(out.shape)
            # out = out.argmax(-1)
            num_of_trg_tokens = len(tgt[0])
            out = out[num_of_trg_tokens-1::num_of_trg_tokens]
            out = torch.argmax(out, dim=-1)
            out_labels.append(out.item())
            out = torch.unsqueeze(out, 0)
        return out_labels

In [20]:
src_sentence = "Eine Gruppe von Menschen steht vor einem Iglu ."
start_symbol = BOS_IDX
src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
num_tokens = src.shape[0]
src = src.to(DEVICE)
ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
out = decode(torch.ravel(src).unsqueeze(0), ys)
print(out)
print(" ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(out))).replace("<bos>", "").replace("<eos>", ""))

tensor([[6]], device='cuda:0')
tensor([[ 6, 39]], device='cuda:0')
tensor([[ 6, 39, 13]], device='cuda:0')
tensor([[ 6, 39, 13, 22]], device='cuda:0')
tensor([[ 6, 39, 13, 22, 37]], device='cuda:0')
tensor([[ 6, 39, 13, 22, 37,  7]], device='cuda:0')
tensor([[ 6, 39, 13, 22, 37,  7, 44]], device='cuda:0')
tensor([[ 6, 39, 13, 22, 37,  7, 44, 13]], device='cuda:0')
tensor([[ 6, 39, 13, 22, 37,  7, 44, 13,  4]], device='cuda:0')
tensor([[  6,  39,  13,  22,  37,   7,  44,  13,   4, 208]], device='cuda:0')
[6, 39, 13, 22, 37, 7, 44, 13, 4, 208, 5]
A group of people standing in front of a store .


  out = decoder(torch.tensor(tgt).to(DEVICE), enc_out, src_mask, tgt_mask)
