# The Transformer

The transformer architecture, introduced in the paper [Attention is All You Need](https://arxiv.org/pdf/1706.03762.pdf), has served as the basis for nearly all modern language models.
This notebook assumes some understanding of basic deep learning, including how CNNs and RNNs work.

**Heavily Sourced From**
https://nlp.seas.harvard.edu/annotated-transformer/ 

In [None]:

from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import pandas as pd
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator
import warnings


# Silence library warnings to keep the notebook output readable
warnings.filterwarnings("ignore")
# Set to False to skip notebook execution (e.g. for debugging)
RUN_EXAMPLES = True

## 1. Architecture

<p align="center">
<img src="./assets/transformer.png" height="480">
</p>

### Encoder Decoder

The transformer can be broken down into an encoder-decoder architecture (left and right in the figure above, respectively).

The model is auto-regressive: at each generation step it takes the previously generated symbols as additional input when producing the next one.

#### Masking

#### Residual Connection
https://arxiv.org/abs/1512.03385 

#### Layer Normalization
https://arxiv.org/abs/1607.06450

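The paper defines the output of each sublayer as `LayerNorm(x + Sublayer(x))`, where `Sublayer(x)` is the function implemented by the sublayer itself. The `SublayerConnection` class in this notebook applies the normalization first instead, computing `x + Dropout(Sublayer(LayerNorm(x)))`.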
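
As a rough illustration of how the residual connection and layer normalization combine, the sketch below computes the post-norm form `LayerNorm(x + Sublayer(x))` next to a pre-norm variant. The `nn.Linear` stand-in for the sublayer and the tensor sizes are assumptions chosen for the example, and `nn.LayerNorm` is PyTorch's built-in layer rather than the `LayerNorm` class defined in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)           # (batch, sequence, d_model)
sublayer = nn.Linear(d_model, d_model)   # hypothetical stand-in for attention or the FFN
norm = nn.LayerNorm(d_model)

# Post-norm: add the residual first, then normalize the sum.
post_norm = norm(x + sublayer(x))

# Pre-norm variant: normalize the input, apply the sublayer, then add the residual.
pre_norm = x + sublayer(norm(x))

print(post_norm.shape, pre_norm.shape)
```

In the post-norm form every position of the output is normalized over its features, while in the pre-norm form the residual path carries `x` through unchanged.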

#### Dropout
https://jmlr.org/papers/v15/srivastava14a.html


In [None]:
# Overall Architecture 
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder 
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
    

class Generator(nn.Module):
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)
        
    def forward(self, x):
        return log_softmax(self.proj(x), dim=-1)

In [None]:
# Helper functions 

def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

### Encoder

The encoder maps an input sequence of symbol/token representations $(x_1, \dots, x_n)$ to a sequence of continuous representations $\mathbf{z} = (z_1, \dots, z_n)$.


#### Encoder Sublayer

<figure align="center">
<img src="./assets/transformer-encoder-sublayer.png" height="480">
<figcaption>The `SublayerConnection`</figcaption>
</figure>

In [None]:

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))



<p align="center">
<img src="./assets/transformer-encoder.png" height="480">
</p>

In [None]:

class EncoderLayer(nn.Module):

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)



### Decoder

The decoder takes the output of the encoder $\mathbf{z}$ and outputs a sequence $(y_1, \dots, y_m)$ of symbols/tokens.

The output embedding refers to the embedding of the *target* sequence that is fed to the decoder during training.
It is the vector representation of the output you want. During training these are usually the correct tokens up to the current position, shifted right by one position so that the sequence begins with a start token (`<SOS>`, etc.).
In this way we train the model to predict the next token in the sequence given the previous tokens.
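
The shift described above can be sketched with plain tensor slicing. The token ids and the `SOS`/`EOS` values are made up for illustration and are not taken from any vocabulary used in this notebook.

```python
import torch

SOS, EOS = 1, 2  # hypothetical special-token ids
tgt = torch.tensor([[SOS, 5, 6, 7, EOS]])  # one target sentence, batch size 1

# The decoder sees everything except the last token ...
decoder_input = tgt[:, :-1]   # [[SOS, 5, 6, 7]]
# ... and is trained to predict the same sequence shifted left by one.
labels = tgt[:, 1:]           # [[5, 6, 7, EOS]]

for step in range(decoder_input.size(1)):
    seen = decoder_input[0, : step + 1].tolist()
    print(f"given {seen} predict {labels[0, step].item()}")
```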



<p align="center">
<img src="./assets/transformer-decoder.png" height="480">
</p>


In [None]:
# Decoder 

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
    
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)



#### Additional Masking in Decoder

We modify the masking in the decoder's self-attention sublayer (called look-ahead or causal masking) to prevent each position from attending to future tokens in the output embedding.
Each token can only attend to itself and to the tokens before it in the sequence.
Combined with the fact that the output embeddings are offset by one position, this ensures that predictions for position $i$ (the current position in the target/output sequence) can depend only on the known outputs at positions $< i$,
thus enforcing the auto-regressive property: the model can only use information from previous tokens in the sequence to make predictions.

This is done by setting the attention scores for future tokens to negative infinity before the softmax, so that their attention weights become zero.
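
A minimal sketch of this effect, assuming uniform raw scores for a sequence of length 4 (the variable names are illustrative only):

```python
import torch

size = 4
scores = torch.zeros(size, size)  # pretend every query-key score is equal

# True strictly above the diagonal, i.e. where the key lies in the future of the query.
future = torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)

masked = scores.masked_fill(future, float("-inf"))
weights = masked.softmax(dim=-1)  # each row still sums to 1
print(weights)
```

Row $i$ of `weights` spreads its mass evenly over positions $0, \dots, i$ and assigns exactly zero to every later position.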

In [None]:
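def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    # True strictly above the diagonal, i.e. at future positions
    future = torch.triu(torch.ones(attn_shape, dtype=torch.bool), diagonal=1)
    # attention() masks entries where mask == 0, so return True where attending is allowed
    return ~future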

### The Attention Mechanism

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

<p align="center">
  <img src="./assets/scaled-dot-attention.png" height="360">
</p>

An attention function can be thought of as a mapping of queries and a set of key-value pairs to an output. 
The output is a weighted sum of the values where the weight is calculated by a 'compatibility function' between the query and key vectors.
The attention used in the original paper is called the "Scaled Dot-Product Attention".
The inputs consist of queries $Q$ and keys $K$ of dimension $d_k$ and values $V$ of dimension $d_v$. 
We compute the dot products of the query with all the keys, divide each by $\sqrt{d_k}$, optionally apply a mask, and apply the softmax function to obtain the weights on the values.
In practice we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$ with one query of dimension $d_k$ per row. The keys and values are likewise packed into matrices $K$ and $V$.
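
To make the shapes concrete, here is a small sketch of the computation on random matrices. The sizes `n_q`, `n_k`, `d_k` and `d_v` are arbitrary choices for the example.

```python
import math
import torch

torch.manual_seed(0)
n_q, n_k, d_k, d_v = 2, 3, 4, 5
Q = torch.randn(n_q, d_k)   # one query per row
K = torch.randn(n_k, d_k)   # one key per row
V = torch.randn(n_k, d_v)   # one value per key

scores = Q @ K.T / math.sqrt(d_k)   # (n_q, n_k): compatibility of each query with each key
weights = scores.softmax(dim=-1)    # (n_q, n_k): each row sums to 1
output = weights @ V                # (n_q, d_v): weighted sum of the values

print(scores.shape, weights.shape, output.shape)
```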

#### Alternate Attention Mechanism

- Additive attention: Computes the compatibility function using a feed-forward network with a single hidden layer. It has similar theoretical complexity to dot-product attention but is slower and less space-efficient in practice.
- Dot-product attention: Same as in the paper except for the scaling factor of $\frac{1}{\sqrt{d_k}}$.

The scaling factor is added because, without it, additive attention outperforms dot-product attention for larger values of $d_k$ ([paper](https://arxiv.org/abs/1703.03906)).
The authors suspect that for large $d_k$ the dot products grow large in magnitude, pushing the softmax function into regions with extremely small gradients.
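
The growth of the dot products can be checked empirically. The sketch below assumes the components of the queries and keys are independent with zero mean and unit variance (the setting the paper uses to motivate the scaling), in which case the dot product has variance $d_k$. The sample count and the values of `d_k` are arbitrary.

```python
import torch

torch.manual_seed(0)
n_samples = 10000
raw_var, scaled_var = {}, {}
for d_k in (4, 64, 512):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)                           # one dot product per sample
    raw_var[d_k] = dots.var().item()                     # grows roughly like d_k
    scaled_var[d_k] = (dots / d_k ** 0.5).var().item()   # stays close to 1
    print(d_k, round(raw_var[d_k], 1), round(scaled_var[d_k], 2))
```

Dividing by $\sqrt{d_k}$ keeps the scores at a comparable scale regardless of the key dimension, which keeps the softmax away from its saturated regions.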

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

### Multihead Attention
Multi-head attention combines multiple attention mechanisms in parallel to allow the model to jointly attend to different mappings learned by each component. 

$$\begin{align}
  \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O\\
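  \text{where } \text{head}_i &= \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
\end{align}$$

Here the projections are learned parameter matrices $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times d_{model}}$.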

<p align="center">
  <img src="./assets/multihead-attention.png" height="360">
</p>

**TODO:** needs better explanation

The multi-head attention mechanism is used in three ways in the transformer:

  1. Encoder-Decoder Attention: The queries come from the previous decoder layer, while the memory keys and values come from the output of the encoder. This means that every position in the decoder can attend over all positions in the input sequence (similar to the attention used in Seq2Seq models).
  2. Encoder Attention: The encoder has self-attention layers where all of the keys, values, and queries come from the same place. In an encoder this is the output of the previous layer. Each position in the encoder can attend to all positions in the previous layer of the encoder.
  3. Decoder Attention: Similar to encoder attention, this allows each position in the decoder to attend to all positions in the decoder up to and including that position, which preserves the auto-regressive property.



In [None]:
class MultiheadAttention(nn.Module):
  def __init__(self, h, d_model, dropout=0.1):
    super().__init__()
    assert d_model % h == 0
    self.d_k = d_model // h # assume d_v = d_k
    self.h = h
    self.linears = clones(nn.Linear(d_model, d_model), 4)
    self.attn = None
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, query, key, value, mask=None):
    if mask is not None:
      mask = mask.unsqueeze(1)
    nbatches = query.size(0)
    # 1) Do all the linear projections in batch from d_model => h x d_k
    query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                         for l, x in zip(self.linears, (query, key, value))]
    # 2) Apply attention on all the projected vectors in batch.
    x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
    # 3) "Concat" using a view and apply a final linear.
    x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
    del query, key, value
    return self.linears[-1](x)
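
The `view`/`transpose` trick in step 1 can be checked in isolation. The sketch below, with made-up sizes, suggests that one `d_model x d_model` projection followed by a reshape gives the same per-head tensors as `h` separate `d_model x d_k` projections built from slices of the same weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
h, d_model, seq_len, nbatches = 2, 8, 3, 1
d_k = d_model // h

x = torch.randn(nbatches, seq_len, d_model)
proj = nn.Linear(d_model, d_model)

# One large projection, then split the last axis into (h, d_k) and move heads forward.
heads = proj(x).view(nbatches, -1, h, d_k).transpose(1, 2)  # (nbatches, h, seq_len, d_k)

# The same result head by head: head i uses output rows [i*d_k, (i+1)*d_k) of the weight.
per_head = []
for i in range(h):
    rows = slice(i * d_k, (i + 1) * d_k)
    per_head.append(x @ proj.weight[rows].T + proj.bias[rows])
per_head = torch.stack(per_head, dim=1)                     # (nbatches, h, seq_len, d_k)

# The "concat" step simply undoes the split.
merged = heads.transpose(1, 2).contiguous().view(nbatches, -1, h * d_k)

print(torch.allclose(heads, per_head, atol=1e-6), merged.shape)
```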



### Feed Forward Neural Network

Each layer of the encoder and decoder also contains a fully connected feed-forward network. 
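
In the paper this network is $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, applied to each position separately and identically. The sketch below illustrates the position-wise property with a small stand-in network; dropout is omitted and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32
# Stand-in for the feed-forward sublayer (no dropout, so the output is deterministic).
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(1, 5, d_model)   # (batch, positions, d_model)
full = ffn(x)                    # applied to the whole sequence at once
single = ffn(x[:, 2])            # applied to position 2 on its own

# Position 2 of the full output depends only on position 2 of the input.
print(torch.allclose(full[:, 2], single, atol=1e-6))
```

Because `nn.Linear` acts on the last axis only, no information moves between positions here; mixing across positions is left to the attention sublayers.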

**TODO:** why tho

In [None]:
class FFN(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

### Embeddings

We use learned embeddings to convert the input and output tokens to vectors of dimension $d_{model}$. As in the paper, the embedding weights are multiplied by $\sqrt{d_{model}}$.

We use a learned linear transformation and the softmax function to convert the decoder output to next-token probabilities.
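
A minimal sketch of that last step, assuming a toy vocabulary of 11 tokens and random decoder outputs. Greedy `argmax` selection is used here purely for illustration and is only one of several possible decoding strategies.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab = 8, 11                     # toy sizes, not the ones used in the notebook
proj = nn.Linear(d_model, vocab)           # learned linear transformation

decoder_out = torch.randn(2, 3, d_model)   # (batch, positions, d_model)
log_probs = torch.log_softmax(proj(decoder_out), dim=-1)
probs = log_probs.exp()                    # a distribution over the vocabulary per position

# Greedy choice of the next token from the last position of each sequence.
next_token = log_probs[:, -1].argmax(dim=-1)
print(probs.sum(-1), next_token)
```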

In [None]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

### Positional Encoding

The model so far has no way to encode the position and order of the tokens. To provide this information we add positional encodings to the input embeddings at the bottom of the encoder and decoder stacks. These have the same dimension $d_{model}$ as the embeddings so that the two can be summed.

The original paper makes use of the following positional encoding function

$$
\begin{aligned}
PE_{(pos, 2i)} &= \sin(pos/10000^{2i/d_{model}})\\
PE_{(pos, 2i + 1)} &= \cos(pos/10000^{2i/d_{model}})
\end{aligned}
$$

where $pos$ is the position and $i$ indexes the pair of dimensions $(2i, 2i + 1)$. Each dimension of the positional encoding corresponds to a sinusoid.
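
The formula can be evaluated directly for a small table and spot-checked against `math.sin` and `math.cos`. The helper name `sinusoidal_pe` and the sizes are assumptions for this sketch, and `d_model` is assumed to be even.

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Return a (max_len, d_model) table of sinusoidal positional encodings."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)            # 0, 2, 4, ...
    angle = position / (10000.0 ** (two_i / d_model))                   # (max_len, d_model / 2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angle)   # even dimensions
    pe[:, 1::2] = torch.cos(angle)   # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=50, d_model=16)

# Spot check: position 3, dimension pair (4, 5), i.e. 2i = 4.
expected_angle = 3 / 10000 ** (4 / 16)
print(pe[3, 4].item(), math.sin(expected_angle))
print(pe[3, 5].item(), math.cos(expected_angle))
```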

**TODO:** why these values

In [None]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, dropout, max_len=5000):
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)
  
  def forward(self, x):
    x = x + self.pe[:, :x.size(1)]
    return self.dropout(x)

### Putting it all together

In [None]:

model = EncoderDecoder(
    encoder=Encoder(EncoderLayer(512, MultiheadAttention(8, 512), FFN(512, 2048), 0.1), 6),
    decoder=Decoder(DecoderLayer(512, MultiheadAttention(8, 512), MultiheadAttention(8, 512), FFN(512, 2048), 0.1), 6),
    src_embed=nn.Sequential(Embeddings(512, 1000), PositionalEncoding(512, 0.1)),
    tgt_embed=nn.Sequential(Embeddings(512, 1000), PositionalEncoding(512, 0.1)),
    generator=Generator(512, 1000)
)

## 2. Training

## 3. Inference


## 4. Uses in the Real World

### BERT


### GPT 


### LLAMA