# Training a Lipschitz constrained model

This notebook has two settings to choose from:
1. MLP on CIFAR-10
2. Transformer on Shakespeare text

Within these configs, you can set the optimizer (AdamW, Muon), the weight norm constraint method (none, spectral capping, spectral normalization), and other hyperparameters.

To train a 145M-parameter transformer, check out the `/nanogpt` directory.

In [22]:
import jax
import numpy as np

from configs import parse_config_from_json
from data_loaders import get_data_loader
from models import create_model
from optimizers import get_optimizer
from trainer import Trainer
from utils import Logger

Specify the training setup. All the options are available in `configs.py`.

In [23]:
# CIFAR-10 MLP setup: Muon optimizer with the 'soft_cap' projection applied by default.
cifar_mlp_muon_constrained = dict(
    # --- optimizer and weight-norm constraint ---
    optimizer='muon',  # or adam
    project=dict(default='soft_cap'),  # specify per-layer (none, soft_cap, spec_normalize, etc.)
    w_max=6,
    lr=0.2,
    beta1=0.9,
    beta2=0.95,
    wd=0,
    spectral_wd=0,
    # --- model shape ---
    input_dim=32 * 32 * 3,  # flattened 32x32 image with 3 channels
    output_dim=10,
    d_embed=256,
    num_blocks=3,
    # --- dtypes and initialization ---
    model_dtype='float32',
    project_dtype='float32',
    zero_init=True,
    sensitive_to_wmax=dict(default=True),  # False for spec_hammer
    # --- data ---
    data='cifar',
    randomize_labels=False,
    # --- training loop, validation and logging ---
    val_iters=20,
    val_interval=100,
    batch_size=512,
    steps=2000,
    accum_steps=1,
    pre_dualize=False,
    post_dualize=True,
    log_interval=50,
    schedule='linear',
)

# CIFAR-10 MLP baseline: Adam optimizer with no weight projection ('none').
cifar_mlp_adam_unconstrained = dict(
    # --- optimizer and weight-norm constraint ---
    optimizer='adam',
    project=dict(default='none'),
    w_max=1,
    lr=0.0013,
    beta1=0.9,
    beta2=0.95,
    wd=0.08,
    spectral_wd=0,
    # --- model shape ---
    input_dim=32 * 32 * 3,  # flattened 32x32 image with 3 channels
    output_dim=10,
    d_embed=256,
    num_blocks=3,
    # --- dtypes and initialization ---
    model_dtype='float32',
    project_dtype='float32',
    zero_init=True,
    sensitive_to_wmax=dict(default=False),
    # --- data ---
    data='cifar',
    randomize_labels=False,
    # --- training loop, validation and logging ---
    val_iters=20,
    val_interval=100,
    batch_size=512,
    steps=2000,
    accum_steps=1,
    pre_dualize=False,
    post_dualize=False,
    log_interval=50,
    schedule='linear',
)

# Shakespeare GPT setup: Muon optimizer with the 'soft_cap' projection applied by default.
shakespeare_gpt_muon_constrained = dict(
    # --- optimizer and weight-norm constraint ---
    optimizer='muon',  # or adam
    project=dict(default='soft_cap'),  # specify per-layer (none, soft_cap, spec_normalize, etc.)
    w_max=6,
    lr=0.1,
    beta1=0.9,
    beta2=0.95,
    wd=0,
    spectral_wd=0,
    # --- transformer shape and scaling ---
    d_embed=256,
    seq_len=256,
    num_blocks=3,
    num_heads=4,
    softmax_scale=1,
    final_scale=1,
    residual_scale=1,
    scales_learnable=False,
    blocks_mass=16,
    layernorm_substitute='none',  # no layer norm
    max_embed_inflation_factor=16,  # prevents embedding gradient columns from increasing too much under dualization
    use_unembed=False,
    # --- dtypes and initialization ---
    model_dtype='float32',
    project_dtype='float32',
    zero_init=True,
    sensitive_to_wmax=dict(default=True),  # False for spec_hammer
    # --- data and model selection ---
    data='shakespeare',
    vocab_size=65,
    model='gpt',
    randomize_labels=False,
    # --- training loop, validation and logging ---
    val_iters=20,
    val_interval=100,
    batch_size=512,
    steps=2000,
    accum_steps=1,
    pre_dualize=False,
    post_dualize=True,
    log_interval=1,
    schedule='linear',
)

