In [219]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # reserved for later use
import einops
import time
import wandb
import os
from dotenv import load_dotenv
load_dotenv()

try:
    from muon import SingleDeviceMuonWithAuxAdam
    MUON_AVAILABLE = True
except ImportError:
    MUON_AVAILABLE = False
    print("Muon not available. Install with: pip install git+https://github.com/KellerJordan/Muon")

In [220]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [221]:
class Config():
    # FIXED PARAMS
    vocab_size = 50257  # tiktoken vocab size for gpt2
    
    # CHANGEABLE MODEL ARCHITECTURE PARAMS
    embedding_dim = 720
    num_blocks = 2
    num_heads = 2
    context_len = 1024
    
    # TRAINING HYPERPARAMS
    num_epochs = 1  # generally should keep this to 1
    lr = 3e-4
    muon_lr = 0.02 # from muon repo: "only the lr and weight decay have to be tuned"
    muon_momentum = 0.95
    betas = (0.9, 0.95)  # for controlling momentum
    eps = 1e-8
    weight_decay = 0.01
    use_muon = True  # whether to use Muon optimizer for hidden layers
    
    # LOGGING & OBSERVABILITY
    wandb_enabled = True
    print_per_layer_params = False
    
    def __init__(self, **kwargs):
        # Set all class attributes as instance attributes first
        for key in dir(self.__class__):
            if not key.startswith('_') and not callable(getattr(self.__class__, key)):
                setattr(self, key, getattr(self.__class__, key))
        
        # Override with any provided kwargs
        for key, value in kwargs.items():
            if hasattr(self.__class__, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Unknown config parameter: {key}")
        
        # Derived params that depend on other params
        self.mlp_dim = 2 * self.embedding_dim
        self.attention_dim = self.embedding_dim // self.num_heads
        
        # Device detection
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

        # COUNTING PARAMETERS
        # learnable params
        self.learnable_params_dict = {
            "embedding": self.vocab_size * self.embedding_dim, 
            "positional_embedding": self.context_len * self.embedding_dim, 
            "MLPs (Weights)": self.num_blocks * 2 * self.embedding_dim * self.mlp_dim, 
            "MLPs (Biases)": self.num_blocks * (self.mlp_dim + self.embedding_dim), 
            "W_Qs": self.num_blocks * self.embedding_dim * self.embedding_dim, 
            "W_Ks": self.num_blocks * self.embedding_dim * self.embedding_dim, 
            "W_Vs": self.num_blocks * self.embedding_dim * self.embedding_dim, 
            "W_Out": self.num_blocks * self.embedding_dim * self.embedding_dim}
        self.learnable_params = (lambda d: sum(d.values()))(self.learnable_params_dict)

        # non-learnable (fixed) params
        self.non_learnable_params_dict = {"deembedding (tied to embedding weights)": self.vocab_size * self.embedding_dim}
        self.non_learnable_params = (lambda d: sum(d.values()))(self.non_learnable_params_dict)
    
    def display_config(self, extended):
        if extended:
            print(f"learnable params dict: {config.learnable_params_dict}")
            print(f"total # of learnable params: {config.learnable_params:,}")
            print(f"non-learnable params dict: {config.non_learnable_params_dict}")
            print(f"total # of non-learnable params: {config.non_learnable_params:,}")
            print(f"** total # of params: {(config.learnable_params + config.non_learnable_params):,} **")
            print("*"*50)
        else:
            print(f"learnable params: {self.learnable_params:,}, non-learnable params: {self.non_learnable_params:,}, total params: {self.learnable_params + self.non_learnable_params:,}")

# config = Config()
# config.display_config(extended=True)

In [222]:
class DataLoader:
    def __init__(self, B, T):
        self.batch_size = B # num of sequences processed together in each batch
        self.seq_len = T # how many tokens are in each sequence/batch
    
        with open("tiny_shakespeare.txt", "r") as f:
            text = f.read()
        
        encoding = tokenizer.encode(text)
        self.tokens = torch.tensor(encoding)

        self.current_pos = 0 # maintain the index of the current data sample

        print(f"loaded {len(self.tokens)} tokens with batch size of {self.batch_size} sequences and {self.seq_len} tokens per sequence in the batch")
        print(f"each epoch has {len(self.tokens) / (self.batch_size * self.seq_len)} batches, with {self.seq_len * self.batch_size} tokens per batch, for a total of {self.seq_len * self.batch_size * (len(self.tokens) / (self.batch_size * self.seq_len))} tokens")
        print("*"*50)
        
    def next_batch(self):
        B, T = self.batch_size, self.seq_len
        x = self.tokens[self.current_pos:self.current_pos+B*T]
        y = self.tokens[self.current_pos+1:self.current_pos+B*T+1]
        x = x.view(B, T)
        y = y.view(B, T)
        self.current_pos += B * T

        if (len(self.tokens) - self.current_pos + 1) < B * T:
            self.current_pos = 0

        return x, y

In [223]:
class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_E = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.W_pos = nn.Embedding(config.context_len, config.embedding_dim)

    def forward(self, tokens):
        # Don't convert to tensor if already a tensor
        if not isinstance(tokens, torch.Tensor):
            tokens = torch.tensor(tokens)
        
        embeddings = self.W_E(tokens)

        # Create positions tensor on the same device as tokens
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        position_embeddings = self.W_pos(positions)

        return embeddings + position_embeddings

class DeEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_D = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

    def forward(self, x):
        embeddings = self.W_D(x)

        return embeddings

In [224]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.W_Q = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_K = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_V = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        
        self.W_out = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)

        self.num_heads = config.num_heads
        self.attention_dim = config.attention_dim
        
        # Register causal mask as buffer
        self.register_buffer("causal_mask", torch.tril(torch.ones(config.context_len, config.context_len)))
    
    def forward(self, x):
        B, T, C = x.shape
        
        Q = einops.rearrange(self.W_Q(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        K = einops.rearrange(self.W_K(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        V = einops.rearrange(self.W_V(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)

        # Calculate attention scores
        scores = (Q @ K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.attention_dim))
        
        # Apply causal mask
        scores = scores.masked_fill(self.causal_mask[:T, :T] == 0, float('-inf'))
        
        # Apply softmax
        QK = torch.softmax(scores, dim=-1)

        QKV = einops.rearrange(QK @ V, 'batch head seq dim -> batch seq (head dim)', head=self.num_heads)
        QKV_Out = self.W_out(QKV)

        return QKV_Out

In [225]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer1 = nn.Linear(config.embedding_dim, config.mlp_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(config.mlp_dim, config.embedding_dim)
    
    def forward(self, x):
        x = self.layer2(self.gelu(self.layer1(x)))

        return x

In [226]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Attention_Layers = Attention(config)
        self.MLP_Layers = MLP(config)

    def forward(self, x):
        x = x + self.Attention_Layers(x)
        return x + self.MLP_Layers(x)

In [227]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embed = Embedding(config)

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for i in range(config.num_blocks)
        ])

        self.deembed = DeEmbedding(config)
        self.deembed.W_D.weight = self.embed.W_E.weight  # tie weights

    def forward(self, x):

        x = self.embed(x)
        # print(f"after embed: {x}")

        for block in self.blocks:
            x = block(x)
        
        x = self.deembed(x)
        # print(f"after deembed: {x}")
        x = torch.softmax(x, dim=-1)

        return x

In [228]:
def inference(inference_config, inference_model, text="They fear us"):
    tokens = tokenizer.encode(text)
    x = torch.tensor(tokens)
    x = x.unsqueeze(0)

    print("*"*50)
    
    # Move input tensor to the same device as the model
    if inference_model:
        print("using passed in model for inference")
        device = next(inference_model.parameters()).device

    else:
        inference_model = Transformer(inference_config)
        print("using random model for inference")
        device = inference_config.device
    
    x = x.to(device)
    out = inference_model(x)
    pred_tokens = out.argmax(dim=-1)
    print(f"predicted tokens: {pred_tokens}")

    # Only take the prediction from the last position (next token after "fox")
    next_token = pred_tokens[0, -1].item()
    predicted_word = tokenizer.decode([next_token])
    print(f"predicted word: {predicted_word}")

    print(f"full sentence: {text}{predicted_word}")

    print("*"*50)
    print("sanity check: all predicted tokens")
    for num, token in enumerate(pred_tokens.flatten()):
        decoded = tokenizer.decode([token.item()])
        
        if num == (len(pred_tokens.flatten()) - 1):
            print(f"** Token {token} -> '{decoded}' **")
        else:
            print(f"Token {token} -> '{decoded}'")

def training(model_config):
    device = model_config.device
    print(f"Using device: {device}")

    torch.manual_seed(42)

    if device == "cuda":
        torch.cuda.manual_seed(42)
    
    if config.wandb_enabled:
        # Initialize wandb
        wandb.init(
            project=os.getenv("WANDB_PROJECT"),
            entity=os.getenv("WANDB_ENTITY"),
            config={
                "vocab_size": model_config.vocab_size,
                "embedding_dim": model_config.embedding_dim,
                "mlp_dim": model_config.mlp_dim,
                "num_blocks": model_config.num_blocks,
                "num_heads": model_config.num_heads,
                "context_len": model_config.context_len,
                "attention_dim": model_config.attention_dim,
                "device": model_config.device,
                "num_epochs": model_config.num_epochs,
                "lr": model_config.lr,
                "betas": model_config.betas,
                "eps": model_config.eps,
                "weight_decay": model_config.weight_decay,
                "use_muon": model_config.use_muon,
                "muon_lr": model_config.muon_lr,
                "muon_momentum": model_config.muon_momentum,
                "learnable_params": model_config.learnable_params,
                "total_params": model_config.learnable_params + model_config.non_learnable_params
            }
        )

    train_loader = DataLoader(8, 1024)

    model = Transformer(model_config).to(device)
    # model = torch.compile(model) # temporary comment to resolve errors with metal
    losses = []
    
    # Create optimizer based on config
    if model_config.use_muon and MUON_AVAILABLE:
        print("Using Muon optimizer for hidden layers")
        
        # Separate parameters for Muon optimization
        hidden_weights = []
        other_params = []
        
        for name, param in model.named_parameters():
            # Hidden weights are 2D+ parameters in transformer blocks
            if param.ndim >= 2 and any(block_name in name for block_name in ['blocks',
        'Attention_Layers', 'MLP_Layers']):
                hidden_weights.append(param)
                if model_config.print_per_layer_params:
                    print(f"  MUON: {name} - shape: {param.shape} - params: {param.numel():,}")
            else:
                # Embeddings, biases, and other parameters
                other_params.append(param)
                if model_config.print_per_layer_params:
                    print(f"  ADAMW: {name} - shape: {param.shape} - params: {param.numel():,}")

        print(f"Muon params: {sum(p.numel() for p in hidden_weights):,}")
        print(f"AdamW params: {sum(p.numel() for p in other_params):,}")
        print(f"Total: {model_config.non_learnable_params + sum(p.numel() for p in hidden_weights) + sum(p.numel() for p in 
        other_params):,}")

        print("*"*50)
        
        param_groups = [
            {
                'params': hidden_weights, 
                'use_muon': True, 
                'lr': model_config.muon_lr, 
                'weight_decay': model_config.weight_decay,
                'momentum': model_config.muon_momentum
            },
            {
                'params': other_params, 
                'use_muon': False, 
                'lr': model_config.lr,
                'weight_decay': model_config.weight_decay,
                'betas': model_config.betas,
                'eps': model_config.eps
            }
        ]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
    else:
        if model_config.use_muon and not MUON_AVAILABLE:
            print("Muon requested but not available, falling back to AdamW")
        else:
            print("Using AdamW optimizer")
        optimizer = torch.optim.AdamW(model.parameters(), lr=model_config.lr, 
                                     betas=model_config.betas, eps=model_config.eps, 
                                     weight_decay=model_config.weight_decay)

    for epoch in range(model_config.num_epochs):
        print(f"Epoch {epoch + 1}/{model_config.num_epochs}")

        # Reset data loader position at the start of each epoch
        train_loader.current_pos = 0

        num_batches = int(len(train_loader.tokens) / (train_loader.batch_size * train_loader.seq_len))

        for batch in range(num_batches):
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)

            # Forward pass
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if config.wandb_enabled:
                # Log to wandb
                wandb.log({
                    "loss": loss.item(),
                    "epoch": epoch + 1,
                    "batch": batch,
                    "step": epoch * num_batches + batch
                })

            if batch % 1 == 0:
                print(f"Batch {batch}/{num_batches}, Loss: {loss.item():.4f}")

    if config.wandb_enabled:
        wandb.finish()
        
    return model

In [229]:
if __name__ == "__main__":
    config = Config(wandb_enabled=False)
    config.display_config(extended=True)
    print(f"theoretical start loss: {np.log(config.vocab_size)}")
    model = training(config)
    inference(config, model)

learnable params dict: {'embedding': 36185040, 'positional_embedding': 737280, 'MLPs (Weights)': 4147200, 'MLPs (Biases)': 4320, 'W_Qs': 1036800, 'W_Ks': 1036800, 'W_Vs': 1036800, 'W_Out': 1036800}
total # of learnable params: 45,221,040
non-learnable params dict: {'deembedding (tied to embedding weights)': 36185040}
total # of non-learnable params: 36,185,040
** total # of params: 81,406,080 **
**************************************************
theoretical start loss: 10.82490511970208
Using device: mps
loaded 338025 tokens with batch size of 8 sequences and 1024 tokens per sequence in the batch
each epoch has 41.2628173828125 batches, with 8192 tokens per batch, for a total of 338025.0 tokens
**************************************************
Using Muon optimizer for hidden layers
Muon params: 8,294,400
AdamW params: 36,926,640
Total: 81,406,080
**************************************************
Epoch 1/1
Batch 0/41, Loss: 10.8006
Batch 1/41, Loss: 10.8043
Batch 2/41, Loss: 10.7966
B

In [230]:
inference(config, model, text="The french President is")

**************************************************
using passed in model for inference
predicted tokens: tensor([[  464, 48718,  1992,   318]], device='mps:0')
predicted word:  is
full sentence: The french President is is
**************************************************
sanity check: all predicted tokens
Token 464 -> 'The'
Token 48718 -> ' french'
Token 1992 -> ' President'
** Token 318 -> ' is' **
