Transformer From Scratch 
========================

While reading Dan Jurafsky and James H. Martin's [Speech and Language Processing](https://web.stanford.edu/~jurafsky/slp3/), I decided to follow Chapter 8 of the book to implement a Transformer in PyTorch. My goal is to have a working transformer that I can train on the guitar dataset. I know some linear algebra, the book gives essentially the entire algorithm in terms of linear algebra, and PyTorch provides convenient yet still very informative abstractions for doing linear algebra, so I had no reason not to pursue this project on top of what I initially proposed.

## 

Attention Layer
---------------

At the heart of the Transformer is the **attention layer**. It is a mechanism that allows words (tokens) to gain contextual meaning from their surrounding words (tokens). It can have multiple **"heads"**, where each head can be thought of as a specialist who asks a particular set of questions about the data. For instance, one head could focus solely on grammar while another focuses on sentiment (even though that might not be exactly what happens under the hood).

Each head's job, then, is to ask the right kinds of *questions* to decide which of the previous words it has seen matter most to the current word. To do this, each head has three main components: the **Query**, **Key**, and **Value** weight matrices. 

<!-- 
    essentially, what it is at the end of the day is a weighted sum, but it's obviously a lot more complicated than that
    don't forget to write out the equations that I have referenced
    maybe throw in some pictures
    say something about how masking and softmax are used to determine which keys to focus on
    also explain how results from different heads are consolidated at the end
-->

In [1]:
import torch 
import torch.nn as nn
import math

class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries are always projected from ``X``. Keys and values are projected
    from ``H_enc`` when it is given (cross-attention) and from ``X``
    otherwise (self-attention).

    Args:
        model_dim: size ``d`` of the input/output token vectors; must be
            divisible by ``num_heads``.
        key_dim: per-head size ``d_k`` of the query/key projections.
        num_heads: number of attention heads (>= 1).

    Raises:
        ValueError: if ``num_heads < 1`` or ``model_dim`` is not divisible
            by ``num_heads``.
    """

    def __init__(self,
                 model_dim, # d
                 key_dim,   # d_k
                 num_heads=1
                ):
        super().__init__()

        if num_heads < 1:
            raise ValueError("num_heads cannot be less than 1!")

        if model_dim % num_heads != 0:
            raise ValueError("model_dim is not divisible by num_heads!")

        self.model_dim = model_dim
        self.key_dim = key_dim
        # Each head emits value_dim features, so the concatenated heads have
        # exactly model_dim features before the output projection.
        self.value_dim = model_dim // num_heads
        self.num_heads = num_heads

        # The per-head projection matrices are packed into a single nn.Linear
        # each: output features [h*k, (h+1)*k) belong to head h. nn.Linear
        # applies over any leading (batch) dimensions.
        self.W_Q = nn.Linear(in_features=model_dim, out_features=num_heads * key_dim, bias=False)
        self.W_K = nn.Linear(in_features=model_dim, out_features=num_heads * key_dim, bias=False)
        self.W_V = nn.Linear(in_features=model_dim, out_features=num_heads * self.value_dim, bias=False)
        self.W_O = nn.Linear(in_features=num_heads * self.value_dim, out_features=model_dim, bias=False)

    def forward(self, X, H_enc=None, mask=None):
        """Attend from the tokens of ``X`` to the tokens of ``H_enc`` (or ``X``).

        Args:
            X: query-side input of shape (batch_size, N, model_dim).
            H_enc: optional key/value-side input of shape
                (batch_size, M, model_dim); defaults to ``X`` (self-attention).
            mask: optional additive mask (0 = keep, -inf = block). Only its
                top-left ``[:N, :M]`` corner is used, so it must be at least
                N x M. It is moved to the scores' device and dtype.

        Returns:
            Tensor of shape (batch_size, N, model_dim).
        """
        K_input = H_enc if H_enc is not None else X
        V_input = K_input

        batch_size = X.shape[0]
        N = X.shape[1]
        M = K_input.shape[1]

        Q = self.W_Q(X)        # (batch_size, N, num_heads * key_dim)
        K = self.W_K(K_input)  # (batch_size, M, num_heads * key_dim)
        V = self.W_V(V_input)  # (batch_size, M, num_heads * value_dim)

        # Split the packed head features and move heads next to the batch
        # axis: (batch_size, num_heads, seq_len, per_head_dim).
        Q = Q.view(batch_size, N, self.num_heads, self.key_dim).transpose(1, 2)
        K = K.view(batch_size, M, self.num_heads, self.key_dim).transpose(1, 2)
        V = V.view(batch_size, M, self.num_heads, self.value_dim).transpose(1, 2)

        attention = Q @ K.mT / math.sqrt(self.key_dim)  # (batch_size, num_heads, N (queries), M (keys))
        if mask is not None:
            # Keep the (N queries x M keys) corner and match device/dtype so a
            # CPU mask also works with a model on another device.
            current_mask = mask[:N, :M].to(device=attention.device, dtype=attention.dtype)
            attention = attention + current_mask  # broadcasts over batch and heads

        prob = nn.functional.softmax(attention, dim=-1)  # normalize over the keys
        values = prob @ V  # (batch_size, num_heads, N, value_dim)
        # Put the heads back side by side: (batch_size, N, num_heads * value_dim)
        values_cat = values.transpose(1, 2).contiguous().view(batch_size, N, -1)

        A = self.W_O(values_cat)  # (batch_size, N, model_dim)

        return A
                    

In [8]:
# Toy sizes for a quick shape check of AttentionLayer in cross-attention mode.
batch_size = 10
N = 10          # number of query tokens
model_dim = 24
num_heads = 4
key_dim = 3

M = 8           # number of key/value tokens (e.g. encoder output length)
X = torch.rand((batch_size, N, model_dim))  # 10 sequences of N=10 tokens, each a 24-dim vector
Y = torch.rand((batch_size, M, model_dim))  # 10 sequences of M=8 tokens, each a 24-dim vector
# Causal additive mask: 0 where key j <= query i, -inf where j > i.
mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

multihead_attention = AttentionLayer(model_dim=model_dim, key_dim=key_dim, num_heads=num_heads)
# Queries come from X, keys/values from Y; the layer slices the N x N mask to N x M.
multihead_attention(X, H_enc=Y, mask=mask).shape  # expected: torch.Size([10, 10, 24])
# To run on a GPU instead:
# multihead_attention.to("cuda")
# multihead_attention(X.to("cuda"), Y.to("cuda"), mask=mask.to("cuda"))

torch.Size([10, 10, 24])

In [29]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Acts on the last dimension of its input, mapping
    model_dim -> hidden_dim -> model_dim for every token independently.
    """

    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        # expansion layer, then projection back to the model width
        self.lin_1 = nn.Linear(in_features=model_dim, out_features=hidden_dim)
        self.lin_2 = nn.Linear(in_features=hidden_dim, out_features=model_dim)
        self.relu = nn.ReLU()

    def forward(self, X):
        """Return ``lin_2(relu(lin_1(X)))``; the last dimension is model_dim."""
        hidden = self.lin_1(X)
        activated = self.relu(hidden)
        return self.lin_2(activated)

