## Understanding GPT

In [1]:
import torch
import torch.nn as nn
import  torch.nn.functional as ff

In [13]:
import math
from functools import partial
from dataclasses import dataclass
from turtle import forward

In [5]:
from nanochat.common import get_dist_info
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW

In [14]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model defined in this notebook.

    Every field carries a type annotation so that ``@dataclass`` actually
    registers it: without annotations the generated ``__init__`` accepts no
    arguments and values cannot be overridden per instance.
    """

    # Maximum sequence length the model is configured for.
    seq_len: int = 1024
    # Vocabulary size (50304 = 786 * 64).
    vocab_size: int = 50304
    # Number of transformer blocks.
    n_layer: int = 12
    # Number of query heads; must divide ``emb_dim`` evenly (768 / 6 = 128).
    n_head: int = 6
    # Number of key/value heads; must divide ``n_head`` evenly.
    n_kv_head: int = 6
    # Width of the token embeddings / residual stream.
    emb_dim: int = 768

    @property
    def head_dim(self) -> int:
        """Per-head channel count, derived from the current field values.

        Computed on access (with integer division) so it stays an ``int`` and
        stays in sync when ``emb_dim`` or ``n_head`` are overridden.
        """
        return self.emb_dim // self.n_head

In [15]:
@torch.compile
def RMS(x, epsilon=1e-6):
    """Scale ``x`` by the reciprocal of its root-mean-square over the last dim.

    There is no learnable gain or bias; the result has the same shape as ``x``.

    Args:
        x: Input tensor; statistics are taken over ``dim=-1``.
        epsilon: Constant added to the mean square before the square root so
            an all-zero row does not divide by zero. Defaults to ``1e-6`` so
            callers may simply write ``RMS(x)``.

    Returns:
        ``x / sqrt(mean(x**2, dim=-1, keepdim=True) + epsilon)``.
    """
    mean_square = torch.mean(x**2, dim=-1, keepdim=True)
    return x / torch.sqrt(mean_square + epsilon)

def relu(x):
    """Element-wise rectified linear unit: ``max(x, 0)``.

    Uses ``torch.relu`` rather than ``x * (x > 0)``: multiplying by a boolean
    mask turns ``-inf`` into ``NaN`` (``-inf * 0``) and yields ``-0.0`` for
    negative inputs, whereas ``torch.relu`` returns ``0`` in both cases.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape as ``x`` with negative entries set to zero.
    """
    return torch.relu(x)
        

In [16]:
class CausalSelfAttention(nn.Module):
    """Projection layers for multi-head attention with grouped key/value heads.

    Only the constructor is defined here; no ``forward`` is implemented in
    this class yet, so calling an instance raises ``NotImplementedError``.

    Attributes:
        layer_idx: Index of the owning block, stored as given.
        n_head: Number of query heads.
        n_kv_head: Number of key/value heads (``n_head`` must be a multiple).
        model_dim: Width of the residual stream.
        head_dim: Channels per head, ``model_dim // n_head``.
        q, k, v: Bias-free input projections.
        proj: Bias-free output projection back to ``model_dim``.
    """

    def __init__(self, config, layer_idx) -> None:
        """Build the q/k/v/output projections.

        Args:
            config: Object exposing ``n_head``, ``n_kv_head`` and the model
                width as ``emb_dim`` (``model_dim`` is also accepted if set).
            layer_idx: Index of the block this attention layer belongs to.

        Raises:
            AssertionError: If the head counts do not divide evenly.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head

        # The config in this notebook names the width ``emb_dim``; honour a
        # ``model_dim`` attribute too in case a caller's config provides one.
        model_dim = getattr(config, "model_dim", None)
        self.model_dim = config.emb_dim if model_dim is None else model_dim

        assert self.model_dim % self.n_head == 0, "model width must be divisible by n_head"
        assert self.n_kv_head <= self.n_head, "n_kv_head cannot exceed n_head"
        assert self.n_head % self.n_kv_head == 0, "n_head must be a multiple of n_kv_head"
        self.head_dim = self.model_dim // self.n_head

        self.q = nn.Linear(self.model_dim, self.n_head * self.head_dim, bias=False)
        # Keys and values use the (possibly smaller) number of KV heads.
        self.k = nn.Linear(self.model_dim, self.n_kv_head * self.head_dim, bias=False)
        self.v = nn.Linear(self.model_dim, self.n_kv_head * self.head_dim, bias=False)
        # Assignment (not ``==``) so the output projection is actually registered.
        self.proj = nn.Linear(self.model_dim, self.model_dim, bias=False)

