### Single Head Attention

In [25]:
import torch 
import torch.nn as nn
import torch.optim as optim 

In [None]:
# Hyperparameters shared by every cell below
n_embed = 120                   # model / embedding width
n_heads = 8                     # attention heads per block
n_layers = 8                    # number of stacked transformer blocks
head_size = n_embed // n_heads  # width of a single head (120 // 8 = 15)
block_size = 128                # context length: the longest sequence the model accepts
dropout = 0.2                   # dropout probability used for regularization
vocab_size = 8000               # number of distinct token ids

Each token is represented by an embedding vector of dimension `n_embed`.

i.e. $E_i$ is a vector of size `n_embed × 1`, where $i$ runs over the `block_size` positions of the context.

Each head has a query matrix $W_q$ and a key matrix $W_k$, each of size `head_size × n_embed`.
In self-attention both are applied to the same input $x$.

So $Q_i = W_q E_i$ (and likewise $K_i = W_k E_i$) is a `head_size × 1` vector, computed for every position and every sequence in the batch.

Each projection can be implemented as `nn.Linear(n_embed, head_size)`.

How much query $i$ attends to key $j$ is given by the dot product $Q_i \cdot K_j$. Collecting these for all pairs gives a score matrix of size $T \times T$.

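In matrix form this is `attend = query @ key.transpose(-2, -1)`, i.e. $QK^\top$. The scores are then scaled by $1/\sqrt{d_k}$ (with $d_k$ = `head_size`), causally masked, and softmax-normalised over the key dimension:
$$\text{attend} = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$$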

The input is also projected down to `head_size` by a value matrix $W_v$; the per-head outputs are concatenated later in multi-head attention:
$V_i = W_v E_i$

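The output of a single head is `attend @ V`, i.e. $\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, with shape `(B, T, head_size)`.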

In [None]:
class SingleHeadAttention(nn.Module):
    """One causal (masked) scaled dot-product self-attention head.

    Projects the input to query, key and value vectors of width ``head_size``
    and returns ``softmax(mask(Q K^T / sqrt(head_size))) @ V``, where the mask
    prevents each position from attending to later positions.

    Args:
        n_embed: width of the input embeddings (last dimension of ``x``).
        head_size: width of the query/key/value projections and of the output.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.n_embed = n_embed
        self.head_size = head_size
        # Registration order (key, query, value) is kept so parameter
        # initialisation and state_dict ordering are unchanged.
        self.key = nn.Linear(n_embed, head_size)
        self.query = nn.Linear(n_embed, head_size)
        self.value = nn.Linear(n_embed, head_size)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, n_embed).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        T = x.shape[-2]
        key = self.key(x)      # (B, T, head_size)
        query = self.query(x)  # (B, T, head_size)

        # (B, T, head_size) @ (B, head_size, T) -> (B, T, T); entry [i, j] = q_i . k_j
        attend = query @ key.transpose(-2, -1)
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attend = attend / (self.head_size ** 0.5)

        # Causal mask: position i may only attend to positions j <= i.
        # Built on x's device so the head also works when the model is on GPU.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        attend = attend.masked_fill(~causal, float('-inf'))
        # Row-wise softmax: for each query, the weights over the keys sum to 1.
        attend = torch.softmax(attend, dim=-1)

        value = self.value(x)  # (B, T, head_size)
        return attend @ value  # (B, T, head_size)




In [28]:
class MultiHeadAttention(nn.Module):
    """Several SingleHeadAttention heads run in parallel, concatenated and projected.

    Args:
        n_embed: model width; must be divisible by ``n_heads`` so the
            concatenated head outputs are exactly ``n_embed`` wide (the input
            width of ``proj``).
        n_heads: number of heads; each has width ``n_embed // n_heads``.
        dropout_p: dropout probability applied after the output projection.
            When None, the notebook-level ``dropout`` setting is used.

    Raises:
        ValueError: if ``n_embed`` is not divisible by ``n_heads``.
    """
    def __init__(self, n_embed, n_heads, dropout_p=None):
        super().__init__()
        # Without this check the concatenated heads would be narrower than
        # proj's input and forward would fail with an obscure shape error.
        if n_embed % n_heads != 0:
            raise ValueError(
                f"n_embed ({n_embed}) must be divisible by n_heads ({n_heads})"
            )
        self.n_embed = n_embed
        self.n_heads = n_heads
        self.head_size = self.n_embed // self.n_heads
        self.heads = nn.ModuleList([SingleHeadAttention(n_embed, self.head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        # Fall back to the global config value so existing two-argument calls are unchanged.
        self.dropout = nn.Dropout(dropout if dropout_p is None else dropout_p)

    def forward(self, x):
        """Map x of shape (B, T, n_embed) to a tensor of shape (B, T, n_embed)."""
        # Each head yields (B, T, head_size); concatenating n_heads of them gives (B, T, n_embed).
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)  # (B, T, n_embed)
        return out

    

In [29]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network (two-layer MLP) of a transformer block.

    Maps (..., n_embed) -> (..., n_embed) through a hidden layer four times
    wider, as in "Attention Is All You Need" (d_model=512, d_ff=2048).

    Args:
        n_embed: input and output width.
        dropout_p: dropout probability applied to the output. When None, the
            notebook-level ``dropout`` setting is used.
    """

    def __init__(self, n_embed, dropout_p=None):
        super().__init__()
        # Fall back to the global config value so existing one-argument calls are unchanged.
        p = dropout if dropout_p is None else dropout_p
        self.network = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),  # up-projection to the 4x hidden width
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),  # down-projection back to n_embed
            nn.Dropout(p),
        )

    def forward(self, x):
        """Apply the MLP independently at every position of x."""
        return self.network(x)
    

In [30]:
# Single block of the transformer
class Block(nn.Module):
    """One pre-norm transformer block.

    Causal multi-head attention is followed by a feed-forward network. Each
    sub-layer is wrapped in a residual (skip) connection.

    Args:
        n_embed: model width.
        n_heads: number of attention heads.
    """

    def __init__(self, n_embed, n_heads):
        super().__init__()
        # Submodules are registered in the original order (attention,
        # feed_forward, ln1, ln2) so parameter order and initialisation match.
        self.attention = MultiHeadAttention(n_embed, n_heads)
        self.feed_forward = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        # Pre-normalization: each sub-layer sees a LayerNorm'd copy of the
        # residual stream, and its output is added back onto the un-normalised stream.
        attn_update = self.attention(self.ln1(x))
        residual = x + attn_update
        ff_update = self.feed_forward(self.ln2(residual))
        return residual + ff_update

In [37]:
## The final Transformer Block
print(vocab_size, n_embed)
class AakritiTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned positional embeddings are passed through
    ``n_layers`` pre-norm ``Block``s and projected to ``vocab_size`` logits.
    Uses the notebook-level ``vocab_size`` and ``block_size`` settings.

    Args:
        n_layers: number of stacked transformer blocks.
        n_embed: model width.
        n_heads: attention heads per block.
    """
    def __init__(self, n_layers, n_embed, n_heads):
        super().__init__()

        # Embedding layers: one learned vector per token id and one per position.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding_table = nn.Embedding(block_size, n_embed)

        self.n_layers = n_layers
        self.n_embed = n_embed
        self.linear = nn.Linear(n_embed, vocab_size)  # language-model head
        # NOTE(review): GPT-style models usually apply a final LayerNorm to the
        # hidden states instead of normalising the logits; bn1 is kept so the
        # parameter set and state_dict keys are unchanged.
        self.bn1 = nn.BatchNorm1d(vocab_size)
        self.n_heads = n_heads
        self.network = nn.Sequential(
            *[Block(self.n_embed, self.n_heads) for _ in range(self.n_layers)]
        )

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape (B, T), with T <= block_size.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds the number of positional embeddings.
        """
        T = x.shape[1]
        max_len = self.pos_embedding_table.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds block_size {max_len}")
        # The positional table is indexed by positions 0..T-1, not by token ids.
        # Its (T, n_embed) result broadcasts over the batch dimension.
        positions = torch.arange(T, device=x.device)
        embeddings = self.token_embedding_table(x) + self.pos_embedding_table(positions)
        out = self.network(embeddings)  # (B, T, n_embed)
        out = self.linear(out)          # (B, T, vocab_size)
        # BatchNorm1d expects channels in dim 1 ((N, C, L)), so move vocab there and back.
        out = self.bn1(out.transpose(1, 2)).transpose(1, 2)
        return out

8000 120


In [38]:
# Echo the depth (n_layers) and head count (n_heads) used to build the model below.
print(n_layers, n_heads)

8 8


In [40]:
# Build the full model from the notebook-level hyperparameters.
model = AakritiTransformer(n_layers=n_layers, n_embed=n_embed, n_heads=n_heads)

In [None]:

from torchinfo import summary
from torchviz import make_dot


# Per-submodule parameter counts. No input_size/input_data is passed, so only
# the "Param #" column is reported (see the table below).
summary(model)

Layer (type:depth-idx)                                  Param #
AakritiTransformer                                      --
├─Embedding: 1-1                                        960,000
├─Embedding: 1-2                                        15,360
├─Linear: 1-3                                           968,000
├─BatchNorm1d: 1-4                                      16,000
├─Sequential: 1-5                                       --
│    └─Block: 2-1                                       --
│    │    └─MultiHeadAttention: 3-1                     58,080
│    │    └─FeedForward: 3-2                            115,800
│    │    └─LayerNorm: 3-3                              240
│    │    └─LayerNorm: 3-4                              240
│    └─Block: 2-2                                       --
│    │    └─MultiHeadAttention: 3-5                     58,080
│    │    └─FeedForward: 3-6                            115,800
│    │    └─LayerNorm: 3-7                              240
│    │    └─