Transformer From Scratch 
========================

Reading along with Dan Jurafsky and James H. Martin's [Speech and Language Processing](https://web.stanford.edu/~jurafsky/slp3/), I decided to follow Chapter 8 of the book to implement a Transformer in PyTorch. My goal is a working Transformer that I can train on the guitar dataset. I know some linear algebra, the book gives essentially the entire algorithm in terms of linear algebra, and PyTorch provides nice but still very informative abstractions for doing linear algebra, so I had no reason not to pursue this project on top of whatever I had initially proposed.


Attention Layer
---------------

At the heart of the Transformer is the **attention layer**, a mechanism that lets words (tokens) gain contextual meaning from the words (tokens) around them. It can have multiple **"heads"**, where each head can be thought of as a specialist who asks a particular set of questions about the data. For instance, one head could focus solely on grammar while another focuses on sentiment (even though that might not be exactly what happens under the hood).

Each head's job, then, is to ask the right kinds of *questions* to decide which of the previous words it has seen matter most to the current word. To do this, each head has three main components: the **Query**, **Key**, and **Value** weight matrices.
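Written out as a sketch, the computation looks roughly like this, using the row-vector convention of the book and of the `AttentionLayer` code (one token per row of $X \in \mathbb{R}^{N \times d}$). The symbol $H$ for the number of heads and $M$ for the causal mask are my own notation; $W^Q_h, W^K_h \in \mathbb{R}^{d \times d_k}$, $W^V_h \in \mathbb{R}^{d \times d_v}$ with $d_v = d / H$, and $W^O \in \mathbb{R}^{d \times d}$:

$$
\begin{aligned}
Q_h &= X W^Q_h, \qquad K_h = X W^K_h, \qquad V_h = X W^V_h \\
M_{ij} &= \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases} \\
\text{head}_h &= \operatorname{softmax}\!\left(\frac{Q_h K_h^\top}{\sqrt{d_k}} + M\right) V_h \\
A &= \left(\text{head}_1 \oplus \cdots \oplus \text{head}_H\right) W^O
\end{aligned}
$$

The softmax is taken row by row, so for a single token $i$ this is just a weighted sum of value vectors, $a_i = \sum_{j \le i} \alpha_{ij} v_j$ with $\alpha_{ij} = \operatorname{softmax}_j\!\left(q_i \cdot k_j / \sqrt{d_k}\right)$. The $-\infty$ entries of $M$ become zero weights after the softmax, which is how the mask keeps a token from attending to words that come after it.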

<!-- 
    essentially, what it is at the end of the day is weighted sum, but it's obviously lot more complicated than that
    don't forget to write out the equations that I have referenced
    maybe throw in some pictures
    say something about how masking and softmax is used to determine what key's to focus on
    also explain how results from different heads are consolidated at the end
-->

In [108]:
import torch 
import torch.nn as nn
import math

class AttentionLayer(nn.Module): 
    def __init__(self,
                 N,         # context length; how many tokens to consider at any point in time
                 model_dim, # d
                 key_dim,   # d_k
                 num_heads=1, 
                ):
        super().__init__() 

        if model_dim % num_heads != 0: 
            raise ValueError("model_dim is not divisible by num_heads!")
        
        self.N = N
        self.model_dim = model_dim
        self.key_dim = key_dim
        self.value_dim = model_dim//num_heads 
        self.num_heads = num_heads
        
        # Query Weights (num_heads, model_dim, key_dim)
        self.W_Q = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Key Weights   (num_heads, model_dim, key_dim)
        self.W_K = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Value Weights (num_heads, model_dim, value_dim)
        self.W_V = nn.Parameter(torch.rand((num_heads, model_dim, self.value_dim)))
        # Output Weights
        self.W_O = nn.Parameter(torch.rand((model_dim, model_dim)))

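        # Causal mask (for an autoregressive model): 0 where key j <= query i,
        # -inf where j > i so that softmax gives future tokens zero weight
        mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
        self.register_buffer("mask", mask) # buffer: follows the module on .to(device) and is saved in state_dict, but is not trained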
        
    def forward(self, X): # X has (N, model_dim) dimensions
        seq_len = X.shape[0]
        Q = X@self.W_Q # (num_heads, N, key_dim)
        K = X@self.W_K # (num_heads, N, key_dim)
        V = X@self.W_V # (num_heads, N, value_dim)

        current_mask = self.mask[:seq_len, :seq_len] # when seq_len < N

        attention = Q@(K.mT) / math.sqrt(self.key_dim) + current_mask # (num_heads, N (queries), N (keys)) 
        heads = nn.functional.softmax(attention, dim = -1)@V #(num_heads, N, value_dim)
        cat_heads = torch.cat(heads.unbind(), dim=1) # (N, value_dim) each and concatenate the columns to form (N, model_dim)
        A = cat_heads@self.W_O # (N, model_dim)

        return A
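One way to sanity-check the masking and scaling logic is to re-implement a single causal head as a standalone function and compare it with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0) called with `is_causal=True`. The helper name `causal_attention` and the sizes below are hypothetical, chosen only for this sketch. It also checks that perturbing the last token leaves the earlier outputs unchanged, which is exactly what the mask is supposed to guarantee.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask.

    Q, K: (seq_len, d_k); V: (seq_len, d_v). Returns (seq_len, d_v).
    """
    seq_len, d_k = Q.shape
    # -inf above the diagonal: query i may only attend to keys j <= i
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    scores = Q @ K.mT / math.sqrt(d_k) + mask
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V                       # each output row is a weighted sum of V's rows

torch.manual_seed(0)
seq_len, d_model, d_k, d_v = 5, 8, 4, 4
X = torch.randn(seq_len, d_model)
W_Q, W_K, W_V = torch.randn(d_model, d_k), torch.randn(d_model, d_k), torch.randn(d_model, d_v)

q, k, v = X @ W_Q, X @ W_K, X @ W_V
out = causal_attention(q, k, v)

# Reference: PyTorch scales by 1/sqrt(d_k) by default; add a batch axis, then drop it
ref = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), is_causal=True
).squeeze(0)
print(torch.allclose(out, ref, atol=1e-4))

# Causality check: changing the last token must not affect outputs for earlier tokens
X2 = X.clone()
X2[-1] += 1.0
out2 = causal_attention(X2 @ W_Q, X2 @ W_K, X2 @ W_V)
print(torch.allclose(out[:-1], out2[:-1]))
```

The same two checks should carry over to `AttentionLayer` head by head, since each head runs this computation on its own slice of the weights.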
                    

In [111]:
X = torch.rand((3,4)) # 3 words, each represented as a (1, 4) row
device = "cuda" if torch.cuda.is_available() else "cpu" # fall back to CPU when no GPU is present
multihead_attention = AttentionLayer(N=3, model_dim=4, key_dim=2, num_heads=2)
multihead_attention.to(device)
multihead_attention(X.to(device))


tensor([[1.9358, 2.6155, 3.2847, 1.3195],
        [1.4636, 2.0018, 2.4938, 0.9933],
        [1.3739, 1.8974, 2.3533, 0.9227]], device='cuda:0',
       grad_fn=<MmBackward0>)
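The heads are consolidated in `torch.cat(heads.unbind(), dim=1)`: each head's `(N, value_dim)` output is placed side by side before the final projection by `W_O`. The sketch below uses random stand-in tensors (not the layer's actual outputs) to show that this is the same as a transpose-and-reshape, the form often seen in batched implementations, and that multiplying by `W_O` equals summing each head's output times its own block of rows of `W_O`. So `W_O` is what lets the heads' answers mix.

```python
import torch

torch.manual_seed(0)
num_heads, N, value_dim = 2, 3, 2
model_dim = num_heads * value_dim
heads = torch.randn(num_heads, N, value_dim)  # stand-in for softmax(...) @ V
W_O = torch.randn(model_dim, model_dim)

# As in AttentionLayer: split along the head axis, then glue the columns together
cat_heads = torch.cat(heads.unbind(), dim=1)            # (N, model_dim)
# Equivalent: move the head axis next to value_dim, then flatten those two axes
reshaped = heads.transpose(0, 1).reshape(N, model_dim)  # (N, model_dim)
print(torch.equal(cat_heads, reshaped))

# The output projection as a sum of per-head contributions
A = cat_heads @ W_O
A_sum = sum(heads[h] @ W_O[h * value_dim:(h + 1) * value_dim] for h in range(num_heads))
print(torch.allclose(A, A_sum, atol=1e-5))
```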

In [None]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        pass
    def forward(self, X):
        pass

class LayerNorm(nn.Module): 
    def __init__(self): 
        super().__init__()
        pass
    def forward(self, X):
        pass
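The two stubs could be filled in along the lines of the book's definitions: a position-wise feedforward network $\text{FFN}(x_i) = \text{ReLU}(x_i W_1 + b_1) W_2 + b_2$ and layer normalization $\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sigma} + \beta$. This is only a minimal sketch under a few assumptions of mine: the hidden width is a parameter I call `hidden_dim`, a small `eps` is added to the variance for numerical stability (as `nn.LayerNorm` does), and the classes are named `FeedForwardSketch` and `LayerNormSketch` so they don't clash with the stubs.

```python
import torch
import torch.nn as nn

class FeedForwardSketch(nn.Module):
    """Position-wise FFN: ReLU(X W1 + b1) W2 + b2, applied to each row (token) independently."""
    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(model_dim, hidden_dim)  # W1, b1
        self.linear2 = nn.Linear(hidden_dim, model_dim)  # W2, b2

    def forward(self, X):  # X: (seq_len, model_dim)
        return self.linear2(torch.relu(self.linear1(X)))  # (seq_len, model_dim)

class LayerNormSketch(nn.Module):
    """Normalize each token vector to zero mean and unit variance, then scale by gamma and shift by beta."""
    def __init__(self, model_dim, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(model_dim))
        self.beta = nn.Parameter(torch.zeros(model_dim))
        self.eps = eps  # assumption: small constant to avoid dividing by zero

    def forward(self, X):  # X: (..., model_dim)
        mean = X.mean(dim=-1, keepdim=True)
        var = X.var(dim=-1, keepdim=True, unbiased=False)  # population variance (divide by d)
        return self.gamma * (X - mean) / torch.sqrt(var + self.eps) + self.beta

torch.manual_seed(0)
X = torch.randn(3, 4)
ffn = FeedForwardSketch(model_dim=4, hidden_dim=16)
ln = LayerNormSketch(model_dim=4)
print(ffn(X).shape)
# With gamma = 1 and beta = 0 this should agree with PyTorch's built-in layer norm
print(torch.allclose(ln(X), nn.LayerNorm(4)(X), atol=1e-5))
```

Both modules keep the `(seq_len, model_dim)` shape, so their outputs can be fed straight back into `AttentionLayer` or added to a residual stream.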