In [1]:
# %pip install wandb
# %pip install datasets
# %pip install tokenizers
# %pip install -U transformers datasets tokenizers accelerate

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import math
import time

In [3]:
from torch.nn import RMSNorm
from torch.utils.data import DataLoader
from datasets import load_dataset
from dataclasses import dataclass
from tokenizers import Tokenizer
from transformers import AutoTokenizer
from pathlib import Path
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
@dataclass
class modelargs:
    # Notebook-wide hyperparameters, read directly as class attributes
    # (e.g. `modelargs.block_size`). Only `eps` carries a type annotation, so it
    # is the sole dataclass field; every other setting is a plain class attribute.
    block_size = 512  # context length; also the CausalDataset window size
    batch_size = 32  # DataLoader batch size
    embeddings_dim = 512  # model width
    attn_dropout = 0.1  # dropout on attention weights (LatentAttention)
    vocab_size = 50257  # GPT-2 BPE vocabulary
    ignore_pad_token_in_loss = False  # packed blocks are used, so padding never happens
    heads = 8  # LatentAttention head_size = embeddings_dim // heads
    dropout = 0.1
    epochs = 1
    max_lr = 6e-4
    decoder_layers = 6
    weight_decay = 0.1
    beta1 = 0.9
    beta2 = 0.95
    clip = 1.0
    experts = 8  # routed experts per MoeLayer
    top_experts = 2  # experts selected per token
    use_shared_experts = True  # add an always-on shared SwiGLU expert
    base_freq = 100_000  # NOTE(review): consumer not visible here; sinusoidal PE hard-codes 10000
    noisy_topk = False  # add softplus-scaled Gaussian noise to router logits
    eps: float = 1e-8  # the only annotated attribute -> the only dataclass field
    loss_scale = 0.03  # NOTE(review): consumer not visible in this section
    use_aux_free_load_balancing = True  # per-expert routing bias used for expert selection
    aux_free_bias_update_rate = 0.001  # step size of each routing-bias update
    mtp_heads = 1  # NOTE(review): consumer not visible in this section
    latent_dim = 64  # width of the compressed KV latent in LatentAttention

In [4]:
# TinyStories from the Hugging Face Hub; its 'train' and 'validation' splits are tokenized below.
dataset = load_dataset(path='roneneldan/TinyStories')

In [5]:
from transformers import AutoTokenizer

# GPT-2 byte-level BPE tokenizer (fast implementation).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='gpt2', use_fast=True)
# Reuse EOS as the pad token so padded batching is possible; the training data
# below is packed into fixed-size blocks, so no padding is actually applied.
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# vocab
# vocab
vocab_size = tokenizer.vocab_size
# modelargs hard-codes the vocabulary size; fail fast if the tokenizer disagrees.
# NOTE(review): presumably the model's embedding/output layers are sized from
# modelargs.vocab_size (not visible here), in which case a mismatch would
# surface much later as an index error.
assert vocab_size == modelargs.vocab_size, (
    f'tokenizer vocab ({vocab_size}) != modelargs.vocab_size ({modelargs.vocab_size})'
)
vocab_size

50257

In [7]:
def encode(text: str):
    """Return the token ids of `text` from the notebook's `tokenizer`, with no special tokens added."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids

In [5]:
artifacts = Path('artifacts')
tokenized_dir = artifacts.joinpath('tokenized')

# Parent first: mkdir without parents=True requires artifacts/ to exist already.
artifacts.mkdir(exist_ok=True)
tokenized_dir.mkdir(exist_ok=True)

In [9]:
def chunked_tokenize(texts, chunk_size=10000, separator='\n'):
    """Tokenize a sequence of documents into one flat stream of token ids.

    Documents are joined with `separator` and encoded `chunk_size` documents at
    a time (one `encode` call per chunk). The same separator is also inserted
    between consecutive chunks, so adjacent documents are separated identically
    no matter where a chunk boundary falls.

    Args:
        texts: sliceable sequence of strings (e.g. a `datasets` text column).
        chunk_size: number of documents joined into a single `encode` call.
        separator: string placed between documents. Defaults to a newline;
            pass e.g. `tokenizer.eos_token` to mark document boundaries explicitly.

    Returns:
        1-D `torch.long` tensor of token ids (empty when `texts` is empty).

    Raises:
        ValueError: if `chunk_size` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    separator_ids = torch.tensor(encode(separator), dtype=torch.long)
    pieces = []
    for i in tqdm(range(0, len(texts), chunk_size), desc='Tokenizing'):
        chunk_text = separator.join(texts[i:i + chunk_size])
        if pieces:
            # Without this, the last document of the previous chunk and the first
            # document of this one would be glued together with no separator.
            pieces.append(separator_ids)
        # Keep each chunk as a compact int64 tensor rather than extending one
        # Python list of ints (tens of bytes per token) over the whole corpus.
        pieces.append(torch.tensor(encode(chunk_text), dtype=torch.long))

    if not pieces:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pieces)

