<a href="https://colab.research.google.com/github/winniema/mini_transformer/blob/main/Multi_Headed_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class Head(nn.Module):
    """ single head of self-attention """

    def __init__(self, head_size, n_embd, block_size):
        super().__init__()
        # project the n_embd (C) input channels down to head_size for this head
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular causal mask, sized for the longest context (block_size)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)

        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)

        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# head_size is the dimension each attention head operates in, usually smaller than C;
# here head_size * num_heads == C, so concatenating the heads restores the width C
head_size = 8
num_heads = 4
assert head_size * num_heads == C

# multiple heads of self-attention (C is the embedding width, T the maximum context length)
heads = nn.ModuleList([Head(head_size, C, T) for _ in range(num_heads)])
# linear projection
proj = nn.Linear(C, C)

out = [h(x) for h in heads] # [(B,T,hs), (B,T,hs), (B,T,hs), (B,T,hs)]
out = torch.cat(out, dim=-1) # concat along the last dimension (B,T,hs*num_heads)
out = proj(out) # (B,T,C) where C = hs*num_heads
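
Because no head depends on another head's output, the list comprehension that runs the heads one after another can be replaced by a single batched computation. The sketch below shows one common way to do this; it is not the notebook's own code. A hypothetical `BatchedMultiHead` module fuses the query, key, and value projections of all heads into one `nn.Linear` and reshapes the result to (B, num_heads, T, head_size) so the head index acts as a batch dimension. After the per-head weights are copied in, it is checked against the looped version. `Head` is redefined locally with explicit `n_embd` and `block_size` arguments so the block runs on its own.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

class Head(nn.Module):
    """Single causal self-attention head (same computation as the cell above)."""
    def __init__(self, head_size, n_embd, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        return F.softmax(wei, dim=-1) @ v

class BatchedMultiHead(nn.Module):
    """Hypothetical fused version: all heads computed with batched matmuls."""
    def __init__(self, num_heads, head_size, n_embd, block_size):
        super().__init__()
        self.num_heads, self.head_size = num_heads, head_size
        # one projection produces q, k, v for every head at once
        self.qkv = nn.Linear(n_embd, 3 * num_heads * head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(self.num_heads * self.head_size, dim=-1)
        # (B, T, nh*hs) -> (B, nh, T, hs): the head index becomes a batch dimension
        q, k, v = (t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
                   for t in (q, k, v))
        wei = q @ k.transpose(-2, -1) * self.head_size ** -0.5  # (B, nh, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        out = F.softmax(wei, dim=-1) @ v                       # (B, nh, T, hs)
        # put heads side by side again: (B, T, nh*hs), same layout as torch.cat(..., dim=-1)
        return out.transpose(1, 2).contiguous().view(B, T, -1)

B, T, C, head_size, num_heads = 4, 8, 32, 8, 4
x = torch.randn(B, T, C)
heads = nn.ModuleList([Head(head_size, C, T) for _ in range(num_heads)])
fused = BatchedMultiHead(num_heads, head_size, C, T)

# copy the per-head weights into the fused projection so both compute the same thing
with torch.no_grad():
    wq = torch.cat([h.query.weight for h in heads], dim=0)  # (nh*hs, C)
    wk = torch.cat([h.key.weight for h in heads], dim=0)
    wv = torch.cat([h.value.weight for h in heads], dim=0)
    fused.qkv.weight.copy_(torch.cat([wq, wk, wv], dim=0))

looped = torch.cat([h(x) for h in heads], dim=-1)
batched = fused(x)
print(looped.shape, torch.allclose(looped, batched, atol=1e-5))
```

Fusing the projections gives up the readability of one module per head in exchange for fewer, larger matrix multiplications, which is usually what makes the parallelism across heads pay off on a GPU.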

Notes:
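* There are three motivations for a multi-headed attention architecture:
  1. Lower computational cost per attention head: each head projects C (n_embd) down to a smaller head_size (here C / num_heads = 32 / 4 = 8), so its queries, keys, and values are cheaper to compute and compare. Summed over all heads, the total cost is roughly the same as that of a single full-width head.
  2. Computation can be parallelized across the heads of attention, since no head depends on another head's output.
  3. Learning long-range dependencies (dependencies on a token from some time back) is easier when the computational path length from model input to output is shorter. Each head inherits this from self-attention itself: any position can attend to any earlier position in a single step, instead of passing information through every intermediate step as a recurrent network does.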
* Each attention head has its own query, key, and value projections, so different heads can learn to attend to different relationships between tokens ("communicate different ideas"). The head outputs are then concatenated (imagine stacking them side by side along the channel dimension). This stacking restores the last dimension to the input channel dimension, since num_heads * head_size = 4 * 8 = 32 = C.
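* The concatenated result is passed through a linear projection (proj) to produce the final output of the multi-headed attention block. Because proj is a full C-by-C layer, it can mix information across heads.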
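
The cost argument in motivation 1 can be checked with a little arithmetic. The sketch below uses a hypothetical helper, `attention_cost`, that counts only the q/k/v projection weights and the multiply-adds for `q @ k^T` and `wei @ v` on one sequence of length T. Softmax, masking, and the output projection are ignored. With the notebook's sizes, each 8-dimensional head does a quarter of the score work of a single 32-dimensional head, while the totals match.

```python
def attention_cost(n_embd, head_size, num_heads, T):
    """Rough per-sequence cost of causal self-attention heads (softmax, mask, proj ignored)."""
    # q, k, v projections: three (n_embd x head_size) weight matrices per head, no bias
    params = 3 * n_embd * head_size * num_heads
    # projection multiply-adds: each of T tokens goes through all of those matrices
    proj_macs = 3 * T * n_embd * head_size * num_heads
    # q @ k^T: (T x hs) @ (hs x T) -> T*T*hs; wei @ v: (T x T) @ (T x hs) -> T*T*hs
    attn_macs_per_head = 2 * T * T * head_size
    return params, proj_macs, attn_macs_per_head, attn_macs_per_head * num_heads

C, T = 32, 8
multi = attention_cost(C, head_size=8, num_heads=4, T=T)    # (3072, 24576, 1024, 4096)
single = attention_cost(C, head_size=32, num_heads=1, T=T)  # (3072, 24576, 4096, 4096)
print("4 heads of 8: ", multi)
print("1 head of 32:", single)
```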
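
The short-path property in motivation 3 can be observed with autograd. The sketch below assumes the notebook's sizes (T = 8, C = 32, head_size = 8) with a batch of one. It builds a single causal head inline from the same operations as `Head` and takes gradients of two output positions with respect to every input token.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
B, T, C, head_size = 1, 8, 32, 8

# one causal attention head, same operations as the Head class above
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
tril = torch.tril(torch.ones(T, T))

def head(x):
    k, q, v = key(x), query(x), value(x)
    wei = q @ k.transpose(-2, -1) * head_size ** -0.5
    wei = wei.masked_fill(tril == 0, float('-inf'))
    return F.softmax(wei, dim=-1) @ v

x = torch.randn(B, T, C, requires_grad=True)
out = head(x)

# gradient of the LAST output position w.r.t. every input token
(grad_last,) = torch.autograd.grad(out[:, -1].sum(), x, retain_graph=True)
# gradient of the FIRST output position w.r.t. every input token
(grad_first,) = torch.autograd.grad(out[:, 0].sum(), x)

print(grad_last[0].norm(dim=-1))   # nonzero for every token, including token 0
print(grad_first[0].norm(dim=-1))  # nonzero only for token 0; later tokens are masked out
```

The last output position depends on every input token, including token 0, after a single attention step, whereas in a recurrent network information from token 0 would have to pass through every intermediate hidden state. The causal mask is what makes the first output depend on token 0 alone.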