Transformer From Scratch 
========================

Reading along Dan Jurafsky and James H. Martin's [Speech and Language Processing](https://web.stanford.edu/~jurafsky/slp3/) book, I decided to follow through Chapter 8 of their book to implement a Transformer using Pytorch. It is my goal to have a working transformer which I can use to train on the guitar dataset. I know some linear algebra, the book essentially gives the entire algorithm in terms of linear algebra, and pytorch provides a nice but still very informative abstractions for doing linear algebra. I had no reason not to pursue this project on top of whatever I proposed to do initially. 

## 

Attention Layer
---------------

At the heart of Transformer is the **attention layer**. It is a mechanism that allows words(tokens) to gain contextual meaning from their surrounding words(tokens). It can have multiple **"heads"**, where each "head" can be thought of as a specialist who asks particular set of questions given some data. For instance, one head could focus solely on grammar while another could instead focus on sentiments (even though that might not be exactly what occurs under the hood).

Each head's job, then, is to ask the right kinds of *questions* to choose which of previous words it has seen matters the most to the current word. To do this, each head consists of three main components: **Query**, **Key**, and **Value** weight matrices. 

<!-- 
    essentially, what it is at the end of the day is weighted sum, but it's obviously lot more complicated than that
    don't forget to write out the equations that I have referenced
    maybe throw in some pictures
    say something about how masking and softmax is used to determine what key's to focus on
    also explain how results from different heads are consolidated at the end
--!>

In [40]:
import torch 
import torch.nn as nn
import math

class AttentionLayer(nn.Module): 
    def __init__(self,
                 N,         # context length; how many tokens to consider at any point in time
                 model_dim, # d
                 key_dim,   # d_k
                 num_heads=1, 
                ):
        super().__init__() 

        if num_heads < 1:
            raise ValueError("num_heads cannot be less than 1!")

        if model_dim % num_heads != 0: 
            raise ValueError("model_dim is not divisible by num_heads!")
        
        self.N = N
        self.model_dim = model_dim
        self.key_dim = key_dim
        self.value_dim = model_dim//num_heads 
        self.num_heads = num_heads
        
        # Query Weights (num_heads, model_dim, key_dim)
        self.W_Q = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Key Weights   (num_heads, model_dim, key_dim)
        self.W_K = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Value Weights (num_heads, model_dim, value_dim)
        self.W_V = nn.Parameter(torch.rand((num_heads, model_dim, self.value_dim)))
        # Output Weights
        self.W_O = nn.Parameter(torch.rand((model_dim, model_dim)))

        # Mask (for autoregressive model)
        mask = torch.tensor([[0 if i>= j else -torch.inf for j in range(N)] for i in range(N)])
        self.register_buffer("mask", mask) # move mask to GPU
        
    def forward(self, X): # X has (N, model_dim) dimensions
        seq_len = X.shape[0]
        Q = X@self.W_Q # (num_heads, N, key_dim)
        K = X@self.W_K # (num_heads, N, key_dim)
        V = X@self.W_V # (num_heads, N, value_dim)

        current_mask = self.mask[:seq_len, :seq_len] # when seq_len < N

        attention = Q@(K.mT) / math.sqrt(self.key_dim) + current_mask # (num_heads, N (queries), N (keys)) 
        heads = nn.functional.softmax(attention, dim = -1)@V #(num_heads, N, value_dim)
        cat_heads = torch.cat(heads.unbind(), dim=1) # (N, value_dim) each and concatenate the columns to form (N, model_dim)
        A = cat_heads@self.W_O # (N, model_dim)

        return A
                    

In [41]:
X = torch.rand((3,4)) # 3 words represented as dim (1, 4) tensors
multihead_attention = AttentionLayer(N=3, model_dim=4, key_dim=2, num_heads=2)
multihead_attention.to("cuda")
multihead_attention(X.to("cuda"))


tensor([[0.7284, 1.4601, 1.0573, 2.4183],
        [0.9968, 1.8307, 1.3983, 3.0009],
        [0.9250, 1.8363, 1.3025, 3.0480]], device='cuda:0',
       grad_fn=<MmBackward0>)

In [42]:
class FeedForward(nn.Module):
    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        self.lin_1 = nn.Linear(in_features=model_dim, out_features=hidden_dim)
        self.lin_2 = nn.Linear(in_features=hidden_dim, out_features=model_dim)
        self.relu = nn.ReLU()
    def forward(self, X):
        return self.lin_2(self.relu(self.lin_1(X)))