In [6]:
# On-disk cache locations for the tokenized train / validation splits.
train_path = tokenized_dir.joinpath('train_ids.pt')
val_path = tokenized_dir.joinpath('val_ids.pt')

In [None]:
# Tokenizing the training split took ~55 minutes, and the next cell reloads the
# ids from disk anyway, so only tokenize when the cache files are missing.
# Delete the files in `tokenized_dir` to force re-tokenization (e.g. after
# changing the tokenizer or `chunked_tokenize`).
if train_path.exists() and val_path.exists():
    print(f'Found cached token ids in {tokenized_dir}; skipping tokenization.')
else:
    train_ids = chunked_tokenize(dataset['train']['text'])
    val_ids = chunked_tokenize(dataset['validation']['text'])

    torch.save(train_ids, train_path)
    torch.save(val_ids, val_path)

Tokenizing:   0%|          | 0/212 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2162077 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing: 100%|██████████| 212/212 [55:26<00:00, 15.69s/it]
Tokenizing: 100%|██████████| 3/3 [00:35<00:00, 11.79s/it]


In [7]:
# The cache files contain only plain tensors, so restrict unpickling to tensors
# and primitive containers (weights_only=True) instead of arbitrary objects.
train_ids = torch.load(train_path, weights_only=True)
val_ids   = torch.load(val_path, weights_only=True)

In [8]:
# converting flat token stream to (input, target) blocks
class CausalDataset(torch.utils.data.Dataset):
    """Fixed-length windows over a flat 1-D token stream for next-token prediction.

    Sample `idx` starts at token `s = idx * stride` and returns
    `x = tokens[s : s + block_size]` and the one-step-shifted targets
    `y = tokens[s + 1 : s + block_size + 1]`.

    Args:
        tokens: 1-D tensor (or other sliceable sequence) of token ids.
        block_size: context length of each sample.
        stride: distance between the starts of consecutive samples. The default
            of 1 yields every possible (overlapping) window; `stride=block_size`
            yields non-overlapping packed blocks.

    Raises:
        ValueError: if `block_size` or `stride` is not positive.
    """

    def __init__(self, tokens, block_size, stride=1):
        if block_size <= 0:
            raise ValueError(f'block_size must be positive, got {block_size}')
        if stride <= 0:
            raise ValueError(f'stride must be positive, got {stride}')
        self.tokens = tokens
        self.block_size = block_size
        self.stride = stride

    def __len__(self):
        # A window needs block_size + 1 tokens (inputs plus the shifted target),
        # so a stream shorter than that yields 0 samples instead of a negative length.
        last_start = len(self.tokens) - self.block_size - 1
        if last_start < 0:
            return 0
        return last_start // self.stride + 1

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        # Raising IndexError lets plain iteration (`for x, y in ds`) terminate;
        # out-of-range slices would otherwise silently return short/empty tensors.
        if not 0 <= idx < n:
            raise IndexError(f'index out of range for dataset of length {n}')
        start = idx * self.stride
        x = self.tokens[start : start + self.block_size]
        y = self.tokens[start + 1 : start + self.block_size + 1]
        return x, y

In [9]:
# Sliding-window next-token datasets over the flat train / validation token streams.
train_dataset = CausalDataset(tokens=train_ids, block_size=modelargs.block_size)
val_dataset = CausalDataset(tokens=val_ids, block_size=modelargs.block_size)

