In [1]:
import pandas as pd
import torch
from torch import nn

n_tokens = 10
emb_dim = 256
x = torch.randn(n_tokens, emb_dim)

class MHA(nn.Module):
    
    def __init__(self, n_heads = 4, emb_dim = emb_dim):
        super(MHA, self).__init__()
        self.n_heads = n_heads
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"
        self.head_dim = emb_dim // n_heads
        self.pq = torch.nn.Linear(self.head_dim, self.head_dim)
        self.pk = torch.nn.Linear(self.head_dim, self.head_dim)
        self.pv = torch.nn.Linear(self.head_dim, self.head_dim)
        self.output = torch.nn.Linear(emb_dim, emb_dim)
    
    def forward(self, x):                                                               #(n_tokens, emb_dim)
        x = x.reshape(n_tokens, self.n_heads, self.head_dim)                            #(n_tokens, n_head, head_dim)
        x = x.permute(1, 0, 2)                                                          #(n_head, n_tokens, head_dim) 
        q = self.pq(x)                                                                  #(n_head, n_tokens, head_dim)
        k = self.pk(x)                                                                  #(n_head, n_tokens, head_dim)
        v = self.pv(x)                                                                  #(n_head, n_tokens, head_dim)
        kt = torch.transpose(k, 1 ,2)                                                   #(n_head, head_dim, n_tokens)
        attn_scores = q@kt                                                              #(n_head, n_tokens, n_tokens)
        n_head_tensor = torch.tensor(self.head_dim, dtype= float)                       #(1)
        attn_weights = nn.functional.softmax(attn_scores)/torch.sqrt(n_head_tensor)     #(n_head, n_tokens, n_tokens)
        context_vector = attn_weights @ v                                               #(n_head, n_tokens, head_dim)
        context_vector = context_vector.permute(1, 0, 2)                                #(n_tokens, n_head, head_dim)
        context_vector = context_vector.contiguous()                                
        context_vector = context_vector.view(n_tokens, self.head_dim * self.n_heads)    #(n_tokens, emb_dim)
        context_vector = self.output(context_vector)
        print(f'{context_vector.shape = }')
        return(context_vector)
        
my_mha = MHA()
my_mha(x)

context_vector.shape = torch.Size([10, 256])


  attn_weights = nn.functional.softmax(attn_scores)/torch.sqrt(n_head_tensor)     #(n_head, n_tokens, n_tokens)


tensor([[-0.0689,  0.0257,  0.1130,  ..., -0.0400, -0.0066, -0.0514],
        [ 0.0220,  0.0993,  0.0570,  ..., -0.0787, -0.0307, -0.0619],
        [ 0.0005,  0.0103,  0.1120,  ..., -0.0099, -0.0527, -0.0462],
        ...,
        [-0.0036,  0.0739,  0.0293,  ..., -0.0195, -0.0068, -0.0577],
        [-0.0003,  0.0262,  0.1123,  ..., -0.0230,  0.0333, -0.0129],
        [-0.0471, -0.0655,  0.1521,  ..., -0.0323, -0.1300, -0.0706]],
       grad_fn=<AddmmBackward0>)