# The Annotated *Annotated Transformer*

## Introduction

This is the annotated *Annotated Transformer*.   
The original *Transformer* paper needs no introduction.  
The *Annotated Transformer* is a codebase from Harvard that implements the paper.    
I expanded on some of the concepts mentioned in the *Annotated Transformer*, added takeaways from related papers and notes on PyTorch syntax, and extended the project to more language modeling tasks. I also included answers to questions I came up with while studying the paper.  
I am making this public so that more students of deep learning can benefit from my trial and error (i.e. stochastic gradient descent). 

Overall this notebook is divided into the following sections:
* Transformer Model
* Data Pipeline
* Training Pipeline

In [None]:
# dependency is managed via conda through an environment file (env.yml)
import copy
import math
import os
import time

import altair as alt # for visualization
import pandas as pd
import spacy # for tokenization

import torch
from torchtext.data.functional import to_map_style_dataset # convert an iterable-style dataset into an indexable (map-style) one
import torchtext.datasets as datasets
from torchtext.vocab import build_vocab_from_iterator # build a vocab (token-to-index mapping) from an iterator of token lists
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

## Transformer Model

### Motivation for the Transformer
* Seq2Seq and other RNN language models achieved SOTA results at the time with an encoder-decoder architecture
* However, RNNs are hard to parallelize because they need O(n) sequential operations during training. Compressing the input into a single context vector also loses information, especially as the context window grows: information from the first token has to pass through O(n) recurrent steps to reach the nth token, which also makes backprop harder
* Attention was proposed as a way to re-compute the context vector for each token being decoded. This goes one step further than the word2vec idea: not only can a word be represented by the words around it on average, this representation can be further refined conditional on the current sentence
* Then comes the idea of *Attention is all you need*: we don't need the RNN part of the model at all. Instead, multi-head attention plus positional encoding processes the full sequence at once, which takes O(1) sequential operations per layer (though the total compute per layer is still O(n^2 d))
* Another keyword here is "self-attention": in both the encoder and the decoder, the queries, keys and values all come from the same sequence. Cross attention is used when the decoder attends to the encoder output (queries from the decoder, keys and values from the encoder)


### Key design
How does the encoder-decoder architecture work?   
* The encoder produces a representation of the input sentence  
* The decoder takes in this representation and the output decoded so far, and predicts the next token autoregressively  
* Later architectures use only the encoder or only the decoder; more to come on these later (#TODO)
  
What is in the Encoder vs. Decoder?
* Encoder: Embedding => Positional encoding => N x (Self-attention => Feedforward network)  
* Decoder: the same, except each layer uses masked self-attention over the target and adds a cross-attention sublayer that takes in the *final* output of the encoder 

What is the src mask vs. target mask? Why do we need them?
* The src mask is applied to the input (source) sequence and the tgt mask to the output (target) sequence
* One reason masks are needed is padding: the attention scores at padding positions are replaced with -inf (-1e9 in the code) so that no attention is paid to them after the softmax. This applies to both the src and tgt masks
* The target needs an additional mask: when predicting the token at position i, only y[:i] is visible, to prevent leakage of future tokens. To handle the special case of producing the first token, a special start-of-sentence token is prepended to the target
* This subsequent mask is a matrix where the area above the diagonal is masked out
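A minimal sketch of the padding mask in action, using made-up scores for one query over four key positions where the last position is padding. It applies the same `masked_fill(mask == 0, -1e9)` trick as the `attention` function later in the notebook, so the padded position ends up with essentially zero weight while the remaining weights still sum to 1.

```python
import torch

# Hypothetical toy scores: one query attending over 4 key positions
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
pad_mask = torch.tensor([[True, True, True, False]])  # False marks a padding position

masked = scores.masked_fill(pad_mask == 0, -1e9)  # -1e9 stands in for -inf
weights = masked.softmax(dim=-1)                  # padded position gets ~0 weight
print(weights)
```

Note that without the mask the padded position would have received the largest weight, since its raw score (3.0) is the highest.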
 
Is there a difference in the src and tgt embedding?
* Some research has shown that sharing (tying) embedding weights can improve performance; the Transformer paper itself shares one weight matrix between the two embedding layers and the pre-softmax linear layer
* In the case where the two languages have separate vocabularies (as in this notebook), two separate embeddings may make more sense

What is the generator?
* The generator projects the decoder's output to the vocabulary space and applies a log-softmax to get (log) probabilities over the next token 

In [None]:
class EncoderDecoder(nn.Module):
    # overall, think of there being only two components. The encoder encodes and the decoder decodes
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        encoded_output = self.encode(src, src_mask)
        decoded_output = self.decode(encoded_output, src_mask, tgt, tgt_mask) # cross attention needs src_mask
        return decoded_output
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
    

# alternatively can put the generator in the encoder decoder module as another function
class Generator(nn.Module):
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return log_softmax(self.proj(x), dim = -1) # log_softmax is numerically more stable than log(softmax(x)); its output pairs with the KL / NLL losses used later

### Encoder in detail
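Encoder: a stack of N=6 identical layers. Note that this implementation applies layer norm *before* each sublayer (pre-norm), whereas the paper applies it after the residual addition; this is why a final layer norm is added at the end of the stack.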

- encoder_layer_0
    - sublayer_0 attention
        - layer norm
        - attention
        - dropout 
        - residual connection
    - sublayer_1 feedforward
        - layer norm
        - feedforward
        - dropout
        - residual connection
- encoder_layer_1  
...  
- layer norm

What is batch norm? (in comparison to layer norm)
* First, for each feature, it normalizes the values across the batch so that they have a mean of 0 and a standard deviation of 1
* Then two learnable parameters, gamma (scale) and beta (shift), let the network recover whatever scale and shift it needs
* It was introduced as a way to stabilize training by reducing "internal covariate shift"; this stability in turn allows for higher learning rates
* It is typically added between the linear component and the non-linear component of a network

What is layer norm?
* Layer norm normalizes across all features of a single example (in the transformer, of each token position) so that they have mean 0 and standard deviation 1
* Similarly, it has two learnable parameters to scale and shift the distribution
* Since it does not depend on batch statistics, it is useful when batches are small
* In RNNs, batch norm struggles because the statistics differ across time steps and sequences vary in length. In the transformer, each sequence and each position is likewise normalized independently
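The difference between the two comes down to which axis the statistics are computed over. A minimal sketch on a random `[batch, seq_len, features]` tensor, using PyTorch's built-in `nn.LayerNorm` and `nn.BatchNorm1d` rather than the notebook's own `LayerNorm` class:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 5, 8) * 3 + 2   # [batch, seq_len, features], deliberately not zero-mean

# Layer norm: statistics over the feature dimension of each token separately
ln = nn.LayerNorm(8)
y_ln = ln(x)
print(y_ln.mean(dim=-1).abs().max())        # ~0 for every (example, position)

# Batch norm: statistics per feature, pooled over the batch and all positions.
# nn.BatchNorm1d expects [batch, features, length], hence the transposes.
bn = nn.BatchNorm1d(8)                       # a freshly built module is in training mode
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)
print(y_bn.mean(dim=(0, 1)).abs().max())     # ~0 for every feature
```

Layer norm's result for one token never depends on the other examples in the batch, which is why it behaves the same with a batch of one.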

What is dropout?  
* During training, each activation is randomly set to 0 with probability p, which forces the network to learn redundant representations

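How do these techniques behave differently during training vs. inference?  
* Batch norm: running (exponential moving average) estimates of the mean and variance are kept during training and used at inference, since inference may see a single example
* Layer norm: behaves the same in both
* Dropout: with drop probability p, the surviving activations are scaled by 1/(1-p) during training so that the expected value matches inference, where dropout is disabled (this "inverted dropout" is what `nn.Dropout` does); alternatively, activations can be left unscaled during training and scaled by (1-p) at inference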
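The inverted-dropout behavior of `nn.Dropout` can be checked directly. This sketch feeds a vector of ones through the module in training mode and then in eval mode; `p = 0.1` is just an illustrative value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.1                      # drop probability, as in nn.Dropout(p=...)
drop = nn.Dropout(p=p)
x = torch.ones(10_000)

drop.train()
y_train = drop(x)
# Dropped activations are exactly 0; survivors are scaled by 1 / (1 - p)
print(y_train.unique())
print(y_train.mean())        # close to 1.0: the expected value is preserved

drop.eval()
y_eval = drop(x)
print(torch.equal(y_eval, x))  # dropout is the identity at inference
```

This is why calling `model.train()` and `model.eval()` matters in the training loop.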

What is the purpose of residual connection?  
* This was originally introduced as part of the paper *Deep Residual Learning for Image Recognition*
* Structurally, the input of a sublayer is added to its output; in the ResNet paper the sum then goes through a ReLU, while in the transformer it does not
* In the extreme, this lets an identity mapping be learned easily, speeding up training of deep networks
* One intuition is that a deeper network should be at least as good as a shallower one, since the additional layers could just be identity functions; empirically, however, deeper plain networks (without residual connections) often train worse
* Instead of learning the mapping between input and output, the network now learns the difference between them (the residual mapping)
* The residual mapping can easily be learned to be 0, which gives the identity function
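The last point can be made concrete with a tiny sketch: if the sublayer outputs 0 (forced here by zeroing the weights of a hypothetical `nn.Linear` sublayer), the residual block reduces to the identity. The block has the same `x + sublayer(x)` form as `SubLayerConnection` below, without the norm and dropout.

```python
import torch
import torch.nn as nn

d = 4
sublayer = nn.Linear(d, d)
# Pretend the sublayer has learned a residual of exactly 0
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

def residual_block(x):
    return x + sublayer(x)   # x + F(x)

x = torch.randn(2, d)
out = residual_block(x)
print(torch.equal(out, x))   # the block passes its input through unchanged
```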



In [None]:
# an encoder consists of N encoder_layers
# each encoder_layer contains a self-attention sublayer and a feedforward sublayer
# each sublayer applies layer norm, dropout and a residual connection
class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N) # replicate given layer n times 
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask): # need to have the mask defined here to handle different masks instead of __init__
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
    

class SubLayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer): # notice x is the input whereas sublayer is a function
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout): # dropout is a scalar here
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)
    

class LayerNorm(nn.Module):
    """learned scaling factor. prevent the gradient vanishing / exploding problems"""
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim = True) # x: [example_cnt, seq_length, embedding_size]; mean over features, keepdim so it broadcasts back
        std = x.std(-1, keepdim = True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
    

def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) # registers each copy as a submodule; unlike Sequential, it has no forward of its own

### Decoder in detail
What is the input and output for a decoder's prediction?
* The input is the encoder's representation (memory) of the source sentence plus the target sequence up to the token being predicted
* The target is the token being predicted 
* Teacher forcing is used during the training where the correct target sequence is used rather than what's previously generated by the model


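Why do we need the source mask in the decoder if the encoder output was already computed with it?  
*  The encoder's src mask only stops positions from attending *to* padding; the encoder still produces an output vector at every padded position (and its final layer norm does not zero them out). In cross attention, the decoder's queries attend over every position of the encoder memory, so the src mask has to be applied again to keep the decoder from attending to the padded positions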

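Why is the subsequent mask size (1, seq_length, seq_length) and how is it applied?  
*  The leading 1 broadcasts across the batch: the mask is combined with the padding mask of shape (example_count, 1, seq_length) and applied to the attention scores of size (example_count, seq_length, seq_length) to mask out the area above the diagonal. In `MultiHeadedAttention` the mask is further unsqueezed to (example_count, 1, seq_length, seq_length) so it also broadcasts across heads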
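To see the shapes and broadcasting, here is a small sketch that copies `subsequent_mask` from the cell below and combines it with a hypothetical padding mask for a batch of two target sequences, the same way `Batch.make_std_mask` does later in the notebook.

```python
import torch

def subsequent_mask(size):
    # Copied from the notebook: True on and below the diagonal
    attn_shape = (1, size, size)
    return torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8) == 0

m = subsequent_mask(4)
print(m.shape)        # torch.Size([1, 4, 4])
print(m[0].int())     # lower-triangular pattern of ones

# Hypothetical padding masks for 2 target sequences of length 4;
# the first one ends in padding. Shape [2, 1, 4], as in Batch.make_std_mask.
pad_mask = torch.tensor([[True, True, True, False],
                         [True, True, True, True]]).unsqueeze(-2)
full = pad_mask & m   # [2, 1, 4] & [1, 4, 4] broadcasts to [2, 4, 4]
print(full.shape)     # torch.Size([2, 4, 4])
print(full[0].int())  # causal pattern, with the padded last column also masked
```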


In [None]:
# a decoder consists of N decoder layers
class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
    
    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
    

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn # attention(query, key, value, mask)
        self.src_attn = src_attn # attention(query, key, value, mask)
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) # self attention
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) # cross attention
        return self.sublayer[2](x, self.feed_forward) # feedforward layer
    

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size) # [1, seq_length, seq_length]
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    ) 
    return subsequent_mask == 0 # True on and below the diagonal (positions that may be attended to)

### Attention in detail
Attention maps a query and key-value pairs into an output.   
A compatibility function between the query and key determines which value to output.  
The output is a weighted sum of the values.  
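A tiny worked example of scaled dot-product attention with one query and three key/value pairs (toy numbers chosen for illustration). The query matches keys 0 and 2 equally well, so those two values get equal weight, and because the values are symmetric around 20 the output comes out at exactly their weighted middle.

```python
import math
import torch

query = torch.tensor([[1.0, 0.0]])                        # [1, d_k]
key = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # [3, d_k]
value = torch.tensor([[10.0], [20.0], [30.0]])            # [3, 1]

scores = query @ key.T / math.sqrt(key.size(-1))  # compatibility of the query with each key
weights = scores.softmax(dim=-1)                  # one weight per value, summing to 1
output = weights @ value                          # weighted sum of the values

print(weights)
print(output)   # 20.0, since keys 0 and 2 tie and 10 and 30 average to 20
```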

What is the complexity of the attention mechanism (sequence length n, dimension d)?
* query @ key^T: each of the n^2 cells is a length-d dot product, so O(n^2 d)
* softmax: each of the n rows takes O(n), so O(n^2)
* weights @ value: [n,n]@[n,d], each of the n*d cells takes n operations, so O(n^2 d)
* overall, O(n^2 d)
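The counts above can be turned into a back-of-the-envelope calculator. This is a rough sketch (the helper `attention_cost` is hypothetical and ignores constant factors such as the cost of `exp`), but it makes the quadratic growth visible: doubling n quadruples the work.

```python
def attention_cost(n, d):
    """Rough multiply-add counts for single-head attention on n tokens of width d."""
    qk = n * n * d        # Q @ K^T: n*n scores, each a length-d dot product
    softmax = n * n       # each of the n rows is normalized in O(n)
    av = n * n * d        # weights @ V: n*d outputs, each a length-n weighted sum
    return {"qk": qk, "softmax": softmax, "av": av, "total": qk + softmax + av}

for n in (128, 256, 512):
    print(n, attention_cost(n, d=64)["total"])
```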

What is multi-head attention?
* Instead of learning one representation of key, query, and value, we learn h different representations
* This allows the model to jointly attend to different parts of the sequence from different representation subspaces
* The h outputs are concatenated and linearly transformed to the expected dimension

How do the dimensions of the different matrices change in multi-head attention?
* Input Q, K, V: each [cnt,n,d]
* Linear projection: [cnt,n,d] @ [d,d] => [cnt,n,d]
* Split into heads: reshape => [cnt,n,h,d/h], transpose => [cnt,h,n,d/h]
* Attention per head: scores [cnt,h,n,n], output [cnt,h,n,d/h]
* Merge heads: transpose => [cnt,n,h,d/h], reshape => [cnt,n,d], followed by the output linear layer
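A shape trace of the same steps, as a minimal sketch with small illustrative sizes. For brevity it reuses one projection for queries, keys and values, unlike `MultiHeadedAttention` below, which uses a separate linear layer for each.

```python
import torch
import torch.nn as nn

cnt, n, d, h = 2, 5, 16, 4
d_k = d // h
x = torch.randn(cnt, n, d)
proj = nn.Linear(d, d)                          # weight is [d, d], shared across the batch

q = proj(x)                                     # [cnt, n, d]
q = q.view(cnt, n, h, d_k)                      # [cnt, n, h, d/h]  split into heads
q = q.transpose(1, 2)                           # [cnt, h, n, d/h]  heads become a batch dim

scores = q @ q.transpose(-2, -1) / d_k ** 0.5   # [cnt, h, n, n]  one score matrix per head
out = scores.softmax(dim=-1) @ q                # [cnt, h, n, d/h]
out = out.transpose(1, 2).contiguous()          # [cnt, n, h, d/h]
out = out.view(cnt, n, h * d_k)                 # [cnt, n, d]  heads concatenated
print(q.shape, scores.shape, out.shape)
```

The `.contiguous()` call is needed because `view` requires a contiguous memory layout, which `transpose` breaks.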

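How come there is no non-linearity (e.g. ReLU) in the attention mechanism?  
* The softmax is already a non-linearity: the attention weights depend non-linearly on the query-key scores. The ReLU non-linearity is supplied by the position-wise feedforward sublayer that follows each attention sublayer; keeping the projections linear may also help stability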

In [None]:
def attention(query, key, value, mask=None, dropout=None): # dropout is a function this time
    d_k = query.size(-1) # per-head dimension d_k
    scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(d_k) # [..., n, d_k] @ [..., d_k, n] => [..., n, n]

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9) # -1e9 stands in for -inf => ~0 weight after the softmax
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn # return both the weighted value, as well as the probability


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4) # four separate linear models for q / k / v / output
        self.attn = None
        self.dropout = nn.Dropout(p = dropout)
        
    def forward(self, query, key, value, mask = None):
        if mask is not None:
            mask = mask.unsqueeze(1) # [cnt, 1, n] or [cnt, n, n] => insert a head dim at 1 so the mask broadcasts over heads
        nbatches = query.size(0)

        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1,2)
            for lin, x in zip(self.linears, (query, key, value))
        ]
        # said differently...
        # query = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        # key = self.linears[1](key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        # value = self.linears[2](value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) # [cnt, h, n, d/h]

        # note this implementation is slightly different from simple concatenation - attention is processed as one matrix
        x = (
            x.transpose(1,2)
            .contiguous() # make sure the memory is contiguous for ops like view and performance
            .view(nbatches, -1, self.h * self.d_k) 
        )
        del query
        del key
        del value
        
        return self.linears[-1](x)

### Embedding and Positional encoding  
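Embedding is really just a lookup table: it takes integer token ids of shape [n] and returns [n, d], which is equivalent to multiplying a one-hot matrix [n, vocab] by the embedding weights [vocab, d].  
The biggest takeaway from the positional encoding is that it makes relative positions easy to learn. For each sin/cos frequency, PE(pos+k) is a fixed linear function (a rotation) of PE(pos) that depends only on the offset k, not on pos. Each position still gets a distinct encoding, so absolute position is available as well.

Dropout is applied to the sum of the embedding and the positional encoding (inside `PositionalEncoding`), rather than to the embedding on its own.  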
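Both claims can be checked numerically. This sketch first confirms that an `nn.Embedding` lookup matches a one-hot matmul, then checks the rotation property for one sin/cos pair. The frequency `w` mirrors `div_term` in `PositionalEncoding` for index 2 with an illustrative `d_model=16`; `pe_pair` is a hypothetical helper.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# 1) An embedding lookup equals a one-hot matrix times the embedding weights.
vocab, d_model = 7, 4
emb = nn.Embedding(vocab, d_model)
tokens = torch.tensor([3, 0, 5])                         # [n] integer ids
one_hot = F.one_hot(tokens, num_classes=vocab).float()   # [n, vocab]
same = torch.allclose(emb(tokens), one_hot @ emb.weight) # [n, d_model] either way
print(same)

# 2) For one frequency w, shifting the position by k is a fixed rotation.
def pe_pair(pos, w):
    return torch.tensor([math.sin(w * pos), math.cos(w * pos)])

w, k = 1 / 10000 ** (2 / 16), 3   # frequency for index 2 with d_model=16, offset k=3
rot = torch.tensor([[math.cos(w * k), math.sin(w * k)],
                    [-math.sin(w * k), math.cos(w * k)]])
for pos in (0, 7, 42):
    # PE(pos + k) == R(k) @ PE(pos); R depends only on k, not on pos
    print(torch.allclose(pe_pair(pos + k, w), rot @ pe_pair(pos, w), atol=1e-6))
```

The rotation follows from the angle-addition identities sin(a+b) = sin a cos b + cos a sin b and cos(a+b) = cos a cos b - sin a sin b.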

In [None]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model) # lookup table
        self.d_model = d_model
    
    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model) # maintain the variance of the embedding 
    

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model) # [max_len, d_model]
        position = torch.arange(0, max_len).unsqueeze(1) # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0) # [1, max_len, d_model]
        self.register_buffer("pe", pe) # register as a parameter without backprop

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False) # notice no gradient calc is needed
        return self.dropout(x)

### Make a model

In [None]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))


def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    
    # encoder, decoder, src_embed, tgt_embed, generator
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)), # a sequential module is defined here directly
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # xavier initialization
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p) # make training faster
    
    return model


Simple inference test: greedy decoding with an untrained model, so the predictions are random; this only checks that the pieces fit together

In [None]:
def inference_test():
    test_model = make_model(11, 11, 2)
    test_model.eval()

    src = torch.tensor([[1,2,3,4,5,6,7,8,9,10]])
    src_mask = torch.ones(1,1,10)

    memory = test_model.encode(src, src_mask)

    ys = torch.zeros(1,1).type_as(src) # [1,1]
    
    for i in range(20):
        out = test_model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)
        )
        prob = test_model.generator(out[:,-1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]

        ys = torch.cat(
            [ys, torch.empty(1,1).type_as(src.data).fill_(next_word)], dim=1
        )
    print(f"prediction {ys}")

In [None]:
def run_tests():
    for _ in range(10):
        inference_test()

run_tests()

### Simple copy task

Data generation

In [None]:
def data_gen(V, batch_size, nbatches):
    """Generate random data for a src-tgt copy task.
    V: vocab size
    """
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch_size, 10)) # [count, seq_len=10], token ids in 1..V-1
        data[:, 0] = 1 # SOS == 1
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        yield Batch(src, tgt, 0) # src, tgt, padding value


class Batch:
    """Object for holding a batch of data with mask during training."""

    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2) # [count, 1, seq_len]
        if tgt is not None:
            self.tgt = tgt[:, :-1] # decoder input: drop the last token [count, seq_len-1]
            self.tgt_y = tgt[:, 1:] # decoder target: drop the first token (SOS) [count, seq_len-1]
            self.tgt_mask = self.make_std_mask(self.tgt, pad) # [count, seq_len-1, seq_len-1]
            self.ntokens = (self.tgt_y != pad).data.sum() # count of non-padding token

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2) # [count, 1, seq_len-1]
        tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        ) # [count, 1, seq_len-1] & [1, seq_len-1, seq_len-1]
        return tgt_mask # [count, seq_len-1, seq_len-1]
    

Learning rate

In [None]:
def rate(step, model_size, factor, warmup):
    """
    we have to default the step to 1 for LambdaLR function
    to avoid zero raising to negative power.
    """
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )
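The schedule is lr = factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)): it grows linearly for the first `warmup` steps, peaks where the two terms are equal (step == warmup), and then decays as 1/sqrt(step). `LambdaLR` multiplies the optimizer's initial learning rate by this value, which is why the optimizers in this notebook are created with a base `lr` of 0.5 or 1.0. A quick sketch evaluating the function (copied from the cell above) with d_model=512 and warmup=4000, the paper's setting:

```python
def rate(step, model_size, factor, warmup):
    # Copied from the notebook cell above
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )

d_model, warmup = 512, 4000
for step in (1, 1000, 4000, 8000, 16000):
    print(step, f"{rate(step, d_model, factor=1.0, warmup=warmup):.2e}")
```

Past the peak, quadrupling the step count halves the learning rate.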

Regularization

What is label smoothing?  
* Label smoothing is a regularization technique that prevents the model from being too confident about the prediction
* We can use it as the criterion instead of cross entropy loss
* Specifically the target distribution is smoothed to be uniform over the vocabulary
* The target becomes 1-alpha, and the rest becomes alpha/(vocab_size-1)
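A minimal sketch of the smoothed target distribution, built the same way as `LabelSmoothing.forward` below. The helper `smoothed_targets`, the labels and `smoothing=0.4` are illustrative choices (a large alpha so the numbers are easy to read).

```python
import torch

def smoothed_targets(target, size, padding_idx, smoothing):
    """Build the target distribution the same way LabelSmoothing.forward does."""
    true_dist = torch.full((target.size(0), size), smoothing / (size - 2))
    true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)  # correct class
    true_dist[:, padding_idx] = 0                                 # never predict padding
    true_dist[target == padding_idx] = 0                          # ignore padded positions
    return true_dist

target = torch.tensor([2, 1, 0])   # hypothetical labels; the last one is padding (index 0)
dist = smoothed_targets(target, size=5, padding_idx=0, smoothing=0.4)
print(dist)
# row 0 is approximately [0, 0.133, 0.6, 0.133, 0.133]; row 2 is all zeros
```

Each non-padding row still sums to 1, so it remains a valid probability distribution.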

What is KL divergence loss vs. cross entropy loss?  
* KL divergence is often used for comparing two probability distributions
* With label smoothing the target is a full distribution rather than a single class, so KL divergence is a natural fit
* Cross entropy = entropy of the target + KL divergence, i.e. H(p, q) = H(p) + KL(p || q)
* Since H(p) does not depend on the model, minimizing KL and minimizing cross entropy give the same gradients; the reported loss values just differ by that constant
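The identity and the matching gradients can be verified numerically. In this sketch the logits are random and the target `p` is a hypothetical smoothed distribution; `nn.KLDivLoss` takes log-probabilities as input and probabilities as target, just as in `LabelSmoothing`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 5, requires_grad=True)
log_q = logits.log_softmax(dim=-1)                        # model log-probabilities
p = torch.tensor([0.6, 0.1, 0.1, 0.1, 0.1]).expand(3, 5)  # a smoothed target distribution

cross_entropy = -(p * log_q).sum()            # H(p, q)
entropy = -(p * p.log()).sum()                # H(p), constant w.r.t. the model
kl = nn.KLDivLoss(reduction="sum")(log_q, p)  # KL(p || q)

print(cross_entropy.item(), (entropy + kl).item())  # equal up to float rounding

# The gradients w.r.t. the logits match, because H(p) does not depend on them
(g_ce,) = torch.autograd.grad(cross_entropy, logits, retain_graph=True)
(g_kl,) = torch.autograd.grad(kl, logits)
print(torch.allclose(g_ce, g_kl, atol=1e-6))
```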

In [None]:
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size # vocab
        true_dist = x.data.clone() # [count, vocab]
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0

        mask = torch.nonzero(target.data == self.padding_idx) # return padding indices
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        
        return self.criterion(x, true_dist.clone().detach())
    

class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = self.generator(x)
        sloss = (
            self.criterion(
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            )
            / norm
        )
        return sloss.data * norm, sloss # return both total loss and avg loss
    

class DummyOptimizer(torch.optim.Optimizer):
    """No-op optimizer used during evaluation."""
    def __init__(self):
        self.param_groups = [{"lr": 0}]

    def step(self):
        pass

    def zero_grad(self, set_to_none=False):
        pass


class DummyScheduler:
    """No-op scheduler used during evaluation."""
    def step(self):
        pass


class TrainState:
    step: int = 0
    accum_step: int = 0
    samples: int = 0
    tokens: int = 0

Training

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data) # no target padding mask needed: ys is built token by token and has no padding
        )
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1
        )
    return ys

In [None]:
def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode="train",
    accum_iter=1,
    train_state=TrainState(),
):
    """Train a single epoch"""
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(
            batch.src, batch.tgt, batch.src_mask, batch.tgt_mask
        )
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)
        # loss_node = loss_node / accum_iter
        if mode == "train" or mode == "train+log":
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and (mode == "train" or mode == "train+log"):
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            print(
                (
                    "Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f "
                    + "| Tokens / Sec: %7.1f | Learning Rate: %6.1e"
                )
                % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr)
            )
            start = time.time()
            tokens = 0
        del loss
        del loss_node
    return total_loss / total_tokens, train_state


def example_simple_model():
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, N=2)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, model_size=model.src_embed[0].d_model, factor=1.0, warmup=400
        ),
    )

    batch_size = 80

    # for testing
    src = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    train_state = TrainState()

    for epoch in range(20):
        model.train()
        _, train_state = run_epoch(
            data_gen(V, batch_size, 20),
            model,
            SimpleLossCompute(model.generator, criterion),
            optimizer,
            lr_scheduler,
            train_state=train_state,
            mode="train"
        )

        model.eval()
        run_epoch(
            data_gen(V, batch_size, 5),
            model,
            SimpleLossCompute(model.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval"
        )[0]
        
        print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))

In [None]:
example_simple_model()

### German to English translation task
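This is the data preprocessing step before training the model.  
First we load a tokenizer for each language. 
Then we build a vocabulary for each language from the training data.  
Finally we map each tokenized sentence to a sequence of integer indices into the vocabulary (equivalent to a one-hot encoding, which the embedding layer turns into vectors).  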
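The three steps can be sketched in plain Python. This only illustrates the idea and is not the torchtext API: `toy_tokenize`, `toy_build_vocab` and `encode` are hypothetical helpers, a lowercase whitespace split stands in for spaCy, and non-special tokens are ordered alphabetically here, which may differ from how `build_vocab_from_iterator` orders them. The specials and `min_freq=2` match `build_vocabulary` below.

```python
from collections import Counter

SPECIALS = ["<s>", "</s>", "<blank>", "<unk>"]   # same specials as build_vocabulary

def toy_tokenize(text):
    # Hypothetical stand-in for the spaCy tokenizer
    return text.lower().split()

def toy_build_vocab(sentences, min_freq=2):
    counts = Counter(tok for s in sentences for tok in toy_tokenize(s))
    kept = sorted(t for t, c in counts.items() if c >= min_freq)
    return {tok: i for i, tok in enumerate(SPECIALS + kept)}  # specials get ids 0..3

def encode(text, vocab):
    unk = vocab["<unk>"]                                      # out-of-vocabulary fallback
    return [vocab.get(tok, unk) for tok in toy_tokenize(text)]

corpus = ["a dog runs", "a cat runs", "the dog sleeps"]
vocab = toy_build_vocab(corpus)
print(vocab)
print(encode("a dog flies", vocab))   # "flies" was never seen, so it maps to <unk>
```

Tokens seen only once ("cat", "the", "sleeps") fall below `min_freq` and also map to `<unk>`.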

In [None]:
def load_tokenizers():
    lm = "de_core_news_sm"
    try:
        spacy_de = spacy.load(lm)
    except OSError:
        os.system(f"python -m spacy download {lm}")
        spacy_de = spacy.load(lm)

    lm = "en_core_web_sm"
    try:
        spacy_en = spacy.load(lm)
    except OSError:
        os.system(f"python -m spacy download {lm}")
        spacy_en = spacy.load(lm)
    
    return spacy_de, spacy_en


def tokenize(text, tokenizer):
    return [tok.text for tok in tokenizer.tokenizer(text)]


def yield_tokens(data_iter, tokenizer, index):
    for from_to_tuple in data_iter: # from_to refer to the original vs. translated sentences 
        yield tokenizer(from_to_tuple[index])

In [None]:
def build_vocabulary(spacy_de, spacy_en):
    def tokenize_de(text):
        return tokenize(text, spacy_de)
    
    def tokenize_en(text):
        return tokenize(text, spacy_en)
    
    train, val, test = datasets.Multi30k(language_pair=("de", "en")) # FIXME: there is some issue with the test dataset server so it cant be downloaded
    vocab_src = build_vocab_from_iterator(
        yield_tokens(train + val, tokenize_de, index=0), # index 0 is german
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_tgt = build_vocab_from_iterator(
        yield_tokens(train + val , tokenize_en, index=1), # index 1 is english
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    vocab_src.set_default_index(vocab_src["<unk>"]) # out of vocabulary index
    vocab_tgt.set_default_index(vocab_tgt["<unk>"])

    return vocab_src, vocab_tgt


def load_vocab(spacy_de, spacy_en):
    if not os.path.exists("vocab.pt"):
        vocab_src, vocab_tgt = build_vocabulary(spacy_de, spacy_en)
        torch.save((vocab_src, vocab_tgt), "vocab.pt")
    else:
        vocab_src, vocab_tgt = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes:")
    print(len(vocab_src))
    print(len(vocab_tgt))
    return vocab_src, vocab_tgt

In [None]:
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)

In [None]:
# helpful examples
for v in ["<s>", "</s>", "<blank>", "ahbirf", "<unk>", "a"]:
    print(f"{v}: {vocab_src[v]} {vocab_tgt[v]}")

`collate_batch` wraps each source/target sentence pair in `<s>` and `</s>` tokens and pads every sequence to `max_padding`

In [None]:
def create_dataloaders(
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    batch_size=12000,
    max_padding=128,
):
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_src,
            vocab_tgt,
            max_padding=max_padding,
            pad_id=vocab_src.get_stoi()["<blank>"], # string to text 
        )

    train_iter, valid_iter, test_iter = datasets.Multi30k(
        language_pair=("de", "en")
    )

    train_iter_map = to_map_style_dataset(
        train_iter
    ) # so the data can be indexed
    valid_iter_map = to_map_style_dataset(valid_iter)

    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn, # turns raw (de, en) string pairs into padded index tensors
    )
    return train_dataloader, valid_dataloader


def collate_batch(
    batch,
    src_pipeline, # the tokenizer
    tgt_pipeline,
    src_vocab, # the vocab
    tgt_vocab,
    max_padding=128,
    pad_id=2,
):
    bs_id = torch.tensor([0])  # <s> token id
    eos_id = torch.tensor([1])  # </s> token id
    src_list, tgt_list = [], []
    for (_src, _tgt) in batch:
        processed_src = torch.cat(
            [
                bs_id,
                torch.tensor(
                    src_vocab(src_pipeline(_src)),
                    dtype=torch.int64,
                    device="cpu",
                ),
                eos_id,
            ],
            0,
        )
        processed_tgt = torch.cat(
            [
                bs_id,
                torch.tensor(
                    tgt_vocab(tgt_pipeline(_tgt)),
                    dtype=torch.int64,
                    device="cpu",
                ),
                eos_id,
            ],
            0,
        )
        src_list.append(
            pad(
                processed_src,
                (
                    0,
                    max_padding - len(processed_src), # padded to the max length
                ),
                value=pad_id,
            )
        )
        tgt_list.append(
            pad(
                processed_tgt,
                (0, max_padding - len(processed_tgt)),
                value=pad_id,
            )
        )

    src = torch.stack(src_list) # [batch_size, max_padding]
    tgt = torch.stack(tgt_list)
    return (src, tgt)
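To see what one padded row looks like, here is a stripped-down sketch of the same `<s>`/`</s>`/padding logic applied to already-numericalized sentences. `collate_one`, the token ids and `max_padding=8` are hypothetical; the special ids 0, 1 and 2 match the defaults in `collate_batch`.

```python
import torch
from torch.nn.functional import pad

bs_id, eos_id, pad_id = 0, 1, 2     # <s>, </s>, <blank> ids, as in collate_batch
max_padding = 8

def collate_one(token_ids):
    # Wrap the sentence in <s> ... </s>, then right-pad with <blank> up to max_padding
    seq = torch.tensor([bs_id] + token_ids + [eos_id], dtype=torch.int64)
    return pad(seq, (0, max_padding - len(seq)), value=pad_id)

# Hypothetical sentences already mapped to ids, e.g. vocab(tokenizer(text))
batch = [[5, 9, 7], [4, 6]]
src = torch.stack([collate_one(ids) for ids in batch])   # [batch_size, max_padding]
print(src)
print(src != pad_id)   # this is what Batch turns into the src mask
```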

In [None]:
def train_model(
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    config,
):
    pad_idx = vocab_tgt["<blank>"]
    d_model = 512
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    module = model

    criterion = LabelSmoothing(
        size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1
    )

    train_dataloader, valid_dataloader = create_dataloaders(
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=config["batch_size"],
        max_padding=config["max_padding"],
    )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, d_model, factor=1, warmup=config["warmup"]
        ),
    )
    train_state = TrainState()

    for epoch in range(config["num_epochs"]):
        model.train()
        print(f"Epoch {epoch} Training ====", flush=True)
        _, train_state = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in train_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
        )

        file_path = "%s%.2d.pt" % (config["file_prefix"], epoch)
        torch.save(module.state_dict(), file_path)
        torch.cuda.empty_cache()

        print(f"Epoch {epoch} Validation ====", flush=True)
        model.eval()
        sloss = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )
        print(sloss)

    file_path = "%sfinal.pt" % config["file_prefix"]
    torch.save(module.state_dict(), file_path)

In [None]:
def load_trained_model():
    config = {
        "batch_size": 32,
        "num_epochs": 5,
        "accum_iter": 10,
        "base_lr": 1.0,
        "max_padding": 72,
        "warmup": 3000,
        "file_prefix": "multi30k_model_",
    }
    model_path = "multi30k_model_final.pt"
    if not os.path.exists(model_path):
        train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config)

    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.load_state_dict(torch.load("multi30k_model_final.pt"))
    return model

In [None]:
model = load_trained_model()