In [10]:
# Training batches are reshuffled each epoch; validation is iterated in order.
# drop_last=True keeps every batch exactly modelargs.batch_size samples.
train_loader = DataLoader(train_dataset, batch_size=modelargs.batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=modelargs.batch_size, shuffle=False, drop_last=True)

#### **Sinusoidal Positional Embeddings**

In [11]:
def precompute_pos_embeddings(block_size=modelargs.block_size, embeddings_dim=modelargs.embeddings_dim, device=None, base=10000.0):
    """Build the fixed sinusoidal positional-embedding table.

    pe[0, pos, 2i]   = sin(pos / base**(2i / embeddings_dim))
    pe[0, pos, 2i+1] = cos(pos / base**(2i / embeddings_dim))

    Args:
        block_size: number of positions (rows) in the table.
        embeddings_dim: embedding width. Odd widths are supported: the last
            column is then a sine with no matching cosine.
        device: target device. ``None`` (default) selects ``'cuda'`` when it is
            available and ``'cpu'`` otherwise, so the notebook also runs on
            CPU-only machines. Explicit devices are used as given.
        base: wavelength base; 10000.0 reproduces the previously hard-coded value.

    Returns:
        Float tensor of shape (1, block_size, embeddings_dim).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    position = torch.arange(0, block_size, dtype=torch.float, device=device).unsqueeze(1)  # (block_size, 1)
    div_term = torch.exp(
        torch.arange(0, embeddings_dim, 2, dtype=torch.float, device=device) * (-math.log(base) / embeddings_dim)
    )  # (ceil(embeddings_dim / 2),)
    angles = position * div_term  # (block_size, ceil(embeddings_dim / 2))

    pe = torch.zeros(block_size, embeddings_dim, device=device)
    pe[:, 0::2] = torch.sin(angles)
    # The odd columns number floor(embeddings_dim / 2); slicing avoids a shape
    # mismatch when embeddings_dim is odd.
    pe[:, 1::2] = torch.cos(angles[:, : embeddings_dim // 2])
    return pe.unsqueeze(0)  # (1, block_size, embeddings_dim)

In [12]:
def apply_pos_embeddings(x, pe):
    """Add the first ``x.size(1)`` positions of the table ``pe`` to ``x``.

    x:  (batch, seq_len, embeddings_dim)
    pe: (1, max_len, embeddings_dim), broadcast across the batch
    """
    seq_len = x.shape[1]
    return x + pe[:, :seq_len]

#### **RMS Normalization**

In [13]:
class Normalization(nn.Module):
    """Thin wrapper around ``torch.nn.RMSNorm`` over the last ``embeddings_dim`` features.

    NOTE(review): RMSNorm is built with its default eps; ``modelargs.eps`` is not
    passed here. Confirm whether that is intentional.
    """

    def __init__(self, embeddings_dim: int = modelargs.embeddings_dim):
        super(Normalization, self).__init__()
        # The attribute name determines the state_dict key, so it must stay `rmsnorm_layer`.
        self.rmsnorm_layer = RMSNorm(embeddings_dim)

    def forward(self, x):
        normed = self.rmsnorm_layer(x)
        return normed

#### **Swish + SwiGLU Expert**

In [14]:
def swish(x):
    """Swish / SiLU activation, computed elementwise as x * sigmoid(x)."""
    return torch.sigmoid(x) * x


class SWiGLUExpert(nn.Module):
    """SwiGLU feed-forward expert: ``w3(swish(w1(x)) * w2(x))``.

    The hidden width is ``floor(8 * embeddings_dim / 3)``.
    """

    def __init__(self, embeddings_dim=modelargs.embeddings_dim):
        super(SWiGLUExpert, self).__init__()
        hidden = (8 * embeddings_dim) // 3
        self.hidden_dim = hidden
        # gate (w1) and up (w2) projections: model width -> hidden width
        self.w1 = nn.Linear(embeddings_dim, hidden, bias=False)
        self.w2 = nn.Linear(embeddings_dim, hidden, bias=False)
        # down projection (w3): hidden width -> model width
        self.w3 = nn.Linear(hidden, embeddings_dim, bias=False)

    def forward(self, x):
        gated = swish(self.w1(x))
        projected = self.w2(x)
        return self.w3(gated * projected)

#### **Mixture of Experts Layer**

In [15]:
class MoeLayer(nn.Module):
    """Top-k mixture of SwiGLU experts with an optional always-on shared expert.

    Per token:
      * router logits ``scores = gate(x)``. With ``modelargs.noisy_topk`` and in
        training mode, softplus-scaled Gaussian noise is added.
      * The ``modelargs.top_experts`` experts are selected by
        ``scores + routing_bias``. The bias exists only when
        ``modelargs.use_aux_free_load_balancing`` is on.
      * The selected experts' outputs are mixed with weights equal to the softmax
        of the *unbiased* ``scores`` over the selected experts only.

    Aux-loss-free balancing: after every training forward, ``routing_bias`` moves
    by ``±aux_free_bias_update_rate`` toward under-loaded experts. Load is the
    number of (token, slot) assignments per expert in the batch. The bias changes
    which experts are chosen but never the mixing weights.
    """

    def __init__(self, embeddings_dim=modelargs.embeddings_dim):
        super().__init__()

        # Experts are built with this layer's width, not the global default.
        self.experts = nn.ModuleList([SWiGLUExpert(embeddings_dim) for _ in range(modelargs.experts)])
        self.gate = nn.Linear(embeddings_dim, modelargs.experts, bias=False)

        if modelargs.use_shared_experts:
            self.shared_expert = SWiGLUExpert(embeddings_dim)
        else:
            self.shared_expert = None

        if modelargs.noisy_topk:
            self.noise = nn.Linear(embeddings_dim, modelargs.experts, bias=False)

        if modelargs.use_aux_free_load_balancing:
            # Buffer, not a parameter: updated by the rule in forward(), never by the optimizer.
            self.register_buffer('routing_bias', torch.zeros(modelargs.experts))
            self.bias_update_speed = modelargs.aux_free_bias_update_rate

    def forward(self, x):
        """Route each token to its top-k experts and mix their outputs.

        Args:
            x: (..., embeddings_dim) activations, e.g. (B, T, embeddings_dim).

        Returns:
            Tensor shaped like ``x``: the gate-weighted sum of the selected
            experts' outputs, plus the shared expert's output when enabled.

        Side effect:
            In training mode with aux-free balancing enabled, updates
            ``self.routing_bias`` in place under ``torch.no_grad()``.
        """
        top_k = modelargs.top_experts

        scores = self.gate(x)  # (..., experts) router logits

        # Noisy top-k is a training-time exploration aid; evaluation stays deterministic.
        if modelargs.noisy_topk and self.training:
            scores = scores + F.softplus(self.noise(x)) * torch.randn_like(scores)

        # The balancing bias influences only *which* experts are selected ...
        selection_scores = scores
        if modelargs.use_aux_free_load_balancing:
            selection_scores = scores + self.routing_bias
        top_k_indices = torch.topk(selection_scores, k=top_k, dim=-1).indices  # (..., k)

        # ... while the mixing weights come from the unbiased scores of the chosen experts.
        top_k_weights = F.softmax(scores.gather(-1, top_k_indices), dim=-1)  # (..., k)

        embed_dim = x.size(-1)
        flat_x = x.reshape(-1, embed_dim)                 # (N, D); reshape also accepts non-contiguous x
        flat_indices = top_k_indices.reshape(-1, top_k)   # (N, k)
        flat_weights = top_k_weights.reshape(-1, top_k)   # (N, k)

        routed = torch.zeros_like(flat_x)
        for expert_id, expert in enumerate(self.experts):
            # Rows (tokens) routed to this expert and the top-k slot each one used.
            token_idx, slot_idx = torch.where(flat_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            expert_out = expert(flat_x[token_idx])
            gate_weight = flat_weights[token_idx, slot_idx].unsqueeze(-1)
            # Accumulate straight into the token rows instead of allocating a full
            # (B, T, D) buffer per expert; the cast keeps dtypes aligned under autocast.
            routed = routed.index_add(0, token_idx, (expert_out * gate_weight).to(routed.dtype))

        out = routed.reshape(x.shape)
        if self.shared_expert is not None:
            out = out + self.shared_expert(x)

        if modelargs.use_aux_free_load_balancing and self.training:
            with torch.no_grad():
                # Load = number of (token, slot) assignments per expert in this batch.
                load = torch.bincount(flat_indices.reshape(-1), minlength=modelargs.experts)
                load = load.to(self.routing_bias.dtype)
                # Below-mean load -> positive nudge; above-mean load -> negative nudge.
                self.routing_bias.add_(self.bias_update_speed * torch.sign(load.mean() - load))

        return out

# citation: opus 4.6

#### **Latent Attention**

In [16]:
class LatentAttention(nn.Module):
    """One attention head whose keys/values pass through a low-rank latent.

    The input is compressed to ``latent = W_dkv(x)`` (``modelargs.latent_dim``
    wide), and only that latent is kept as the KV cache. Keys are never
    materialised: the query projection is absorbed into latent space, using

        q k^T = (x W_q^T)(c W_k^T)^T = x (W_q^T W_k) c^T

    so scores are computed directly against the cached latent ``c``. Values are
    decompressed with ``W_v``. The output width is
    ``embeddings_dim // modelargs.heads``. NOTE(review): presumably several
    instances are combined into multi-head attention by the caller (not visible here).
    """

    def __init__(
        self,
        embeddings_dim=modelargs.embeddings_dim,
        attn_dropout=modelargs.attn_dropout,
    ):
        super().__init__()
        self.head_size = embeddings_dim // modelargs.heads
        self.latent_dim = modelargs.latent_dim

        # KV compression: embeddings_dim -> latent_dim -> head_size
        self.W_dkv = nn.Linear(embeddings_dim, self.latent_dim, bias=False)
        self.W_k = nn.Linear(self.latent_dim, self.head_size, bias=False)
        self.W_v = nn.Linear(self.latent_dim, self.head_size, bias=False)

        # query
        self.query = nn.Linear(embeddings_dim, self.head_size, bias=False)

        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, x, kv_cache=None, mask=None):
        """Attend from the new tokens in ``x`` to all cached and new positions.

        Args:
            x: (B, T, embeddings_dim) new tokens.
            kv_cache: (B, T_past, latent_dim) latent returned by a previous call,
                or ``None`` when there is no history.
            mask: optional tensor broadcastable to (B, T, T_past + T); positions
                where ``mask == 0`` are excluded.

        Returns:
            ``(out, kv_cache)``: ``out`` is (B, T, head_size), and ``kv_cache`` is the
            updated (B, T_past + T, latent_dim) latent to pass to the next call.
        """
        T = x.size(1)

        # Compress the new tokens into the latent and append them to the cache.
        latent = self.W_dkv(x)  # (B, T, latent_dim)
        kv_cache = latent if kv_cache is None else torch.cat([kv_cache, latent], dim=1)
        T_cache = kv_cache.size(1)

        # Values are decompressed; keys are not (see the absorbed query below).
        v = self.W_v(kv_cache)  # (B, T_cache, head_size)

        # query.weight: (head_size, D); W_k.weight: (head_size, latent_dim)
        # => absorbed_q: (D, latent_dim), so (x @ absorbed_q) @ kv_cache^T == q @ k^T
        absorbed_q = self.query.weight.T @ self.W_k.weight
        q_latent = x @ absorbed_q  # (B, T, latent_dim)

        scores = q_latent @ kv_cache.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T_cache)
        # Most negative finite value of the scores' dtype. Unlike -1e20 it also
        # fits in fp16, and fully masked rows still give a finite (uniform) softmax.
        neg_fill = torch.finfo(scores.dtype).min

        if mask is not None:
            scores = scores.masked_fill(mask == 0, neg_fill)

        # Query row i sits at absolute position (T_cache - T) + i, so it may see
        # columns 0..(T_cache - T + i). Offsetting the diagonal makes cached
        # decoding correct. Without a cache (T_cache == T) this is the ordinary
        # lower-triangular mask.
        causal = torch.tril(torch.ones(T, T_cache, device=x.device), diagonal=T_cache - T)
        scores = scores.masked_fill(causal == 0, neg_fill)

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        out = weights @ v  # (B, T, head_size)
        return out, kv_cache