# Training a Lipschitz-constrained model

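This notebook has two settings to choose from:
1. An MLP on CIFAR-10 (two example configs: Muon with weight constraints, and an unconstrained AdamW baseline)
2. A transformer on Shakespeare text (Muon with weight constraints)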

Within each config you can set the optimizer (AdamW or Muon), the weight-norm constraint method applied to each layer through the `project` key (none, spectral capping, spectral normalization, ...), and other hyperparameters.

To train a 145M-parameter transformer, see the `/nanogpt` directory.

In [None]:
import jax
import numpy as np

from configs import parse_config_from_json
from data_loaders import get_data_loader
from models import create_model
from optimizers import get_optimizer
from trainer import Trainer
from utils import Logger

Specify the training setup. All available options are defined in `configs.py`.

In [None]:
# Settings that are identical across every example config below. Each config
# spreads these in first and then lists only what makes it different, so the
# differences between setups are visible at a glance. All values here are
# immutable, so sharing them between configs cannot cause aliasing.
COMMON_SETTINGS = {
    'beta1': 0.9,
    'beta2': 0.95,
    'spectral_wd': 0,
    'd_embed': 256,
    'num_blocks': 3,
    'model_dtype': 'float32',
    'project_dtype': 'float32',
    'zero_init': True,
    'randomize_labels': False,
    'val_iters': 20,
    'val_interval': 100,
    'batch_size': 512,
    'steps': 2000,
    'accum_steps': 1,
    'pre_dualize': False,
    'schedule': 'linear',
}

# MLP on CIFAR-10, Muon optimizer, weights projected with 'soft_cap'.
cifar_mlp_muon_constrained = {
    **COMMON_SETTINGS,
    'optimizer': 'muon',  # or adam
    'project': {'default': 'soft_cap'},  # specify per-layer (none, soft_cap, spec_normalize, etc.)
    'w_max': 6,
    'lr': 0.2,
    'wd': 0,
    'input_dim': 32 * 32 * 3,
    'output_dim': 10,
    'sensitive_to_wmax': {'default': True},  # False for spec_hammer
    'data': 'cifar',
    'post_dualize': True,
    'log_interval': 50,
}

# MLP on CIFAR-10, Adam optimizer with weight decay, no weight projection (baseline).
cifar_mlp_adam_unconstrained = {
    **COMMON_SETTINGS,
    'optimizer': 'adam',
    'project': {'default': 'none'},
    'w_max': 1,
    'lr': 0.0013,
    'wd': 0.08,
    'input_dim': 32 * 32 * 3,
    'output_dim': 10,
    'sensitive_to_wmax': {'default': False},
    'data': 'cifar',
    'post_dualize': False,
    'log_interval': 50,
}

# GPT-style transformer on character-level Shakespeare, Muon optimizer, 'soft_cap' projection.
shakespeare_gpt_muon_constrained = {
    **COMMON_SETTINGS,
    'optimizer': 'muon',  # or adam
    'project': {'default': 'soft_cap'},  # specify per-layer (none, soft_cap, spec_normalize, etc.)
    'w_max': 6,
    'lr': 0.1,
    'wd': 0,
    'seq_len': 256,
    'num_heads': 4,
    'softmax_scale': 1,
    'final_scale': 1,
    'residual_scale': 1,
    'scales_learnable': False,
    'blocks_mass': 16,
    'layernorm_substitute': 'none',  # no layer norm
    'max_embed_inflation_factor': 16,  # prevents embedding gradient columns from increasing too much under dualization
    'use_unembed': False,
    'sensitive_to_wmax': {'default': True},  # False for spec_hammer
    'data': 'shakespeare',
    'vocab_size': 65,
    'model': 'gpt',
    'post_dualize': True,
    'log_interval': 1,  # NOTE(review): logs far more often than the CIFAR configs (50) — intentional?
}

# Turn each raw settings dict into a parsed config object (parsed in this order).
cifar_mlp_constrained_config, cifar_mlp_unconstrained_config, shakespeare_gpt_constrained_config = (
    parse_config_from_json(raw_settings)
    for raw_settings in (
        cifar_mlp_muon_constrained,
        cifar_mlp_adam_unconstrained,
        shakespeare_gpt_muon_constrained,
    )
)

Choose which config to train with.

In [35]:
# Name of the setup to train; must be one of the keys of `available_configs`.
# An unknown name raises a KeyError instead of silently training the wrong setup.
CONFIG_NAME = 'shakespeare_gpt_constrained'

available_configs = {
    'shakespeare_gpt_constrained': shakespeare_gpt_constrained_config,
    'cifar_mlp_constrained': cifar_mlp_constrained_config,
    'cifar_mlp_unconstrained': cifar_mlp_unconstrained_config,
}
config = available_configs[CONFIG_NAME]

Set up the experiment: seed the random number generators, then build the data loaders, model, optimizer, and logger.

In [36]:
# One seed drives both RNGs used here: NumPy's global generator and the
# explicit JAX PRNG key, so they cannot drift apart when the seed is changed.
SEED = 0
np.random.seed(SEED)
key = jax.random.PRNGKey(SEED)

In [37]:
# Data: train/validation loaders plus the loss function supplied by the data loader factory.
train_loader, val_loader, loss_fn = get_data_loader(config)

# Model: build from the config, then call its `jit()` hook
# (presumably JIT-compiles the model's functions — see models.py).
model = create_model(config)
model.jit()

# Optimizer and logger are both configured from the same config object.
optimizer, logger = get_optimizer(config), Logger(config)

Initialize the model parameters and the optimizer state.

In [38]:
# Split the PRNG key: `subkey` is consumed by parameter initialization, while
# the new `key` is carried forward into training.
key, subkey = jax.random.split(key, 2)
params = model.initialize(subkey)

# The optimizer state is built from the freshly initialized parameters.
opt_state = optimizer.init_state(params)

Create the trainer.

In [39]:
# Bundle every component built above into a single Trainer.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    config=config,
    logger=logger,
)

Let's train!

In [None]:
# Run training. The returned params, optimizer state and PRNG key replace the
# current bindings, so re-running this cell continues from where it stopped.
params, opt_state, key = trainer.train(params, opt_state, key)
# Fetch whatever the logger recorded (contents are defined by utils.Logger).
results = logger.get_results()