In [108]:
import torch 
import torch.nn as nn
import math

class AttentionLayer(nn.Module):
    """Causal (autoregressive) multi-head self-attention.

    Every head owns its own query/key/value projection. The per-head
    outputs are concatenated along the feature axis and mixed by a final
    ``(model_dim, model_dim)`` output projection.

    Args:
        N: Context length, i.e. the longest sequence the layer accepts.
        model_dim: Embedding size ``d`` of each token.
        key_dim: Size ``d_k`` of each head's query/key vectors.
        num_heads: Number of attention heads. Must divide ``model_dim``;
            each head's value vectors have size ``model_dim // num_heads``.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 N,         # context length; how many tokens to consider at any point in time
                 model_dim, # d
                 key_dim,   # d_k
                 num_heads=1,
                ):
        super().__init__()

        if model_dim % num_heads != 0:
            raise ValueError("model_dim is not divisible by num_heads!")

        self.N = N
        self.model_dim = model_dim
        self.key_dim = key_dim
        # Head outputs are concatenated back to model_dim, so each head gets an equal slice.
        self.value_dim = model_dim // num_heads
        self.num_heads = num_heads

        # All weights are drawn uniformly from [0, 1) by torch.rand.
        # Query Weights (num_heads, model_dim, key_dim)
        self.W_Q = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Key Weights   (num_heads, model_dim, key_dim)
        self.W_K = nn.Parameter(torch.rand((num_heads, model_dim, key_dim)))
        # Value Weights (num_heads, model_dim, value_dim)
        self.W_V = nn.Parameter(torch.rand((num_heads, model_dim, self.value_dim)))
        # Output Weights (model_dim, model_dim)
        self.W_O = nn.Parameter(torch.rand((model_dim, model_dim)))

        # Additive causal mask: 0 where key index <= query index, -inf strictly
        # above the diagonal so softmax gives future tokens zero weight.
        mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
        # Registered as a buffer so it follows the module across devices
        # (.to() / .cuda()) without being treated as a trainable parameter.
        self.register_buffer("mask", mask)

    def forward(self, X):
        """Apply causal multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape ``(seq_len, model_dim)`` or, with optional
                leading batch dimensions, ``(..., seq_len, model_dim)``.
                ``seq_len`` must not exceed the context length ``N``.

        Returns:
            Tensor with the same shape as ``X``.

        Raises:
            ValueError: If ``X`` has fewer than two dimensions or if
                ``seq_len`` is larger than ``N``.
        """
        if X.dim() < 2:
            raise ValueError("X must have shape (..., seq_len, model_dim)")

        seq_len = X.shape[-2]
        if seq_len > self.N:
            raise ValueError(
                f"sequence length {seq_len} exceeds context length {self.N}"
            )

        # Insert a singleton head axis so the per-head weights broadcast over
        # any leading batch dimensions: (..., 1, seq_len, model_dim).
        X_heads = X.unsqueeze(-3)
        Q = X_heads @ self.W_Q  # (..., num_heads, seq_len, key_dim)
        K = X_heads @ self.W_K  # (..., num_heads, seq_len, key_dim)
        V = X_heads @ self.W_V  # (..., num_heads, seq_len, value_dim)

        current_mask = self.mask[:seq_len, :seq_len]  # when seq_len < N

        # Scale by the layer's own key_dim (not a same-named global variable).
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.key_dim) + current_mask  # (..., num_heads, queries, keys)
        heads = nn.functional.softmax(scores, dim=-1) @ V  # (..., num_heads, seq_len, value_dim)

        # Put the head axis next to the feature axis and merge them; for a 2-D
        # input this equals concatenating the per-head outputs column-wise.
        cat_heads = heads.transpose(-3, -2).flatten(-2)  # (..., seq_len, model_dim)
        A = cat_heads @ self.W_O

        return A
                    

In [103]:
# Hand-written reference computation of causal multi-head attention on a tiny
# random example, using free-standing weight tensors.
N, model_dim, key_dim, num_heads = 3, 4, 2, 2
value_dim = model_dim // num_heads

# Random input and weights (drawn in this order: X, W_Q, W_K, W_V, W_O).
X = torch.rand((N, model_dim))
W_Q = torch.rand((num_heads, model_dim, key_dim), requires_grad=True)
W_K = torch.rand((num_heads, model_dim, key_dim), requires_grad=True)
W_V = torch.rand((num_heads, model_dim, value_dim), requires_grad=True)
W_O = torch.rand((model_dim, model_dim), requires_grad=True)

# Additive causal mask: 0 on and below the diagonal, -inf above it.
mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

# Per-head projections of the input.
Q = X @ W_Q  # (num_heads, N, key_dim)
K = X @ W_K  # (num_heads, N, key_dim)
V = X @ W_V  # (num_heads, N, value_dim)

# Scaled dot-product scores with future positions masked out.
scores = Q @ K.transpose(-2, -1)  # (num_heads, N (queries), N (keys))
attention = scores / math.sqrt(key_dim) + mask
heads = torch.softmax(attention, dim=-1) @ V  # (num_heads, N, value_dim)

# Concatenate the heads column-wise, then apply the output projection.
cat_heads = torch.cat(tuple(heads.unbind(0)), dim=1)  # (N, model_dim)
A = cat_heads @ W_O

print(heads)
print(A)

tensor([[[0.7208, 0.9492],
         [1.3220, 1.5685],
         [1.1685, 1.4063]],

        [[0.6066, 0.5019],
         [1.2620, 1.2463],
         [1.1876, 1.0844]]], grad_fn=<UnsafeViewBackward0>)
tensor([[1.5921, 0.5447, 1.7739, 1.8473],
        [3.0285, 1.0847, 3.2703, 3.5603],
        [2.7400, 0.9735, 2.9330, 3.1974]], grad_fn=<MmBackward0>)


In [107]:
# Build the layer with the same sizes as the hand-written example and run it on
# X (defined in an earlier cell). AttentionLayer draws its own random weights in
# __init__, so this output is not expected to equal the manually computed A.
layer_kwargs = dict(N=3, model_dim=4, key_dim=2, num_heads=2)
multihead_attention = AttentionLayer(**layer_kwargs)
multihead_attention(X)

tensor([[1.3525, 0.8059, 1.0190, 1.8855],
        [2.5131, 1.6030, 1.9719, 3.5664],
        [2.3800, 1.5152, 1.8651, 3.3762]], grad_fn=<MmBackward0>)