<a href="https://colab.research.google.com/github/rczhen/code_ml/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

In [18]:
@dataclass
class GPTConfig:
    """Bundle of model hyperparameters with GPT-2 style default values.

    Every field has a default, so ``GPTConfig()`` can be constructed without
    arguments and individual values overridden by keyword.
    """

    # presumably the maximum sequence (context) length -- confirm against the model code
    block_size: int = 1024
    # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    # presumably the number of transformer blocks -- confirm against the model code
    n_layer: int = 12
    # presumably the number of attention heads per block -- confirm against the model code
    n_head: int = 12
    # presumably the embedding / hidden width -- confirm against the model code
    n_embd: int = 768
    # dropout probability; the default of 0.0 disables dropout
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    bias: bool = True

@dataclass
class Config:
    """Small hyperparameter bundle used for the toy attention demo.

    Every field has a default, so ``Config()`` can be constructed without
    arguments and individual values overridden by keyword.
    """

    # side length of the causal mask, i.e. the longest sequence the attention layer accepts
    block_size: int = 8
    # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    # presumably the number of stacked blocks -- confirm against the model code
    n_layers: int = 2
    # number of attention heads; must divide d_model evenly
    n_heads: int = 4
    # model (embedding) width; split evenly across the heads
    d_model: int = 16
    # dropout probability for the attention dropout layers; 0.0 disables dropout
    dropout_rate: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    bias: bool = True

In [23]:
class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    Each position attends only to itself and earlier positions. Queries, keys
    and values for all heads come from one fused linear projection; the head
    outputs are concatenated and passed through an output projection.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the two linear projections carry a bias term (defaults to
            ``True`` when the attribute is absent).
    """

    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0, (
            "d_model must be divisible by n_heads"
        )

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # applied to the attention weights, right after the softmax
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        # applied to the block output, before the caller adds the residual
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # Honour the config's bias switch; configs without the field keep the
        # previous behaviour (bias enabled).
        use_bias = getattr(config, "bias", True)
        # fused projection producing q, k and v in a single matmul
        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=use_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        # Lower-triangular matrix of ones, reshaped to (1, 1, block_size,
        # block_size) so it broadcasts over the batch and head dimensions.
        causal = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer(
            "mask", causal.view(1, 1, self.block_size, self.block_size)
        )

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, D) with D == d_model. T must not exceed
                ``block_size``, otherwise the sliced mask cannot broadcast
                against the (T, T) score matrix.

        Returns:
            Tensor of shape (B, T, D).
        """
        B, T, D = x.size()  # batch size, sequence length, d_model

        # (B, T, D) -> (B, T, 3D) -> three (B, T, D) chunks
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2)
        # (B, T, D) -> (B, T, nh, hs) -> (B, nh, T, hs): heads act as a batch dim
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        # (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T), scaled by
        # 1/sqrt(hs) to keep the softmax from saturating
        attention = (q @ k.transpose(-1, -2)) * self.head_size ** -0.5
        # Hide future positions. The diagonal is never masked, so every row
        # keeps at least one finite score and the softmax cannot produce NaN.
        attention = attention.masked_fill(
            self.mask[:, :, :T, :T] == 0, float('-inf')
        )
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        y = attention @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, D): concatenate the heads
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        y = self.w_o(y)  # (B, T, D) -> (B, T, D)
        y = self.residual_dropout(y)

        return y


"""
Why scaled?
Without normalization, the variance of the raw attention scores (q . k) is on the order of head_size
(with the Config defaults above, head_size = d_model // n_heads = 16 // 4 = 4).
Dividing the scores by sqrt(head_size) brings the variance back to about 1.
"Scaled" attention additionally multiplies the attention scores by 1/sqrt(head_size),
so when the inputs Q and K have unit variance, the attention scores will have unit variance too,
and the softmax will stay diffuse and not saturate too much.
"""


'\nwhy scaled? \nif no normalization, the variance of weights wil be on the order of head_size, here is 16\nwhen deviding by sqrt(head_size), bring the variance back\n"Scaled" attention additional divides attention scores by 1/sqrt(head_size). \nso when input Q,K are unit variance, attentions will be unit variance too\nand Softmax will stay diffuse and not saturate too much.\n'

In [20]:
# Smoke test: push one random sequence through the attention layer.
# Build a real Config instance rather than passing the dataclass type itself;
# using the class only works while every field has a plain class-level default.
config = Config()
x = torch.rand(1, config.block_size, config.d_model)  # (B=1, T=block_size, D=d_model)
print(x)
layer = CausalSelfAttention(config)
print(layer(x))

tensor([[[0.5348, 0.4341, 0.4442, 0.5332, 0.4855, 0.7372, 0.8768, 0.4357,
          0.3864, 0.5594, 0.1571, 0.1295, 0.4696, 0.7287, 0.9868, 0.8231],
         [0.7122, 0.8425, 0.4135, 0.6154, 0.7098, 0.0762, 0.3721, 0.8747,
          0.8312, 0.4790, 0.2323, 0.8670, 0.6137, 0.8522, 0.5907, 0.9117],
         [0.2219, 0.5944, 0.6904, 0.6469, 0.8516, 0.9672, 0.8572, 0.9335,
          0.1321, 0.1047, 0.6608, 0.3437, 0.5292, 0.3708, 0.1160, 0.3417],
         [0.4557, 0.9365, 0.1227, 0.0639, 0.4600, 0.2101, 0.2688, 0.0186,
          0.5006, 0.6234, 0.2310, 0.2598, 0.2191, 0.0989, 0.3002, 0.9022],
         [0.3349, 0.2037, 0.1727, 0.9087, 0.4454, 0.9141, 0.4793, 0.0218,
          0.7232, 0.7357, 0.3814, 0.8412, 0.6648, 0.3664, 0.3311, 0.3038],
         [0.2569, 0.0541, 0.4904, 0.1427, 0.8338, 0.5350, 0.7984, 0.9666,
          0.9394, 0.0824, 0.8052, 0.2830, 0.7710, 0.9442, 0.8201, 0.5451],
         [0.9633, 0.0586, 0.4428, 0.1028, 0.0578, 0.4362, 0.4324, 0.9554,
          0.9741, 0.3322, 0.4701