<a href="https://colab.research.google.com/github/rczhen/code_ml/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

In [14]:
@dataclass
class GPTConfig:
    """Hyperparameters for a GPT-2 sized model.

    Field order is part of the interface (positional construction), so it
    is kept exactly as declared below.
    """

    # Maximum sequence length (context window).
    block_size: int = 1024
    # GPT-2's vocabulary has 50257 entries; it is padded up to the nearest
    # multiple of 64 for efficiency.
    vocab_size: int = 50304
    # Number of stacked layers.
    n_layer: int = 12
    # Number of attention heads.
    n_head: int = 12
    # Embedding width.
    n_embd: int = 768
    # Dropout probability.
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2.
    # False: a bit better and faster.
    bias: bool = True

@dataclass
class Config:
    """Small hyperparameter set, sized for quick experiments.

    Field order is part of the interface (positional construction), so it
    is kept exactly as declared below.
    """

    # Maximum sequence length (context window).
    block_size: int = 8
    # GPT-2's vocabulary has 50257 entries; it is padded up to the nearest
    # multiple of 64 for efficiency.
    vocab_size: int = 50304
    # Number of stacked layers.
    n_layers: int = 2
    # Number of attention heads; must divide d_model evenly.
    n_heads: int = 4
    # Model (embedding) width.
    d_model: int = 16
    # Dropout probability.
    dropout_rate: float = 0.1
    # True: bias in Linears and LayerNorms, like GPT-2.
    # False: a bit better and faster.
    bias: bool = True

In [15]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Each position attends only to itself and to earlier positions in the
    sequence. Queries, keys and values for all heads are produced by a single
    fused linear projection and then split per head.

    Args:
        config: Object exposing ``d_model``, ``n_heads``, ``block_size`` and
            ``dropout_rate``. An optional boolean ``bias`` attribute selects
            whether the linear projections carry a bias term; it defaults to
            ``True`` when the attribute is absent.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0, (
            "d_model must be divisible by n_heads"
        )

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.n_heads
        self.block_size = config.block_size

        # Applied to the attention weights, right after the softmax.
        self.attention_dropout = nn.Dropout(config.dropout_rate)
        # Applied to the block output, before the caller adds the residual.
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        # Honour the config's bias switch; fall back to nn.Linear's default
        # (bias enabled) for config objects that do not define it.
        use_bias = getattr(config, "bias", True)
        # Fused projection producing q, k and v in one matmul.
        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=use_bias)
        # Output projection applied after the heads are concatenated.
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=use_bias)

        # Lower-triangular matrix of ones, reshaped to (1, 1, block_size,
        # block_size) so it broadcasts over (B, n_heads, T, T) scores.
        # Registered as a buffer so it follows the module across devices.
        causal = torch.tril(torch.ones(self.block_size, self.block_size))
        self.register_buffer(
            "mask", causal.view(1, 1, self.block_size, self.block_size)
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, D)`` with ``D == d_model`` and
                ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, D)``.

        Raises:
            ValueError: If the sequence length ``T`` exceeds ``block_size``.
        """
        B, T, D = x.size()  # batch size, sequence length, d_model
        if T > self.block_size:
            # The mask only covers block_size positions; without this check
            # the failure surfaces as an opaque broadcasting error.
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        # (B, T, D) --> (B, T, 3D) --> three (B, T, D) chunks
        q, k, v = self.w_qkv(x).split(self.d_model, dim=2)
        # (B, T, D) --> (B, T, nh, hs) --> (B, nh, T, hs); heads act as a batch dim
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        # Scaled dot-product scores: (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
        attention = (q @ k.transpose(-1, -2)) * self.head_size ** -0.5
        # Future positions get -inf so the softmax assigns them zero weight.
        attention = attention.masked_fill(
            self.mask[:, :, :T, :T] == 0, float('-inf')
        )
        attention = F.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        # Weighted sum of values: (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
        y = attention @ v
        # Re-assemble heads: (B, nh, T, hs) --> (B, T, nh, hs) --> (B, T, D)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        y = self.w_o(y)  # (B, T, D) --> (B, T, D)
        y = self.residual_dropout(y)

        return y



In [16]:
# Smoke test: push one random batch through the attention layer and print both
# the input and the output.
x = torch.rand(1, Config.block_size, Config.d_model)  # shape (1, 8, 16) with Config's defaults
print(x)
# Config is passed as the class itself rather than an instance; the dataclass
# defaults are readable as plain class attributes.
layer = CausalSelfAttention(Config)
# eval() is never called here, so the layer's nn.Dropout modules
# (p = Config.dropout_rate) are active for this forward pass.
print(layer(x))

tensor([[[0.3122, 0.0226, 0.8950, 0.0783, 0.1770, 0.9433, 0.9514, 0.1017,
          0.8026, 0.9098, 0.1098, 0.4137, 0.2848, 0.6199, 0.9206, 0.3427],
         [0.4449, 0.6739, 0.8439, 0.0493, 0.0081, 0.6511, 0.3104, 0.7010,
          0.5342, 0.0650, 0.9289, 0.1583, 0.7377, 0.7175, 0.9571, 0.8403],
         [0.6918, 0.3198, 0.3399, 0.0931, 0.6521, 0.7145, 0.9907, 0.2112,
          0.6763, 0.5887, 0.5782, 0.3483, 0.3435, 0.1437, 0.8163, 0.6681],
         [0.3850, 0.5111, 0.5583, 0.6228, 0.3647, 0.8345, 0.2221, 0.8506,
          0.2525, 0.1807, 0.2781, 0.2693, 0.7080, 0.1854, 0.1001, 0.9481],
         [0.8707, 0.5919, 0.0291, 0.1116, 0.9614, 0.3914, 0.1568, 0.6211,
          0.6675, 0.4568, 0.7768, 0.0650, 0.8675, 0.5075, 0.0551, 0.6097],
         [0.4323, 0.3263, 0.6006, 0.4719, 0.7960, 0.0774, 0.3016, 0.0890,
          0.2929, 0.8220, 0.3581, 0.1270, 0.4589, 0.2781, 0.5647, 0.6605],
         [0.3650, 0.5132, 0.2096, 0.2240, 0.3324, 0.7566, 0.9816, 0.5836,
          0.5430, 0.3579, 0.5746