In [1]:
# Here we implement only the multi-head attention sublayer,
# not a full transformer block.

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
# data hyperparameters
seq_len = 8  # aka context length

# model hyperparameters
embed_dim = 128
n_heads = 4  # embed_dim / n_heads must be an integer (each head gets an equal slice)

# training hyperparameters
batch_size = 5

# Enforce the head-split constraint here, before any model is built.
assert embed_dim % n_heads == 0, (
    f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
)

# Seed early so weight init and the fake data below are reproducible on Restart & Run All.
torch.manual_seed(0)

In [10]:
# Scratch check of format-spec alignment: ">28" right-aligns the text in a 28-character field
# (the same trick is used for the shape-tracking labels in MultiHeadAttention.forward).
a = 10
print("Hello is {:>28} {}".format('raeez', a))

Hello is                        raeez 10


Class for multi-head attention

In [11]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention sublayer (attention only; no residual, norm or MLP).

    Q, K and V for all heads are held in single ``embed_dim -> embed_dim``
    linear layers (one "super head"). Each projection's output is split into
    ``num_heads`` slices of size ``head_dim``, attention is computed with
    ``F.scaled_dot_product_attention`` using a causal mask, and the
    concatenated head outputs are mixed by a final linear layer ``W0``.

    Args:
        num_heads: number of attention heads; must be positive.
        embed_dim: size of each token embedding; must be divisible by
            ``num_heads`` so every head gets an equal integer slice.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not evenly
            divide ``embed_dim``.
    """

    def __init__(self, num_heads, embed_dim):
        super().__init__()

        # Validate up front: otherwise a bad split only surfaces later as an
        # opaque shape error from .view() inside forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        # head dimensionality is embed_dim split evenly across the heads
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Q, K and V for all heads, each stored as one "super head" matrix
        # note: in model 5, these 3 matrices are combined into one
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)

        # final linear projection mixes the concatenated head outputs
        self.W0 = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, track_sizes=False):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: tensor of shape [batch, seq_len, embed_dim].
            track_sizes: if True, print the tensor shape after each stage.

        Returns:
            Tensor of shape [batch, seq_len, embed_dim].
        """
        # extract the dimension sizes of the inputs (token embeddings)
        B, T, E = x.shape  # [batch, tokens (seq_len), embed_dim]
        if track_sizes: print(f"1) {'Input data shape:':>28} {x.shape}")

        # project through Q, K and V (all heads still packed in one matrix each)
        q = self.query(x)  # [batch, seq_len, embed_dim]
        k = self.key(x)
        v = self.value(x)
        if track_sizes: print(f"2) {'q/k/v pre-split shape':>28} {q.shape}")

        # split the feature axis into heads (note: head splitting is done after X @ Wq)
        q = q.view(B, T, self.num_heads, self.head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim)
        v = v.view(B, T, self.num_heads, self.head_dim)

        # SDPA expects [B, num_heads, T, head_dim] (attention runs over the
        # second-to-last dim). transpose(1, 2) swaps only the token and head
        # dims, leaving batch (dim 0) and head_dim (dim 3) in place.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if track_sizes: print(f"3) {'q/k/v post-split shape':>28} {q.shape}")

        # causal mask: each token attends only to itself and earlier tokens
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        if track_sizes: print(f"4) {'Data post attention shape':>28} {out.shape}")

        # move tokens back ahead of heads: [B, num_heads, T, head_dim] -> [B, T, num_heads, head_dim]
        out = out.transpose(1, 2)
        if track_sizes: print(f"5) {'Post attention data reshape':>28} {out.shape}")

        # concatenate the heads back into embed_dim (the [A1 A2 ... Ah] step);
        # reshape rather than view because the transposed tensor is non-contiguous
        out = out.reshape(B, T, E)
        if track_sizes: print(f"6) {'Data merged to size':>28} {out.shape}")

        # finally apply the linear mixing matrix
        out = self.W0(out)
        if track_sizes: print(f"7) {'Post-MHA H0 linear mixing':>28} {out.shape}")

        return out

In [14]:
# Build the attention sublayer from the config-cell hyperparameters and display its layers.
mha = MultiHeadAttention(num_heads=n_heads, embed_dim=embed_dim)
mha

MultiHeadAttention(
  (key): Linear(in_features=128, out_features=128, bias=False)
  (query): Linear(in_features=128, out_features=128, bias=False)
  (value): Linear(in_features=128, out_features=128, bias=False)
  (W0): Linear(in_features=128, out_features=128, bias=False)
)

In [15]:
# Run a batch of random "token embeddings" through the attention sublayer.
data = torch.randn(size=(batch_size, seq_len, embed_dim))
out = mha(data)
print(f'Input size:{data.shape}')
print(f'Output size:{out.shape}')

# The attention sublayer must preserve the [batch, seq_len, embed_dim] shape.
assert out.shape == data.shape, f"shape changed: {data.shape} -> {out.shape}"

Input size:torch.Size([5, 8, 128])
Output size:torch.Size([5, 8, 128])


In [20]:
# Summarize the configuration, then trace tensor shapes through every stage of the sublayer.
print('\n'.join([
    f'     Seq length: {seq_len}',
    f'embedding dimen: {embed_dim}',
    f'    No of heads: {n_heads}',
    f'     Head dimensionality: {embed_dim // n_heads}',
]))

print('\n Dimensions of data as iit passes through attention sublayer of one transformer block')
out = mha(data, track_sizes=True)

     Seq length: 8
embedding dimen: 128
    No of heads: 4
     Head dimensionality: 32

 Dimensions of data as iit passes through attention sublayer of one transformer block
1)            Input data shape: torch.Size([5, 8, 128])
2)        q/k/v pre-split shape torch.Size([5, 8, 128])
3)       q/k/v post-split shape torch.Size([5, 4, 8, 32])
4)    Data post attention shape torch.Size([5, 4, 8, 32])
5)  Post attention data reshape torch.Size([5, 8, 4, 32])
6)          Data merged to size torch.Size([5, 8, 128])
7)    Post-MHA H0 linear mixing torch.Size([5, 8, 128])
