<a href="https://colab.research.google.com/github/rczhen/code_ml/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
@dataclass
class GPTConfig:
    """Hyper-parameters of a full-size GPT-2 style model.

    The default sizes (12 layers, 12 heads, 768-dim embeddings) match the
    smallest GPT-2 model.
    """

    block_size: int = 1_024  # maximum sequence length
    vocab_size: int = 50_304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # embedding / model width
    dropout: float = 0.0  # dropout probability
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

@dataclass
class Config:
    """Tiny configuration used to exercise the attention variants in this notebook."""

    block_size: int = 8  # maximum sequence length; sizes the causal mask
    vocab_size: int = 50_304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layers: int = 2  # number of transformer blocks
    n_heads: int = 4  # number of query heads
    n_kv_heads: int = 2  # number of key/value heads (used by grouped-query attention)
    d_model: int = 16  # model width; must be divisible by n_heads
    dropout_rate: float = 0.0  # dropout probability for attention and residual dropout
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

# v1: standard CausalSelfAttention

In [3]:
class CausalSelfAttention(nn.Module):
    """
    Standard multi-head causal self-attention (GPT-2 style).

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional ``bias`` flag (default True) controls
            whether the Q/K/V and output projections carry a bias term.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        # The config documents `bias` as "bias in Linears"; honour it (True keeps the previous behaviour).
        bias = getattr(config, "bias", True)

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=bias) # fused Q, K, V projection
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias) # output projection

        # lower-triangular causal mask, reshaped to (1, 1, T, T) so it broadcasts over (B, n_heads, T, T)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size() # batch size, sequence length, d_model

        # project once, split into Q, K, V and move the head axis next to the batch axis
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2) # (B, T, 3D) --> 3 x (B, T, D)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)

        # scaled dot-product attention with a causal mask
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # divide by sqrt(hs) so score variance stays ~1
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) # block future positions
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # weighted sum of values, then merge heads back into d_model
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y


"""
Why scaled?
Without normalization, the variance of the raw attention scores (q @ k^T) grows with head_size,
because each score is a sum of head_size products; here head_size = d_model / n_heads = 4.
Dividing the scores by sqrt(head_size) brings their variance back to ~1.
So "scaled" dot-product attention multiplies the scores by 1/sqrt(head_size):
when the inputs Q and K have unit variance, the scores have unit variance too,
and the softmax stays diffuse instead of saturating towards a one-hot distribution.
"""


'\nwhy scaled? \nif no normalization, the variance of weights wil be on the order of head_size, here is 16\nwhen deviding by sqrt(head_size), bring the variance back\n"Scaled" attention additional divides attention scores by 1/sqrt(head_size). \nso when input Q,K are unit variance, attentions will be unit variance too\nand Softmax will stay diffuse and not saturate too much.\n'

In [4]:
# Smoke test: run one random (1, block_size, d_model) batch through the layer.
cfg = Config()
x = torch.rand(1, cfg.block_size, cfg.d_model)
print(x)
layer = CausalSelfAttention(cfg)
print(layer(x))

tensor([[[0.4969, 0.6570, 0.5957, 0.2806, 0.9898, 0.4197, 0.0525, 0.3336,
          0.1488, 0.0221, 0.7811, 0.7002, 0.5456, 0.2738, 0.3161, 0.4923],
         [0.3783, 0.2664, 0.5428, 0.1738, 0.7641, 0.9601, 0.5239, 0.8026,
          0.0244, 0.0028, 0.8761, 0.8630, 0.2437, 0.8981, 0.1249, 0.6858],
         [0.3772, 0.4514, 0.9887, 0.0030, 0.9620, 0.2002, 0.7762, 0.4336,
          0.9140, 0.2241, 0.4045, 0.1784, 0.4262, 0.4291, 0.9030, 0.4159],
         [0.1943, 0.2074, 0.1744, 0.0545, 0.1915, 0.9909, 0.2459, 0.7920,
          0.9644, 0.3800, 0.5882, 0.2471, 0.3493, 0.0587, 0.8093, 0.4968],
         [0.0784, 0.3818, 0.0194, 0.9578, 0.2897, 0.9427, 0.2215, 0.1784,
          0.7407, 0.5388, 0.2052, 0.3051, 0.4202, 0.8896, 0.3862, 0.0851],
         [0.0419, 0.6332, 0.2759, 0.0148, 0.9401, 0.9824, 0.6808, 0.8424,
          0.5608, 0.0333, 0.7411, 0.7728, 0.7805, 0.3241, 0.6716, 0.4933],
         [0.4805, 0.7104, 0.3978, 0.9024, 0.8240, 0.5586, 0.9696, 0.4489,
          0.0664, 0.7289, 0.3661

# v2: Talking Heads Attention

In [5]:
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with optional Talking-Heads mixing.

    Talking-Heads attention (Shazeer et al., 2020) adds two learned
    (n_heads, n_heads) linear maps that mix information across heads:
    one on the scaled attention logits right before softmax, and one on the
    attention probabilities right after softmax.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional ``bias`` flag (default True) controls
            the bias of the Q/K/V and output projections.
        talking_heads: enable the two cross-head mixing maps.
    """
    def __init__(self, config, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        bias = getattr(config, "bias", True) # config documents `bias` as controlling bias in Linears

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias)

        # Talking heads: two small (n_heads, n_heads) maps that linearly mix heads
        #   (1) w_talking_weights: on the scaled logits Q @ K^T, before masking and softmax;
        #   (2) w_talking_logits: on the softmax probabilities, before weighting V.
        # Both are bias-free, as in the paper. For (2) this is required for causality: a bias
        # would give masked (future) positions a non-zero weight and leak later tokens into
        # earlier outputs. For (1) a per-head bias is a constant shift that softmax cancels anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        # lower-triangular causal mask, reshaped to (1, 1, T, T) so it broadcasts over (B, n_heads, T, T)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size() # batch size, sequence length, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2) # (B, T, D) @ (D, 3D) --> (B, T, 3D) --> split at dim=2 --> (B, T, D)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        if self.talking_heads:
            # move heads last so nn.Linear mixes across heads; (B, nh, T_q, T_k) --> (B, T_q, T_k, nh) in general
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_weights(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            # bias-free mixing keeps masked positions exactly 0
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_logits(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [6]:
# Smoke test: run one random (1, block_size, d_model) batch through the talking-heads layer.
cfg = Config()
x = torch.rand(1, cfg.block_size, cfg.d_model)
print(x)
layer = CausalSelfAttention(cfg)
print(layer(x))

tensor([[[4.7509e-01, 5.5477e-01, 3.4897e-01, 2.4837e-01, 6.4799e-01,
          5.6900e-01, 1.4102e-01, 6.8804e-01, 1.6640e-01, 9.2718e-01,
          1.6572e-01, 6.8809e-01, 8.1066e-02, 9.5104e-03, 4.3837e-01,
          7.3487e-01],
         [6.1554e-01, 3.5329e-01, 7.0374e-01, 2.1320e-01, 5.9138e-01,
          8.1152e-01, 4.6939e-01, 6.5154e-01, 5.7460e-01, 3.1282e-01,
          9.4076e-01, 5.3829e-01, 1.0570e-01, 9.5753e-01, 8.8933e-01,
          4.8705e-01],
         [6.1586e-01, 7.4315e-02, 7.2140e-01, 8.6935e-01, 8.8169e-01,
          8.5381e-01, 3.9540e-02, 8.6814e-01, 3.5598e-01, 4.8922e-01,
          4.8109e-01, 5.6206e-01, 1.2296e-02, 2.8425e-01, 1.8736e-01,
          8.2965e-01],
         [1.5933e-01, 1.6134e-01, 8.9811e-01, 7.9967e-01, 2.2775e-01,
          7.9299e-01, 9.1371e-01, 6.0428e-01, 4.6222e-01, 8.8225e-01,
          3.0300e-01, 4.2599e-01, 4.6615e-01, 6.6094e-01, 7.5380e-01,
          2.0109e-01],
         [8.9163e-04, 4.6125e-01, 4.9539e-01, 1.6191e-01, 8.1248e-01

# v3: GQA, MQA, MLA

In [7]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA): a single key head and a single value head are
    shared by all query heads, so the K and V projections output only head_size
    features instead of d_model.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional ``bias`` flag (default True) controls
            the bias of the Q/K/V and output projections.
        talking_heads: enable the two bias-free cross-head mixing maps.
    """
    def __init__(self, config, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        bias = getattr(config, "bias", True) # config documents `bias` as controlling bias in Linears

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # Separate projections: Q has n_heads heads, K and V have only 1 head each
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=bias)
        self.w_k = nn.Linear(self.d_model, self.head_size, bias=bias)
        self.w_v = nn.Linear(self.d_model, self.head_size, bias=bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias)

        # Talking heads: two small (n_heads, n_heads) maps that linearly mix heads
        #   (1) w_talking_weights: on the scaled logits Q @ K^T, before masking and softmax;
        #   (2) w_talking_logits: on the softmax probabilities, before weighting V.
        # Both are bias-free: a bias in (2) would give masked (future) positions a non-zero
        # weight and break causality; a bias in (1) is cancelled by softmax anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        # lower-triangular causal mask, reshaped to (1, 1, T, T) so it broadcasts over (B, n_heads, T, T)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size() # batch size, sequence length, d_model

        # Separate Q, K, V projections
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D) --> (B, T, nh, hs) --> (B, nh, T, hs)
        k = self.w_k(x).view(B, T, 1, self.head_size).transpose(1, 2) # (B, T, hs) --> (B, T, 1, hs) --> (B, 1, T, hs)
        v = self.w_v(x).view(B, T, 1, self.head_size).transpose(1, 2) # (B, T, hs) --> (B, T, 1, hs) --> (B, 1, T, hs)

        # Expand K and V to match Q's head dimension; expand() creates a view (no copy).
        # Not strictly required (matmul broadcasts), but it makes the shapes explicit.
        k = k.expand(B, self.n_heads, T, self.head_size) # (B, 1, T, hs) --> (B, nh, T, hs)
        v = v.expand(B, self.n_heads, T, self.head_size) # (B, 1, T, hs) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        if self.talking_heads:
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_weights(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            # bias-free mixing keeps masked positions exactly 0
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_logits(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [8]:
# Smoke test: run one random (1, block_size, d_model) batch through multi-query attention.
cfg = Config()
x = torch.rand(1, cfg.block_size, cfg.d_model)
print(x)
layer = MultiQueryAttention(cfg)
print(layer(x))

tensor([[[1.4342e-01, 5.5843e-01, 9.4697e-01, 7.6076e-01, 2.7843e-01,
          2.4228e-02, 9.6862e-01, 3.9223e-01, 3.7804e-01, 3.9904e-01,
          3.2556e-02, 2.4766e-01, 6.2905e-01, 7.6078e-01, 6.8240e-01,
          6.8771e-01],
         [9.9645e-01, 2.3885e-01, 5.8239e-01, 6.3197e-01, 2.3005e-01,
          3.3318e-01, 3.1115e-01, 7.3059e-01, 9.8998e-01, 7.9331e-01,
          8.2870e-01, 1.4672e-01, 2.8890e-01, 7.9187e-02, 3.8313e-01,
          4.1179e-02],
         [3.2908e-01, 7.6473e-01, 4.6602e-01, 9.8476e-01, 2.4817e-01,
          6.2331e-01, 1.1480e-01, 4.7744e-01, 4.6600e-01, 7.3312e-01,
          2.9563e-02, 1.0326e-01, 6.4935e-01, 3.7376e-01, 1.6241e-01,
          6.3827e-01],
         [5.1974e-01, 8.6120e-01, 1.1577e-01, 9.7562e-01, 8.7746e-01,
          7.7222e-01, 5.5917e-01, 3.8288e-03, 8.3871e-01, 1.8970e-01,
          3.9696e-01, 5.4650e-01, 4.1318e-01, 6.5634e-01, 9.3445e-01,
          5.0466e-01],
         [7.1505e-01, 8.2612e-01, 9.8310e-01, 2.3374e-02, 7.2700e-01

In [10]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA): query heads are split into n_kv_heads groups
    and each group shares one key head and one value head. This is a compromise
    between MHA (n_kv_heads == n_heads) and MQA (n_kv_heads == 1).

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``n_kv_heads``,
            ``block_size`` and ``dropout_rate``. An optional ``bias`` flag
            (default True) controls the bias of the Q/K/V and output projections.
        talking_heads: enable the two bias-free cross-head mixing maps.
    """
    def __init__(self, config, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        bias = getattr(config, "bias", True) # config documents `bias` as controlling bias in Linears

        # Number of K&V heads (groups) and how many query heads share each of them
        assert config.n_heads % config.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_heads = config.n_kv_heads
        self.n_reps = self.n_heads // self.n_kv_heads

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # Separate Q, K, V projections; K and V only produce n_kv_heads heads
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=bias)
        self.w_k = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=bias)
        self.w_v = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias)

        # Talking heads: two small (n_heads, n_heads) maps that linearly mix heads
        #   (1) w_talking_weights: on the scaled logits Q @ K^T, before masking and softmax;
        #   (2) w_talking_logits: on the softmax probabilities, before weighting V.
        # Both are bias-free: a bias in (2) would give masked (future) positions a non-zero
        # weight and break causality; a bias in (1) is cancelled by softmax anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        # lower-triangular causal mask, reshaped to (1, 1, T, T) so it broadcasts over (B, n_heads, T, T)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size() # batch size, sequence length, d_model

        # separate Q, K, V
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D)-->(B, T, nh, hs)-->(B, nh, T, hs)
        k = self.w_k(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)
        v = self.w_v(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)

        # Repeat each K/V head n_reps times so query head h uses kv head h // n_reps (copies memory)
        k = torch.repeat_interleave(k, self.n_reps, dim=1) # (B, nh_kv, T, hs) --> (B, nh, T, hs)
        v = torch.repeat_interleave(v, self.n_reps, dim=1) # (B, nh_kv, T, hs) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        if self.talking_heads:
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_weights(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            # bias-free mixing keeps masked positions exactly 0
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_logits(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y


In [11]:
# Smoke test: run one random (1, block_size, d_model) batch through grouped-query attention.
cfg = Config()
x = torch.rand(1, cfg.block_size, cfg.d_model)
print(x)
layer = GroupedQueryAttention(cfg)
print(layer(x))

tensor([[[0.9923, 0.0274, 0.6775, 0.8857, 0.2252, 0.4074, 0.7280, 0.5103,
          0.0937, 0.0104, 0.3573, 0.6979, 0.3146, 0.7746, 0.0900, 0.8332],
         [0.4300, 0.5203, 0.0446, 0.5694, 0.2910, 0.3367, 0.9080, 0.8484,
          0.3832, 0.1859, 0.5531, 0.1580, 0.7124, 0.0344, 0.5398, 0.3613],
         [0.8035, 0.5223, 0.3308, 0.5732, 0.9738, 0.5370, 0.2289, 0.8004,
          0.0371, 0.0939, 0.5391, 0.7671, 0.2368, 0.8424, 0.3405, 0.1614],
         [0.7087, 0.6738, 0.9095, 0.4273, 0.9208, 0.7502, 0.5396, 0.9197,
          0.6661, 0.5823, 0.6719, 0.8124, 0.1856, 0.2895, 0.7377, 0.8691],
         [0.9879, 0.5011, 0.6483, 0.3833, 0.9163, 0.1785, 0.1416, 0.0341,
          0.1233, 0.2735, 0.1725, 0.6249, 0.1403, 0.5212, 0.9444, 0.8579],
         [0.5202, 0.8323, 0.9396, 0.9410, 0.0360, 0.9103, 0.4516, 0.3952,
          0.7082, 0.6757, 0.8405, 0.4517, 0.3588, 0.2418, 0.3794, 0.6383],
         [0.5471, 0.0027, 0.6618, 0.6652, 0.7158, 0.9049, 0.3120, 0.8153,
          0.8799, 0.1996, 0.3987

In [None]:
# uses broadcasting and reshaping to avoid memory duplication:
# Key Changes:

# Reshape Q instead of expanding K,V: Groups Q heads into (B, n_kv_heads, n_rep, T, hs)
# Use broadcasting: K,V are unsqueezed to (B, n_kv_heads, 1, T, hs) for broadcasting
# No memory duplication: All operations work with views and broadcasting

class GroupedQueryAttention(nn.Module):
    """
    Memory-efficient Grouped Query Attention (GQA): Compromise between MHA and MQA.
    Groups multiple query heads to share K&V heads WITHOUT using repeat_interleave:
    Q is reshaped to (B, n_kv_heads, n_rep, T, hs) and K/V are broadcast over the
    n_rep axis, so K/V are never copied per query head.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``; optionally ``n_kv_heads`` and ``bias`` (default True).
        n_kv_heads: number of K/V heads. Defaults to ``config.n_kv_heads`` when the
            config has it, otherwise ``max(1, n_heads // 4)``.
        talking_heads: enable the two bias-free cross-head mixing maps.
    """
    def __init__(self, config, n_kv_heads=None, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        # Number of K&V heads (groups): explicit argument > config field > n_heads // 4 (at least 1).
        # Previously the config's n_kv_heads was ignored, and n_heads < 4 produced 0 groups.
        if n_kv_heads is None:
            n_kv_heads = getattr(config, "n_kv_heads", None)
        if n_kv_heads is None:
            n_kv_heads = max(1, config.n_heads // 4)
        self.n_kv_heads = n_kv_heads
        assert self.n_kv_heads > 0 and self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        # Calculate how many query heads per KV group
        self.n_rep = self.n_heads // self.n_kv_heads
        bias = getattr(config, "bias", True) # config documents `bias` as controlling bias in Linears

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # before the residual add

        # Different sizes for Q vs K&V projections
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=bias)  # Query: full dimension
        self.w_k = nn.Linear(self.d_model, self.n_kv_heads * self.head_size, bias=bias)  # Key: reduced
        self.w_v = nn.Linear(self.d_model, self.n_kv_heads * self.head_size, bias=bias)  # Value: reduced
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias)

        # Talking heads: (1) w_talking_weights mixes heads on the logits before softmax,
        # (2) w_talking_logits mixes heads on the probabilities after softmax.
        # Both are bias-free: a bias in (2) would give masked (future) positions a non-zero
        # weight and break causality; a bias in (1) is cancelled by softmax anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        # lower-triangular causal mask of shape (1, 1, block_size, block_size)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def _mix_heads(self, attention, proj, B, T):
        """Apply a (n_heads, n_heads) linear map across heads of a grouped (B, n_kv, n_rep, T, T) tensor."""
        # reshape (not view): the input may be a non-contiguous permuted layout
        attention = attention.reshape(B, self.n_heads, T, T)
        attention = proj(attention.permute(0, 2, 3, 1))  # (B, T, T, n_heads)
        attention = attention.permute(0, 3, 1, 2)  # (B, n_heads, T, T)
        return attention.reshape(B, self.n_kv_heads, self.n_rep, T, T)

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size()

        # Project and split into heads; Q has n_heads heads, K and V only n_kv_heads
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2)  # (B, n_heads, T, hs)
        k = self.w_k(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2)  # (B, n_kv_heads, T, hs)
        v = self.w_v(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2)  # (B, n_kv_heads, T, hs)

        # Group Q heads: (B, n_heads, T, hs) -> (B, n_kv_heads, n_rep, T, hs).
        # Query head h lands in group h // n_rep, matching repeat_interleave ordering.
        q_grouped = q.view(B, self.n_kv_heads, self.n_rep, T, self.head_size)

        # Add a singleton group axis so K,V broadcast over n_rep: (B, n_kv_heads, 1, T, hs)
        k_expanded = k.unsqueeze(2)
        v_expanded = v.unsqueeze(2)

        # (B, n_kv_heads, n_rep, T, hs) @ (B, n_kv_heads, 1, hs, T) -> (B, n_kv_heads, n_rep, T, T)
        attention = q_grouped @ k_expanded.transpose(-1, -2)
        attention *= self.head_size ** -0.5

        if self.talking_heads:
            attention = self._mix_heads(attention, self.w_talking_weights, B, T)

        # Apply causal mask - broadcast across the group dimension
        attention = attention.masked_fill(self.mask[:, :, :T, :T].unsqueeze(2) == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)

        if self.talking_heads:
            # bias-free mixing keeps masked positions exactly 0
            attention = self._mix_heads(attention, self.w_talking_logits, B, T)

        attention = self.attention_dropout(attention)

        # (B, n_kv_heads, n_rep, T, T) @ (B, n_kv_heads, 1, T, hs) -> (B, n_kv_heads, n_rep, T, hs)
        y = attention @ v_expanded

        # Back to (B, n_heads, T, hs), then merge heads into (B, T, D)
        y = y.reshape(B, self.n_heads, T, self.head_size)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        y = self.w_o(y)
        y = self.residual_dropout(y)

        return y


In [None]:
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA), simplified from DeepSeek-V2.

    Keys and values are not projected directly from the input. The input is first
    down-projected into a small shared latent ``c_kv`` of width ``d_latent``, and
    the per-head keys and values are up-projected from that latent. In a decoder
    with a KV cache, only ``c_kv`` (B, T, d_latent) would need caching instead of
    full K and V; this module itself keeps no cache. Decoupled RoPE and query
    compression from the paper are omitted.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``; optionally ``bias`` (default True) and ``d_latent``.
        talking_heads: enable the two bias-free cross-head mixing maps.
        d_latent: width of the K/V latent. Defaults to ``config.d_latent`` if
            present, otherwise ``max(1, d_model // 4)``.
    """
    def __init__(self, config, talking_heads=True, d_latent=None) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size
        bias = getattr(config, "bias", True) # config documents `bias` as controlling bias in Linears

        if d_latent is None:
            d_latent = getattr(config, "d_latent", None)
        if d_latent is None:
            d_latent = max(1, self.d_model // 4)
        assert d_latent > 0
        self.d_latent = d_latent

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        self.w_q = nn.Linear(self.d_model, self.d_model, bias=bias) # queries: full width
        self.w_dkv = nn.Linear(self.d_model, self.d_latent, bias=bias) # down-projection to the shared K/V latent
        self.w_uk = nn.Linear(self.d_latent, self.d_model, bias=bias) # latent --> keys for all heads
        self.w_uv = nn.Linear(self.d_latent, self.d_model, bias=bias) # latent --> values for all heads
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=bias)

        # Talking heads: (1) w_talking_weights mixes heads on the logits before softmax,
        # (2) w_talking_logits mixes heads on the probabilities after softmax.
        # Both are bias-free: a bias in (2) would give masked (future) positions a non-zero
        # weight and break causality; a bias in (1) is cancelled by softmax anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        # lower-triangular causal mask, reshaped to (1, 1, T, T) so it broadcasts over (B, n_heads, T, T)
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size))
                                    .view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        """
        Args:
            x: input of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.size() # batch size, sequence length, d_model

        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D) --> (B, nh, T, hs)

        # compress once into the latent, then expand to per-head keys and values
        c_kv = self.w_dkv(x) # (B, T, D) --> (B, T, d_latent)
        k = self.w_uk(c_kv).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, d_latent) --> (B, nh, T, hs)
        v = self.w_uv(c_kv).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, d_latent) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        if self.talking_heads:
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_weights(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            # bias-free mixing keeps masked positions exactly 0
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_logits(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

# GPT