# Multi-head attention

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

## Single-head causal self-attention

In [2]:
### init data: batch size, sequence length (time) and channel count
B, T, C = 2, 4, 12
head_size = 16
x = torch.randn(B, T, C)

### projection weights W_k, W_q, W_v, each (head_size, C)
# drawn in the order k, q, v
k_proj = torch.rand((head_size, C))
q_proj = torch.rand((head_size, C))
v_proj = torch.rand((head_size, C))

### project inputs: (B, T, C) @ (C, hs) -> (B, T, hs)
k = torch.matmul(x, k_proj.t())
q = torch.matmul(x, q_proj.t())
v = torch.matmul(x, v_proj.t())

### scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T)
attn = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_size)
tril = torch.ones(T, T).tril()
# entries above the diagonal are future positions: -inf makes softmax give them 0
attn = attn.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
print("causal attention score")
print(attn)

### weighted sum of values: (B, T, T) @ (B, T, hs) -> (B, T, hs)
out = torch.matmul(attn, v)

causal attention score
tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.3321e-01, 8.6679e-01, 0.0000e+00, 0.0000e+00],
         [1.0438e-01, 7.8750e-02, 8.1687e-01, 0.0000e+00],
         [1.0947e-16, 1.3340e-19, 2.7633e-15, 1.0000e+00]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [8.7014e-01, 1.2986e-01, 0.0000e+00, 0.0000e+00],
         [3.2443e-37, 3.1438e-36, 1.0000e+00, 0.0000e+00],
         [9.2693e-02, 1.4553e-02, 5.9197e-25, 8.9275e-01]]])


## Multi-head causal self-attention (two heads)

In [3]:
def _causal_attention_head(x, k_proj, q_proj, v_proj, head_size, tril):
    """Run one causal self-attention head over ``x``.

    Args:
        x: input tokens, (B, T, C).
        k_proj, q_proj, v_proj: projection weights, each (head_size, C).
        head_size: per-head dimension; scores are scaled by 1/sqrt(head_size).
        tril: (T, T) lower-triangular mask; zero entries mark future positions.

    Returns:
        (k, q, v, attn, out): k/q/v are (B, T, hs), attn is the (B, T, T)
        causal attention matrix and out is the (B, T, hs) head output.
    """
    k = x @ k_proj.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
    q = x @ q_proj.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
    v = x @ v_proj.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
    attn = q @ k.transpose(-2, -1) / math.sqrt(head_size)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
    # Mask THIS head's scores. The earlier copy-pasted version masked the
    # single-head `attn` from the previous cell, silently discarding them.
    attn = attn.masked_fill(tril == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    out = attn @ v     # (B, T, T) @ (B, T, hs) -> (B, T, hs)
    return k, q, v, attn, out


### init data
B, T, C = 2, 4, 12  # batch, time, channels
head_size = 16
x = torch.randn(B, T, C)

# One causal mask shared by both heads: position i attends to positions <= i.
tril = torch.tril(torch.ones(T, T))

########################## Head 1 ##########################
### define Wq, Wk, Wv (independent weights per head)
k_proj_1 = torch.rand((head_size, C))
q_proj_1 = torch.rand((head_size, C))
v_proj_1 = torch.rand((head_size, C))
k_1, q_1, v_1, attn_1, out_1 = _causal_attention_head(x, k_proj_1, q_proj_1, v_proj_1, head_size, tril)

########################## Head 2 ##########################
### define Wq, Wk, Wv (independent weights per head)
k_proj_2 = torch.rand((head_size, C))
q_proj_2 = torch.rand((head_size, C))
v_proj_2 = torch.rand((head_size, C))
k_2, q_2, v_2, attn_2, out_2 = _causal_attention_head(x, k_proj_2, q_proj_2, v_proj_2, head_size, tril)

########################## fuse multi head ##########################
# Output projection maps the concatenated 2*hs features back to hs.
multi_head_proj = torch.rand((head_size, head_size * 2))  # [hs, hs * 2]

concat_attention_output = torch.cat([out_1, out_2], dim=-1)  # [B, T, hs * 2]

multi_head_output = concat_attention_output @ multi_head_proj.T  # [B, T, hs]

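## KV caching

A naive autoregressive loop re-projects the whole sequence at every step. The cached loop projects only the newest token and reuses the stored keys and values. Both must generate the same sequence.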

In [4]:
# ----------------------------------------------------------------------
# hyper-parameters
# ----------------------------------------------------------------------
B           = 1        # batch size
C           = 12       # input/channel size per token
head_size   = 16       # hidden size of our (single) attention head
steps       = 10       # how many new tokens to append
PROJ_SEED   = 0        # seed for the random projection weights below

# ----------------------------------------------------------------------
# projection matrices (random for demo)
# ----------------------------------------------------------------------
# Seed before drawing the weights so that Restart & Run All reproduces the
# same projections, and hence the same generated sequences, on every run.
torch.manual_seed(PROJ_SEED)
k_proj  = torch.randn(head_size, C)
q_proj  = torch.randn(head_size, C)
v_proj  = torch.randn(head_size, C)
next_proj = torch.randn(C, head_size)   # (C, hs): maps last hidden (hs) -> next x (C)


In [5]:
# ----------------------------------------------------------------------
# seed context: a single initial token vector, shape (B, 1, C)
# ----------------------------------------------------------------------
torch.manual_seed(42)
x_seq_no_cache = torch.randn(B, 1, C)

# ----------------------------------------------------------------------
# naive autoregressive loop (no KV cache): every step re-projects and
# re-attends over the whole sequence generated so far
# ----------------------------------------------------------------------
for t in range(steps):
    # 1) project every token seen so far -> each (B, T, hs)
    k = torch.matmul(x_seq_no_cache, k_proj.t())
    q = torch.matmul(x_seq_no_cache, q_proj.t())
    v = torch.matmul(x_seq_no_cache, v_proj.t())

    # 2) causal attention over the full T x T score matrix
    T = x_seq_no_cache.shape[1]
    mask = torch.ones(T, T).tril()
    attn = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_size)   # (B, T, T)
    attn = attn.masked_fill(mask == 0, float("-inf")).softmax(dim=-1)

    # 3) outputs for all positions; only the newest one feeds the next step
    out = torch.matmul(attn, v)        # (B, T, T) @ (B, T, hs) -> (B, T, hs)
    out_last = out[:, -1]              # (B, hs)

    # 4) map the last hidden state to the next token: (B, hs) @ (hs, C) -> (B, 1, C)
    x_next = torch.matmul(out_last, next_proj.t())[:, None, :]

    # 5) grow the sequence by one token: (B, T, C) -> (B, T+1, C)
    x_seq_no_cache = torch.cat((x_seq_no_cache, x_next), dim=1)

In [None]:
# ----------------------------------------------------------------------
# seed context (same seed as the naive loop) and empty K/V caches
# ----------------------------------------------------------------------
torch.manual_seed(42)
x_seq_kv_cache = torch.randn(B, 1, C)       # (B, 1, C) initial token
# zero-length along time, so the first torch.cat simply yields k_step / v_step
k_cache = torch.zeros(B, 0, head_size)      # (B, 0, hs)
v_cache = torch.zeros(B, 0, head_size)      # (B, 0, hs)

# ----------------------------------------------------------------------
# KV-cached autoregressive loop: each step projects only the newest token
# and reuses the keys/values cached for earlier tokens
# ----------------------------------------------------------------------
for _ in range(steps):
    # 1) newest token only, keeping the time axis: (B, 1, C)
    x_step = x_seq_kv_cache[:, -1:, :]

    k_step = torch.matmul(x_step, k_proj.t())   # (B, 1, hs)
    q_step = torch.matmul(x_step, q_proj.t())   # (B, 1, hs)
    v_step = torch.matmul(x_step, v_proj.t())   # (B, 1, hs)

    # 2) extend both caches by one position along time -> (B, t, hs)
    k_cache = torch.cat((k_cache, k_step), dim=1)
    v_cache = torch.cat((v_cache, v_step), dim=1)

    # 3) newest query against every cached key: (B, 1, hs) @ (B, hs, t) -> (B, 1, t).
    #    No causal mask is needed: the cache only holds current and past tokens.
    attn = torch.matmul(q_step, k_cache.transpose(1, 2)) / math.sqrt(head_size)
    attn = attn.softmax(dim=-1)

    # 4) output for the newest position: (B, 1, t) @ (B, t, hs) -> (B, 1, hs)
    out_last = torch.matmul(attn, v_cache)

    # 5) next token vector: (B, 1, hs) @ (hs, C) -> (B, 1, C)
    x_next = torch.matmul(out_last, next_proj.t())

    # 6) append to the full sequence -> (B, t+1, C)
    x_seq_kv_cache = torch.cat((x_seq_kv_cache, x_next), dim=1)

In [7]:
# The KV-cached loop must generate the same sequence as the naive loop
# (compared within assert_close's default float tolerances).
torch.testing.assert_close(actual=x_seq_no_cache, expected=x_seq_kv_cache)