class LayerNorm(nn.Module): 
    def __init__(self, model_dim, epsilon=1e-6): 
        super().__init__()
        # gamma and beta is to normalize features 
        self.gamma = nn.Parameter(torch.ones(model_dim)) # initialize gammas to ones because if initialized randomly to 0, it's dead signal
        self.beta = nn.Parameter(torch.zeros(model_dim))
        self.eps = epsilon
    def forward(self, X): # (N, model_dim)
        model_dim = X.shape[-1]
        mean = torch.mean(X, dim=-1, keepdims=True)
        std = torch.std(X, dim=-1, keepdims=True)
        X_hat = (X - mean) / (std + self.eps) # add epsilon for numerical stability
        layer_norm = self.gamma * X_hat + self.beta
        return layer_norm

In [44]:
ff = FeedForward(model_dim=4, hidden_dim=8)
ff(X)
norm = LayerNorm(model_dim=4)
norm(X)

tensor([[ 0.8143,  0.7121, -0.1866, -1.3398],
        [-0.4161,  1.1091, -1.1701,  0.4772],
        [-0.1556,  0.6163,  0.8854, -1.3461]], grad_fn=<AddBackward0>)

In [45]:
class TransformerBlock(nn.Module): 
    def __init__(self, 
                 N, 
                 model_dim, 
                 key_dim, 
                 hidden_dim, 
                 num_heads=1): 
        super().__init__() 
        self.N = N
        self.model_dim = model_dim
        self.key_dim = key_dim 
        self.hidden_dim = hidden_dim 
        self.num_heads = num_heads
        
        self.attention_layer = AttentionLayer(N=N, model_dim=model_dim, key_dim=key_dim, num_heads=num_heads)
        self.ffn = FeedForward(model_dim=model_dim, hidden_dim=hidden_dim)
        self.norm_1 = LayerNorm(model_dim=model_dim)
        self.norm_2 = LayerNorm(model_dim=model_dim)

    def forward(self,X): 
        T_1 = self.norm_1(X) 
        T_2 = self.attention_layer(T_1)
        T_3 = T_2 + X
        T_4 = self.norm_2(T_3)
        T_5 = self.ffn(T_4)
        H = T_5 + T_3 

        return H
    

In [47]:
block = TransformerBlock(N=3, model_dim=4, key_dim=2, hidden_dim=8, num_heads=2)
block(X)

tensor([[ 1.0590,  0.0180, -0.3674, -0.2382],
        [ 0.6360,  0.5467, -0.3314,  0.2132],
        [ 1.0297,  0.3499,  0.0983, -0.3943]], grad_fn=<AddBackward0>)

In [67]:
class Transformer(nn.Module): 
    def __init__(self, 
                 N, 
                 model_dim, 
                 key_dim, 
                 hidden_dim, 
                 num_heads=1,
                 num_stack=1
                ):
        super().__init__()

        if num_stack < 1: 
            raise ValueError("num_stack cannot be less than 1!")

        # missing language head, embedding/unembedding matricies
        blocks = [TransformerBlock(N, model_dim, key_dim, hidden_dim, num_heads) for _ in range(num_stack)] 
        self.model = nn.Sequential(*blocks)
    def forward(self,X): 
        return self.model(X) 
        

In [70]:
model = Transformer(N=3, model_dim=4, key_dim=2, hidden_dim=8, num_heads=2, num_stack=9)
model.state_dict()

OrderedDict([('model.0.attention_layer.W_Q',
              tensor([[[0.5050, 0.2384],
                       [0.7371, 0.0694],
                       [0.2770, 0.8450],
                       [0.1893, 0.7880]],
              
                      [[0.0665, 0.0112],
                       [0.9444, 0.6999],
                       [0.4469, 0.6458],
                       [0.7735, 0.2013]]])),
             ('model.0.attention_layer.W_K',
              tensor([[[0.2654, 0.5657],
                       [0.2708, 0.6153],
                       [0.1084, 0.9999],
                       [0.2192, 0.9200]],
              
                      [[0.6538, 0.9476],
                       [0.9976, 0.9429],
                       [0.2572, 0.9861],
                       [0.7417, 0.9392]]])),
             ('model.0.attention_layer.W_V',
              tensor([[[0.8646, 0.3119],
                       [0.4881, 0.1162],
                       [0.6396, 0.5684],
                       [0.2829, 0.1063]],
    