In [1]:
import torch
import torch.nn as nn


In [13]:
class SingleHeadAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections, and of the output.
        context_length: maximum supported sequence length; it fixes the size
            of the causal mask. Longer inputs are not supported.
        dropout: dropout probability applied to the attention weights.
        kqv_bias: whether the three linear projections carry a bias term.
    """

    def __init__(self,d_in,d_out,context_length,dropout,kqv_bias=False):
        super().__init__()
        self.W_Q = torch.nn.Linear(d_in,d_out,bias=kqv_bias)
        self.W_K = torch.nn.Linear(d_in,d_out,bias=kqv_bias)
        self.W_V = torch.nn.Linear(d_in,d_out,bias=kqv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        # Registered as a buffer so the mask follows the module across
        # .to(device)/.cuda() calls; persistent=False keeps it out of
        # state_dict, so existing checkpoints still load unchanged.
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length,context_length),diagonal=1),
            persistent=False,
        )

    def forward(self,x):
        """Compute the context vectors for a batch of sequences.

        Args:
            x: tensor of shape (b, c, d_in) with c <= context_length.

        Returns:
            Tensor of shape (b, c, d_out).
        """
        _,num_tokens,_ = x.shape
        queries = self.W_Q(x) # (b , c , d_out)
        keys = self.W_K(x)    # (b , c , d_out)
        values = self.W_V(x)  # (b , c , d_out)
        d_k = keys.shape[-1]
        attention_score = queries @ keys.transpose(1,2) # (b , c , c)
        # Slice the mask to the actual sequence length so inputs shorter than
        # context_length work instead of failing on a shape mismatch.
        causal_mask = self.mask[:num_tokens,:num_tokens].bool()
        # -inf scores become exactly zero weight after the softmax.
        attention_score = attention_score.masked_fill(causal_mask,float("-inf")) # (b , c , c)
        attention_weight = torch.softmax(attention_score/d_k**0.5,dim=-1) # (b , c , c)
        attention_weight = self.dropout(attention_weight) # (b , c , c)
        context_vector = attention_weight @ values # (b , c , d_out)
        return context_vector

In [14]:
class MultiHeadAttention(nn.Module):
    """Several independent SingleHeadAttention heads evaluated side by side.

    Every head receives the same input; the per-head outputs are joined
    along the last dimension.
    """

    def __init__(self,num_heads,d_in,d_out,context_length,dropout,kqv_bias=False):
        super().__init__()
        # Build the heads one at a time; each gets its own projection weights.
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                SingleHeadAttention(d_in,d_out,context_length,dropout,kqv_bias)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self,x):
        """Apply every head to ``x`` and concatenate the results on the last axis."""
        per_head_outputs = tuple(head(x) for head in self.heads)
        return torch.cat(per_head_outputs,dim=-1)



In [23]:
# Fixed seed so the random input below is reproducible across runs.
torch.manual_seed(123)
d_in = 3             # size of each input token vector
d_out = 2            # output size of each attention head
context_length = 5   # number of tokens per sequence
dropout = 0          # dropout probability (0 disables dropout)
# Use d_in here so the cell runs on a fresh kernel without relying on a
# leftover variable from an earlier session.
x = torch.randn(context_length,d_in)   # (context_length, d_in)
x_batch = torch.stack((x,x,x))         # (3, context_length, d_in): three copies of x
num_heads = 2


In [24]:
# Build the multi-head attention module from the configuration values
# defined in the setup cell.
mha = MultiHeadAttention(
    num_heads=num_heads,
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
)
# Last expression of the cell, so the resulting tensor is displayed.
mha(x_batch)


tensor([[[ 0.0406, -0.0787,  0.0990,  0.0047],
         [ 0.0022, -0.3431,  0.3303,  0.0755],
         [ 0.0025, -0.3457,  0.4547,  0.0703],
         [ 0.0142, -0.2599,  0.3762,  0.0511],
         [ 0.0072, -0.1075,  0.1825,  0.0155]],

        [[ 0.0406, -0.0787,  0.0990,  0.0047],
         [ 0.0022, -0.3431,  0.3303,  0.0755],
         [ 0.0025, -0.3457,  0.4547,  0.0703],
         [ 0.0142, -0.2599,  0.3762,  0.0511],
         [ 0.0072, -0.1075,  0.1825,  0.0155]],

        [[ 0.0406, -0.0787,  0.0990,  0.0047],
         [ 0.0022, -0.3431,  0.3303,  0.0755],
         [ 0.0025, -0.3457,  0.4547,  0.0703],
         [ 0.0142, -0.2599,  0.3762,  0.0511],
         [ 0.0072, -0.1075,  0.1825,  0.0155]]], grad_fn=<CatBackward0>)