class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    For each token vector x of size d:
        mean = (1/d) * sum_i x_i
        var  = (1/d) * sum_i (x_i - mean)^2      (population variance)
        out  = gamma * (x - mean) / sqrt(var + eps) + beta

    This matches the textbook definition and ``torch.nn.LayerNorm``, and it
    stays finite when ``model_dim == 1``.

    Args:
        model_dim: size of the normalized (last) dimension.
        epsilon: small constant added to the variance for numerical stability.
    """

    def __init__(self, model_dim, epsilon=1e-6):
        super().__init__()
        # learnable per-feature scale (gamma) and shift (beta)
        self.gamma = nn.Parameter(torch.ones(model_dim))  # ones: pass the normalized signal through unscaled at init
        self.beta = nn.Parameter(torch.zeros(model_dim))
        self.eps = epsilon

    def forward(self, X):  # (..., model_dim), e.g. (batch_size, N, model_dim)
        mean = X.mean(dim=-1, keepdim=True)
        centered = X - mean
        # Divide by d (not d - 1): the biased/population variance used by LayerNorm.
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        X_hat = centered / torch.sqrt(var + self.eps)
        return self.gamma * X_hat + self.beta

In [20]:
# Quick smoke test of the two sub-layers on the demo input X.
ff = FeedForward(hidden_dim=12, model_dim=model_dim)
ff(X)  # output discarded; only checks that the forward pass runs
norm = LayerNorm(model_dim=model_dim)
norm(X)  # last expression, so the notebook displays it

torch.Size([24]) torch.Size([24])


tensor([[[100.9493, 100.2801, 100.1258,  ..., 100.2354, 100.9412, 100.4811],
         [100.0720, 100.8310, 100.8976,  ..., 100.8133, 100.0188, 100.8997],
         [100.1473, 100.7575, 100.9484,  ..., 100.8251, 100.1841, 100.3081],
         ...,
         [100.2794, 100.0046, 100.7232,  ..., 100.8640, 100.9785, 100.0878],
         [100.8674, 100.9698, 100.3904,  ..., 100.3108, 100.5854, 100.9261],
         [100.2624, 100.3437, 100.0919,  ..., 100.0314, 100.0253, 100.1013]],

        [[100.8234, 100.7289, 100.7664,  ..., 100.5530, 100.8430, 100.4411],
         [100.7440, 100.0608, 100.2004,  ..., 100.1383, 100.9858, 100.4086],
         [100.1379, 100.4247, 100.3906,  ..., 100.4993, 100.9951, 100.6967],
         ...,
         [100.0161, 100.8644, 100.0283,  ..., 100.9211, 100.3531, 100.8766],
         [100.1065, 100.1919, 100.6886,  ..., 100.6221, 100.0700, 100.5314],
         [100.0892, 100.5095, 100.7930,  ..., 100.8587, 100.5753, 100.9021]],

        [[100.8972, 100.8424, 100.2850,  ...

In [52]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer block, optionally with cross-attention.

    Every sub-layer reads a normalized copy of the residual stream and adds
    its output back to the stream:

        T = X + SelfAttention(norm_1(X), mask)
        T = T + CrossAttention(norm_3(T), H_enc)   # only if cross_attention
        H = T + FFN(norm_2(T))

    Args:
        N: stored as ``self.N``; not otherwise used by the block.
        model_dim: token vector size ``d``.
        key_dim: per-head query/key size ``d_k``.
        hidden_dim: inner size of the feed-forward network.
        num_heads: number of attention heads.
        cross_attention: if True, add a cross-attention sub-layer whose keys
            and values come from ``H_enc``.
    """

    def __init__(self,
                 N,
                 model_dim,
                 key_dim,
                 hidden_dim,
                 num_heads=1,
                 cross_attention=False
                ):
        super().__init__()
        self.N = N
        self.model_dim = model_dim
        self.key_dim = key_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.cross_attention = cross_attention

        self.attention_layer = AttentionLayer(model_dim=model_dim, key_dim=key_dim, num_heads=num_heads)
        self.ffn = FeedForward(model_dim=model_dim, hidden_dim=hidden_dim)
        self.norm_1 = LayerNorm(model_dim=model_dim)
        self.norm_2 = LayerNorm(model_dim=model_dim)

        if self.cross_attention:
            self.cross_attention_layer = AttentionLayer(model_dim=model_dim, key_dim=key_dim, num_heads=num_heads)
            self.norm_3 = LayerNorm(model_dim=model_dim)

    def forward(self, X, H_enc=None, mask=None):
        """Apply the block.

        Args:
            X: input tokens; assumed (batch_size, N, model_dim) as for AttentionLayer.
            H_enc: keys/values for cross-attention; required when
                ``cross_attention`` is True, ignored otherwise.
            mask: additive mask for the self-attention sub-layer only.

        Returns:
            Tensor with the same shape as ``X``.

        Raises:
            ValueError: if ``cross_attention`` is True and ``H_enc`` is None.
        """
        # Validate before doing any work.
        if self.cross_attention and H_enc is None:
            raise ValueError("H_enc cannot be None if cross_attention is enabled!")

        T_1 = self.norm_1(X)
        T_2 = self.attention_layer(T_1, mask=mask)
        T_3 = T_2 + X
        if self.cross_attention:
            T_a = self.norm_3(T_3)
            T_b = self.cross_attention_layer(T_a, H_enc=H_enc)
            # Out-of-place on purpose: norm_3 may have saved T_3 for its
            # backward pass, and an in-place `+=` would then make backward() fail.
            T_3 = T_3 + T_b
        T_4 = self.norm_2(T_3)
        T_5 = self.ffn(T_4)
        H = T_5 + T_3

        return H

