|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 2:</h2>|<h1>Large language models</h1>|
|<h2>Section:</h2>|<h1>Build a GPT</h1>|
|<h2>Lecture:</h2>|<h1><b>Multihead attention: theory and implementation</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters

In [None]:
# data hyperparameters
seq_len = 8 # aka context window (number of tokens per sequence)

# model hyperparameters
embed_dim = 128
n_heads = 4 # note: embed_dim/n_heads must be an integer

# training hyperparameters
batch_size = 5

# Class for multihead attention

In [None]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention followed by a linear output projection.

  The Q, K, and V weights of all heads are stored together as single
  embed_dim -> embed_dim linear layers (one "super-head" each); forward()
  splits their outputs into num_heads chunks of size
  head_dim = embed_dim // num_heads.

  Args:
    num_heads: number of attention heads.
    embed_dim: embedding dimensionality; must be an integer multiple of
      num_heads.

  Raises:
    ValueError: if embed_dim is not divisible by num_heads.
  """

  def __init__(self, num_heads, embed_dim):
    super().__init__()

    # fail early with a clear message: with a non-divisible embed_dim the
    # integer division below would truncate and the head-splitting view()
    # in forward() would only fail later with an opaque shape error
    if embed_dim % num_heads != 0:
      raise ValueError(
        f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

    # head-dimensionality is embed_dim split across the heads
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    # num_heads Q, K, and V matrices, initialized as one "super-head"
    #    note: in model 5, these three matrices are combined into one
    self.query = nn.Linear(embed_dim, embed_dim, bias=False)
    self.key   = nn.Linear(embed_dim, embed_dim, bias=False)
    self.value = nn.Linear(embed_dim, embed_dim, bias=False)

    # final linear projection merges the heads' outputs
    self.W0 = nn.Linear(embed_dim, embed_dim, bias=False)

  def _split_heads(self, mat, B, T):
    """Reshape [B, T, embed_dim] into [B, num_heads, T, head_dim]."""
    # split the embedding axis into heads, then move the heads axis forward
    # because pytorch's SDPA function needs [B, num_heads, T, head_dim]
    return mat.view(B, T, self.num_heads, self.head_dim).transpose(1,2)

  def forward(self,x,track_sizes=False):
    """Apply causal multi-head self-attention to a batch of token embeddings.

    Args:
      x: tensor of shape [batch, tokens (sequence length), embed_dim].
      track_sizes: if True, print the tensor shape after each processing step.

    Returns:
      Tensor with the same shape as x.
    """

    # extract the dimension sizes of the inputs (token embeddings)
    B, T, E = x.shape # [batch, tokens (sequence length), embed_dim]
    if track_sizes: print(f"1){' Input data shape:':>28} {x.shape}")

    # push data through Q, K, and V (actually multiple heads still in the same matrix)
    q = self.query(x) # [batch, seq_len, embed_dim]
    k = self.key(x)
    v = self.value(x)
    if track_sizes: print(f"2){'q/k/v pre-split shape:':>28} {q.shape}")

    # split up the heads (note: head-splitting is done after XW_Q)
    q = self._split_heads(q, B, T)
    k = self._split_heads(k, B, T)
    v = self._split_heads(v, B, T)
    if track_sizes: print(f"3){'q/k/v post-split shape:':>28} {q.shape}")

    # now we can call SDPA (is_causal=True masks out future tokens)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    if track_sizes: print(f"4){'Data post-attention shape:':>28} {out.shape}")

    # but our code still needs [B, T, num_heads, head_dim]
    out = out.transpose(1,2)
    if track_sizes: print(f"5){'Post-attention data reshape:':>28} {out.shape}")

    # merge heads back into embed_dim
    #   (reshape rather than view, because the transposed tensor is not contiguous)
    out = out.reshape(B, T, E)
    if track_sizes: print(f"6){'Data merged to size:':>28} {out.shape}")

    # finally, apply linear mixing matrix
    out = self.W0(out)
    if track_sizes: print(f"7){'Post-MHA H0 linear mixing:':>28} {out.shape}")

    return out

In [None]:
# instantiate the attention module and display its layers
mha = MultiHeadAttention(num_heads=n_heads, embed_dim=embed_dim)
mha

In [None]:
# random (fake) token embeddings of size [batch, tokens, embed_dim]
data = torch.randn(batch_size, seq_len, embed_dim)

# forward pass through the attention module; the output size should match the input
out = mha(data)
print(f'Input size:  {data.shape}')
print(f'Output size: {out.shape}')

In [None]:
# summary of the hyperparameters that determine the tensor sizes
print(f'    Sequence length: {seq_len:2d}')
print(f'Embedding dimension: {embed_dim}')
print(f'    Number of heads: {n_heads:2d}')
print(f'Head dimensionality: {embed_dim // n_heads}')

# rerun the forward pass with track_sizes=True so that the module
# prints the tensor shape after each step
print('\nDimensions of the data as it passes through the attention sublayer of one Transformer block:')
out = mha(data,track_sizes=True)