In [37]:
import torch.nn as nn


In [38]:
import torch
class causal_attention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections (and of the output).
        context_length: Side length of the square causal mask that is
            registered as a buffer; the mask is sliced to the actual
            sequence length in ``forward``.
        dropout: Probability passed to ``nn.Dropout`` and applied to the
            attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        # Projection layers are created in query/key/value order so that
        # seeded initialisation stays reproducible.
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return context vectors of shape ``(batch, num_tokens, d_out)``.

        ``x`` is unpacked as ``(batch, num_tokens, d_in)``.
        """
        _, seq_len, _ = x.shape

        q = self.w_query(x)
        k = self.w_key(x)
        v = self.w_value(x)

        # Raw similarity between every query and every key: (batch, T, T).
        scores = torch.matmul(q, k.transpose(1, 2))

        # Positions after the current token get -inf so that softmax
        # assigns them zero weight.
        blocked = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(blocked, -torch.inf)

        # Scale by sqrt(key dimension) before normalising.
        scale = k.size(-1) ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return torch.matmul(weights, v)

In [39]:
class MultiHeadAttention_v1(nn.Module):
    """Multi-head attention built from independent ``causal_attention`` heads.

    Every head receives the same constructor arguments and the same input;
    the per-head outputs are concatenated along the last dimension, so the
    output's last dimension is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                causal_attention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last axis."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [40]:
import torch

# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # your
        [0.55, 0.87, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)

In [41]:
# Stack the same sequence twice to get a batch of two identical examples.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [42]:
## Applying multi-head attention (v1) to this batch.
# The module is left in its default training mode with dropout=0.5, so some
# attention weights are zeroed at random.
torch.manual_seed(123)
d_in, d_out = 3, 2
context_length = batch.shape[1]
mha = MultiHeadAttention_v1(
    d_in, d_out, context_length, dropout=0.5, num_heads=3
)
context_vextor = mha(batch)
print(context_vextor)
print(context_vextor.shape)


tensor([[[-0.9038,  0.4432,  0.9544,  0.2127,  0.0000,  0.0000],
         [-0.7381, -0.2026,  0.0000,  0.0000,  1.1584,  0.6021],
         [-0.7751,  0.0077,  0.4847,  0.3565,  0.4538,  0.2094],
         [-0.4090,  0.0315,  0.5655,  0.3037,  0.5935,  0.3046],
         [-0.4745,  0.0076,  0.4557,  0.2409,  0.8967,  0.4397],
         [-0.5318, -0.0458,  1.0153,  0.6987,  0.3049,  0.1135]],

        [[ 0.0000,  0.0000,  0.9544,  0.2127,  0.9131,  0.5458],
         [-0.7381, -0.2026,  1.1781,  0.6513,  0.6619,  0.3054],
         [-0.2883,  0.1414,  0.7557,  0.4154,  1.2498,  0.6203],
         [-1.1349, -0.1685,  0.8795,  0.6697,  0.5447,  0.2524],
         [-0.3373, -0.1226,  0.2740,  0.2004,  0.9127,  0.3868],
         [-0.6897, -0.0976,  0.6164,  0.4248,  0.3936,  0.2035]]],
       grad_fn=<CatBackward0>)
torch.Size([2, 6, 6])


### Implementing multi-head attention in a single class with weight splits

In [54]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using shared, split projections.

    One query, one key and one value projection of width ``d_out`` are
    computed and then reshaped into ``num_heads`` heads of width
    ``d_out // num_heads``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the square causal mask registered
            as a buffer.
        dropout: Probability passed to ``nn.Dropout`` and applied to the
            attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias.

    Raises:
        AssertionError: If ``d_out`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by nums head"

        self.d_out = d_out
        self.num_head = num_heads
        self.head_dim = d_out // num_heads

        # Layers are created in this fixed order so that seeded
        # initialisation stays reproducible.
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Final linear layer that mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (b, T, d_out) into (b, num_head, T, head_dim)."""
        per_head = projected.view(batch_size, seq_len, self.num_head, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Return context vectors of shape ``(batch, num_tokens, d_out)``.

        ``x`` is unpacked as ``(batch, num_tokens, d_in)``.
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.w_query(x), batch_size, seq_len)
        k = self._split_heads(self.w_key(x), batch_size, seq_len)
        v = self._split_heads(self.w_value(x), batch_size, seq_len)

        # Per-head similarity scores: (b, num_head, T, T).
        scores = torch.matmul(q, k.transpose(2, 3))

        # The (T, T) mask broadcasts over the batch and head dimensions.
        blocked = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(blocked, -torch.inf)

        # Scale by sqrt(head_dim), normalise, then apply dropout.
        scale = k.size(-1) ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (b, num_head, T, T) @ (b, num_head, T, head_dim)
        #   -> (b, num_head, T, head_dim)
        per_head_context = torch.matmul(weights, v)

        # Move the head axis next to head_dim and flatten both into d_out.
        merged = (
            per_head_context.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.d_out)
        )
        return self.out_proj(merged)

In [61]:
# Run the weight-split implementation on the toy batch with dropout disabled.
torch.manual_seed(123)
batch = torch.stack([inputs, inputs], dim=0)
batch_size, context_length, d_in = batch.shape
d_out = 3
mha = MultiHeadAttention(
    d_in, d_out, context_length, dropout=0.0, num_heads=3
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.0766,  0.0755, -0.0321],
         [ 0.0311,  0.1048, -0.0368],
         [ 0.0165,  0.1088, -0.0409],
         [-0.0470,  0.0841, -0.0825],
         [-0.1018,  0.0327, -0.1292],
         [-0.1060,  0.0508, -0.1246]],

        [[ 0.0766,  0.0755, -0.0321],
         [ 0.0311,  0.1048, -0.0368],
         [ 0.0165,  0.1088, -0.0409],
         [-0.0470,  0.0841, -0.0825],
         [-0.1018,  0.0327, -0.1292],
         [-0.1060,  0.0508, -0.1246]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 3])
