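# MLP Constrained Training

Train an MLP on CIFAR with either a constrained configuration (Muon optimizer with soft-cap weight projection) or an unconstrained baseline (Adam, no projection).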

In [1]:
import json
import argparse
import jax
import numpy as np
import copy
from pathlib import Path

import os
from pathlib import Path
import sys

notebook_path = Path(os.getcwd())
parent_path = notebook_path.parent / 'min_train_MLP'
sys.path.append(str(parent_path))

from configs import parse_config_from_json
from data_loaders import get_data_loader
from models import create_model
from optimizers import get_optimizer
from trainer import Trainer
from utils import Logger, save_results

Specify the model, optimizer, and training hyperparameters for both configurations

In [2]:
def make_mlp_config(*, optimizer, wd, lr, project, w_max, sensitive_to_wmax,
                    post_dualize, **overrides):
    """Build a raw MLP training-config dict for `parse_config_from_json`.

    The keyword-only arguments are exactly the settings that differ between
    the constrained and unconstrained variants below; every other entry is a
    shared default.

    Args:
        optimizer: optimizer name (e.g. 'muon', 'adam').
        wd: weight-decay coefficient.
        lr: learning rate.
        project: projection name stored as ``{'default': project}``.
        w_max: value stored under 'w_max'.
        sensitive_to_wmax: flag stored as ``{'default': sensitive_to_wmax}``.
        post_dualize: value stored under 'post_dualize'.
        **overrides: optional extra keys; each replaces the shared default of
            that key (or adds a new key).

    Returns:
        A fresh dict (nested dicts included), with keys in a fixed order, so
        callers may mutate it without affecting other configs.
    """
    config = {
        'optimizer': optimizer,
        'beta1': 0.9,
        'beta2': 0.95,
        'wd': wd,
        'spectral_wd': 0,
        'lr': lr,
        'output_dim': 10,
        'input_dim': 32 * 32 * 3,  # presumably flattened 32x32x3 CIFAR images
        'd_embed': 256,
        'num_blocks': 3,
        'model_dtype': 'float32',
        'project_dtype': 'float32',
        'zero_init': True,
        'project': {'default': project},
        'w_max': w_max,
        'sensitive_to_wmax': {'default': sensitive_to_wmax},
        'batch_size': 512,
        'data': 'cifar',
        'randomize_labels': False,
        'val_iters': 20,
        # Presumably one CIFAR-10 epoch (50_000 / 512 ~= 98 steps), making
        # 4900 steps ~= 50 epochs -- TODO confirm against the data loader.
        'val_interval': 98,
        'steps': 4900,
        'accum_steps': 1,
        'pre_dualize': False,
        'post_dualize': post_dualize,
        'log_interval': 14,
        'schedule': 'linear',
    }
    config.update(overrides)
    return config


# Constrained run: Muon optimizer with soft-cap weight projection (w_max=6).
Constrained_MLP_Config = make_mlp_config(
    optimizer='muon',
    wd=0,
    lr=0.398,
    project='soft_cap',
    w_max=6,
    sensitive_to_wmax=True,
    post_dualize=True,
)

# Unconstrained baseline: Adam with weight decay and no projection.
Unconstrained_MLP_Config = make_mlp_config(
    optimizer='adam',
    wd=0.08,
    lr=0.0013,
    project='none',
    w_max=1,
    sensitive_to_wmax=False,
    post_dualize=False,
)

# Convert both raw dicts into the project's config objects (constrained first).
constrained_config, unconstrained_config = map(
    parse_config_from_json,
    (Constrained_MLP_Config, Unconstrained_MLP_Config),
)

In [3]:
# Choose the configuration to train: 'constrained' or 'unconstrained'.
CONFIG_CHOICE = 'unconstrained'

config = {
    'constrained': constrained_config,
    'unconstrained': unconstrained_config,
}[CONFIG_CHOICE]

Set up the experiment and initialize its components

In [4]:
# One seed for both NumPy's global RNG and the JAX PRNG key.
SEED = 0
np.random.seed(SEED)
key = jax.random.PRNGKey(SEED)

In [5]:
# Build the data pipeline, model, optimizer and logger from the selected config.
# get_data_loader returns the train/val loaders together with the loss function.
train_loader, val_loader, loss_fn = get_data_loader(config)
model = create_model(config)
# Presumably JIT-compiles the model's functions via JAX -- confirm in models.py.
model.jit()
optimizer = get_optimizer(config)
logger = Logger(config)

Initialize model and optimizer

In [6]:
# Split off a fresh subkey for parameter initialization so that `key`, which is
# handed to trainer.train later, is not reused for the same randomness.
key, subkey = jax.random.split(key)
params = model.initialize(subkey)
# The optimizer state is built from the freshly initialized parameters.
opt_state = optimizer.init_state(params)

Create trainer

In [7]:
# Bundle the components created above into a Trainer (all passed by keyword).
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    config=config,
    logger=logger,
)

Train model

In [8]:
# Run the training loop. trainer.train returns new values for the parameters,
# optimizer state and PRNG key, which replace the initial ones.
params, opt_state, key = trainer.train(params, opt_state, key)

# Collect what the logger recorded during training (presumably the train/val
# metrics printed in the log below -- confirm in utils.Logger).
results = logger.get_results()

[18:34:10 gpu 1.3G ram 3.8G] Step:14/4900 train_loss:2.1249 train_acc:0.2090 ETA:04:19:07
[18:34:10 gpu 1.3G ram 3.8G] Step:28/4900 train_loss:1.9817 train_acc:0.2617 ETA:02:13:53
[18:34:10 gpu 1.3G ram 3.8G] Step:42/4900 train_loss:1.9068 train_acc:0.3203 ETA:01:30:12
[18:34:10 gpu 1.3G ram 3.8G] Step:56/4900 train_loss:1.8717 train_acc:0.2969 ETA:01:07:59
[18:34:10 gpu 1.3G ram 3.8G] Step:70/4900 train_loss:1.8679 train_acc:0.3105 ETA:00:54:31
[18:34:11 gpu 1.3G ram 3.8G] Step:84/4900 train_loss:1.8900 train_acc:0.2969 ETA:00:45:29
[18:34:11 gpu 1.3G ram 3.8G] Step:98/4900 train_loss:1.7900 train_acc:0.4141 ETA:00:39:02
  Step:98/4900 val_loss:1.7866 val_acc:0.3563
[18:34:11 gpu 1.3G ram 3.8G] Step:112/4900 train_loss:1.7676 train_acc:0.3516 ETA:00:34:19
[18:34:11 gpu 1.3G ram 3.8G] Step:126/4900 train_loss:1.8046 train_acc:0.3594 ETA:00:30:30
[18:34:11 gpu 1.3G ram 3.8G] Step:140/4900 train_loss:1.6607 train_acc:0.4102 ETA:00:27:27
[18:34:11 gpu 1.3G ram 3.8G] Step:154/4900 train_lo