In [53]:
# Smoke-test one self-attention-only block (run with the causal mask) and one
# block with cross-attention over Y.
encoder_block = TransformerBlock(
    N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8, num_heads=num_heads,
)
decoder_block = TransformerBlock(
    N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8, num_heads=num_heads,
    cross_attention=True,
)
encoder_block(X, mask=mask)
# NOTE(review): the decoder's self-attention receives no mask here, so it is not
# causal; pass mask=mask if autoregressive behavior is intended.
decoder_block(X, H_enc=Y)

tensor([[[ 0.5950,  0.6532, -0.0883,  ...,  0.8696,  0.9251,  0.5034],
         [ 0.1859,  0.9912,  0.9921,  ...,  1.2090, -0.1045,  0.7301],
         [ 0.1552,  1.0726,  0.9090,  ...,  1.3952,  0.0828,  0.1536],
         ...,
         [ 0.1655,  0.2254,  0.4131,  ...,  1.6077,  1.0697, -0.1074],
         [ 0.5230,  1.3601,  0.4634,  ...,  0.6692,  0.5866,  1.0162],
         [ 0.2869,  0.6113,  0.0617,  ...,  0.6822, -0.1595, -0.0758]],

        [[ 0.3425,  1.2843,  0.5048,  ...,  0.9217,  1.1042,  0.7714],
         [ 0.8797,  0.4325,  0.0953,  ...,  0.7544,  1.1505,  0.9455],
         [ 0.1512,  0.3505,  0.6392,  ...,  0.5101,  1.2656,  0.6963],
         ...,
         [ 0.3292,  0.8944,  0.0065,  ...,  1.2100,  0.6854,  0.7299],
         [ 0.4246,  0.1905,  0.8780,  ...,  0.6986,  0.3605,  0.4341],
         [ 0.0440,  0.5063,  0.9983,  ...,  0.7885,  0.7122,  0.9352]],

        [[ 0.9221,  0.7674,  0.2664,  ...,  1.2873,  0.8794,  0.8752],
         [-0.0435, -0.1394,  0.7310,  ...,  0

In [46]:
class TransformerStack(nn.Module):
    """``num_stack`` TransformerBlocks applied one after another.

    Every block receives the same ``H_enc`` and ``mask``; each block's output
    is the next block's input.

    Raises:
        ValueError: if ``num_stack < 1``.
    """

    def __init__(self,
                 N,
                 model_dim,
                 key_dim,
                 hidden_dim,
                 num_heads=1,
                 cross_attention=False,
                 num_stack=1
                ):
        super().__init__()

        if num_stack < 1:
            raise ValueError("num_stack cannot be less than 1!")

        # nn.ModuleList (rather than a plain list) registers each block as a
        # submodule, so its parameters show up in state_dict()/parameters().
        layers = []
        for _ in range(num_stack):
            layers.append(
                TransformerBlock(N, model_dim, key_dim, hidden_dim, num_heads,
                                 cross_attention=cross_attention)
            )
        self.blocks = nn.ModuleList(layers)

    def forward(self, X, H_enc=None, mask=None):
        """Run ``X`` through every block in order and return the final output."""
        hidden = X
        for block in self.blocks:
            hidden = block(hidden, H_enc, mask)
        return hidden
        

In [54]:
# Nine stacked self-attention blocks, run without a mask on the demo input X.
model = TransformerStack(
    N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8,
    num_heads=num_heads, num_stack=9,
)
model.state_dict()  # not displayed: only the cell's last expression is shown
model(X)

tensor([[[ 1.6120,  0.5040,  0.3937,  ...,  2.0055, -1.2134,  0.9813],
         [ 0.7487,  0.2654,  0.1714,  ...,  1.7507, -1.4817,  2.1662],
         [ 0.6117,  0.7551,  0.4294,  ...,  2.4666, -2.3493,  0.7458],
         ...,
         [ 0.6705,  0.2939,  0.7746,  ...,  2.3897, -1.4351,  1.0199],
         [ 1.4329,  0.9844,  0.4120,  ...,  1.9334, -1.7569,  1.4429],
         [ 1.1453,  0.5180,  0.2179,  ...,  1.2895, -2.1818,  0.3902]],

        [[ 0.9372,  0.2691,  0.6632,  ...,  1.8215, -1.2884, -0.1815],
         [ 0.6051,  0.4189,  0.4812,  ...,  1.6336, -1.1065, -0.1776],
         [ 0.1922,  0.2067,  0.5248,  ...,  1.2971, -0.9452,  1.3682],
         ...,
         [-0.3864,  1.0897,  0.0713,  ...,  1.4096, -1.7086,  1.2176],
         [-0.0498,  0.1457,  0.7907,  ...,  1.7237, -1.9853,  0.7064],
         [ 0.1061,  0.6879,  0.8388,  ...,  1.4432, -1.7721,  1.8719]],

        [[ 0.4825,  0.7374,  0.0675,  ...,  2.1558, -0.2209, -0.0694],
         [-0.3750,  0.3154,  0.1831,  ...,  2

In [None]:
# from https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
# I don't completely understand positional encoding yet, but I have built the intuition that
# it is analogous to how binary numbers encode integers: lower-order bits flip more frequently
# than higher-order bits, which is modeled here by sinusoidal waves of decreasing frequency.
# It also takes advantage of the trigonometric angle-addition formulas (the encoding of
# position pos + k is a linear function of the encoding of position pos), which supposedly
# helps the model figure out relative positioning...
# https://medium.com/thedeephub/positional-encoding-explained-a-deep-dive-into-transformer-pe-65cfe8cfe10b 
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / model_dim))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))

    Args:
        model_dim: embedding size (may be odd; the final cosine column is
            then simply absent).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length supported.
        batch_first: if False (default, as in the original tutorial) inputs
            are (seq_len, batch_size, model_dim); if True they are
            (batch_size, seq_len, model_dim).
    """

    def __init__(self, model_dim: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / model_dim) for i = 0, 1, ..., computed in log space
        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(max_len, 1, model_dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd model_dim there is one fewer odd column than even column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: model_dim // 2])
        # a buffer: saved in state_dict and moved by .to(), but not trained
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``, or
               ``[batch_size, seq_len, embedding_dim]`` when ``batch_first``.
        """
        if self.batch_first:
            # (seq_len, 1, d) -> (1, seq_len, d) so it broadcasts over the batch
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
# What I need to do to finish up the encoder-decoder architecture:
# the only difference in the decoder is the cross-attention layer,
# which is much like the self-attention layer except that its queries come from
# the decoder while its keys and values come from the final H of the encoder;
# thus the decoder needs to take in the memory from the encoder.

In [None]:
import torch 
import math

# Step-by-step walk-through of multi-head attention with hand-made weights.
# Not the most efficient way to do it, but it mirrors how I think about it.
N = 5
model_dim = 4
batch_size = 10
key_dim = 2
num_heads = 2
num_batch = batch_size  # same value; both names are used below
value_dim = model_dim // num_heads
# Causal additive mask: 0 where key j <= query i, -inf where j > i.
mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

X = torch.ones((num_batch, N, model_dim))
X[:, 1] = 2      # second token of every sequence
X[:, :, 2] = 3   # third feature of every token
print(X)  # deliberately not symmetric, so each weight's effect is traceable
single_head = torch.ones((model_dim, key_dim))
single_head_v = torch.ones((model_dim, value_dim))
# Each head's matrix is a distinct constant multiple so heads are easy to tell
# apart; the heads are concatenated side by side along the columns.
W_Q = torch.cat([single_head * i for i in range(2, 2 + num_heads)], dim=1)    # (model_dim, num_heads * key_dim)
W_K = torch.cat([single_head * i for i in range(4, 4 + num_heads)], dim=1)    # (model_dim, num_heads * key_dim)
W_V = torch.cat([single_head_v * i for i in range(6, 6 + num_heads)], dim=1)  # (model_dim, num_heads * value_dim)
W_O = torch.rand((model_dim, model_dim))  # only used by the final (commented-out) projection
print(X.shape)
print(W_Q.shape)
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
print(Q.shape)  # (batch_size, N, num_heads * key_dim)
print(Q.unbind()[0])  # first sequence: each head's query columns sit side by side
Q_reshaped = Q.view(num_batch, N, num_heads, key_dim).transpose(1, 2)    # (batch_size, num_heads, N, key_dim)
K_reshaped = K.view(num_batch, N, num_heads, key_dim).transpose(1, 2)    # (batch_size, num_heads, N, key_dim)
V_reshaped = V.view(num_batch, N, num_heads, value_dim).transpose(1, 2)  # (batch_size, num_heads, N, value_dim)
print(Q_reshaped.unbind()[0])
# True: view() and transpose() return views of Q's storage, so nothing is copied
# (a copy only happens on e.g. .contiguous() of a non-contiguous tensor).
print(Q.data_ptr() == Q_reshaped.data_ptr())
attention = Q_reshaped @ K_reshaped.mT / math.sqrt(key_dim) + mask  # (N, N) mask broadcasts over batch and heads
print(attention.shape, V_reshaped.shape)
probs = torch.nn.functional.softmax(attention, dim=-1)
values = probs @ V_reshaped / 24  # arbitrary scale so each head's values read as one number
print(values.shape)  # (batch_size, num_heads, N, value_dim)
# Two ways to put the heads back side by side as (batch_size, N, num_heads * value_dim):
# 1) (B, H, N, Vd) -> (B, H, Vd, N) -> (B, H*Vd, N) -> (B, N, H*Vd)
values_cat = values.transpose(-2, -1).flatten(start_dim=1, end_dim=2).transpose(-2, -1)
# 2) (B, H, N, Vd) -> (B, N, H, Vd) -> (B, N, H*Vd)
values_cat_2 = values.transpose(1, 2).contiguous().view(batch_size, N, -1)
torch.equal(values_cat, values_cat_2)  # True: both produce the same layout
# values_cat @ W_O