## Notebook Organization Notes and Errata
This notebook is meant to be both an educational tool and a lightweight-yet-fully-functional utility for training language models. These two goals generally pull in opposite directions, and that tension has motivated some of the choices and sacrifices I've made.

* This notebook is organized in a bottom-up approach. Lower-level components, such as specific embedding, attention, and feed-forward classes are defined first, and used later, so that one could feasibly shift+enter their way through the entire notebook without issue. This choice was also motivated by the 'educational' aspect of this notebook - the individual components and their structure/code are important, and we build up from there.

* I have presented, in many cases, several different classes for each component. This is an intentional choice. One can (hopefully, if I've done my job right) see the similarities between the methodologies in spite of their minor differences. My focus on the Llama(-esque) architecture as the "modern" approach is primarily dictated by the ubiquity of this architecture and its specific methodological choices across different language models, and the clear success it has shown at *modelling language*. I've put some effort into making the components more interchangeable than their original forms, and I don't think I've introduced any significant errors (famous last words).

* For clarity of code in the notebook itself, non-architectural utilities and other methods have been moved to separate files, so that the focus of the code presented here is *how* the information is modified as it flows through the model. I've also left out certain optimizations and additions (e.g. KV caches) to avoid adding undue complexity. If you want to use transformer-based models for any task that requires performance, you have *plenty* of other frameworks to use.

* The training loop uses several values set in globals.py (accessed via a widget interface). It copies those values at the start of a run so that they remain consistent throughout a single training run, even if the widget is changed. Widget controls aren't incredibly fine-grained, but are serviceable. This system will also "allow" you to use settings that will fail or use too much VRAM. Settings-tuning is on my radar, but is out of scope for now.

* It's *very* important to note that a vast majority of the code in this notebook (and in the utility files) is adapted from other sources. I've done my best to note those sources, and to explain any modifications I've made. This is the nature of Software Development, and it's my hope that the work I've done here will empower people to dive deeper into the sources themselves, armed with a better understanding of how it's working at the lowest level.

<sup>Thorold Tronrud, ML/AI Software Engineer, StarFish Medical</sup>

In [None]:
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.95, 0)
import math
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
import functools
from torch.utils.checkpoint import checkpoint
from enum import Enum
from dataclasses import dataclass, fields, asdict
from IPython import display
from torchinfo import summary
import tiktoken
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from tqdm import trange
import torch.jit as jit
from typing import List
import globals as g
from modelling_utils import apply_rotary_pos_emb, repeat_kv, Transformer, clean_class_name, pretty_generate
from training_utils import plot_grad_flow, get_lr, get_batch, generate_training_data, generate_training_data_mix, create_raw_text_dataset
import wandb
g.get_device()

# If running on machine with multiple GPUs, you can set a specific device like this:
g.set_device("cuda:0")

## Embeddings
Embeddings convert a stream of discrete tokens into a trajectory through a semantic space. In many cases, these word-token embeddings are combined with positional information: either absolute, learned positional embeddings added to the token embeddings (e.g. GPT-2), or sinusoidal rotary embeddings applied to the queries and keys inside the attention block (e.g. Llama). Positional embeddings give the model information about the relative placement of tokens and the distance between them, which is often useful in language, where word order can change meaning.

"NoPE", or "No Positional Embeddings" has shown success, however, so YMMV.

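As a rough illustration of the rotary variant, the sketch below rotates each pair of query/key channels by an angle proportional to the token position, then checks that the resulting dot product depends only on the distance between the two positions. The helper `rotate_pairs` and the sizes are made up for this example. It pairs adjacent channels for readability, which may differ from the channel layout used by the classes below.

```python
import torch

def rotate_pairs(x, pos, base=10000.0):
    """Rotate consecutive channel pairs of x by angles pos * inv_freq (hypothetical helper)."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = pos * inv_freq                      # one angle per channel pair, shape (dim/2,)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin  # standard 2D rotation of each pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (3 positions apart) at two different absolute positions
score_near = torch.dot(rotate_pairs(q, 5), rotate_pairs(k, 2))
score_far = torch.dot(rotate_pairs(q, 105), rotate_pairs(k, 102))
same_offset = torch.allclose(score_near, score_far)

# Rotation leaves the vector length untouched; only the direction encodes position
norm_preserved = torch.allclose(rotate_pairs(q, 7).norm(), q.norm())
```

Both `same_offset` and `norm_preserved` should come out `True`: the attention score between two tokens sees only how far apart they are, not where they sit in the sequence.
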
In [None]:
class AbsolutePositionalEmbeddings(nn.Module):
    """
    GPT-2 style word-token embeddings + absolute positional embeddings
        Combines word-token-embeddings and word-position-embeddings to transmit
        both pieces of information throughout the rest of the model
    """
    def __str__(self):
        return "AbsolutePositionalEmbeddings"
        
    def __init__(self, config):
        super().__init__()
        # Validity assertions
        assert config.max_seq_len > 0, f"Max positional embeddings must be > 0: set max_seq_len positive"
        
        self.vocab_size = config.vocab_size
        self.n_embed = config.n_embd
        self.max_positional_embeddings = config.max_seq_len

        self.wte = nn.Embedding(self.vocab_size, self.n_embed)
        self.wpe = nn.Embedding(self.max_positional_embeddings, self.n_embed)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, idx):
        # Batch, seq_len
        b, seq_len = idx.size()
        assert seq_len <= self.max_positional_embeddings, f"Input sequence length {seq_len} is greater than max positional embeddings of {self.max_positional_embeddings}"

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(torch.arange(0, seq_len, dtype=torch.long, device=idx.device))
        emb = self.drop(tok_emb + pos_emb)
        return emb

class WordTokenEmbeddings(nn.Module):
    """
    Word-token-embeddings only
        For use in other architectures that handle positional embeddings
        differently (e.g. LLama)
    """
    def __str__(self):
        return "WordTokenEmbeddings"
        
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.n_embed = config.n_embd

        self.wte = nn.Embedding(self.vocab_size, self.n_embed)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, idx):
        tok_emb = self.wte(idx)
        return self.drop(tok_emb)

class RotaryEmbedding(nn.Module):
    """
    Rotary embeddings - encoding token position through a sinusoidal
    signal imprinted into the data passed to the attention heads
    """        
    def __init__(self, config, differential_layer = False):
        super().__init__()
        if hasattr(config, "rope_scaling_factor"):
            self.scaling_factor = config.rope_scaling_factor
        else:
            self.scaling_factor = 1.0
        if hasattr(config, "qk_rope_dim"):
            self.dim = config.qk_rope_dim
        else:
            self.dim = config.n_embd//config.n_head
        if differential_layer:
            self.dim = self.dim//2 #differential attn
        self.max_position_embeddings = config.max_seq_len
        if hasattr(config, "rope_base"):
            self.base = config.rope_base
        else:
            self.base = 100000
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = self.max_position_embeddings
        self._set_cos_sin_cached(self.max_seq_len_cached, device = self.inv_freq.device, dtype = self.inv_freq.dtype)

    def _set_cos_sin_cached(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            0, self.max_seq_len_cached, device = self.inv_freq.device, dtype = self.inv_freq.dtype
        ).unsqueeze(0)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(t.shape[0], -1, 1)
        position_ids_expanded = t[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype).to(self.inv_freq.device), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype).to(self.inv_freq.device), persistent=False)
        

    @torch.no_grad()
    def forward(self, x):
        # x: [bs, num_attention_heads, seq_len, head_size]
        bs,nh,sl,embd = x.size()
        if sl > self.max_seq_len_cached:
            self._set_cos_sin_cached(seq_len = sl, device = x.device, dtype = x.dtype)
        return self.cos_cached[:,:sl].to(dtype=x.dtype), self.sin_cached[:,:sl].to(dtype=x.dtype)

"""
Here be dragons:
while RoPE can be extended infinitely, the model has only "learned"
to utilize differences within the training range. To expand context
on the fly, you can either fine-tune it, or modify the frequencies used
by RoPE, which is what we're doing here.
https://arxiv.org/abs/2309.00071
"""

# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


class YarnRotaryEmbedding(RotaryEmbedding):
    """
    LLama-style rotary embeddings adapted from Transformers LLama source
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L94

        !!Applied in the self-attention module, not as part of embedding block!!
    """        
    def __init__(self, config):
        self.original_max_position_embeddings = config.max_seq_len
        
        self.beta_fast = config.beta_fast if hasattr(config, "beta_fast") else 32
        self.beta_slow = config.beta_slow if hasattr(config, "beta_slow") else 1
        self.mscale = config.mscale if hasattr(config, "mscale") else 1
        self.mscale_all_dim = config.mscale_all_dim if hasattr(config, "mscale_all_dim") else 0

        super().__init__(config)

    def _set_cos_sin_cached(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        
        freq_extra = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(device=device)
        freq_inter = 1.0 / (self.scaling_factor * self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(device=device)

        low, high = yarn_find_correction_range (
            self.beta_fast,
            self.beta_slow,
            self.dim,
            self.base,
            self.original_max_position_embeddings
        )

        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, self.dim // 2).to(device=device, dtype = torch.float32)
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        
        t = torch.arange(
            0, self.max_seq_len_cached, device = self.inv_freq.device, dtype = self.inv_freq.dtype
        ).unsqueeze(0)

        position_ids_expanded = t[:, None, :].float()
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(t.shape[0], -1, 1)
        
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)

        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )
        
        self.register_buffer("cos_cached", _mscale * emb.cos().to(dtype).to(self.inv_freq.device), persistent=False)
        self.register_buffer("sin_cached", _mscale * emb.sin().to(dtype).to(self.inv_freq.device), persistent=False)

## Attention
Attention is the main ingredient that makes Transformers so powerful. An attention block is composed of different attention "heads", which control the flow of information, interact between model layers, and each add their own information to the residual.

Each attention head computes "query", "key", and "value" vectors for each input token. The attention pattern is the softmax of the scaled dot products between query and key vectors (with a causal mask, so that each token only attends to itself and earlier tokens), giving a weight for the effect each token has on every other token. Finally, multiplying this token-token weight matrix by the value matrix gives a result vector per token for the head.

The results for all heads are stacked side-by-side and passed through a final output layer, which produces the block's predicted "change" to the semantic position of all the tokens.

[This article](https://transformer-circuits.pub/2021/framework/index.html) does an amazing job breaking down the function of attention heads, albeit with a slightly different model for their behaviour.
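
The three steps described above can be sketched for a single head with toy tensors. The sizes are arbitrary and the random `q`, `k`, `v` stand in for the outputs of learned projections. The result is compared against PyTorch's fused `scaled_dot_product_attention` (available in PyTorch >= 2.0).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, hs = 5, 8                        # sequence length, head size (arbitrary toy values)
q = torch.randn(1, 1, T, hs)        # (batch, heads, T, hs)
k = torch.randn(1, 1, T, hs)
v = torch.randn(1, 1, T, hs)

# Token-token scores: scaled dot product of every query with every key
scores = (q @ k.transpose(-2, -1)) / math.sqrt(hs)          # (1, 1, T, T)

# Causal mask: position i may only attend to positions <= i
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))

# Attention pattern: each row is a probability distribution over visible tokens
pattern = F.softmax(scores, dim=-1)

# Weighted sum of value vectors gives the head's result for each token
manual = pattern @ v                                        # (1, 1, T, hs)

# PyTorch's fused implementation of the same computation
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

`manual` and `fused` should agree up to floating-point error, which is the same equivalence the attention classes below rely on when they fall back to the manual path.
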

In [None]:
class MHASelfAttention(nn.Module):
    """
    LLama-style Multi-Head Attention with rotary embeddings
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L577

    Modified to include same manual attention fallback
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.using_rotary_encoding = False
        if config.rotary_encoding:
            self.rotary_emb = YarnRotaryEmbedding(config)
            self.using_rotary_encoding = True
        else:
            self.rotary_emb = None
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
                                        .view(1, 1, config.max_seq_len, config.max_seq_len))
            self.attn_dropout = nn.Dropout(config.dropout)
    
    def forward(self, x, position_embeddings = None):
        B, T, C = x.size() # Batch size, sequence length, embedding dimensionality (n_embd)

        # Calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        """
        Primary difference in positional embeddings implementation.
        This encodes the positions of each part of the inputs into the attention
        matrix, as opposed to encoding it directly into the embedding trajectory
        """
        if self.using_rotary_encoding:
            if position_embeddings is None:
                cos, sin = self.rotary_emb(v)
            else:
                cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
        
        # Causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
            
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.o_proj(y))
        return y

class GQASelfAttention(nn.Module):
    """
    LLama-style Grouped Query Attention with rotary embeddings
    https://arxiv.org/pdf/2305.13245
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L577

    Standard MHA is really just the case where the number of KV heads is equal to the number of Q heads,
    but I've kept it as a separate class to have a "simpler" case available. Compilation *should* take care
    of most of the performance differences w/r/t performing matrix multiplications in one go or not
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_head % config.n_kv_head == 0
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd//self.n_head
        self.num_key_value_heads = config.n_kv_head
        self.num_key_value_groups = self.n_head // self.num_key_value_heads
        
        # Q, K, and V aren't identical in size anymore, so can't batch all in one MatMul
        self.q_proj = nn.Linear(self.n_embd, self.n_head*self.head_dim, bias = config.bias)
        self.k_proj = nn.Linear(self.n_embd, self.num_key_value_heads*self.head_dim, bias = config.bias)
        self.v_proj = nn.Linear(self.n_embd, self.num_key_value_heads*self.head_dim, bias = config.bias)
        
        # Output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        
        # Regularization
        self.resid_dropout = nn.Dropout(config.dropout)

        
        self.using_rotary_encoding = False
        if config.rotary_encoding:
            self.rotary_emb = RotaryEmbedding(config)
            self.using_rotary_encoding = True
        else:
            self.rotary_emb = None
        
        self.dropout = config.dropout
        
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
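                                        .view(1, 1, config.max_seq_len, config.max_seq_len))
            # Required by the manual attention fallback in forward()
            self.attn_dropout = nn.Dropout(config.dropout)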

    def forward(self, x, position_embeddings = None):
        B, T, C = x.size() # Batch size, sequence length, embedding dimensionality (n_embd)

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
        k = k.view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2) # (B, n_kv, T, hs)
        v = v.view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2) # (B, n_kv, T, hs)

        
        if self.using_rotary_encoding:
            if position_embeddings is None:
                cos, sin = self.rotary_emb(v)
            else:
                cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        """
        Primary difference with rotary MHA -- KV groups are repeated to match
        Q dimension before being passed through SDPA
        """
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        
        # Causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
            
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.o_proj(y))
        return y

def lambda_init_fn(depth):
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

class DifferentialAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_head % config.n_kv_head == 0
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd//self.n_head//2
        self.num_key_value_heads = config.n_kv_head//2
        self.num_key_value_groups = self.n_head // self.num_key_value_heads
        
        # Q, K, and V aren't identical in size anymore, so can't batch all in one MatMul
        self.q_proj = nn.Linear(self.n_embd, self.n_embd, bias = config.bias)
        self.k_proj = nn.Linear(self.n_embd, self.n_embd//self.num_key_value_groups, bias = config.bias)
        self.v_proj = nn.Linear(self.n_embd, self.n_embd//self.num_key_value_groups, bias = config.bias)
        
        # Output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.lambda_init = lambda_init_fn(config.n_layer)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(mean=0,std=0.1))
        
        # Regularization
        self.resid_dropout = nn.Dropout(config.dropout)

        self.rotary_emb = RotaryEmbedding(config, differential_layer = True)
        self.using_rotary_encoding = True
        self.dropout = config.dropout
        # replace boolean buffer with an index array
        # booleans cause the PyTorch 2.0 compiler to panic because it assumes
        # it's a changeable output shape, even if it isn't.
        self.register_buffer("every_other_mask_e", (torch.arange(0,self.n_head*2,2)))
        self.register_buffer("every_other_mask_o", (torch.arange(1,self.n_head*2,2)))
        

    def forward(self, x, position_embeddings = None):
        B, T, C = x.size() # Batch size, sequence length, embedding dimensionality (n_embd)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        q = q.view(B, T, 2*self.n_head, self.head_dim).transpose(1,2)
        k = k.view(B, T, 2*self.num_key_value_heads, self.head_dim).transpose(1,2)
        v = v.view(B, T, self.num_key_value_heads, 2*self.head_dim).transpose(1,2)
        # QK Norm
        q,k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary_emb(k)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)

        k = torch.repeat_interleave(k, dim=1, repeats=self.num_key_value_groups)
        v = torch.repeat_interleave(v, dim=1, repeats=self.num_key_value_groups*2)
        
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_weights = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        # attn_weight_size = B, 2xNH, T, HD
        attn = torch.index_select(attn_weights, 1, self.every_other_mask_e) - lambda_full * torch.index_select(attn_weights, 1, self.every_other_mask_o)
        
        
        attn = F.rms_norm(attn,(attn.size(-1),))
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(B, T, self.n_head * 2 * self.head_dim)
        y = self.o_proj(attn)
        return y

## Feed Forward/MLP Block

The feed-forward block is generally a shallow multi-layer perceptron (MLP) that calculates a residual on top of the output from the attention block. This further adjusts the direction of the next "step" through the model, although it is not [well understood](https://transformer-circuits.pub/2021/framework/index.html#additional-intuition) what these adjustments directly correspond to (or even, in fact, if they do correspond to any *specific* features outside of rare cases).

The slightly more complex gated MLP provides a more elegant channel for ignoring specific pieces of information passed from the attention block: a strongly negative value in the gate projection drives the (SiLU) gate towards zero, which suppresses the corresponding channel of the up projection without the model needing to produce its exact negative value.
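
A small numeric sketch of this gating, assuming a SiLU gate as in the `GatedMLP` class below. The gate and up-projection values are invented for illustration. Note that SiLU only approaches zero for large negative inputs, so the suppression is approximate rather than exact.

```python
import torch
import torch.nn.functional as F

# Pretend outputs of the gate and up projections for four hidden channels
gate_pre = torch.tensor([-12.0, -1.0, 0.0, 3.0])
up = torch.tensor([5.0, 5.0, 5.0, 5.0])

gate = F.silu(gate_pre)      # silu(x) = x * sigmoid(x)
gated = gate * up            # what would be fed to the down projection

# The strongly negative channel is almost entirely suppressed,
# while the positive channel passes through at roughly gate_pre * up
suppressed = gated[0].abs().item()
passed = gated[3].item()
```

Here `suppressed` is well below `1e-3`, while `passed` stays above 14, even though both channels carry the same up-projection value.
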

In [None]:
class MLP(nn.Module):
    """
    GPT-2 style MLP feed-forward block
    https://github.com/karpathy/nanoGPT/blob/master/model.py#L78

        Expands embedding dimension to intermediate size, then collapses
        back to pass to next block in transformer
    """
    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class GatedMLP(nn.Module):
    """
    (simplified) LLama-style gated MLP feed-forward block
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186

        More parameters, but gating allows for (theoretically) improved information retention between hidden states
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.n_embd
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        return self.dropout(self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)))

## Decoder Block
In a modern LM, a chain of decoder blocks uses the input signal to predict the most likely next step in the semantic space of the model. Each block is composed of an attention block and a feed-forward block, which each in turn compute a residual "move" on top of the previous state.

N blocks allow for N adjustments on top of the input, so with more layers a model can make finer moves, or take more factors into account to determine the next step.
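
The residual structure can be sketched with a toy stack. `ToyBlock` and `rms_norm` are hypothetical stand-ins written for this example: plain linear layers replace the real attention and feed-forward modules (so no information moves between tokens here), and the point is only to show each block adding a "move" to a running state.

```python
import torch
import torch.nn as nn

def rms_norm(x, eps=1e-6):
    """Parameter-free RMS normalisation over the last dimension."""
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder block: two residual updates."""
    def __init__(self, n_embd):
        super().__init__()
        self.mix = nn.Linear(n_embd, n_embd, bias=False)   # stands in for attention
        self.ff = nn.Linear(n_embd, n_embd, bias=False)    # stands in for the MLP

    def forward(self, x):
        x = x + self.mix(rms_norm(x))   # first residual "move"
        x = x + self.ff(rms_norm(x))    # second residual "move"
        return x

torch.manual_seed(0)
n_embd, n_layer = 16, 4
blocks = nn.ModuleList([ToyBlock(n_embd) for _ in range(n_layer)])

x = torch.randn(2, 5, n_embd)           # (batch, seq_len, n_embd)
states = [x]
with torch.no_grad():
    for block in blocks:
        states.append(block(states[-1]))

# Size of each block's total adjustment to the running state
moves = [(after - before).norm().item() for before, after in zip(states[:-1], states[1:])]
```

`moves` holds one entry per block, and the state keeps the same shape throughout, which is what lets blocks be stacked to any depth.
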

In [None]:
class Block(nn.Module):
    """
    Standard transformers block structure
        Passed attention class, feed forward class, and layer normalization class

        Forward involves attention residual followed by feed forward residual to be
        passed to next block in transformer
    """
    def __init__(self, config, i = 0):
        super().__init__()
        
        # Cycle through the configured classes if fewer than n_layer are given
        atncls = config.attn_class[i % len(config.attn_class)]
        self.attn = atncls(config)

        ffcls = config.ff_class[i % len(config.ff_class)]
        self.ff = ffcls(config)

    def forward(self, x):
        x = x + self.attn(F.rms_norm(x,(x.size(-1),)))
        x = x + self.ff(F.rms_norm(x,(x.size(-1),)))
        return x      

## The Transformer
A Transformer model consists of an input-processing embedding layer, which produces a "semantic trajectory" for the inputs, and a series of decoder blocks, which use this trajectory to predict the next step in the semantic space.

To generate output, we need a de-translation layer, to convert from the position in the semantic space back to a mix of the closest tokens. We know this is the behaviour, because our embedding layer and output layer use the same weights, so the output of the "Causal LM" head is the similarity of each token's semantic position with that of the Transformer's final step!
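
A minimal sketch of that tying, with toy sizes and a made-up hidden state: once the embedding and the output layer share one weight matrix, the logits are exactly the dot products between the final hidden state and every token's embedding vector.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 10, 4              # toy sizes

wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
wte.weight = lm_head.weight             # tie: both modules now share one parameter

# A made-up final hidden state for a single position
hidden = torch.randn(1, n_embd)

with torch.no_grad():
    logits = lm_head(hidden)                      # (1, vocab_size)
    similarity = hidden @ wte.weight.T            # dot product with every token embedding

# Both modules point at the same underlying storage
shared = wte.weight.data_ptr() == lm_head.weight.data_ptr()
```

`logits` and `similarity` are the same numbers, so the "Causal LM" head is scoring how well the final hidden state lines up with each token's embedding.
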

In [None]:
class TransformerForCausalLM(Transformer):
    """
    Architecturally-relevant pieces based on Karpathy's nanoGPT
    https://github.com/karpathy/nanoGPT/blob/master/model.py#L118
    """
    def __init__(self, config):
        super().__init__()
        self.conf = {field.name: getattr(config, field.name) for field in fields(config)}
        """
        Transformer decoder stack
            Word token embedding sequences are passed through
            a series of processing blocks before normalization
        """
        self.transformer = nn.ModuleDict(dict(
            # Embeddings + dropout
            wte = config.emb_class(config),
            # Backbone of Attn->FF blocks
            h = nn.ModuleList([
                Block(config, i) for i in range(config.n_layer)
            ]),

            # Output hidden state normalization
            #n_f = config.ln_class(config)
        ))
        """
        Language Modelling "head"
            maps the hidden state output by the transformer decoder stack
            to a token from its vocabulary
        """
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        """
        "Weight Tying"
            From Karpathy, which credits https://paperswithcode.com/method/weight-tying
            
            By making input embeddings and output decodings map within the same space, 
            the generation of a new token is more like a "movement" through this latent space.
            The word embeddings then must be clustered by frequency/semantic usage, with n_embd
            different potential similarities or differences(/nuances).
        """
        self.transformer.wte.wte.weight = self.lm_head.weight

        # Init all weights
        self.apply(self._init_weights)
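        # Apply special scaled init to the residual projections, per GPT-2 paper.
        # nanoGPT names both of these c_proj; here the attention output is o_proj
        # and the gated MLP output is down_proj, so match those as well.
        for pn, p in self.named_parameters():
            if pn.endswith(('c_proj.weight', 'o_proj.weight', 'down_proj.weight')):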
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # Set up gradient checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": True}
        self.checkpoint_fn = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
        
        # Report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    """
    Forward method modified from Karpathy
    https://github.com/karpathy/nanoGPT/blob/master/model.py#L118

        Calculates cross entropy loss if targets are provided. For foundational training, targets are just
        the inputs shifted one position to the left (i.e. the token that directly follows each input position).
    """
    def forward(self, idx, targets = None):
        device = idx.device
        b, t = idx.size()

        hidden_state = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            """
            When training, we use gradient checkpointing. This adds some
            extra compute overhead, but allows for much larger batch sizes:
            intermediate activations are discarded in the forward pass and
            recomputed during the backward pass instead of being stored.
            
            See:
            https://pytorch.org/docs/stable/checkpoint.html
            """
            if self.training: 
                hidden_state = self.checkpoint_fn(
                    block.__call__,
                    hidden_state
                )
            else:
                hidden_state = block(hidden_state)
            
        hidden_state = F.rms_norm(hidden_state,(hidden_state.size(-1),))

        if targets is not None:
            logits = self.lm_head(hidden_state)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(hidden_state[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss
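
The forward docstring above says that targets are the inputs shifted one position to the left, and that positions labelled `-1` are ignored by the loss. Here is a minimal pure-Python sketch of both ideas. `make_pair` and `cross_entropy` are hypothetical helper names used only for illustration (they are not part of this notebook), and the logits are made-up numbers.

```python
import math

def make_pair(tokens):
    """Split a token stream into (inputs, targets); targets are inputs shifted left by one."""
    return tokens[:-1], tokens[1:]

def cross_entropy(logits, targets, ignore_index=-1):
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, tgt in zip(logits, targets):
        if tgt == ignore_index:
            continue  # this position does not contribute to the loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[tgt]
        count += 1
    return total / max(count, 1)

inputs, targets = make_pair([5, 9, 2, 7])
# inputs == [5, 9, 2] and targets == [9, 2, 7]

# Uniform logits over 4 classes: every counted position costs log(4)
uniform_logits = [[0.0] * 4 for _ in range(3)]
loss = cross_entropy(uniform_logits, [1, 2, -1])
```

The third position is skipped because its label is `-1`, so the mean is taken over two positions only.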

In [None]:
@dataclass
class Config:
    """
    Default config is a scaled-down version of Llama-3 for ~220M parameters
    """
    max_seq_len = 8192
    vocab_size = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer = 12
    n_head = 16
    n_kv_head = 8
    n_embd = 1024
    intermediate_size = 3584
    dropout = 0.0
    bias = False
    rotary_encoding = True
    rope_base = 100000
    
    # What component classes are we using?
    emb_class = WordTokenEmbeddings
    
    diffattn = True
    attn_class = [MHASelfAttention]#[GQASelfAttention, DifferentialAttention]
    
    ff_class = [MLP]#[GatedMLP]

In [None]:
"""
Download the dataset, tokenize it, and create a single file "block" of tokens
to grab sections from during training.
Our total training length is relatively short: in 10k steps we might see about
5e9 (5 billion) tokens, or about half of the stored data. Our default weight decay
is also very high, so in general we can presume that every piece of text the model
sees is "new", and that training loss roughly corresponds to validation performance.

Lazy, and not good practice for wider-scope applications, but we're staying lean.
This is not language-modelling-to-save-the-world.

Downloads 10 billion tokens' worth of text; this takes several hours on a gigabit connection.
"""
# create a single-sourced chunk of data
#generate_training_data(dset = "HuggingFaceTB/smollm-corpus", dset_name = "cosmopedia-v2", tokens_to_save = 1e10)

# or, mix and match to incorporate more different types of text
generate_training_data_mix(dset = ["neuralwork/arxiver","mteb/raw_biorxiv","wikimedia/wikipedia","ccdv/pubmed-summarization", "HuggingFaceTB/smollm-corpus","togethercomputer/RedPajama-Data-1T","togethercomputer/RedPajama-Data-1T", "launch/gov_report"], 
                               dset_name = [None, None,"20231101.en","section", "fineweb-edu-dedup", "common_crawl", "stackexchange", "plain_text"], 
                               text_col = ["markdown","abstract","text","article", "text", "text", "text", "document"],
                               splits = ["train","train","train","train","train","train","train", "train"],
                               tokens_to_save = [2145236507,21967910,4630994150,483329429, 1e10, 4e9, 5e9, 2e8],
                               senc = "gpt2", num_proc = 1,
                               out_fname = "train.bin", overwrite = False)
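
`get_batch` and the data-generation helpers live in separate files, so their internals are not shown in this notebook. As a rough sketch of the general idea (one flat file of token ids that training reads windows from), the example below assumes tokens are stored as unsigned 16-bit integers, which is enough for GPT-2 token ids. The names `write_tokens` and `sample_window` are hypothetical and the real helpers may work differently.

```python
import array
import os
import random
import tempfile

def write_tokens(path, tokens):
    """Store token ids back-to-back as unsigned 16-bit integers (assumed format)."""
    with open(path, "wb") as fh:
        array.array("H", tokens).tofile(fh)

def sample_window(path, block_size, rng):
    """Read block_size + 1 consecutive tokens at a random offset; return (ctx, targ)."""
    itemsize = array.array("H").itemsize
    n_tokens = os.path.getsize(path) // itemsize
    # Leave room for the one extra token needed by the shifted targets
    start = rng.randrange(0, n_tokens - block_size)
    buf = array.array("H")
    with open(path, "rb") as fh:
        fh.seek(start * itemsize)
        buf.fromfile(fh, block_size + 1)
    window = buf.tolist()
    return window[:-1], window[1:]

# Toy "dataset": token ids 0..999 written to a temporary file
path = os.path.join(tempfile.mkdtemp(), "toy.bin")
write_tokens(path, list(range(1000)))
ctx, targ = sample_window(path, block_size=8, rng=random.Random(0))
```

Only the requested window is read from disk, so the file can be far larger than memory.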


## Training Loop
The training loop uses a set of global training configurations to train a language model on the "train.bin" file. In each step, gradients are accumulated over several micro-batches via backpropagation, and the optimizer then updates the weights once using the accumulated gradient.
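
To make the accumulation step concrete, here is a small pure-Python sketch that uses a one-parameter least-squares model in place of a transformer. It is an illustration under stated assumptions, not this notebook's training code: `grad_mse` is a hypothetical helper, the data are made up, and each micro-batch gradient is divided by the number of micro-batches so that the accumulated result matches the gradient of one large batch.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.9, 12.1]
w, lr, micro = 0.0, 0.01, 2

# Gradient of the full batch, computed in one pass
full = grad_mse(w, xs, ys)

# The same gradient, accumulated over equal-sized micro-batches
n_micro = len(xs) // micro
accum = 0.0
for i in range(0, len(xs), micro):
    accum += grad_mse(w, xs[i:i + micro], ys[i:i + micro]) / n_micro

# One optimizer step (plain gradient descent here) after accumulation
w_new = w - lr * accum
```

With the same bookkeeping, the number of tokens consumed per optimizer step is `seq_len * BATCH_S * GRAD_ACCUM_STEPS`, which is the figure the training function prints when it starts.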

In [None]:
def train_model(model, seq_len = 1024, data_fname = "train.bin"):
    # Lock down training hyperparameters so they don't get changed on the fly
    LR = g.LR
    WD = g.WD
    MAX_STEPS = g.MAX_STEPS
    BATCH_S = g.BATCH_S
    GRAD_ACCUM_STEPS = g.GRAD_ACCUM_STEPS
    SHOW_GRADIENTS = g.SHOW_GRADIENT
    USING_WANDB = g.USE_WANDB
    CKPT_F = g.CHKPT_FREQ
    INIT_CKPT = g.loadchkpt
    print(f"Approx. {seq_len*BATCH_S*GRAD_ACCUM_STEPS} tokens per optimizer step")
    if USING_WANDB:
        wandb.init(project="OWTBot", name=f"LM_{int(model.get_num_params()//1e6)}M", 
                   config={  'config':"",
                             'LR': LR,
                             'seq_len':seq_len,
                             'BatchSize':BATCH_S,
                             'Grad Accum': GRAD_ACCUM_STEPS,
                            })
    
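    # Create the optimizer (the loss itself is computed inside the model's forward pass).
    # The AdamW variant with betas (0.9, 0.95) is kept below, commented out, for reference;
    # the active call passes use_adamw = False to create the AdEMAMix optimizer instead.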
    #optimizers, schedulers = model.configure_optimizers(weight_decay = WD,learning_rate = LR, use_adamw = True, betas=(0.9, 0.95), device_type=g.dev)
    optimizer = model.configure_optimizers(weight_decay = WD,learning_rate = LR, use_adamw = False, betas=(0.9, 0.95), device_type=g.dev)
    start = 0

    
    # Values and lists to carry through training
    losses = [] # Loss values after each step
    total_toks = [] # Total tokens trained on at each step
    total_tok_ct = 0 # Running total tokens
    
    # Attempt to load a previous checkpoint, if provided
    if len(INIT_CKPT) > 0:
        try:
            
            ckpt = torch.load(INIT_CKPT,map_location=torch.device('cpu'))
            if "model" in ckpt.keys():
                model.load_state_dict(ckpt['model'])
            else:
                print(f"No bundled model!!!\nYou'll be training a model from scratch, potentially using a different model's optimizer...")
            if "step" in ckpt.keys():
                start = ckpt['step']
            else:
                print(f"No bundled step count!!! Learning rate progress will be lost, and start from zero.")
            if "total_tokens" in ckpt.keys():
                total_tok_ct = ckpt['total_tokens']
            else:
                print(f"No bundled training token count!!! No downstream effects, but will restart from zero.")
            if "optimizer" in ckpt.keys():
                print("Loading checkpoint optimizer state")
                # load optimizer state to preserve gradient information
                optimizer.load_state_dicts(ckpt['optimizer'])
            else:
                print(f"No bundled optimizer!!!\nThis will likely cause issues in training, as the gradients and momentum are missing.")
        except Exception as e:
            print(f"Failed to load checkpoint {INIT_CKPT}, starting from scratch. {e}")
    
    # Note: not used below, since the model's forward() computes cross entropy itself
    loss_function = nn.CrossEntropyLoss(reduction="mean")
    
    # Get first batch
    ctxs,targ = get_batch(fname = data_fname, num = 0, block_size = seq_len, batch_size = GRAD_ACCUM_STEPS*BATCH_S, dev = g.device)
    

    if SHOW_GRADIENTS:
        # Display figure to update during training
        f,(ax1,ax2) = plt.subplots(2,1,figsize=(18,10))
    else:
        f,ax1 = plt.subplots(1,1,figsize=(18,5))
        
    dh = display.display(f, display_id=True)
    
    
    for ep in trange(start, MAX_STEPS):
        
        if ep % CKPT_F == 0 and ep > 0:
            checkpoint = { 
                'step': ep,
                'model': model.state_dict(),
                'optimizer':optimizer.get_state_dicts(),
                'total_tokens':total_tok_ct
            }
            torch.save(checkpoint, f"ckpt{ep}.pth")
            
        bloss = 0
        nb = 0
        toks = 0
    
        # Determine and set the learning rate for this iteration
        #lr = get_lr(ep,lr_decay_iters=6e5)
        #for param_group in optimizer.param_groups:
        #    param_group['lr'] = lr
        
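        for _q in range(0,len(ctxs),BATCH_S):
            nb += 1
            logits,loss = model(ctxs[_q:_q+BATCH_S],targets=targ[_q:_q+BATCH_S])
            # Accumulate gradient, scaled so the accumulated gradient is the mean
            # over micro-batches rather than their sum
            (loss / GRAD_ACCUM_STEPS).backward()
            bloss += loss.detach().item()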
        total_tok_ct += seq_len*len(ctxs)
        
        # Async load next batch while doing plot rendering
        ctxs,targ  = get_batch(fname = data_fname, num = ep+1, block_size = seq_len, batch_size = GRAD_ACCUM_STEPS*BATCH_S, dev = g.device)
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        
        if SHOW_GRADIENTS:
            if ep % 20 == 0:
                ax2.clear()
                plot_grad_flow(model.named_parameters(),f,ax2)
            
        # Keep track of the total number of tokens the model's seen    
        total_toks.append(total_tok_ct)
        
        # Take optimizer step with accumulated gradient
        optimizer.step()
        #for opt, sched in zip(optimizers, schedulers):
        #    opt.step()
        #    sched.step()
        model.zero_grad(set_to_none=True)
        losses.append(bloss/nb)

        if ep % 20 == 0:
            # Update figures
            ax1.clear()
            ax1.plot(total_toks,losses)
            ax1.xaxis.set_major_formatter(ticker.FormatStrFormatter('%0.2e'))
            plt.tight_layout()
            dh.update(f)
        
        if USING_WANDB:
            # W&B logging
            wandb.log({
                "step": ep,
                "total_toks":total_toks[-1],
                "train/loss": losses[-1],
                "lr": 0,
            })
    if USING_WANDB:
        wandb.finish()
        
    checkpoint = { 
        'step': MAX_STEPS,
        'model': model.state_dict(),
        'optimizer':optimizer.get_state_dicts(),
        'total_tokens':total_tok_ct
    }
    # Use MAX_STEPS rather than ep+1: ep is undefined if the loop body never ran
    torch.save(checkpoint, f"ckpt{MAX_STEPS}.pth")
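
The loop above calls `torch.nn.utils.clip_grad_norm_(model.parameters(), 1)` before every optimizer step. The sketch below shows the underlying idea, clipping by the global gradient norm, in plain Python on nested lists rather than tensors. `clip_by_global_norm` is a hypothetical name, and the small `eps` term is an assumption added for numerical safety.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # Never scale up: gradients already within the limit are left unchanged
    scale = min(1.0, max_norm / (total + eps))
    return [[g * scale for g in tensor] for tensor in grads], total

grads = [[3.0, 0.0], [0.0, 4.0]]  # joint norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))

small, _ = clip_by_global_norm([[0.1, 0.2]], max_norm=1.0)  # already within the limit
```

Because every parameter shares one scale factor, the direction of the update is preserved and only its length changes.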

## Configurations
A few configurations that define different model architectures from the components defined above.
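
When editing a configuration, it helps to have a rough idea of the resulting model size before building it. The sketch below is a back-of-the-envelope estimate, not the notebook's `get_num_params`. It assumes bias-free linear layers, query and output projections of size `n_embd x n_embd`, key/value projections sized by `n_kv_head`, tied input/output embeddings, and it ignores normalization weights. `estimate_params` is a hypothetical helper, and the actual attention and feed-forward classes may add parameters it does not count.

```python
def estimate_params(vocab_size, n_embd, n_layer, n_head, n_kv_head,
                    intermediate_size, gated_mlp, tied_embeddings=True):
    """Rough parameter count for a decoder-only transformer (assumptions in the text above)."""
    head_dim = n_embd // n_head
    kv_dim = n_kv_head * head_dim
    # Query and output projections, plus key and value projections
    attn = 2 * n_embd * n_embd + 2 * n_embd * kv_dim
    # A gated MLP has three weight matrices, a plain MLP has two
    mlp = (3 if gated_mlp else 2) * n_embd * intermediate_size
    emb = vocab_size * n_embd * (1 if tied_embeddings else 2)
    return emb + n_layer * (attn + mlp)

# Default dimensions with grouped-query attention and a gated MLP
gqa_gated = estimate_params(50304, 1024, 12, 16, 8, 3584, gated_mlp=True)
# Same dimensions with full multi-head attention and a plain MLP
mha_plain = estimate_params(50304, 1024, 12, 16, 16, 3584, gated_mlp=False)
print(f"{gqa_gated / 1e6:.1f}M, {mha_plain / 1e6:.1f}M")  # prints 221.4M, 189.9M
```

Under these assumptions, the default dimensions come to roughly 221M parameters with grouped-query attention and a gated MLP, and roughly 190M with full multi-head attention and a plain MLP.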

In [None]:
"""
Default config produces a ~220M parameter model based on
the Llama-3 architecture
"""
conf_default = Config()

"""
Very small 135M model based on SmolLM-135M:
https://huggingface.co/HuggingFaceTB/SmolLM-135M
"""
conf_smol = Config()
conf_smol.n_layer = 30
conf_smol.n_embd = 512
conf_smol.n_head = 8
conf_smol.n_kv_head = 4
conf_smol.intermediate_size = 1408

"""
Small ~230M model based on SmolLM-135M:
https://huggingface.co/HuggingFaceTB/SmolLM-135M
"""
conf_smed = Config()
conf_smed.n_layer = 30
conf_smed.n_embd = 768
conf_smed.n_head = 12
conf_smed.n_kv_head = 6
#conf_smed.qk_nope_dim = 48
#conf_smed.qk_rope_dim = 16
conf_smed.intermediate_size = 2048
#conf_smed.attn_class = [SkipRopeAttention]

"""
Medium config: 32 layers, n_embd of 960, 15 attention heads and 5 KV heads
"""
conf_med = Config()
conf_med.n_layer = 32
conf_med.n_embd = 960
conf_med.n_head = 15
conf_med.n_kv_head = 5
conf_med.qk_nope_dim = 48
conf_med.qk_rope_dim = 16
conf_med.intermediate_size = 2560

"""
Reproduction of Llama-3.2 1B, from model specific config
"""
conf_1B = Config()
conf_1B.n_layer = 16
conf_1B.n_embd = 2048
conf_1B.n_head = 32
conf_1B.n_kv_head = 8
conf_1B.intermediate_size = 8192
conf_1B.attn_class = [GQASelfAttention, DifferentialAttention]

"""
Reproduction of GPT-2 124M, based on default config in Karpathy's nanoGPT
"""
conf_gpt2 = Config()
conf_gpt2.n_layer = 12
conf_gpt2.n_head = 12
conf_gpt2.n_kv_head = 12
conf_gpt2.n_embd = 768
conf_gpt2.intermediate_size = 3072
conf_gpt2.rotary_encoding = False
conf_gpt2.emb_class = AbsolutePositionalEmbeddings
conf_gpt2.attn_class = [MHASelfAttention]
conf_gpt2.ff_class = [MLP]
conf_gpt2.bias = True



## Let 'er rip!
Here we (finally) 1. modify the training configuration, 2. create the model, 3. tell PyTorch to "compile" it to optimize the modules, and 4. perform the training!

In [None]:
SEQ_LEN = 1024
C = conf_smol

C.max_seq_len = SEQ_LEN
model = TransformerForCausalLM(C).to(g.device, dtype=g.dtype)
# output an architectural summary - n.b. the total number of parameters reported here
# will have double the amount of embedding params -- remember, we've linked those two sets
# of weights.
model_stats = summary(model, input_size=(1, SEQ_LEN), dtypes=[torch.long],depth=5)
g.estimate_batch_size_globals(model_stats, seq_len=SEQ_LEN)
print(str(model_stats))

In [None]:
g.train_control_setup()

In [None]:
# Optional -- compile model with torch
# Uses torch's jit system to create custom kernels for the model
# ideally, speeding up training and execution.
# Anecdotally, I've found it makes the first training step take significantly longer,
# but subsequent iterations are faster, leading to ~4 hours saved over 10k steps.
# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
model = torch.compile(model)

In [None]:
torch.manual_seed(420691337)
np.random.seed(420691337)
train_model(model, seq_len = SEQ_LEN, data_fname = "train.bin")