In [2]:
import torch
import torch.nn as nn
import math

In [3]:
# Toy input: a random tensor of shape (2, 3, 4).
hidden_states = torch.randn(2, 3, 4)

# Four independent bias-free 4 -> 4 linear maps, created in q, k, v, o order
# so that weight initialisation consumes the RNG in the same sequence.
q_proj, k_proj, v_proj, o_proj = (nn.Linear(4, 4, bias=False) for _ in range(4))


In [4]:
# Feed the same input through each of the three projections.
q, k, v = (proj(hidden_states) for proj in (q_proj, k_proj, v_proj))

# Cell output: the shapes of the three projected tensors.
tuple(t.shape for t in (q, k, v))

(torch.Size([2, 3, 4]), torch.Size([2, 3, 4]), torch.Size([2, 3, 4]))

In [6]:
# 2 x 3 mask of ones with the last entry of the second row zeroed out
# (presumably marking a padded position -- the mask is not used in the cells shown).
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 1.0, 0.0]])
mask

tensor([[1., 1., 1.],
        [1., 1., 0.]])

In [18]:
# Scaled dot-product scores. The scale is read from q's last dimension
# (4 for the tensors above) rather than repeated as a literal, so the cell
# stays correct if the projection width changes.
P = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])

# Lower-triangular 0/1 mask with the same shape as the scores
# ((2, 3, 3) here): entry [i, j] is 1 only when j <= i.
causal_mask = torch.tril(torch.ones(P.shape))
causal_mask

tensor([[[1., 0., 0.],
         [1., 1., 0.],
         [1., 1., 1.]],

        [[1., 0., 0.],
         [1., 1., 0.],
         [1., 1., 1.]]])

In [19]:
# Convert the 0/1 mask into an additive penalty: 0 where the mask is 1,
# -1e9 where it is 0, then add it to the scores.
# Note: this rebinds P, so re-running the cell alone adds the penalty again.
blocked_penalty = (1 - causal_mask) * -1e9
P = torch.add(P, blocked_penalty)
P

tensor([[[ 2.0278e-01, -1.0000e+09, -1.0000e+09],
         [-6.7292e-02,  2.2792e-01, -1.0000e+09],
         [ 1.0554e-01, -3.6337e-01, -2.5134e-01]],

        [[ 2.8567e-02, -1.0000e+09, -1.0000e+09],
         [ 3.8103e-02,  2.2329e-01, -1.0000e+09],
         [ 2.6402e-02,  5.7007e-02,  5.0832e-02]]], grad_fn=<AddBackward0>)

In [20]:
# Normalise each row of scores over the last axis; positions that received
# the -1e9 penalty come out as 0 in the displayed result.
A = P.softmax(dim=-1)
A

tensor([[[1.0000, 0.0000, 0.0000],
         [0.4267, 0.5733, 0.0000],
         [0.4300, 0.2690, 0.3009]],

        [[1.0000, 0.0000, 0.0000],
         [0.4538, 0.5462, 0.0000],
         [0.3272, 0.3374, 0.3353]]], grad_fn=<SoftmaxBackward0>)

In [22]:
# Weighted sum of the value vectors using the attention weights.
O = torch.matmul(A, v)
O

tensor([[[-0.6265,  0.8823,  0.3128,  0.2613],
         [-0.4459,  0.1374,  0.3487, -0.2749],
         [-0.2798,  0.4121,  0.3666, -0.0914]],

        [[-0.5247,  0.2238,  0.1720, -0.0454],
         [-0.3772,  0.1483,  0.0423,  0.0421],
         [-0.2919,  0.2075,  0.0152,  0.1126]]], grad_fn=<UnsafeViewBackward0>)

In [38]:
num_heads = 4

# Random (2, 3, 4 * num_heads) tensor whose last axis is split into
# (num_heads, 4), giving shape (2, 3, num_heads, 4).
hidden_states = torch.randn(2, 3, num_heads * 4).reshape(2, 3, num_heads, 4)

# q/k/v projections act on the size-4 last axis (the per-head slice);
# created in q, k, v order so initialisation draws from the RNG as before.
q_proj, k_proj, v_proj = (nn.Linear(4, 4, bias=False) for _ in range(3))

# The output projection works on the full 4 * num_heads width.
o_proj = nn.Linear(num_heads * 4, num_heads * 4, bias=False)

In [39]:
# Swap axes 1 and 2 once so heads precede the sequence axis:
# (2, 3, 4, 4) -> (2, 4, 3, 4).
heads_first = hidden_states.transpose(1, 2)

# Each projection acts on the last axis, so the shape stays (2, 4, 3, 4).
q, k, v = (proj(heads_first) for proj in (q_proj, k_proj, v_proj))
tuple(t.shape for t in (q, k, v))

(torch.Size([2, 4, 3, 4]), torch.Size([2, 4, 3, 4]), torch.Size([2, 4, 3, 4]))

In [40]:
# Per-head scaled dot-product scores. The scale comes from q's last
# dimension (4 here) instead of a repeated literal.
P = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])

# Lower-triangular 0/1 mask built from the score tensor's own shape
# ((2, 4, 3, 3) here) so it follows batch, head and sequence sizes.
causal_mask = torch.tril(torch.ones(P.shape))

# Add -1e9 wherever the mask is 0 so those positions get ~0 weight after softmax.
P = P + (1 - causal_mask) * -1e9
A = torch.softmax(P, dim=-1)

# Weighted sum of the value vectors for every head.
O = A @ v
O.shape

torch.Size([2, 4, 3, 4])

In [41]:
# O has heads on axis 1 and sequence positions on axis 2; read batch and
# sequence sizes from it rather than hard-coding 2 and 3.
batch_size, seq_len = O.shape[0], O.shape[2]

# Move heads next to the feature axis and concatenate them:
# (batch, heads, seq, 4) -> (batch, seq, heads * 4).
O_merged = O.transpose(1, 2).reshape(batch_size, seq_len, -1)
final_output = o_proj(O_merged)
final_output.shape

torch.Size([2, 3, 16])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, embed_dim):
        self.__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.q_proj = nn.Linear(input_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(input_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(input_dim, embed_dim, bias=False)
        self.o_proj = nn.Linear(embed_dim * num_heads, input_dim, bias=False)
    
    def forward(self, x):
        seq_len = x.shape[1]
        q = self.q_proj(x).view(-1, seq_len, self.num_heads, self.embed_dim)
        k = self.k_proj(x).view(-1, seq_len, self.num_heads, self.embed_dim)
        v = self.v_proj(x).view(-1, seq_len, self.num_heads, self.embed_dim)
        