<a href="https://colab.research.google.com/github/rczhen/code_ml/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

# reference: https://github.com/karpathy/nanoGPT/blob/master/model.py

In [2]:
@dataclass
class GPTConfig:
    """Model hyperparameters using nanoGPT-style field names."""

    block_size: int = 1024
    # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    bias: bool = True

@dataclass
class Config:
    """Tiny configuration used by the demo cells of this notebook.

    Field names differ from ``GPTConfig`` (``n_layers`` / ``n_heads`` /
    ``d_model`` / ``dropout_rate``) and a ``n_kv_heads`` field is added.
    """

    block_size: int = 8
    # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    n_layers: int = 2
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 16
    dropout_rate: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster.
    # For LayerNorm, the reason for not using a bias may be the same motivation as RMSNorm,
    # i.e. re-scaling is important, re-centering is not.
    bias: bool = True

# v1: standard CausalSelfAttention

In [3]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention (v1: the standard formulation).

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the linear projections carry a bias term; when the
            attribute is missing the nn.Linear default (bias enabled) is used.

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)``. ``T`` must not
    exceed ``block_size`` because the causal mask is pre-built at that size.
    """

    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)
        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=use_bias) # fused Q, K, V projection
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2) # (B, T, D) @ (D, 3D) --> (B, T, 3D) --> split at dim=2 --> (B, T, D)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        # positions above the diagonal (future tokens) get -inf so softmax assigns them zero weight
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y


"""
Why scaled?
Without normalization, the variance of the attention scores (q @ k^T) is on the order of head_size
(with the Config above, head_size = d_model // n_heads = 16 // 4 = 4).
Dividing by sqrt(head_size) brings the variance back to about 1.
"Scaled" attention additionally multiplies the attention scores by 1/sqrt(head_size),
so when the inputs Q and K have unit variance, the attention scores have unit variance too,
and the softmax stays diffuse instead of saturating.
"""


'\nwhy scaled? \nif no normalization, the variance of weights wil be on the order of head_size, here is 16\nwhen deviding by sqrt(head_size), bring the variance back\n"Scaled" attention additional divides attention scores by 1/sqrt(head_size). \nso when input Q,K are unit variance, attentions will be unit variance too\nand Softmax will stay diffuse and not saturate too much.\n'

In [4]:
# Smoke test: one random sequence of shape (1, block_size, d_model) through the layer.
# The Config class itself (not an instance) is passed; the layer only reads attributes,
# which resolve to the dataclass field defaults.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = CausalSelfAttention(Config)
print(layer(x))

tensor([[[ 0.4004,  0.5980,  0.5896, -0.0539,  0.1002, -0.0575, -0.2603,
          -0.3139,  0.3912,  0.0885,  0.0016,  0.4248, -0.1102, -0.1208,
          -0.2419,  0.3066],
         [ 0.3270,  0.6477,  0.5268, -0.0702,  0.1035, -0.1017, -0.3017,
          -0.3382,  0.3299,  0.1279,  0.0534,  0.4702,  0.0201, -0.1228,
          -0.3237,  0.2428],
         [ 0.2985,  0.5682,  0.5130, -0.0482,  0.1458, -0.0866, -0.3149,
          -0.3121,  0.3063,  0.1244,  0.0519,  0.4533,  0.0221, -0.0707,
          -0.2767,  0.2646],
         [ 0.2555,  0.5477,  0.4681, -0.0656,  0.1501, -0.0931, -0.2824,
          -0.3103,  0.2383,  0.1213,  0.0534,  0.4383,  0.0465, -0.0278,
          -0.2521,  0.2358],
         [ 0.2507,  0.5460,  0.4752, -0.0700,  0.1498, -0.1363, -0.2806,
          -0.2952,  0.2235,  0.1202,  0.0859,  0.4288,  0.0455, -0.0308,
          -0.2474,  0.2685],
         [ 0.2504,  0.5350,  0.4752, -0.0766,  0.1567, -0.1345, -0.2461,
          -0.2719,  0.2220,  0.0921,  0.0915,  0.413

# v2: Talking Heads Attention

In [5]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional talking-heads mixing (v2).

    Talking heads adds two small (n_heads, n_heads) linear maps that mix
    information across heads: one immediately before the softmax and one
    immediately after it.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the Q/K/V and output projections carry a bias term
            (nn.Linear default, i.e. enabled, when the attribute is missing).
        talking_heads: when False the layer behaves like plain multi-head
            causal self-attention and the two mixing layers are not created.

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)`` with
    ``T <= block_size``.
    """

    def __init__(self, config, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)
        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=use_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        # talking heads: two small learnable (n_heads, n_heads) matrices applied across the head
        # dimension immediately before and immediately after the softmax:
        #   (1) w_talking_weights mixes the scaled Q·K^T scores (pre-softmax, before masking)
        #   (2) w_talking_logits mixes the softmax output (post-softmax, before the weighted sum of values)
        # Both are bias-free on purpose:
        #   - a post-softmax bias would turn the zero weights of masked (future) positions into
        #     non-zero ones and leak future values into the output, breaking causality;
        #   - a pre-softmax bias is constant along the key axis, so the softmax cancels it anyway.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    def forward(self, x):
        """Apply (talking-heads) causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2) # (B, T, D) @ (D, 3D) --> (B, T, 3D) --> split at dim=2 --> (B, T, D)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs), hs for head_size

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        if self.talking_heads:
            # nn.Linear acts on the last dim, so move the head dim there, mix, and move it back
            attention = attention.permute(0, 2, 3, 1) # (B, nh, T, T) --> (B, T, T, nh); (B, nh, T_q, T_k) --> (B, T_q, T_k, nh) when query and key have different length
            attention = self.w_talking_weights(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        # masking happens after the pre-softmax mixing, so mixed scores of future positions are still removed
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            attention = attention.permute(0, 2, 3, 1)  # (B, nh, T, T) --> (B, T, T, nh)
            attention = self.w_talking_logits(attention) # (B, T, T, nh) @ (nh, nh) --> (B, T, T, nh); zeros of masked positions stay zero (no bias)
            attention = attention.permute(0, 3, 1, 2) # (B, T, T, nh) --> (B, nh, T, T)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [6]:
# Smoke test for the talking-heads variant (talking_heads defaults to True):
# one random sequence of shape (1, block_size, d_model); the Config class is passed directly.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = CausalSelfAttention(Config)
print(layer(x))

tensor([[[ 0.0665, -0.2212, -0.1294, -0.6010,  0.2886, -0.0849, -0.1126,
           0.5110, -0.2082, -0.4438, -0.1416,  0.0666,  0.1182,  0.2938,
           0.1773, -0.0441],
         [ 0.0755, -0.2447, -0.1591, -0.5085,  0.2118, -0.1085, -0.0031,
           0.5469, -0.2007, -0.4433, -0.1104,  0.1142,  0.0752,  0.1899,
           0.1563, -0.0866],
         [ 0.0481, -0.2075, -0.1224, -0.4976,  0.2144, -0.1059, -0.0053,
           0.5651, -0.1682, -0.4552, -0.0725,  0.0670,  0.0768,  0.2074,
           0.1868, -0.0627],
         [ 0.0658, -0.2205, -0.1184, -0.4823,  0.2013, -0.0866,  0.0091,
           0.5799, -0.1682, -0.4842, -0.0686,  0.0645,  0.0737,  0.2011,
           0.1866, -0.0600],
         [ 0.0394, -0.2117, -0.1212, -0.5000,  0.2139, -0.1206, -0.0051,
           0.5750, -0.1629, -0.4585, -0.0665,  0.0563,  0.0997,  0.1984,
           0.1722, -0.0850],
         [ 0.0494, -0.2151, -0.1200, -0.4883,  0.2011, -0.1035,  0.0121,
           0.5853, -0.1697, -0.4641, -0.0616,  0.043

# v3: GQA, MQA, MLA

In [9]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention: Uses single key and value heads shared across all query heads.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the linear projections carry a bias term (nn.Linear
            default, i.e. enabled, when the attribute is missing).

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)`` with
    ``T <= block_size``.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)

        # MODIFIED: Separate projections - Q has n_heads, K&V have only 1 head each
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=use_bias)
        self.w_k = nn.Linear(self.d_model, self.head_size, bias=use_bias)
        self.w_v = nn.Linear(self.d_model, self.head_size, bias=use_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    def forward(self, x):
        """Apply multi-query causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # MODIFIED: Separate Q, K, V projections
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D) --> (B, T, nh, hs) --> (B, nh, T, hs)
        k = self.w_k(x).view(B, T, 1, self.head_size).transpose(1, 2) # (B, T, hs) --> (B, T, 1, hs) --> (B, 1, T, hs)
        v = self.w_v(x).view(B, T, 1, self.head_size).transpose(1, 2) # (B, T, hs) --> (B, T, 1, hs) --> (B, 1, T, hs)

        # MODIFIED: Expand K&V to match Q's head dimension for broadcasting
        # expansion is not a must, but will make broadcasting more clear
        k = k.expand(B, self.n_heads, T, self.head_size) # (B, 1, T, hs) --> (B, nh, T, hs)
        v = v.expand(B, self.n_heads, T, self.head_size) # (B, 1, T, hs) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        # positions above the diagonal (future tokens) get -inf so softmax assigns them zero weight
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [10]:
# Smoke test for MultiQueryAttention: one random sequence of shape (1, block_size, d_model);
# the Config class is passed directly, so its field defaults are used.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = MultiQueryAttention(Config)
print(layer(x))