class MLP(nn.Module):
    """Position-wise feed-forward network with a squared-ReLU activation.

    Expands the last dimension from ``emb_dim`` to ``4 * emb_dim``, applies
    ``relu(.)**2`` and projects back to ``emb_dim``. Both linear layers are
    bias-free.
    """

    def __init__(self, config) -> None:
        """Create the up/down projections.

        Args:
            config: Object exposing ``emb_dim``.
        """
        super().__init__()
        hidden_dim = 4 * config.emb_dim
        self.up_proj = nn.Linear(config.emb_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.emb_dim, bias=False)

    def forward(self, x):
        """Apply ``down_proj(relu(up_proj(x)) ** 2)``.

        Args:
            x: Tensor whose last dimension is ``emb_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = relu(self.up_proj(x))
        return self.down_proj(hidden.square())
    

In [17]:
class Block(nn.Module):
    """One pre-norm transformer block: attention then MLP, each with a residual.

    Attributes:
        attn: ``CausalSelfAttention`` sub-layer.
        mlp: ``MLP`` sub-layer.
        norm_eps: Epsilon handed to ``RMS``; taken from ``config.norm_eps``
            when present, otherwise ``1e-6``.
    """

    def __init__(self, config, layer_idx) -> None:
        """Build the two sub-layers.

        Args:
            config: Model configuration forwarded to both sub-layers.
            layer_idx: Index of this block, forwarded to the attention layer.
        """
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
        self.norm_eps = getattr(config, "norm_eps", 1e-6)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates.

        ``RMS`` is always given an explicit epsilon, so the call is valid
        whether or not ``RMS`` declares a default for that argument.

        Args:
            x: Residual-stream tensor.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.attn(RMS(x, self.norm_eps))
        x = x + self.mlp(RMS(x, self.norm_eps))
        return x

In [18]:
class GPT(nn.Module):
    """Decoder-only transformer language model built from ``Block`` layers.

    Attributes:
        config: The configuration object passed to the constructor.
        wte: Token embedding table (``vocab_size`` x ``emb_dim``).
        blocks: ``nn.ModuleList`` holding ``config.n_layer`` ``Block`` modules.
        lm_head: Bias-free projection from ``emb_dim`` to ``vocab_size``.
        cos, sin: Precomputed rotary-embedding tables (persistent buffers).
    """

    def __init__(self, config) -> None:
        """Create the embedding, the blocks, the output head and rotary tables.

        Args:
            config: Object exposing ``vocab_size``, ``emb_dim``, ``n_layer``,
                ``n_head`` and ``seq_len`` (plus whatever ``Block`` reads).
        """
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        # A ModuleList (not a plain list) so the blocks' parameters are
        # registered and follow .to() / .parameters() / state_dict().
        self.blocks = nn.ModuleList(
            [Block(config, layer_idx) for layer_idx in range(config.n_layer)]
        )
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        # Same integer head size the attention layers derive for themselves.
        head_dim = config.emb_dim // config.n_head
        # Tables cover 10x the configured sequence length.
        cos, sin = self._precompute_rotary_embeddings(config.seq_len * 10, head_dim)
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def forward(self, idx):
        """Map token ids to next-token logits.

        Embeds the ids, applies every block in order and projects with
        ``lm_head``.

        NOTE(review): the ``cos``/``sin`` rotary buffers are not consumed here
        because ``Block.forward`` only accepts ``x``; no normalisation is
        applied outside the blocks either. Confirm the intended design.

        Args:
            idx: Long tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        x = self.wte(idx)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(x)

    def get_device(self):
        """Return the device that holds the token-embedding weights."""
        return self.wte.weight.device

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Build the cos/sin tables for rotary position embeddings.

        Channel pair ``i`` rotates with inverse frequency
        ``base ** (-2i / head_dim)``. ``channel_range`` already holds the even
        indices ``2i``, so it is divided by ``head_dim`` without a further
        factor of two.

        Args:
            seq_len: Number of positions to tabulate.
            head_dim: Channels per attention head.
            base: Base of the geometric frequency progression.
            device: Target device; defaults to the embedding table's device.

        Returns:
            ``(cos, sin)``, each a bfloat16 tensor of shape
            ``(1, seq_len, 1, head_dim // 2)`` for even ``head_dim``.
        """
        if device is None:
            device = self.get_device()

        # Stride over the channels: one frequency per pair of channels.
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        token_index = torch.arange(seq_len, dtype=torch.float32, device=device)

        # (seq_len, head_dim / 2): rotation angle per position and channel pair.
        freqs = torch.outer(token_index, inv_freq)
        cos, sin = torch.cos(freqs), torch.sin(freqs)
        cos, sin = cos.bfloat16(), sin.bfloat16()
        # Add singleton batch and head axes: (1, seq_len, 1, head_dim / 2).
        cos = cos[None, :, None, :]
        sin = sin[None, :, None, :]

        return cos, sin

    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature, top_k=None, seed=42):
        """Autoregressively generate tokens for a single prompt.

        Each step re-runs the model on the whole sequence so far (no KV
        cache).

        Args:
            tokens: Prompt as a ``list`` of token ids (batch size is 1).
            max_tokens: Number of tokens to generate.
            temperature: ``> 0`` samples from the softmax of
                ``logits / temperature``; otherwise the arg-max is taken.
            top_k: If given, only the ``top_k`` highest logits stay eligible.
            seed: Seed for the sampling generator (used when sampling).

        Yields:
            The generated token ids as Python ints, one at a time.
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            # manual_seed is a method; assigning to it would leave rng unseeded.
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        for _ in range(max_tokens):
            # (batch, seq_len, vocab_size)
            logits = self.forward(ids)
            # Keep only the final position: (batch, vocab_size)
            logits = logits[:, -1, :]
            if top_k is not None:
                # top_values: (batch, k); its last column is the k-th largest logit.
                top_values, _unused = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = torch.where(logits < top_values[:, [-1]], -float('inf'), logits)
            if temperature > 0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token
            

# Training GPT

In this notebook you will:
- Inspect NanoChat's GPT architecture
- Run a forward pass
- Compute the masked next-token (causal language modeling) loss
- Train a tiny chat model for a few steps

In [3]:
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer

# Build the tokenizer with NanoChat's `get_tokenizer` helper imported above.
tokenizer = get_tokenizer()