In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import math

In [2]:
@dataclass
class Config:
  """Hyperparameters shared by the model components in this notebook."""

  # Presumably the tokenizer vocabulary size (not read by the code shown here).
  vocab_size: int = 50_000
  # Sequence length T used for the causal mask and the demo input.
  seq_len: int = 4_096
  # Width of the token representations going in and out of attention.
  d_model: int = 5_120
  # Number of attention heads; must divide d_model evenly.
  n_heads: int = 32
  # Presumably the key/value head count for grouped-query attention — not read
  # by MultiHeadAttentionBlock; TODO confirm where it is consumed.
  n_kv_heads: int = 8
  # Presumably the number of stacked layers (not read by the code shown here).
  n_layers: int = 40
  # Presumably the feed-forward inner width (not read by the code shown here).
  hidden_size: int = 14_336

In [9]:
class MultiHeadAttentionBlock(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The input is projected into per-head queries, keys and values; the scaled
  query/key scores are (optionally) masked and softmax-normalised into
  attention weights, which mix the values. The heads are then concatenated and
  projected back to ``d_model``.

  Only ``config.d_model`` and ``config.n_heads`` are read: every query head has
  its own key/value projection, so ``config.n_kv_heads`` is not used here.
  """

  def __init__(self, config: "Config"):
    super().__init__()

    self.d_model = config.d_model
    self.n_heads = config.n_heads
    assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
    self.head_dim = self.d_model // self.n_heads

    self.w_q = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False)
    self.w_k = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False)
    self.w_v = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False)

    self.w_o = nn.Linear(self.n_heads * self.head_dim, self.d_model, bias=False)

  def forward(self, x, mask=None):
    """Apply self-attention to ``x``.

    Args:
      x: input tensor of shape (B, T, d_model).
      mask: optional tensor broadcastable to (B, nh, T, T), e.g. (B, 1, T, T).
        Positions where ``mask == 0`` (``False`` for boolean masks) receive no
        attention weight.

    Returns:
      Tensor of shape (B, T, d_model).
    """
    B, T, _ = x.shape
    query = self.w_q(x)    # (B, T, nh * hd)
    key = self.w_k(x)      # (B, T, nh * hd)
    value = self.w_v(x)    # (B, T, nh * hd)

    # Split the feature axis into heads and move heads in front of time.
    query = query.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)    # (B, nh, T, hd)
    key = key.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)        # (B, nh, T, hd)
    value = value.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)    # (B, nh, T, hd)

    attention_scores = query @ key.transpose(2, 3) / math.sqrt(self.head_dim)   # (B, nh, T, T)

    # apply mask
    if mask is not None:
      # Use the most negative finite value of the score dtype: it is
      # representable in half precision (unlike -1e9) and, unlike -inf, keeps
      # a fully masked row finite after the softmax.
      fill_value = torch.finfo(attention_scores.dtype).min
      attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

    # Normalise over the key axis so each query's weights sum to 1 and masked
    # positions contribute (numerically) nothing.
    attention_weights = torch.softmax(attention_scores, dim=-1)    # (B, nh, T, T)

    z = attention_weights @ value    # (B, nh, T, hd)
    z = z.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)    # (B, T, nh * hd)

    return self.w_o(z)    # (B, T, d_model)

In [10]:
def causal_mask(size):
  """Return a (1, size, size) boolean mask that is True where attention is allowed.

  Entry [0, i, j] is True when j <= i (a position may look at itself and at
  earlier positions) and False for the strictly upper triangle.
  """
  lower_triangle = torch.tril(torch.ones(1, size, size))
  return lower_triangle.bool()

In [11]:
config = Config()
# causal_mask returns (1, T, T); the extra leading axis gives (1, 1, T, T),
# which broadcasts over batch and heads against (B, nh, T, T) scores.
mask = causal_mask(config.seq_len).unsqueeze(0)
mask.shape  # (1, 1, T, T)

torch.Size([1, 1, 4096, 4096])

In [12]:
# Demo input: a single batch of random activations at the configured size.
BATCH_SIZE = 1
SEQ_LEN, D_MODEL = config.seq_len, config.d_model
X = torch.randn((BATCH_SIZE, SEQ_LEN, D_MODEL))
X.shape  # (B, T, d_model)

torch.Size([1, 4096, 5120])

In [None]:
# Smoke test: run one masked forward pass and inspect the output shape.
attn = MultiHeadAttentionBlock(config)
output = attn(X, mask=mask)
output.shape  # expected (B, T, d_model)