# Transformer Implementation

Following [The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)

## Prelims

In [4]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR

# import pandas as pd
import polars as pl  # Polars because we're cool like that
import altair as alt

from torch.utils.data import DataLoader

import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# # Can't import torchtext with ROCm
# from torchtext.data.functional import to_map_style_dataset
# from torchtext.vocab import build_vocab_from_iterator
# import torchtext.datasets as datasets

In [5]:
# Set to False to skip notebook execution (e.g. for debugging).
# show_example / execute_example only call their function when this is True
# and the module is running as __main__.
RUN_EXAMPLES = True

In [10]:
# Convenience helper functions
def is_interactive_notebook() -> bool:
    """Report whether this module is executing as the top-level ``__main__``.

    NOTE(review): despite the name, this is also True when the file is run as
    a plain script; it cannot tell a notebook kernel apart from ``python file.py``.
    """
    running_as_main = "__main__" == __name__
    return running_as_main


def show_example(fn, args=()):
    """Call ``fn(*args)`` and return its result, but only when examples are enabled.

    Examples are enabled when the module runs as ``__main__`` and the module-level
    ``RUN_EXAMPLES`` flag is True; otherwise nothing is called and None is returned.

    Args:
        fn: Callable to run.
        args: Positional arguments for ``fn``. Defaults to an immutable empty
            tuple (a shared ``[]`` default is a mutable-default hazard); any
            sequence, including a list, is accepted.

    Returns:
        ``fn(*args)`` when examples are enabled, else None.
    """
    # __name__ is checked first, so RUN_EXAMPLES is only read when running as __main__.
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)
    return None


def execute_example(fn, args=()) -> None:
    """Call ``fn(*args)`` for its side effects only when examples are enabled.

    Examples are enabled when the module runs as ``__main__`` and the module-level
    ``RUN_EXAMPLES`` flag is True. Any return value of ``fn`` is discarded.

    Args:
        fn: Callable to run.
        args: Positional arguments for ``fn``. Defaults to an immutable empty
            tuple instead of a shared mutable ``[]``; lists are still accepted.
    """
    # __name__ is checked first, so RUN_EXAMPLES is only read when running as __main__.
    if __name__ == "__main__" and RUN_EXAMPLES:
        fn(*args)


class DummyOptimizer(torch.optim.Optimizer):
    """No-op stand-in for a real optimizer.

    ``torch.optim.Optimizer.__init__`` is intentionally not called. The only
    state is a single parameter group whose ``lr`` is 0, presumably so code that
    reads ``param_groups[0]["lr"]`` still works (TODO confirm against callers).
    """

    def __init__(self):
        self.param_groups = [{"lr": 0}]

    def step(self):
        """Do nothing."""

    def zero_grad(self, set_to_none=False):
        """Do nothing; ``set_to_none`` is accepted for API compatibility only."""


class DummyScheduler:
    """No-op learning-rate scheduler stand-in."""

    def step(self):
        """Do nothing."""

# Part 1: Model Architecture
## Model Architecture

In [11]:
class EncoderDecoder(nn.Module):
    """
    Standard encoder-decoder architecture for neural sequence transduction.

    The encoder maps an input sequence of symbol representations to a continuous
    representation ("memory"); the decoder consumes that memory together with the
    target-side sequence. ``forward`` returns the decoder output and does not
    apply ``generator``; callers apply the generator separately.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Assignment order fixes sub-module registration order (and thus the
        # order of parameters()/state_dict()), so keep it stable.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder under ``src_mask``."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder over it, attending to ``memory``."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode the masked source, then decode the masked target against it."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

In [12]:
class Generator(nn.Module):
    """Final generation step: a linear projection followed by log-softmax.

    Maps features of width ``d_model`` to log-probabilities over ``vocab``
    classes along the last dimension.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        # log-softmax (not softmax), so outputs are log-probabilities.
        return log_softmax(logits, dim=-1)

The Transformer architecture looks like this:

![Transformer Architecture](transformer.png)

### Encoder
The encoder is a stack of $N = 6$ identical layers.

In [13]:
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    # deepcopy rather than reusing ``module``: each entry gets its own, untied parameters.
    stack = nn.ModuleList()
    for _ in range(N):
        stack.append(copy.deepcopy(module))
    return stack

The encoder uses [layer normalization](https://arxiv.org/abs/1607.06450) inside each residual sub-layer connection and once more on the output of the full layer stack.

In [14]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and bias.

    Note: unlike ``torch.nn.LayerNorm``, this uses ``Tensor.std`` with its
    default (Bessel-corrected, unbiased) estimator and adds ``eps`` to the
    standard deviation rather than to the variance.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))  # gain, initialized to 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias, initialized to 0
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Same operation order as a_2 * (x - mean) / (std + eps) + b_2.
        return self.a_2 * (x - mu) / (sigma + self.eps) + self.b_2

The encoder also uses a [residual connection](https://arxiv.org/abs/1512.03385) around each of the two sub-layers.

In [15]:
class SublayerConnection(nn.Module):
    """
    Residual connection around a sub-layer, with layer norm and dropout.

    Computes ``x + dropout(sublayer(norm(x)))``: the input is normalized before
    the sub-layer (pre-norm), and the un-normalized ``x`` is what gets added back.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply the residual connection around ``sublayer`` with layer norm and dropout.

        ``sublayer`` is any callable taking one tensor; its output must be
        addable to ``x``.
        """
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

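Layer norm, residual connections, and a stack of layers together form the core encoder. The paper writes the output of each sub-layer as $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$ (post-norm). For code simplicity, `SublayerConnection` instead normalizes first and computes $x + \mathrm{Dropout}(\mathrm{Sublayer}(\mathrm{LayerNorm}(x)))$ (pre-norm). `Encoder` then applies one final `LayerNorm` to the output of the stack.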

In [16]:
class Encoder(nn.Module):
    """Stack of ``N`` cloned encoder layers followed by a final layer normalization."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Run ``x`` through every layer with the same ``mask``, then normalize."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

The encoder layer has two sublayers:
1. Multi-head self-attention
2. Position-wise fully connected feed-forward network

In [18]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each of the two sub-layers is wrapped in its own ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size  # model width; read by Encoder to size its final norm

    def forward(self, x, mask):
        """Apply masked self-attention, then the feed-forward sub-layer."""
        attn_wrapper, ff_wrapper = self.sublayer
        # Self-attention: queries, keys and values all come from the same input.
        x = attn_wrapper(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_wrapper(x, self.feed_forward)

### Decoder
The decoder is also a stack of $N = 6$ identical layers.

In [19]:
class Decoder(nn.Module):
    """Stack of ``N`` cloned decoder layers followed by a final layer normalization."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run ``x`` through every layer against the encoder ``memory``, then normalize."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In addition to the self-attention and feed-forward sub-layers, each decoder layer inserts a third sub-layer that performs multi-head attention over the output of the encoder stack (cross-attention). As in the encoder, each sub-layer is wrapped in a residual connection with layer normalization.

In [21]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over the source, then feed-forward.

    Each of the three sub-layers is wrapped in its own ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size  # model width; read by Decoder to size its final norm
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply target self-attention, cross-attention to ``memory``, then feed-forward."""
        self_wrapper, cross_wrapper, ff_wrapper = self.sublayer
        # Self-attention over the target, restricted by tgt_mask.
        x = self_wrapper(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from encoder memory.
        x = cross_wrapper(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ff_wrapper(x, self.feed_forward)

The self-attention sub-layer in the decoder stack needs a causal mask to prevent positions from attending to subsequent positions.

In [114]:
def causal_mask(size):
    """Mask out subsequent positions.

    Returns a bool tensor of shape ``(1, size, size)`` whose entry ``[0, i, j]``
    is True when ``j <= i`` (not a subsequent position) and False when ``j > i``.
    """
    idx = torch.arange(size)
    # Broadcast a row of column indices against a column of row indices: col <= row.
    allowed = idx.unsqueeze(0) <= idx.unsqueeze(1)
    return allowed.unsqueeze(0)

In [131]:
# Example of what the causal mask looks like
def example_mask(size=20):
    """Build an Altair heatmap of ``causal_mask(size)``.

    Cell (Window=y, Masking=x) is colored by ``mask[0][x, y]``: limegreen where
    the mask is True (``y <= x``) and gray where it is False (``y > x``).

    Args:
        size: Sequence length of the mask to plot. Defaults to 20.

    Returns:
        A 300x300 ``alt.Chart`` of rectangles.
    """
    mask = causal_mask(size)[0]
    positions = range(size)
    # One row per (y, x) cell in y-major order, built in a single DataFrame from
    # column lists instead of concatenating size**2 one-row frames.
    # mask.T[y, x] == mask[x, y], so flattening the transpose yields that order.
    LS_data = pl.DataFrame(
        {
            "Causal Mask": mask.T.reshape(-1).tolist(),
            "Window": [y for y in positions for _ in positions],
            "Masking": [x for _ in positions for x in positions],
        }
    )

    return (
        alt.Chart(LS_data)
        .mark_rect()
        .properties(height=300, width=300)
        .encode(
            x="Window:O",
            y="Masking:O",
            color=alt.Color(
                "Causal Mask:N",
                scale=alt.Scale(domain=[True, False], range=["limegreen", "gray"]),
            ),
        )
        # .interactive()
    )

In [132]:
# Render the causal-mask heatmap; show_example returns None (nothing is drawn)
# unless RUN_EXAMPLES is True and the module is running as __main__.
show_example(example_mask)

In [None]:
### Attention