In [2]:
import torch
import torch.nn as nn
# Six 3-dimensional input vectors, one per token of the sentence
# "Your journey starts with one step" (row i holds x^(i+1)).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Stack two copies along a new leading axis to get a batch of 2 identical sequences.
batchs = torch.stack([inputs] * 2)
batchs.shape

torch.Size([2, 6, 3])

# Method 1: Wrapper that concatenates independent single-head causal-attention modules

In [5]:
class casualattention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Despite the "casual" spelling in the class name, this is *causal*
    attention: a strictly-upper-triangular mask hides every position j > i
    from query position i before the softmax.

    Args:
        din: size of the last dimension of the input.
        dout: size of the query/key/value projections and of the output.
        biasbool: whether the Q/K/V linear layers use a bias term.
        batchsize: maximum supported sequence length, i.e. the side of the
            square causal mask. NOTE: despite its name this is the *context
            length*, not the batch size; the name is kept for compatibility.
        dropoutsize: dropout probability applied to the attention weights.
    """

    def __init__(self, din, dout, biasbool, batchsize, dropoutsize):
        super().__init__()
        self.W_query = nn.Linear(din, dout, bias=biasbool)
        self.W_keys = nn.Linear(din, dout, bias=biasbool)
        self.W_values = nn.Linear(din, dout, bias=biasbool)
        self.dropout = nn.Dropout(dropoutsize)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer('mask', torch.triu(torch.ones(batchsize, batchsize), diagonal=1))

    def forward(self, x):
        """Apply causal self-attention over the second-to-last axis of ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, din)``, typically
                ``(batch, num_tokens, din)``; unbatched ``(num_tokens, din)``
                input is also accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, dout)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds the mask size chosen at
                construction time.
        """
        num_tokens = x.shape[-2]
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            # Without this check the slice below would silently return the whole
            # (smaller) mask: masked_fill_ then either fails with a broadcast error
            # or, for a 1x1 mask, broadcasts and disables causal masking entirely.
            raise ValueError(
                f"input has {num_tokens} tokens but the causal mask only covers "
                f"{context_length}"
            )

        query = self.W_query(x)
        keys = self.W_keys(x)
        values = self.W_values(x)

        # (..., T, dout) @ (..., dout, T) -> (..., T, T) raw attention scores.
        scores = query @ keys.transpose(-2, -1)
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(causal_mask, -torch.inf)
        # Scale by sqrt(key dim) before the softmax.
        attn_weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values
    
class Multiheadattention(nn.Module):
    """Multi-head attention assembled from independent ``casualattention`` heads.

    Every head maps the input to ``dout`` features. The head outputs are
    concatenated along the last axis, giving ``dout * noofattentionhead``
    features, and passed through a square linear output projection.

    Args:
        din, dout, biasbool, batchsize, dropoutsize: forwarded unchanged to
            every ``casualattention`` head (``batchsize`` sizes each head's
            causal mask).
        noofattentionhead: number of heads to create.
    """

    def __init__(self, din, dout, biasbool, batchsize, dropoutsize, noofattentionhead):
        super().__init__()
        head_list = []
        for _ in range(noofattentionhead):
            head_list.append(casualattention(din, dout, biasbool, batchsize, dropoutsize))
        self.heads = nn.ModuleList(head_list)

        concat_width = dout * noofattentionhead
        self.out_proj = nn.Linear(concat_width, concat_width)

    def forward(self, x):
        """Run every head on ``x``, concatenate along the last axis, then project."""
        head_outputs = []
        for attention_head in self.heads:
            head_outputs.append(attention_head(x))
        merged = torch.cat(head_outputs, dim=-1)
        return self.out_proj(merged)


# Seed right before building the modules so the weight initialisation is reproducible.
torch.manual_seed(123)

# Number of tokens per sequence; used as the size of each head's causal mask.
context_length = batchs.shape[1]
ca = Multiheadattention(
    din=3,
    dout=2,
    biasbool=False,
    batchsize=context_length,
    dropoutsize=0.0,
    noofattentionhead=8,
)

context_vecs = ca(batchs)

# 8 heads x 2 features each are concatenated (and projected) -> 16 features per token.
assert context_vecs.shape == (batchs.shape[0], context_length, 2 * 8)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

        



tensor([[[ 0.1232, -0.1066, -0.0376, -0.1082, -0.3488,  0.0211,  0.3412,
           0.1926, -0.3484,  0.5276,  0.3979,  0.1163,  0.3184,  0.0403,
           0.0790, -0.4035],
         [ 0.0984, -0.0307, -0.0094, -0.1888, -0.4576,  0.0386,  0.2738,
           0.3218, -0.4608,  0.6029,  0.3603,  0.2969,  0.4167,  0.1128,
           0.1826, -0.4163],
         [ 0.0886, -0.0022, -0.0031, -0.2085, -0.4895,  0.0468,  0.2533,
           0.3586, -0.4961,  0.6282,  0.3520,  0.3553,  0.4418,  0.1322,
           0.2140, -0.4146],
         [ 0.1011, -0.0076, -0.0316, -0.1693, -0.4635,  0.0505,  0.2437,
           0.3211, -0.4247,  0.5862,  0.3282,  0.3417,  0.3826,  0.1484,
           0.1847, -0.3842],
         [ 0.0954, -0.0044, -0.0319, -0.1285, -0.4194,  0.0414,  0.2410,
           0.2809, -0.3754,  0.5705,  0.3284,  0.3530,  0.3389,  0.1275,
           0.1464, -0.3400],
         [ 0.1050, -0.0089, -0.0432, -0.1365, -0.4345,  0.0466,  0.2355,
           0.2888, -0.3688,  0.5581,  0.3132,  0.343

# Method 2: One module that splits fused Q/K/V projections into heads

In [8]:
class multiheadv2(nn.Module):
    """Multi-head causal self-attention with a single set of Q/K/V projections.

    The ``d_out``-wide projections are split into ``attention_head`` heads of
    ``head_dim = d_out // attention_head`` features. Each head runs causal
    scaled dot-product attention independently. The heads are then merged
    back to ``d_out`` features and passed through a linear output projection.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total output width across all heads; must be divisible by
            ``attention_head``.
        context_length: maximum supported sequence length (side of the
            square causal mask).
        dropout: dropout probability applied to the attention weights.
        attention_head: number of heads.
        boolbias: whether the Q/K/V linear layers use a bias term.

    Raises:
        ValueError: if ``attention_head`` is not positive or does not divide
            ``d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, attention_head, boolbias):
        super().__init__()
        if attention_head <= 0 or d_out % attention_head != 0:
            # Otherwise the .view() in forward() fails later with a cryptic shape error.
            raise ValueError(
                f"d_out ({d_out}) must be divisible by attention_head ({attention_head})"
            )
        self.head_dim = d_out // attention_head
        self.d_out = d_out
        self.attention_head = attention_head

        self.W_query = nn.Linear(d_in, d_out, bias=boolbias)
        self.W_key = nn.Linear(d_in, d_out, bias=boolbias)
        self.W_value = nn.Linear(d_in, d_out, bias=boolbias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def _split_heads(self, t, b, num_token):
        # (b, T, d_out) -> (b, T, heads, head_dim) -> (b, heads, T, head_dim)
        return t.view(b, num_token, self.attention_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply multi-head causal self-attention.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_length``.
        """
        # The third dimension is the *input* width (d_in), which is not needed here.
        b, num_token, _ = x.shape
        context_length = self.mask.shape[0]
        if num_token > context_length:
            # Without this check the mask slice silently stays smaller than the
            # scores: masked_fill_ then errors or, for a 1x1 mask, broadcasts and
            # disables causal masking.
            raise ValueError(
                f"input has {num_token} tokens but context_length is {context_length}"
            )

        keys = self._split_heads(self.W_key(x), b, num_token)
        queries = self._split_heads(self.W_query(x), b, num_token)
        values = self._split_heads(self.W_value(x), b, num_token)

        # (b, heads, T, head_dim) @ (b, heads, head_dim, T) -> (b, heads, T, T)
        attn_score = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_token, :num_token]
        attn_score.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax.
        attn_weights = torch.softmax(attn_score / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, heads, T, head_dim) -> (b, T, heads, head_dim) -> (b, T, d_out)
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_token, self.d_out)
        return self.out_proj(context_vec)  # optional projection
    
# Seed right before building the module so the weight initialisation is reproducible.
torch.manual_seed(123)

# Number of tokens per sequence; used as the size of the causal mask.
context_length = batchs.shape[1]
ca = multiheadv2(
    d_in=3,
    d_out=24,
    context_length=context_length,
    dropout=0.0,
    attention_head=8,
    boolbias=False,
)

context_vecs = ca(batchs)

# d_out=24 is split into 8 heads of 3 features, merged back and projected to 24.
assert context_vecs.shape == (batchs.shape[0], context_length, 24)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)



        


tensor([[[ 0.3008, -0.1006,  0.3409,  0.0193,  0.2135,  0.3918,  0.0463,
          -0.2520,  0.2986,  0.1838, -0.2335,  0.3096,  0.5285, -0.1114,
          -0.0949,  0.1673,  0.4495, -0.2283, -0.1096,  0.0981, -0.5552,
           0.1351,  0.4825, -0.0715],
         [ 0.2895, -0.0249,  0.3578, -0.0220,  0.1746,  0.3995,  0.0053,
          -0.2650,  0.2269,  0.2103, -0.2278,  0.3667,  0.5568, -0.1256,
          -0.0434,  0.1500,  0.5413, -0.2296, -0.1346,  0.0883, -0.6241,
           0.0612,  0.3667, -0.0474],
         [ 0.2864,  0.0035,  0.3658, -0.0380,  0.1613,  0.3977, -0.0085,
          -0.2692,  0.2044,  0.2146, -0.2232,  0.3859,  0.5650, -0.1307,
          -0.0279,  0.1420,  0.5711, -0.2296, -0.1400,  0.0880, -0.6438,
           0.0396,  0.3268, -0.0373],
         [ 0.2696, -0.0032,  0.3284, -0.0109,  0.1526,  0.3744, -0.0135,
          -0.2644,  0.1833,  0.1758, -0.1899,  0.3336,  0.5152, -0.0962,
          -0.0169,  0.1189,  0.5340, -0.2107, -0.1523,  0.0530, -0.6005,
          