Specify here which config you want to use by assigning it to `config_dict`!

In [24]:
# Select exactly one of the config dicts defined above by assigning it to
# `config_dict`; keep the alternatives commented out. Each alternative is a
# drop-in swap: it uses the dict's real name and targets `config_dict`, which
# is the variable parsed below.
config_dict = cifar_mlp_muon_constrained
# config_dict = cifar_mlp_adam_unconstrained
# config_dict = shakespeare_gpt_muon_constrained

# Turn the plain dict into the config object consumed by the rest of the notebook.
config = parse_config_from_json(config_dict)

Set up experiment and initialize components

In [25]:
# Single source of truth for the seed, so the NumPy and JAX RNGs cannot drift apart.
SEED = 0

np.random.seed(SEED)  # seeds NumPy's legacy global RNG
key = jax.random.PRNGKey(SEED)  # root JAX PRNG key; split and passed on by later cells

In [26]:
# Build every component from the parsed config.
# The data-loader factory returns the train/val loaders together with the loss function.
train_loader, val_loader, loss_fn = get_data_loader(config)
model = create_model(config)
# NOTE(review): presumably JIT-compiles the model's functions — confirm in the model implementation.
model.jit()
optimizer = get_optimizer(config)
# The logger is handed to the Trainer and queried for results after training.
logger = Logger(config)

Initialize model and optimizer

In [27]:
# Split the PRNG key: `subkey` is consumed by initialization, `key` is carried forward.
key, subkey = jax.random.split(key)
# Initialize the model parameters from the fresh subkey.
params = model.initialize(subkey)
# Build the optimizer state from the initial parameters.
opt_state = optimizer.init_state(params)

Create trainer

In [28]:
# Hand the model, optimizer, data loaders, loss function, config and logger to the Trainer.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    config=config,
    logger=logger,
)

Let's train!

In [29]:
# Run training; the trainer returns the updated parameters, optimizer state and PRNG key.
params, opt_state, key = trainer.train(params, opt_state, key)

# Retrieve the results recorded by the logger.
results = logger.get_results()

[22:31:44 gpu -1.0G ram 2.7G] Step:50/2000 train_loss:1.6408 train_acc:0.4727 ETA:00:16:01
[22:32:07 gpu -1.0G ram 2.7G] Step:100/2000 train_loss:1.4056 train_acc:0.5332 ETA:00:15:05
  Step:100/2000 val_loss:1.4424 val_acc:0.4928
[22:32:34 gpu -1.0G ram 2.7G] Step:150/2000 train_loss:1.3619 train_acc:0.5859 ETA:00:15:13
[22:32:57 gpu -1.0G ram 2.7G] Step:200/2000 train_loss:1.1653 train_acc:0.6172 ETA:00:14:36
  Step:200/2000 val_loss:1.3741 val_acc:0.5083
[22:33:23 gpu -1.0G ram 2.7G] Step:250/2000 train_loss:1.2460 train_acc:0.6270 ETA:00:14:25
[22:33:47 gpu -1.0G ram 2.7G] Step:300/2000 train_loss:1.0298 train_acc:0.6934 ETA:00:13:51
  Step:300/2000 val_loss:1.3391 val_acc:0.5273
[22:34:13 gpu -1.0G ram 2.7G] Step:350/2000 train_loss:1.1002 train_acc:0.6406 ETA:00:13:36
[22:34:36 gpu -1.0G ram 2.7G] Step:400/2000 train_loss:0.9660 train_acc:0.7500 ETA:00:13:06
  Step:400/2000 val_loss:1.2954 val_acc:0.5447
[22:35:03 gpu -1.0G ram 2.7G] Step:450/2000 train_loss:1.0783 train_acc:0.691