<a href="https://colab.research.google.com/github/rczhen/code_ml/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

# reference: https://github.com/karpathy/nanoGPT/blob/master/model.py

In [None]:
@dataclass
class GPTConfig:
    """GPT-2-sized hyperparameters using nanoGPT's field names.

    Note that these names (n_layer, n_head, n_embd, dropout) differ from the
    d_model / n_heads / dropout_rate names read by the modules in this notebook.
    """

    # Maximum context length (number of positions).
    block_size: int = 1024
    # GPT-2's 50257-token vocabulary, padded up to the nearest multiple of 64 for efficiency.
    vocab_size: int = 50304
    # Number of transformer blocks.
    n_layer: int = 12
    # Number of attention heads per block.
    n_head: int = 12
    # Width of the embeddings / residual stream.
    n_embd: int = 768
    # Dropout probability.
    dropout: float = 0.0
    # True: Linears and LayerNorms carry a bias, like GPT-2. False: a bit better and faster.
    bias: bool = True

@dataclass
class Config:
    """Tiny configuration used to smoke-test the attention variants below."""

    # Maximum sequence length; also sizes the causal-mask buffers.
    block_size: int = 8
    # GPT-2's 50257-token vocabulary, padded up to the nearest multiple of 64 for efficiency.
    vocab_size: int = 50304
    # Number of transformer blocks.
    n_layers: int = 2
    # Number of query heads.
    n_heads: int = 4
    # Number of key/value heads (grouped-query attention); must divide n_heads.
    n_kv_heads: int = 2
    # Model width; must be divisible by n_heads.
    d_model: int = 16
    # Dropout probability.
    dropout_rate: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster.
    # For LayerNorm, dropping the bias hints at the motivation for RMSNorm:
    # re-scaling matters, re-centering does not.
    # NOTE(review): the modules in this notebook do not currently read this flag.
    bias: bool = True

# v1: standard CausalSelfAttention

In [None]:
class CausalSelfAttention(nn.Module):
    """Standard multi-head causal self-attention (GPT-2 style).

    Every position attends only to itself and to earlier positions.
    ``config`` must provide ``d_model``, ``n_heads``, ``block_size`` and
    ``dropout_rate``. Inputs are (B, T, d_model) with T <= block_size; the
    output has the same shape.
    """

    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Dropout on the attention probabilities, and on the block output before the residual add.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # One fused projection yields Q, K and V; w_o mixes the concatenated heads.
        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        # Lower-triangular 0/1 matrix, shaped (1, 1, block_size, block_size) so it
        # broadcasts over (B, n_heads, T, T) score tensors.
        causal = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer("mask", causal.view(1, 1, self.block_size, self.block_size))

    def _split_heads(self, t, batch, seq_len):
        # (B, T, D) -> (B, T, nh, hs) -> (B, nh, T, hs)
        return t.view(batch, seq_len, self.n_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.size()

        # (B, T, 3D) is cut into three (B, T, D) pieces, each reshaped to (B, nh, T, hs).
        q, k, v = (self._split_heads(t, batch, seq_len)
                   for t in self.w_qkv(x).chunk(3, dim=-1))

        # Scaled dot-product scores (B, nh, T, T); future positions are set to -inf
        # so the softmax gives them zero weight.
        scale = self.head_size ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale
        future = self.mask[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        probs = self.attention_dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then merge heads back: (B, nh, T, hs) -> (B, T, D).
        merged = (probs @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.residual_dropout(self.w_o(merged))


"""
why scaled?
Without scaling, the variance of the attention scores q·k is on the order of head_size (4 here: d_model=16 split over n_heads=4).
Dividing by sqrt(head_size) brings that variance back to about 1:
"scaled" attention divides the attention scores by sqrt(head_size), i.e. multiplies them by 1/sqrt(head_size).
So when the inputs Q and K have unit variance, the scaled scores have unit variance too,
and the softmax stays diffuse instead of saturating (becoming nearly one-hot).
"""


'\nwhy scaled?\nif no normalization, the variance of weights wil be on the order of head_size, here is 16\nwhen deviding by sqrt(head_size), bring the variance back\n"Scaled" attention additional divides attention scores by 1/sqrt(head_size).\nso when input Q,K are unit variance, attentions will be unit variance too\nand Softmax will stay diffuse and not saturate too much.\n'

In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = CausalSelfAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[ 0.1569,  0.2025,  0.0082, -0.0556, -0.3686,  0.0776,  0.2143,
          -0.0328, -0.1326,  0.0038,  0.1973,  0.1777,  0.3261,  0.4472,
          -0.1291, -0.1470],
         [ 0.0839,  0.1502,  0.0342, -0.0937, -0.3915,  0.1091,  0.2111,
           0.0225, -0.1306,  0.1197,  0.1535,  0.1027,  0.3091,  0.4546,
          -0.0607, -0.1843],
         [ 0.0777,  0.1380,  0.0271, -0.1116, -0.3637,  0.1233,  0.1699,
           0.0354, -0.1330,  0.0848,  0.0887,  0.1241,  0.2519,  0.4527,
          -0.0549, -0.1606],
         [ 0.0457,  0.1185,  0.0582, -0.0996, -0.3259,  0.0922,  0.1785,
           0.0802, -0.1493,  0.0712,  0.0738,  0.0730,  0.3146,  0.4426,
          -0.1139, -0.1492],
         [ 0.0536,  0.1174,  0.0461, -0.0833, -0.3345,  0.0757,  0.1689,
           0.0940, -0.1408,  0.1073,  0.0535,  0.0697,  0.3289,  0.4242,
          -0.1137, -0.1474],
         [ 0.0273,  0.1011,  0.0419, -0.0870, -0.3143,  0.0929,  0.1638,
           0.1210, -0.1451,  0.1236,  0.0162,  0.078

# v2: Talking Heads Attention

https://arxiv.org/pdf/2003.02436

In [None]:
class TalkingHeadsCausalSelfAttention(nn.Module):
    """Causal self-attention with talking heads (arXiv:2003.02436).

    Two learned (n_heads x n_heads) linear maps mix information across heads:
    ``w_talking_logits`` acts on the scaled attention logits just before the
    softmax, and ``w_talking_weights`` acts on the attention weights just after it.
    With ``talking_heads=False`` this is plain causal multi-head attention.
    Inputs are (B, T, d_model) with T <= block_size; the output has the same shape.
    """

    def __init__(self, config, talking_heads=True) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        self.attention_dropout = nn.Dropout(config.dropout_rate) # after softmax
        self.residual_dropout = nn.Dropout(config.dropout_rate) # after attention block, before adding with residual connection

        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        # Talking heads: two small (n_heads, n_heads) linear maps that mix the heads
        # immediately before and immediately after the softmax.
        # Both are bias-free, as in the paper. For the post-softmax map this is
        # required for causality. Masked (future) positions have weight 0 in every
        # head, and a bias-free linear mix keeps them at exactly 0. A bias would
        # give them non-zero weight and leak future values into the output.
        self.talking_heads = talking_heads
        if self.talking_heads:
            self.w_talking_weights = nn.Linear(self.n_heads, self.n_heads, bias=False)  # post-softmax (weights)
            self.w_talking_logits = nn.Linear(self.n_heads, self.n_heads, bias=False)  # pre-softmax (logits)

        self.register_buffer("mask", torch.tril(torch.ones(self.block_size, self.block_size)) # register buffer for low triangular matrix mask
                                    .view(1, 1, self.block_size, self.block_size))  # reshape for (B, n_head, T, T) inputs

    @staticmethod
    def _mix_heads(scores, proj):
        """Apply the linear map ``proj`` across the head axis of a (B, nh, T_q, T_k) tensor."""
        # nn.Linear acts on the last dim: (B, nh, T_q, T_k) -> (B, T_q, T_k, nh) -> mix -> back.
        return proj(scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        B, T, D = x.size() # batch size, sequence length, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2) # (B, T, 3D) --> 3 x (B, T, D)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) # (B, T, nh, hs) --> (B, nh, T, hs)

        # attention logits: (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T), scaled
        attention = q @ k.transpose(-1, -2)
        attention = attention * self.head_size ** -0.5
        if self.talking_heads:
            # Mix heads BEFORE masking: mixing -inf entries across heads could yield NaN.
            attention = self._mix_heads(attention, self.w_talking_logits)
        attention = attention.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        if self.talking_heads:
            # Rows need no longer sum to 1 after this mix; masked entries stay exactly 0.
            attention = self._mix_heads(attention, self.w_talking_weights)
        attention = self.attention_dropout(attention)

        # output
        y = attention @ v # (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, D) # (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = self.w_o(y) # (B, T, D) @ (D, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y

In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = TalkingHeadsCausalSelfAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[-2.8340, -0.0426,  1.5041,  ..., -0.4150, -3.3890, -0.2439],
         [-2.8062,  0.0277,  1.5124,  ..., -0.4782, -3.4129, -0.2836],
         [-2.7896,  0.0685,  1.5190,  ..., -0.5058, -3.4253, -0.2956],
         ...,
         [-2.8243,  0.0383,  1.5419,  ..., -0.4916, -3.4096, -0.2543],
         [-2.8244,  0.0396,  1.5425,  ..., -0.4902, -3.4108, -0.2540],
         [-2.8255,  0.0398,  1.5407,  ..., -0.4893, -3.4124, -0.2549]]],
       grad_fn=<ViewBackward0>)


# v3: GQA, MQA, MLA

In [None]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention: one key head and one value head are shared by all query heads.

    The K and V projections output only ``head_size`` features instead of
    ``d_model``, which is what makes MQA attractive for K/V caching at inference
    time (no cache is implemented here). Inputs are (B, T, d_model) with
    T <= block_size; the output has the same shape.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Dropout on the attention probabilities, and on the output before the residual add.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # Q keeps n_heads heads; K and V each get a single head of size head_size.
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.head_size)
        self.w_v = nn.Linear(self.d_model, self.head_size)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        # (1, 1, block_size, block_size) lower-triangular mask, broadcast over (B, nh, T, T).
        lower = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer("mask", lower.view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        batch, seq_len, width = x.size()
        hs = self.head_size

        # Queries: (B, T, D) -> (B, T, nh, hs) -> (B, nh, T, hs).
        q = self.w_q(x).view(batch, seq_len, self.n_heads, hs).transpose(1, 2)
        # Shared K/V: (B, T, hs) -> (B, 1, T, hs) -> expanded (as a view) to (B, nh, T, hs).
        # The expand is optional since matmul broadcasts the size-1 head dim, but it makes shapes explicit.
        k = self.w_k(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        v = self.w_v(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        # Scaled, causally masked scores (B, nh, T, T) -> attention probabilities.
        scores = (q @ k.transpose(-2, -1)) * hs ** -0.5
        scores = scores.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        weights = self.attention_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, D), then the output projection.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.residual_dropout(self.w_o(merged))

In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = MultiQueryAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[-0.1048,  0.0370,  0.1984,  0.1773, -0.1839,  0.0440,  0.2067,
          -0.4353, -0.0939,  0.3132, -0.3061,  0.0848,  0.1260,  0.0328,
          -0.1256,  0.0107],
         [ 0.0250,  0.0641,  0.1823,  0.1719, -0.0930,  0.0568,  0.0918,
          -0.2752, -0.2561,  0.0892, -0.3545,  0.1193,  0.1654, -0.0068,
          -0.0926,  0.0482],
         [ 0.0592,  0.1439,  0.2031,  0.0955,  0.0162,  0.0046, -0.0912,
          -0.2180, -0.2133, -0.0308, -0.4258,  0.1769,  0.0639,  0.0609,
          -0.0916,  0.0108],
         [ 0.0729,  0.1623,  0.2129,  0.0760,  0.0450, -0.0102, -0.1399,
          -0.2054, -0.2114, -0.0646, -0.4486,  0.2016,  0.0450,  0.0770,
          -0.0845, -0.0055],
         [ 0.0589,  0.1553,  0.2128,  0.0523,  0.0501,  0.0032, -0.1384,
          -0.2100, -0.1703, -0.0458, -0.4251,  0.2069,  0.0277,  0.0849,
          -0.0972, -0.0168],
         [ 0.0599,  0.1499,  0.2060,  0.0518,  0.0500,  0.0150, -0.1305,
          -0.1953, -0.1744, -0.0463, -0.4096,  0.199

In [None]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention: Groups multiple query heads to share K&V heads. Compromise between MHA and MQA.
    Simple implementation. Using repeat_interleave which requires extra memory.

    Query head h uses K/V head h // n_reps, where n_reps = n_heads // n_kv_heads.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Number of K/V heads (groups) and how many query heads share each one.
        assert config.n_heads % config.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_heads = config.n_kv_heads
        self.n_reps = self.n_heads // self.n_kv_heads

        # Dropout on the attention probabilities, and on the output before the residual add.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # Separate projections: Q has n_heads heads, K and V have n_kv_heads heads each.
        kv_width = self.head_size * self.n_kv_heads
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, kv_width)
        self.w_v = nn.Linear(self.d_model, kv_width)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        # (1, 1, block_size, block_size) lower-triangular mask, broadcast over (B, nh, T, T).
        lower = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer("mask", lower.view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        batch, seq_len, width = x.size()
        hs = self.head_size

        q = self.w_q(x).view(batch, seq_len, self.n_heads, hs).transpose(1, 2)     # (B, nh, T, hs)
        k = self.w_k(x).view(batch, seq_len, self.n_kv_heads, hs).transpose(1, 2)  # (B, nh_kv, T, hs)
        v = self.w_v(x).view(batch, seq_len, self.n_kv_heads, hs).transpose(1, 2)  # (B, nh_kv, T, hs)

        # Materialise n_reps consecutive copies of each K/V head: (B, nh_kv, T, hs) -> (B, nh, T, hs).
        # This allocates new memory, unlike a broadcasting approach.
        k = k.repeat_interleave(self.n_reps, dim=1)
        v = v.repeat_interleave(self.n_reps, dim=1)

        # Scaled, causally masked scores (B, nh, T, T) -> attention probabilities.
        scores = (q @ k.transpose(-2, -1)) * hs ** -0.5
        scores = scores.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        weights = self.attention_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, D), then the output projection.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.residual_dropout(self.w_o(merged))


In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = GroupedQueryAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[-0.0998, -0.2203,  0.2570,  0.3095, -0.5632, -0.4385, -0.1056,
           0.2443, -0.3195,  0.3184, -0.3649, -0.3327, -0.1754, -0.1521,
           0.0035, -0.4032],
         [-0.0241, -0.1858,  0.3369,  0.3096, -0.5109, -0.4098, -0.0690,
           0.3041, -0.2755,  0.2990, -0.3211, -0.2464, -0.1262, -0.2037,
           0.0553, -0.4309],
         [-0.0709, -0.1929,  0.3086,  0.3396, -0.5252, -0.4374, -0.0714,
           0.3116, -0.2691,  0.3350, -0.3915, -0.2731, -0.1580, -0.1831,
           0.0336, -0.4853],
         [-0.1138, -0.1755,  0.2827,  0.3573, -0.5378, -0.4545, -0.0732,
           0.3054, -0.2665,  0.3543, -0.4031, -0.2883, -0.1889, -0.2050,
           0.0135, -0.4879],
         [-0.1142, -0.1764,  0.2786,  0.3451, -0.5458, -0.4456, -0.0904,
           0.3044, -0.2773,  0.3661, -0.4040, -0.2972, -0.1840, -0.1962,
           0.0040, -0.4900],
         [-0.1215, -0.1853,  0.3020,  0.3758, -0.5414, -0.4579, -0.0802,
           0.3391, -0.2675,  0.3769, -0.4411, -0.298

In [None]:
# Variant that uses reshaping and broadcasting instead of repeat_interleave:

# Reshape Q rather than repeating K,V: the Q heads are grouped into (B, n_kv_heads, n_rep, T, hs).
# K,V are viewed as (B, n_kv_heads, 1, T, hs) and .expand()-ed along the n_rep axis, which is a zero-copy view.
# No explicit copy of K,V is made here (torch.matmul may still materialize expanded operands internally).

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention: Groups multiple query heads to share K&V heads. Compromise between MHA and MQA.

    This version groups the query heads as (B, n_kv_heads, n_reps, T, hs) and
    expands K/V along the n_reps axis as a view, instead of copying them with
    repeat_interleave. Query head h uses K/V head h // n_reps.
    """
    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Number of K/V heads (groups) and how many query heads share each one.
        assert config.n_heads % config.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_heads = config.n_kv_heads
        self.n_reps = self.n_heads // self.n_kv_heads

        # Dropout on the attention probabilities, and on the output before the residual add.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # Separate projections: Q has n_heads heads, K and V have n_kv_heads heads each.
        kv_width = self.head_size * self.n_kv_heads
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, kv_width)
        self.w_v = nn.Linear(self.d_model, kv_width)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        # Mask shaped (1, 1, 1, block_size, block_size) to broadcast over (B, nh_kv, n_rep, T, T) scores.
        lower = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer("mask", lower.view(1, 1, 1, self.block_size, self.block_size))

    def forward(self, x):
        batch, seq_len, width = x.size()
        groups, reps, hs = self.n_kv_heads, self.n_reps, self.head_size

        q = self.w_q(x).view(batch, seq_len, self.n_heads, hs).transpose(1, 2)  # (B, nh, T, hs)
        k = self.w_k(x).view(batch, seq_len, groups, hs).transpose(1, 2)        # (B, nh_kv, T, hs)
        v = self.w_v(x).view(batch, seq_len, groups, hs).transpose(1, 2)        # (B, nh_kv, T, hs)

        # Group consecutive query heads: (B, nh, T, hs) -> (B, nh_kv, n_rep, T, hs).
        q = q.view(batch, groups, reps, seq_len, hs)
        # Add a singleton rep axis to K/V and expand it (a view, no copy): -> (B, nh_kv, n_rep, T, hs).
        k = k.unsqueeze(2).expand(-1, -1, reps, -1, -1)
        v = v.unsqueeze(2).expand(-1, -1, reps, -1, -1)

        # Scaled, causally masked scores (B, nh_kv, n_rep, T, T) -> attention probabilities.
        scores = (q @ k.transpose(-2, -1)) * hs ** -0.5
        scores = scores.masked_fill(self.mask[..., :seq_len, :seq_len] == 0, float('-inf'))
        weights = self.attention_dropout(F.softmax(scores, dim=-1))

        # (B, nh_kv, n_rep, T, hs) -> (B, T, nh_kv, n_rep, hs) -> (B, T, D); head order matches (B, nh, T, hs).
        merged = (weights @ v).permute(0, 3, 1, 2, 4).contiguous().view(batch, seq_len, width)
        return self.residual_dropout(self.w_o(merged))


In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = GroupedQueryAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[ 0.0894, -0.1228,  0.1479,  0.1442,  0.1159, -0.0993,  0.0558,
          -0.0434, -0.1432, -0.1889,  0.3234, -0.0256, -0.3636,  0.2188,
           0.3530, -0.3256],
         [ 0.1047, -0.1770,  0.2269,  0.1470,  0.1749, -0.1796,  0.1427,
           0.1188, -0.4242, -0.3087,  0.4202,  0.1873, -0.4917,  0.1946,
           0.4046, -0.4180],
         [ 0.1040, -0.1385,  0.1524,  0.1543,  0.1503, -0.1322,  0.0959,
           0.0836, -0.3886, -0.3345,  0.4065,  0.1637, -0.4828,  0.1715,
           0.3891, -0.4110],
         [ 0.1297, -0.1269,  0.1058,  0.1403,  0.1385, -0.1293,  0.0749,
           0.0716, -0.3534, -0.3310,  0.3972,  0.1511, -0.4868,  0.1714,
           0.3889, -0.4163],
         [ 0.0965, -0.1708,  0.0774,  0.1488,  0.1709, -0.1204,  0.1045,
           0.0893, -0.3621, -0.3165,  0.3536,  0.1568, -0.4565,  0.1682,
           0.3469, -0.3900],
         [ 0.0837, -0.1741,  0.0810,  0.1271,  0.1624, -0.1013,  0.1034,
           0.1005, -0.3408, -0.3286,  0.3465,  0.137

In [None]:
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA): Compresses K&V into lower-dimensional latent representations.
    paper: https://arxiv.org/pdf/2405.04434, section 2.1

    K and V are both decoded from one shared latent of size ``latent_dim``
    (compression only happens when latent_dim < d_model). Decoupled RoPE from
    the paper is not implemented. Inputs are (B, T, d_model) with
    T <= block_size; the output has the same shape.
    """
    def __init__(self, config, latent_dim=16) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Dropout on the attention probabilities, and on the output before the residual add.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        self.latent_dim = latent_dim
        self.w_q = nn.Linear(self.d_model, self.d_model)

        # K and V share one latent, so only that latent would need caching at inference.
        # Per the paper, w_k_decode / w_v_decode could then be absorbed into w_q / w_o,
        # avoiding explicit decoding; neither caching nor absorption is done here.
        self.w_kv_compress = nn.Linear(self.d_model, latent_dim, bias=False)
        self.w_k_decode = nn.Linear(latent_dim, self.d_model, bias=False)
        self.w_v_decode = nn.Linear(latent_dim, self.d_model, bias=False)

        self.w_o = nn.Linear(self.d_model, self.d_model)

        # (1, 1, block_size, block_size) lower-triangular mask, broadcast over (B, nh, T, T).
        lower = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer("mask", lower.view(1, 1, self.block_size, self.block_size))

    def forward(self, x):
        batch, seq_len, width = x.size()

        def to_heads(t):
            # (B, T, D) -> (B, T, nh, hs) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_heads, self.head_size).transpose(1, 2)

        q = to_heads(self.w_q(x))
        latent = self.w_kv_compress(x)  # (B, T, latent_dim)
        k = to_heads(self.w_k_decode(latent))  # latent -> (B, T, D) -> (B, nh, T, hs)
        v = to_heads(self.w_v_decode(latent))

        # Scaled, causally masked scores (B, nh, T, T) -> attention probabilities.
        scores = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5
        scores = scores.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        weights = self.attention_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, D), then the output projection.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.residual_dropout(self.w_o(merged))

In [None]:
# Smoke test: one random (1, block_size, d_model) input; the output has the same shape.
x = torch.rand(1, Config.block_size, Config.d_model)
layer = MultiHeadLatentAttention(Config())  # pass a Config instance, not the dataclass type itself
print(layer(x))

tensor([[[-0.2075,  0.2085,  0.1388, -0.0365, -0.1302,  0.0492,  0.0641,
          -0.0202,  0.2320, -0.0848,  0.2486, -0.0417,  0.0033, -0.0470,
           0.2917, -0.1073],
         [-0.2155,  0.2575,  0.1465, -0.0253, -0.1123,  0.0525,  0.0087,
          -0.0201,  0.2502, -0.1432,  0.2767,  0.0568, -0.0008, -0.0500,
           0.3169, -0.1091],
         [-0.2121,  0.2804,  0.1482, -0.0247, -0.1220,  0.0549, -0.0424,
          -0.0265,  0.2719, -0.1567,  0.2916,  0.0959,  0.0063, -0.0386,
           0.3454, -0.1459],
         [-0.2080,  0.2604,  0.1383, -0.0291, -0.1334,  0.0558, -0.0374,
          -0.0242,  0.2811, -0.1271,  0.2855,  0.0742,  0.0047, -0.0344,
           0.3431, -0.1514],
         [-0.2201,  0.2577,  0.1248, -0.0352, -0.1288,  0.0751, -0.0552,
          -0.0147,  0.2945, -0.1353,  0.2889,  0.0880, -0.0098, -0.0156,
           0.3605, -0.1758],
         [-0.2299,  0.2581,  0.1070, -0.0353, -0.1213,  0.0848, -0.0753,
          -0.0078,  0.3108, -0.1415,  0.3037,  0.099

# GPT

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(d, 4d) -> GELU -> Linear(4d, d) -> Dropout."""

    def __init__(self, config) -> None:
        super().__init__()

        hidden = 4 * config.d_model  # 4x expansion of the model width
        self.fc1 = nn.Linear(config.d_model, hidden)
        self.fc2 = nn.Linear(hidden, config.d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        hidden = self.gelu(self.fc1(x))
        return self.dropout(self.fc2(hidden))

# https://docs.pytorch.org/docs/stable/generated/torch.nn.GELU.html
# gelu(x) = x * Phi(x), where Phi(x) is the cumulative distribution function of the standard Gaussian distribution.

In [None]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + Attn(LN1(x)), then x + MLP(LN2(x)).

    ``config.attention_type`` picks the attention implementation from the
    registry below; an unknown key raises ``KeyError``.
    """

    def __init__(self, config) -> None:
        super().__init__()

        # Built at construction time, so each name resolves to whichever class it is
        # bound to when the block is created (e.g. the later, broadcasting
        # GroupedQueryAttention redefinition).
        attention_classes = {
            "causal": CausalSelfAttention,
            "talk_heads": TalkingHeadsCausalSelfAttention,
            "mqa": MultiQueryAttention,
            "gqa": GroupedQueryAttention,
            "mla": MultiHeadLatentAttention,
        }
        attention_cls = attention_classes[config.attention_type]

        self.attention_layer = attention_cls(config)
        self.ffn = MLP(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x):
        # LayerNorm is applied to each sub-layer's input (pre-LN); residuals bypass it.
        attended = x + self.attention_layer(self.ln1(x))
        return attended + self.ffn(self.ln2(attended))


In [None]:
class GPT(nn.Module):
    """Decoder-only GPT language model: token + position embeddings, ``n_layers`` Blocks, LayerNorm, LM head.

    ``config`` must provide vocab_size, block_size, d_model, dropout_rate and
    n_layers, plus whatever the attention type selected in ``Block`` reads.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            w_token_emb = nn.Embedding(config.vocab_size, config.d_model),
            w_pos_emb = nn.Embedding(config.block_size, config.d_model),
            trans_in_dropout = nn.Dropout(config.dropout_rate), # dropout on the summed input embeddings
            trans_out_ln = nn.LayerNorm(config.d_model), # final layer norm after the blocks, before lm_head
            trans_layers = nn.ModuleList(Block(config) for _ in range(config.n_layers)),
        ))

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # TODO: custom weight initialisation and parameter-count reporting are not
        # implemented; PyTorch's default per-layer initialisation is used.

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer token ids, with T <= config.block_size.
            targets: optional (B, T) integer next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T}, block size {self.config.block_size}"

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # (T,)

        # input: token embeddings (B, T, d_model) + position embeddings (T, d_model), broadcast over the batch
        token_emb = self.transformer.w_token_emb(idx)
        pos_emb = self.transformer.w_pos_emb(pos)
        x = self.transformer.trans_in_dropout(token_emb + pos_emb)

        # transformer blocks
        for layer in self.transformer.trans_layers:
            x = layer(x)
        x = self.transformer.trans_out_ln(x) # (B, T, d_model)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # reshape (not view) so non-contiguous targets are accepted as well
        logits = logits.reshape(B * T, self.config.vocab_size)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Before each step the context is cropped to the last ``config.block_size``
        tokens, because the position embedding only covers that many positions.
        Dropout stays active unless the caller has put the model in eval mode.

        Args:
            idx: (B, T) integer token ids used as the prompt.
            max_new_tokens: number of tokens to sample.
            temperature: the final-step logits are divided by this value.
            top_k: if given, sample only among the ``top_k`` most likely tokens.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature # (B, vocab_size), last position only
            if top_k is not None:
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


# Test Attentions with a Simple Task

https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py

In [None]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
torch.manual_seed(1337)

# Hyperparameters (adapted from karpathy/ng-video-lecture gpt.py)
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# NOTE(review): n_embd / n_head / n_layer / dropout appear unused below; the model
# is sized by the Config dataclass instead.
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character in the text, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Mappings between characters and integer ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join([itos[i] for i in l])


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    """Sample a random mini-batch of (input, next-token target) windows.

    Reads the module-level ``train_data`` / ``val_data``, ``block_size``,
    ``batch_size`` and ``device``.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y): two (batch_size, block_size) tensors moved to ``device``; each row
        of ``y`` is the matching row of ``x`` shifted one position ahead.
    """
    source = train_data if split == 'train' else val_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # (batch_size, 1) + (block_size,) broadcasts to a (batch_size, block_size) index grid
    window = starts.unsqueeze(1) + torch.arange(block_size)
    inputs, targets = source[window], source[window + 1]
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the module-level ``model`` on the train and val splits.

    For each split, averages ``eval_iters`` mini-batches drawn with ``get_batch``,
    with gradients disabled. The model is switched to eval mode for the
    measurement and then restored to whatever mode it was in before the call,
    even if evaluation raises.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # restore the caller's mode instead of unconditionally forcing train mode
        model.train(was_training)
    return out

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the char-level comparison run with ``attention_type='causal'``."""
    attention_type: str = 'causal'
    block_size: int = 32
    vocab_size: int = 65
    n_layers: int = 4
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 64
    dropout_rate: float = 0.0
    bias: bool = True


# re-seed so every attention variant starts from the same RNG state,
# independent of which cells were executed before this one
torch.manual_seed(1337)

model = GPT(Config())  # pass a Config instance, not the class object
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# # generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.210592 M parameters
step 0: train loss 4.4870, val loss 4.4812
step 100: train loss 2.6412, val loss 2.6553
step 200: train loss 2.3580, val loss 2.3581
step 300: train loss 2.2012, val loss 2.2083
step 400: train loss 2.0734, val loss 2.0922
step 500: train loss 1.9744, val loss 1.9897
step 600: train loss 1.8508, val loss 1.8655
step 700: train loss 1.6373, val loss 1.6542
step 800: train loss 1.1933, val loss 1.2377
step 900: train loss 0.8628, val loss 0.9049
step 1000: train loss 0.6628, val loss 0.6903
step 1100: train loss 0.4888, val loss 0.5265
step 1200: train loss 0.4063, val loss 0.4287
step 1300: train loss 0.3410, val loss 0.3660
step 1400: train loss 0.2917, val loss 0.3150
step 1500: train loss 0.2529, val loss 0.2758
step 1600: train loss 0.2398, val loss 0.2597
step 1700: train loss 0.2310, val loss 0.2444
step 1800: train loss 0.2068, val loss 0.2260
step 1900: train loss 0.1864, val loss 0.2049
step 2000: train loss 0.1806, val loss 0.1971
step 2100: train loss 0.

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the char-level comparison run with ``attention_type='talk_heads'``."""
    attention_type: str = 'talk_heads'
    block_size: int = 32
    vocab_size: int = 65
    n_layers: int = 4
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 64
    dropout_rate: float = 0.0
    bias: bool = True


# re-seed so every attention variant starts from the same RNG state,
# independent of which cells were executed before this one
torch.manual_seed(1337)

model = GPT(Config())  # pass a Config instance, not the class object
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# # generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.210592 M parameters
step 0: train loss 4.2271, val loss 4.2197
step 100: train loss 2.6239, val loss 2.6475
step 200: train loss 2.3401, val loss 2.3560
step 300: train loss 2.1910, val loss 2.2089
step 400: train loss 2.0817, val loss 2.1092
step 500: train loss 1.9874, val loss 2.0083
step 600: train loss 1.8646, val loss 1.8761
step 700: train loss 1.6545, val loss 1.6903
step 800: train loss 1.2388, val loss 1.2895
step 900: train loss 0.8888, val loss 0.9229
step 1000: train loss 0.6572, val loss 0.6915
step 1100: train loss 0.5364, val loss 0.5698
step 1200: train loss 0.4070, val loss 0.4465
step 1300: train loss 0.3373, val loss 0.3666
step 1400: train loss 0.2996, val loss 0.3205
step 1500: train loss 0.2669, val loss 0.2877
step 1600: train loss 0.2441, val loss 0.2681
step 1700: train loss 0.2169, val loss 0.2313
step 1800: train loss 0.2198, val loss 0.2385
step 1900: train loss 0.1972, val loss 0.2215
step 2000: train loss 0.1864, val loss 0.2029
step 2100: train loss 0.

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the char-level comparison run with ``attention_type='mqa'``."""
    attention_type: str = 'mqa'
    block_size: int = 32
    vocab_size: int = 65
    n_layers: int = 4
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 64
    dropout_rate: float = 0.0
    bias: bool = True


# re-seed so every attention variant starts from the same RNG state,
# independent of which cells were executed before this one
torch.manual_seed(1337)

model = GPT(Config())  # pass a Config instance, not the class object
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# # generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.185472 M parameters
step 0: train loss 4.3284, val loss 4.3242
step 100: train loss 2.6253, val loss 2.6457
step 200: train loss 2.4919, val loss 2.5095
step 300: train loss 2.3973, val loss 2.4069
step 400: train loss 2.3516, val loss 2.3733
step 500: train loss 2.3048, val loss 2.3052
step 600: train loss 2.2480, val loss 2.2665
step 700: train loss 2.2068, val loss 2.2288
step 800: train loss 2.1658, val loss 2.1994
step 900: train loss 2.1147, val loss 2.1528
step 1000: train loss 2.0840, val loss 2.1300
step 1100: train loss 2.0514, val loss 2.1113
step 1200: train loss 2.0286, val loss 2.0981
step 1300: train loss 2.0089, val loss 2.0779
step 1400: train loss 1.9822, val loss 2.0471
step 1500: train loss 1.9602, val loss 2.0363
step 1600: train loss 1.9391, val loss 2.0137
step 1700: train loss 1.9245, val loss 2.0073
step 1800: train loss 1.9038, val loss 2.0025
step 1900: train loss 1.8975, val loss 1.9981
step 2000: train loss 1.8744, val loss 1.9879
step 2100: train loss 1.

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the char-level comparison run with ``attention_type='gqa'``."""
    attention_type: str = 'gqa'
    block_size: int = 32
    vocab_size: int = 65
    n_layers: int = 4
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 64
    dropout_rate: float = 0.0
    bias: bool = True


# re-seed so every attention variant starts from the same RNG state,
# independent of which cells were executed before this one
torch.manual_seed(1337)

model = GPT(Config())  # pass a Config instance, not the class object
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# # generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.193792 M parameters
step 0: train loss 4.2811, val loss 4.2823
step 100: train loss 2.6382, val loss 2.6425
step 200: train loss 2.4949, val loss 2.5020
step 300: train loss 2.4018, val loss 2.4161
step 400: train loss 2.3529, val loss 2.3487
step 500: train loss 2.2826, val loss 2.2993
step 600: train loss 2.2474, val loss 2.2645
step 700: train loss 2.1883, val loss 2.2123
step 800: train loss 2.1669, val loss 2.1887
step 900: train loss 2.1161, val loss 2.1564
step 1000: train loss 2.0819, val loss 2.1237
step 1100: train loss 2.0446, val loss 2.1096
step 1200: train loss 2.0093, val loss 2.0791
step 1300: train loss 1.9978, val loss 2.0664
step 1400: train loss 1.9771, val loss 2.0363
step 1500: train loss 1.9579, val loss 2.0327
step 1600: train loss 1.9309, val loss 2.0276
step 1700: train loss 1.9093, val loss 2.0122
step 1800: train loss 1.8935, val loss 1.9917
step 1900: train loss 1.8828, val loss 1.9743
step 2000: train loss 1.8619, val loss 1.9665
step 2100: train loss 1.

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the char-level comparison run with ``attention_type='mla'``."""
    attention_type: str = 'mla'
    block_size: int = 32
    vocab_size: int = 65
    n_layers: int = 4
    n_heads: int = 4
    n_kv_heads: int = 2
    d_model: int = 64
    dropout_rate: float = 0.0
    bias: bool = True


# re-seed so every attention variant starts from the same RNG state,
# independent of which cells were executed before this one
torch.manual_seed(1337)

model = GPT(Config())  # pass a Config instance, not the class object
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# # generate from the model
# context = torch.zeros((1, 1), dtype=torch.long, device=device)
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.18944 M parameters
step 0: train loss 4.3211, val loss 4.3351
step 100: train loss 2.6422, val loss 2.6467
step 200: train loss 2.5181, val loss 2.5328
step 300: train loss 2.4352, val loss 2.4318
step 400: train loss 2.3668, val loss 2.3769
step 500: train loss 2.3253, val loss 2.3384
step 600: train loss 2.2967, val loss 2.3073
step 700: train loss 2.2611, val loss 2.2794
step 800: train loss 2.2204, val loss 2.2563
step 900: train loss 2.1763, val loss 2.2236
step 1000: train loss 2.1505, val loss 2.1800
step 1100: train loss 2.1271, val loss 2.1670
step 1200: train loss 2.1007, val loss 2.1408
step 1300: train loss 2.0637, val loss 2.1127
step 1400: train loss 2.0502, val loss 2.0949
step 1500: train loss 2.0251, val loss 2.0857
step 1600: train loss 1.9922, val loss 2.0839
step 1700: train loss 1.9729, val loss 2.0589
step 1800: train loss 1.9651, val loss 2.0461
step 1900: train loss 1.9336, val loss 2.0342
step 2000: train loss 1.9321, val loss 2.0240
step 2100: train loss 1.9

In [None]:
Results:
causal: step 4999: train loss 0.1054, val loss 0.1122, 0.210592 M parameters
talk: step 4999: train loss 0.1087, val loss 0.1157, 0.210592 M parameters
mqa: step 4999: train loss 1.6523, val loss 1.8181, 0.185472 M parameters
gqa: step 4999: train loss 1.6479, val loss 1.7981 (n_heads=4, n_kv_heads=2), 0.193792 M parameters
mla: step 4999: train loss 1.6794, val loss 1.8569 (d_model=64, latent_dim=16), 0.18944 M parameters

# DIN (Q and KV have different sequence lengths)

In [None]:
"""
User behavior: [B, L, D], batch size, length of behavior sequence, dim of embedding.
Candidate item: [B, D].

Expected tensors:
Output user embedding: [B, D]
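The output user embedding does not have to keep dim D the way BERT does (BERT keeps D so that layers can be stacked).
E.g. the latest spice uses 3 heads for the click history and 1 head for the purchase history, so the final user sequence vector has size 3 * 64 + 64.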

attention scores: [B, n_heads, L], attention scores of candidate item, wrt all L behaviors

To perform multi-head attention, the input embeddings must first go through linear transformations (projections).
A plain inner product between the candidate embedding and a behavior embedding can only provide a single attention head.

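Attention also does not have to be implemented as cross attention like this.
In the original DIN paper, the attention unit takes the outer product of k and q together with k and q themselves, and passes them through an FC layer with dim=36. SP also uses an MLP to produce the attention scores.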
"""


class DIN(nn.Module):
    """Multi-head target attention of one candidate item over a behavior sequence.

    The candidate embedding is projected to one query per head, and each of the
    L behavior embeddings is projected to a key and a value. The query length (1)
    therefore differs from the key/value length (L). The per-head outputs are
    concatenated back to ``d_model``; there is no output projection.

    Subclassing ``nn.Module`` registers ``w_q`` / ``w_kv`` as submodules. Their
    weights then appear in ``parameters()`` and follow ``.to(device)``, and the
    layer can be called as ``din(seq, candidates)``.
    """

    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: embedding dim D shared by the behaviors and the candidate.
            n_heads: number of attention heads; must divide ``d_model``.
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)  # candidate -> queries
        self.w_kv = nn.Linear(d_model, 2 * d_model)  # behaviors -> keys and values (fused)

        # optional output projection, currently disabled:
        # self.out = nn.Linear(d_model, d_model)

    def forward(self, seq, candidates, mask=None):
        """Attend from each candidate over its own behavior sequence.

        Args:
            seq: behavior embeddings, (B, L, d_model).
            candidates: candidate item embeddings, (B, d_model).
            mask: optional (B, L) tensor. Truthy entries mark real behaviors;
                falsy entries mark padding, which gets zero attention weight.
                ``None`` (the default) attends to all L positions.

        Returns:
            User embedding of shape (B, d_model). A row whose mask is all falsy
            yields a zero vector.
        """
        B, L, _ = seq.size()

        q = self.w_q(candidates)  # (B, d_model)
        k, v = self.w_kv(seq).split(self.d_model, dim=2)  # each (B, L, d_model)

        q = q.view(B, 1, self.n_heads, self.d_head).transpose(1, 2)  # (B, n_heads, 1, d_head)
        k = k.view(B, L, self.n_heads, self.d_head).transpose(1, 2)  # (B, n_heads, L, d_head)
        v = v.view(B, L, self.n_heads, self.d_head).transpose(1, 2)  # (B, n_heads, L, d_head)

        # (B, n_heads, 1, d_head) @ (B, n_heads, d_head, L) -> (B, n_heads, 1, L), scaled
        scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        if mask is not None:
            # (B, L) -> (B, 1, 1, L): broadcast over the heads and the single query
            keep = mask.to(torch.bool)[:, None, None, :]
            scores = scores.masked_fill(~keep, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        if mask is not None:
            # a row of all-padding softmaxes to NaN; give it zero weight instead
            weights = torch.nan_to_num(weights, nan=0.0)

        y = weights @ v  # (B, n_heads, 1, L) @ (B, n_heads, L, d_head) -> (B, n_heads, 1, d_head)
        y = y.transpose(1, 2).contiguous().view(B, self.d_model)  # -> (B, d_model)

        # y = self.out(y)  # (B, d_model) -> (B, d_model)

        return y


"""
attention can also be implemented with broadcasting + reduce sum
attentions: (B, n_heads, 1, L)
v: (B, n_heads, L, d_head)
v * attentions.transpose(-1, -2): the (L, 1) weights are broadcast d_head times along the last axis -> (B, n_heads, L, d_head)
then sum along axis = -2 -> (B, n_heads, d_head)

broadcasting: in [m, n] +-*/ [m, 1] (or [1, n]), the latter is duplicated n times (or m times, respectively) to match [m, n]
https://www.youtube.com/watch?v=tKcLaGdvabM&ab_channel=DeepLearningAI
"""