In [2]:
import torch
import torch.nn as nn


In [None]:
class MultiHeadAttentionOpt(nn.Module):
    """Causal multi-head self-attention with all heads computed in one batched matmul.

    The input is projected once to ``d_out`` features for queries, keys and
    values; the feature axis is then split into ``num_heads`` heads of size
    ``d_out // num_heads``. Each position may attend only to itself and to
    earlier positions (upper-triangular causal mask).

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_len: Maximum sequence length the causal mask is built for.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        kqv_bias: Whether the linear projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_len, num_heads, dropout, kqv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out not divisible by num_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = d_out // num_heads

        # Layer creation order is kept (Q, K, V, out) so seeded initialisation
        # yields the same weights as before.
        self.W_Q = nn.Linear(d_in, d_out, bias=kqv_bias)
        self.W_K = nn.Linear(d_in, d_out, bias=kqv_bias)
        self.W_V = nn.Linear(d_in, d_out, bias=kqv_bias)
        # NOTE(review): the output projection shares the kqv_bias flag; kept
        # as-is so the set of parameters does not change.
        self.out_proj = nn.Linear(d_out, d_out, bias=kqv_bias)
        self.dropout = nn.Dropout(dropout)

        # Registered as a buffer so it follows the module across .to(device).
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load. Entries equal to 1 mark "future" positions to be hidden.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_len, context_len), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        """Compute causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_in)`` with
                ``seq_len <= context_len``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_out)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``context_len`` the mask
                was built for.
        """
        batch_size, num_tokens, _ = x.shape
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_len {self.mask.shape[0]}"
            )

        # Project, then split the feature axis into heads:
        # (b, c, d_out) -> (b, c, h, d_h) -> (b, h, c, d_h)
        head_shape = (batch_size, num_tokens, self.num_heads, self.d_head)
        queries = self.W_Q(x).view(head_shape).transpose(1, 2)
        keys = self.W_K(x).view(head_shape).transpose(1, 2)
        values = self.W_V(x).view(head_shape).transpose(1, 2)

        attention_score = queries @ keys.transpose(2, 3)  # (b, h, c, c)

        # Hide future positions; the mask is cropped so inputs shorter than
        # context_len broadcast correctly.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attention_score = attention_score.masked_fill(causal_mask, -torch.inf)

        # Scaled softmax over the key axis; masked entries become exactly 0.
        attention_weight = torch.softmax(attention_score / self.d_head ** 0.5, dim=-1)
        attention_weight = self.dropout(attention_weight)

        context_vector = attention_weight @ values      # (b, h, c, d_h)
        context_vector = context_vector.transpose(1, 2)  # (b, c, h, d_h)
        # Merge heads back: (b, c, h, d_h) -> (b, c, d_out = h * d_h)
        context_vector = context_vector.reshape(batch_size, num_tokens, self.d_out)
        return self.out_proj(context_vector)


In [8]:
# Smoke test: run a batch of three identical sequences through the attention
# module and display the output shape.
torch.manual_seed(123)

din, dout = 3, 4
context_length = 5
dropout = 0
num_heads = 2

x = torch.randn(context_length, din)
x_batch = torch.stack([x, x, x])

sa = MultiHeadAttentionOpt(
    d_in=din,
    d_out=dout,
    context_len=context_length,
    num_heads=num_heads,
    dropout=dropout,
)
context_vector = sa(x_batch)
context_vector.shape


torch.Size([3, 5, 4])