tensor([[[-0.0089, -0.1340, -0.2575, -0.0265,  0.1558,  0.3195,  0.2710,
          -0.2836, -0.1693, -0.0023,  0.1426, -0.0382, -0.1539,  0.0310,
           0.1011,  0.0823],
         [-0.0349, -0.1151, -0.3152, -0.0322,  0.1249,  0.2989,  0.2809,
          -0.2436, -0.1454,  0.0034,  0.1665, -0.0396, -0.1526,  0.0235,
           0.0900,  0.1204],
         [ 0.0193, -0.1448, -0.2811, -0.0327,  0.1149,  0.2927,  0.2965,
          -0.2318, -0.1360,  0.0403,  0.2389,  0.0062, -0.1406,  0.0070,
           0.0880,  0.0946],
         [ 0.0499, -0.1823, -0.2870, -0.0109,  0.1140,  0.3278,  0.3338,
          -0.2283, -0.1626,  0.0739,  0.3091,  0.0432, -0.1271,  0.0119,
           0.0722,  0.0947],
         [ 0.0237, -0.1871, -0.3040, -0.0161,  0.1049,  0.3428,  0.3317,
          -0.2379, -0.1550,  0.0535,  0.2957,  0.0377, -0.1396,  0.0060,
           0.0905,  0.1305],
         [ 0.0360, -0.1961, -0.2937, -0.0102,  0.1115,  0.3513,  0.3367,
          -0.2427, -0.1639,  0.0606,  0.3052,  0.044

In [11]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention: Groups multiple query heads to share K&V heads. Compromise between MHA and MQA.
    Simple implementation. Using repeat_interleave which requires extra memory.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``n_kv_heads``,
            ``block_size`` and ``dropout_rate``. ``n_heads`` must be divisible
            by ``n_kv_heads``. An optional boolean ``bias`` attribute selects
            whether the linear projections carry a bias term (nn.Linear
            default, i.e. enabled, when the attribute is missing).

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)`` with
    ``T <= block_size``.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # MODIFIED: Number of K&V heads (groups)
        assert config.n_heads % config.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_heads = config.n_kv_heads
        self.n_reps = self.n_heads // self.n_kv_heads # query heads served by each K/V head

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)

        # MODIFIED: Separate QKV projections
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=use_bias)
        self.w_k = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=use_bias)
        self.w_v = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=use_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    def forward(self, x):
        """Apply grouped-query causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # MODIFIED: separate QKV
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D)@(D, D)-->(B, T, D)-->(B, T, nh, hs)-->(B, nh, T, hs)
        k = self.w_k(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)
        v = self.w_v(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)

        # MODIFIED: Repeat K&V heads with repeat_interleave, extra memory used
        # query head h ends up paired with K/V head h // n_reps
        k = torch.repeat_interleave(k, self.n_reps, dim=1) # (B, nh_kv, T, hs) --> (B, nh, T, hs)
        v = torch.repeat_interleave(v, self.n_reps, dim=1) # (B, nh_kv, T, hs) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        # positions above the diagonal (future tokens) get -inf so softmax assigns them zero weight
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y


In [13]:
# Smoke test for GroupedQueryAttention: one random sequence of shape (1, block_size, d_model);
# the Config class is passed directly, so n_heads=4 and n_kv_heads=2 come from its field defaults.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = GroupedQueryAttention(Config)
print(layer(x))

tensor([[[-0.0879,  0.2400,  0.3814,  0.4572, -0.6883, -0.5715,  0.4995,
           0.0522,  0.3881,  0.4449,  0.2672,  0.5377, -0.1666, -0.1189,
          -0.2822, -0.0406],
         [-0.0295,  0.2227,  0.3507,  0.3563, -0.6217, -0.5170,  0.4981,
           0.0667,  0.2725,  0.4023,  0.1905,  0.4688, -0.1318, -0.0580,
          -0.3032, -0.0014],
         [-0.0785,  0.1981,  0.3914,  0.3099, -0.6436, -0.5381,  0.5371,
           0.0389,  0.3159,  0.4041,  0.1765,  0.4879, -0.1453, -0.0330,
          -0.3453,  0.0031],
         [-0.0412,  0.2127,  0.3815,  0.3468, -0.5874, -0.5222,  0.4987,
           0.0622,  0.2921,  0.3794,  0.1878,  0.4841, -0.1380, -0.0410,
          -0.3167, -0.0192],
         [-0.0593,  0.2154,  0.3756,  0.3351, -0.6243, -0.5268,  0.5278,
           0.0546,  0.3093,  0.3947,  0.1951,  0.5102, -0.1402, -0.0340,
          -0.3339,  0.0181],
         [-0.0679,  0.2016,  0.3807,  0.3361, -0.6380, -0.5300,  0.5290,
           0.0374,  0.3119,  0.4002,  0.1829,  0.489

In [14]:
# uses broadcasting and reshaping to avoid memory duplication:

# Reshape Q instead of expanding K,V: Groups Q heads into (B, n_kv_heads, n_rep, T, hs)
# Use broadcasting: K,V are unsqueezed to (B, n_kv_heads, 1, T, hs) for broadcasting
# No memory duplication: All operations work with views and broadcasting

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention: Groups multiple query heads to share K&V heads. Compromise between MHA and MQA.

    This version reshapes Q into (B, n_kv_heads, n_reps, T, hs) and expands K/V
    along a size-1 group axis instead of materialising repeated K/V copies.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``n_kv_heads``,
            ``block_size`` and ``dropout_rate``. ``n_heads`` must be divisible
            by ``n_kv_heads``. An optional boolean ``bias`` attribute selects
            whether the linear projections carry a bias term (nn.Linear
            default, i.e. enabled, when the attribute is missing).

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)`` with
    ``T <= block_size``.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # MODIFIED: Number of K&V heads (groups)
        assert config.n_heads % config.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_heads = config.n_kv_heads
        self.n_reps = self.n_heads // self.n_kv_heads # query heads served by each K/V head

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)

        # MODIFIED: Separate QKV projections
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=use_bias)
        self.w_k = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=use_bias)
        self.w_v = nn.Linear(self.d_model, self.head_size * self.n_kv_heads, bias=use_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        # MODIFIED: mask size to fit the (B, nh_kv, n_rep, T, T) attention shape
        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, 1, self.block_size, self.block_size))  # reshape for (B, nh_kv, n_rep, T, T) inputs

    def forward(self, x):
        """Apply grouped-query causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # MODIFIED: separate QKV
        q = self.w_q(x).view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, D)@(D, D)-->(B, T, D)-->(B, T, nh, hs)-->(B, nh, T, hs)
        k = self.w_k(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)
        v = self.w_v(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2) # (B, T, nh_kv*hs)-->(B, T, nh_kv, hs)-->(B, nh_kv, T, hs)

        # MODIFIED v2: reshape Q + expanding KV, no repeat_interleave, no extra memory use.
        # Q: (B, nh, T, hs) -(reshape)-> (B, nh_kv, n_rep, T, hs),
        # K&V: (B, nh_kv, T, hs) -(reshape)-> (B, nh_kv, 1, T, hs) -(expand)-> (B, nh_kv, n_rep, T, hs)
        # Splitting the head axis as (nh_kv, n_rep) pairs query head h with K/V head h // n_reps.
        q = q.view(B, self.n_kv_heads, self.n_reps, T, self.head_size)
        k = k.view(B, self.n_kv_heads, 1, T, self.head_size).expand(B, self.n_kv_heads, self.n_reps, T, self.head_size)
        v = v.view(B, self.n_kv_heads, 1, T, self.head_size).expand(B, self.n_kv_heads, self.n_reps, T, self.head_size)

        # MODIFIED v2: attention
        attention = q @ k.transpose(-1, -2) # (B, nh_kv, n_rep, T, hs) @ (B, nh_kv, n_rep, hs, T) --> (B, nh_kv, n_rep, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        # positions above the diagonal (future tokens) get -inf so softmax assigns them zero weight
        attention = attention.masked_fill(self.mask[:, :, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh_kv, n_rep, T, T) @ (B, nh_kv, n_rep, T, hs) --> (B, nh_kv, n_rep, T, hs)
        y = y.permute(0, 3, 1, 2, 4).contiguous().view(B, T, D) # (B, nh_kv, n_rep, T, hs)-->(B, T, nh_kv, n_rep, hs)-->(B, T, D), was (B, nh, T, hs)-->(B, T, nh, hs)-->(B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y


In [15]:
# Smoke test for the GroupedQueryAttention class defined in the cell above:
# one random sequence of shape (1, block_size, d_model); the Config class is passed directly.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = GroupedQueryAttention(Config)
print(layer(x))

tensor([[[ 0.4045,  0.1040, -0.1319,  0.0223, -0.0322,  0.3646, -0.1368,
          -0.2340,  0.0120, -0.3832, -0.0699, -0.2371, -0.0072, -0.4912,
           0.0115, -0.3148],
         [ 0.2969,  0.1523, -0.0991,  0.0581, -0.1015,  0.2936, -0.1960,
          -0.1718, -0.0031, -0.4423,  0.0079, -0.1886, -0.0448, -0.4556,
           0.0661, -0.2229],
         [ 0.3067,  0.1638, -0.1413,  0.0782, -0.1533,  0.3009, -0.1730,
          -0.1629,  0.0260, -0.4450, -0.0553, -0.1745, -0.0252, -0.5111,
           0.0664, -0.3143],
         [ 0.2796,  0.2039, -0.1180,  0.0715, -0.1636,  0.3043, -0.1814,
          -0.1636,  0.0194, -0.4972, -0.0223, -0.1814, -0.0279, -0.5082,
           0.0759, -0.2658],
         [ 0.2330,  0.2026, -0.0992,  0.0606, -0.1850,  0.2850, -0.1863,
          -0.1281, -0.0032, -0.4625, -0.0174, -0.1755, -0.0331, -0.5127,
           0.0734, -0.2376],
         [ 0.2497,  0.2138, -0.1104,  0.0779, -0.2040,  0.2711, -0.1984,
          -0.1141,  0.0188, -0.4913, -0.0115, -0.174

In [18]:
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA): Compresses K&V into lower-dimensional latent representations.
    paper: https://arxiv.org/pdf/2405.04434, section 2.1

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the query and output projections carry a bias term
            (nn.Linear default, i.e. enabled, when the attribute is missing).
            The latent compress/decode projections are always bias-free.
        latent_dim: width of the shared K/V latent vector. Note that the
            default of 16 equals ``d_model`` of the demo ``Config``, so it
            only compresses when ``d_model`` is larger than 16.

    ``forward`` maps ``(B, T, d_model)`` to ``(B, T, d_model)`` with
    ``T <= block_size``.
    """
    def __init__(self, config, latent_dim=16) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)

        # MODIFIED
        self.latent_dim = latent_dim
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        # kv share the same latent vector, hence only "one" latent vector need to be cached during inference.
        # plus, during inference, w_k_decode/w_v_decode can be merged with w_q/w_o respectively, hence no decoding required
        self.w_kv_compress = nn.Linear(self.d_model, latent_dim, bias=False)
        self.w_k_decode = nn.Linear(latent_dim, self.d_model, bias=False)
        self.w_v_decode = nn.Linear(latent_dim, self.d_model, bias=False)

        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    def forward(self, x):
        """Apply latent-compressed causal self-attention to ``x`` of shape (B, T, d_model)."""
        B, T, D = x.size() # batch size, sequence length, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.w_q(x) # (B, T, D) @ (D, D) --> (B, T, D)
        latent = self.w_kv_compress(x) # (B, T, D) @ (D, latent_dim) --> (B, T, latent_dim)
        k = self.w_k_decode(latent) # (B, T, latent_dim) @ (latent_dim, D) --> (B, T, D)
        v = self.w_v_decode(latent) # (B, T, latent_dim) @ (latent_dim, D) --> (B, T, D)

        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)

        # attention
        attention = q @ k.transpose(-1, -2) # (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention *= self.head_size ** -0.5 # scaled dot product attention
        # positions above the diagonal (future tokens) get -inf so softmax assigns them zero weight
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [19]:
# Smoke test for MultiHeadLatentAttention with its default latent_dim=16:
# one random sequence of shape (1, block_size, d_model); the Config class is passed directly.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = MultiHeadLatentAttention(Config)
print(layer(x))

tensor([[[ 2.8464e-02, -2.1961e-01, -1.5361e-01,  4.4400e-02, -1.5287e-01,
           2.8343e-01,  1.2061e-01, -1.2442e-02,  2.9561e-02, -1.2725e-02,
          -1.0118e-01,  5.4990e-02,  2.6925e-01,  1.5270e-01,  4.3015e-01,
          -5.4792e-02],
         [ 9.9094e-03, -1.7150e-01, -1.3730e-01,  1.0227e-01, -1.8065e-01,
           2.9052e-01,  1.0725e-01, -4.3471e-02,  5.4774e-02, -2.5308e-02,
          -8.8920e-02, -3.6775e-04,  1.5665e-01,  1.1298e-01,  3.9812e-01,
           4.1996e-03],
         [-6.1434e-03, -1.2881e-01, -1.0459e-01,  1.3003e-01, -1.8325e-01,
           2.7150e-01,  6.8840e-02, -3.5086e-02,  9.8818e-02, -5.4927e-02,
          -8.5095e-02, -2.2444e-03,  1.3021e-01,  9.1252e-02,  3.7169e-01,
           6.3924e-03],
         [-1.2102e-02, -1.0252e-01, -9.9437e-02,  1.4513e-01, -1.8700e-01,
           2.3622e-01,  4.6900e-02, -3.2036e-02,  1.3167e-01, -7.9713e-02,
          -5.2629e-02, -1.1770e-02,  1.1101e-01,  7.5719e-02,  3.2709e-01,
           4.0546e-03],
    

# GPT

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4 * d_model, GELU, project back to d_model.

    Args:
        config: object exposing ``d_model`` and ``dropout_rate``. An optional
            boolean ``bias`` attribute selects whether the linear layers carry
            a bias term (nn.Linear default, i.e. enabled, when missing).

    ``forward`` preserves the input shape ``(..., d_model)``, which is what the
    residual connection around this module requires.
    """

    def __init__(self, config) -> None:
        super().__init__()

        # The config documents `bias` as the switch for bias terms in Linears, so honour it here.
        use_bias = getattr(config, "bias", True)
        self.fc1 = nn.Linear(config.d_model, 4 * config.d_model, bias=use_bias)
        # project back down to d_model so the output can be added to the residual stream
        self.fc2 = nn.Linear(4 * config.d_model, config.d_model, bias=use_bias)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        """Return the feed-forward transformation of ``x``; output shape equals input shape."""
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

# https://docs.pytorch.org/docs/stable/generated/torch.nn.GELU.html
# gelu(x) = x * phi(x), where phi(x) is the Cumulative Distribution Function for Gaussian Distribution.

In [20]:
class Block(nn.Module):
    """Transformer block: attention and feed-forward sub-layers, each with pre-LayerNorm and a residual add."""

    def __init__(self, config) -> None:
        super().__init__()

        # sub-layers (registration order kept: attention, ffn, then the two norms)
        self.attention_layer = CausalSelfAttention(config)
        self.ffn = MLP(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # Pre-layer-norm: normalise first, run the sub-layer, then add back onto the residual stream
        attended = self.attention_layer(self.ln1(x))
        x = x + attended
        fed_forward = self.ffn(self.ln2(x))
        return x + fed_forward


# DIN (Q & KV have different seq length)