In [2]:
from dataclasses import dataclass
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from train_gpt import GPT

@dataclass
class Hyperparameters:
    """Configuration values for this notebook; every field carries a default."""

    # ---- data ----
    # input .bin files to train on
    train_files: str = "data/fineweb10B/fineweb_train_*.bin"
    # input .bin files used to evaluate validation loss
    val_files: str = "data/fineweb10B/fineweb_val_*.bin"
    # number of validation tokens; keep this fixed so comparisons stay consistent
    val_tokens: int = 10_485_760
    # FlexAttention sequence length (training)
    train_seq_len: int = 48 * 1024
    # FlexAttention sequence length (validation)
    val_seq_len: int = 4 * 64 * 1024

    # ---- optimization ----
    # number of iterations to run
    num_iterations: int = 1770
    # fraction of training spent cooling down the learning rate
    cooldown_frac: float = 0.4

    # ---- architecture ----
    vocab_size: int = 50257

    # ---- evaluation and logging ----
    # evaluate val loss every this many steps; 0 means only at the end
    val_loss_every: int = 125
    save_checkpoint: bool = False
    wandb: bool = False

    # ---- optimizer setup ----
    # 0.05 is the default for muon
    lr: float = 0.05
    # choices = ['adam', 'adamw_sn']
    opt1: str = "adam"
    # choices = ['muon', 'adamw_sn', 'adamw_snsm']
    optimizer: str = "muon"
    # momentum
    beta1: float = 0.9
    beta2: float = 0.95
    use_momentum_sched: bool = False
    muon_momentum_warmup: bool = True
    single_gpu: bool = True

    # ---- scheduler ----
    # choices: ['linear', 'default']; "default" is "constant then decay"
    scheduler: str = "default"
    # number of warmup steps
    warmup: int = 0
    # final lr decay fraction
    final_rate: float = 0.1

    # ---- adamw_snsm args ----
    rank: int = 256
    update_proj_gap: int = 200


args = Hyperparameters()

In [3]:
print("Constructing model")
# max_seq_len is the longer of the training and validation sequence lengths.
model: nn.Module = GPT(
    vocab_size=args.vocab_size,
    num_layers=12,
    num_heads=6,
    model_dim=768,
    max_seq_len=max(args.train_seq_len, args.val_seq_len),
).cuda()
# Switch every nn.Embedding module to bfloat16; other modules are skipped.
for m in model.modules():
    if not isinstance(m, nn.Embedding):
        continue
    m.bfloat16()

print("Constructing optimizers")
# Split the parameters into the groups to optimize.
# Block parameters with >= 2 dims whose name does not contain "embed".
hidden_matrix_params = [
    weight
    for name, weight in model.blocks.named_parameters()
    if "embed" not in name and weight.ndim >= 2
]
# Any model parameter whose name contains "embed".
embed_params = [
    weight
    for name, weight in model.named_parameters()
    if "embed" in name
]
# Any model parameter with fewer than 2 dims.
scalar_params = [weight for weight in model.parameters() if weight.ndim <= 1]
# The lm_head weight on its own.
head_params = [model.lm_head.weight]


Constructing model
Constructing optimizers


In [4]:
# Print the shape of each parameter in model.blocks that has at least two
# dimensions: exactly-2-D ones get a leading tab, higher-rank ones do not.
for n, p in model.blocks.named_parameters():
    if p.ndim < 2:
        continue  # nothing is printed for parameters with fewer than 2 dims
    if p.ndim == 2:
        print(f"\t{n}: shape = {p.shape}")
    else:
        print(f"{n}: shape = {p.shape}")

0.attn.qkv_w: shape = torch.Size([3, 768, 768])
	0.attn.c_proj.weight: shape = torch.Size([768, 768])
	0.mlp.c_fc.weight: shape = torch.Size([3072, 768])
	0.mlp.c_proj.weight: shape = torch.Size([768, 3072])
1.attn.qkv_w: shape = torch.Size([3, 768, 768])
	1.attn.c_proj.weight: shape = torch.Size([768, 768])
	1.mlp.c_fc.weight: shape = torch.Size([3072, 768])
	1.mlp.c_proj.weight: shape = torch.Size([768, 3072])
2.attn.qkv_w: shape = torch.Size([3, 768, 768])
	2.attn.c_proj.weight: shape = torch.Size([768, 768])
	2.mlp.c_fc.weight: shape = torch.Size([3072, 768])
	2.mlp.c_proj.weight: shape = torch.Size([768, 3072])
3.attn.qkv_w: shape = torch.Size([3, 768, 768])
	3.attn.c_proj.weight: shape = torch.Size([768, 768])
	3.mlp.c_fc.weight: shape = torch.Size([3072, 768])
	3.mlp.c_proj.weight: shape = torch.Size([768, 3072])
4.attn.qkv_w: shape = torch.Size([3, 768, 768])
	4.attn.c_proj.weight: shape = torch.Size([768, 768])
	4.mlp.c_fc.weight: shape = torch.Size([3072, 768])
	4.mlp.c_proj.