# Guide 4: Research Projects with JAX

In [None]:
import os
import sys
from typing import Any, Sequence
import datetime
import json
from tqdm.auto import tqdm
import numpy as np
from copy import copy
from glob import glob
from collections import defaultdict
# JAX/Flax
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
# Logging with Tensorboard or Weights and Biases
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

In [None]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with batch statistics and a PRNG key.

    Both extra fields default to ``None`` so the state can also be built for
    models that do not use them.
    """
    # A simple extension of TrainState to also include batch statistics
    # If a model has no batch statistics, it is None
    batch_stats : Any = None
    # You can further extend the TrainState by any additional part here
    # For example, rng to keep for init, dropout, etc.
    rng : Any = None

In [None]:
class TrainerModule:
    """Reusable training harness for Flax models.

    The constructor builds the model, the logger, the (optionally jitted) step
    functions and the initial parameters. ``train_model`` then runs the
    training loop with periodic validation, checkpointing of the best model
    and an optional final test evaluation.

    Subclasses have to implement ``create_functions``. ``run_model_init``,
    ``print_tabulate`` and the ``on_*`` hooks can be overridden where needed.
    """

    def __init__(self, 
                 model_class : nn.Module,
                 model_hparams : dict,
                 optimizer_hparams : dict,
                 exmp_input : Any,
                 seed : int = 42,
                 logger_params : dict = None,
                 cluster : bool = False,
                 debug : bool = False,
                 check_val_every_n_epoch : int = 1,
                 **kwargs):
        """Set up model, logger, step functions and the initial state.

        Args:
          model_class: Flax module class, instantiated with ``model_hparams``.
          model_hparams: Keyword arguments passed to ``model_class``.
          optimizer_hparams: Optimizer settings consumed by ``init_optimizer``.
          exmp_input: Example input used to tabulate and initialize the model.
          seed: Seed of the PRNG key used for parameter initialization.
          logger_params: Logger settings, see ``init_logger``.
          cluster: If True, progress bars are disabled (see ``tracker``).
          debug: If True, the train/eval step functions are not jitted.
          check_val_every_n_epoch: Validation is run every this many epochs.
          **kwargs: Additional entries stored in the saved config.
        """
        super().__init__()
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_hparams = optimizer_hparams
        self.cluster = cluster
        self.debug = debug
        self.seed = seed
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.exmp_input = exmp_input
        # Everything in this dict is written to hparams.json by init_logger
        self.config = {
            'model_class': model_class.__name__,
            'model_hparams': model_hparams,
            'optimizer_hparams': optimizer_hparams,
            'logger_params': logger_params,
            'cluster': self.cluster,
            'debug': self.debug,
            'seed': self.seed
        }
        self.config.update(kwargs)
        self.has_batch_norm = False
        # Create empty model. Note: no parameters yet
        self.model = self.model_class(**self.model_hparams)
        self.print_tabulate(exmp_input)
        # Init trainer parts
        self.init_logger(logger_params)
        self.create_jitted_functions()
        self.init_model(exmp_input)

    def init_logger(self, logger_params):
        """Create the logger and the log directory, and dump the config.

        Keys read from ``logger_params`` (all optional):
          base_log_dir: Root directory, defaults to ``'checkpoints/'``.
          logger_name: Extra sub-directory below the model class name.
          logger_type: ``'TensorBoard'`` (default) or ``'wandb'``, case-insensitive.
          project_name: Passed as ``name`` to ``WandbLogger``.

        Sets ``self.logger`` and ``self.log_dir``; writes ``hparams.json`` and
        creates a ``metrics/`` folder inside the log directory.
        """
        if logger_params is None:
            logger_params = dict()
        
        base_log_dir = logger_params.get('base_log_dir', 'checkpoints/')
        # Prepare logging
        log_dir = os.path.join(base_log_dir, self.config["model_class"])
        if 'logger_name' in logger_params:
            log_dir = os.path.join(log_dir, logger_params['logger_name'])
        
        logger_type = logger_params.get('logger_type', 'TensorBoard').lower()
        if logger_type == 'tensorboard':
            self.logger = TensorBoardLogger(save_dir=log_dir, 
                                            name='')
        elif logger_type == 'wandb':
            self.logger = WandbLogger(name=logger_params.get('project_name', None),
                                      save_dir=log_dir, 
                                      config=self.config)
        else:
            assert False, f'Unknown logger type \"{logger_type}\"'
        # From here on, use the directory the logger reports
        log_dir = self.logger.log_dir
        os.makedirs(os.path.join(log_dir, 'metrics/'), exist_ok=True)
        with open(os.path.join(log_dir, 'hparams.json'), 'w') as f:
            json.dump(self.config, f, indent=4)
        self.log_dir = log_dir
    
    def create_jitted_functions(self):
        """Store the train/eval step functions, jitted unless ``debug`` is set."""
        train_step, eval_step = self.create_functions()
        if self.debug:  # Skip jitting 
            print('Skipping jitting due to debug=True')
            self.train_step = train_step
            self.eval_step = eval_step
        else:
            self.train_step = jax.jit(train_step)
            self.eval_step = jax.jit(eval_step)

    def create_functions(self):
        """Return the ``(train_step, eval_step)`` pair; must be overridden.

        Expected signatures, as used by ``train_epoch`` and ``eval_model``:
          train_step(state, batch) -> (state, metrics)
          eval_step(state, batch) -> metrics
        where ``metrics`` is a dict of scalar values.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def init_model(self, exmp_input):
        """Initialize the model variables and store them in ``self.state``.

        The state is created without an optimizer (``tx=None``); the optimizer
        is attached later by ``init_optimizer``.
        """
        # Initialize model
        model_rng = random.PRNGKey(self.seed)
        model_rng, init_rng = random.split(model_rng)
        exmp_input = [exmp_input] if not isinstance(exmp_input, (list, tuple)) else exmp_input
        variables = self.run_model_init(exmp_input, init_rng)
        self.init_params = variables['params']
        self.init_batch_stats = variables.get('batch_stats')  # Returns none if no batch stats exist
        self.state = TrainState(step=0, 
                                apply_fn=self.model.apply,
                                params=variables['params'],
                                batch_stats=variables.get('batch_stats'),
                                rng=model_rng,
                                tx=None,
                                opt_state=None)

    def run_model_init(self, exmp_input, init_rng):
        """Call ``model.init`` on the example input and return its variables."""
        return self.model.init(init_rng, *exmp_input, train=True)

    def print_tabulate(self, exmp_input):
        """Print a tabulated summary of the model for the example input."""
        # Wrap a single input like init_model does, so that it is not
        # star-unpacked along its first axis.
        exmp_input = [exmp_input] if not isinstance(exmp_input, (list, tuple)) else exmp_input
        tabulate_fn = nn.tabulate(self.model, random.PRNGKey(0))
        print(tabulate_fn(*exmp_input, train=True))

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        """Build the optimizer with its learning rate schedule and reset the state.

        Keys read from ``self.optimizer_hparams`` (all optional):
          optimizer: ``'adam'``, ``'adamw'`` (default) or ``'sgd'``.
          lr: Peak learning rate, default 1e-3.
          warmup: Number of warmup steps, default 0.
          gradient_clip: Global-norm clipping threshold, default 1.0.
          weight_decay: Passed to ``optax.adamw`` directly; for the other
            optimizers it is applied through ``optax.add_decayed_weights``.
        Remaining keys are forwarded to the optax optimizer constructor.

        Args:
          num_epochs: Number of training epochs.
          num_steps_per_epoch: Number of optimizer steps per epoch.
        """
        hparams = copy(self.optimizer_hparams)

        # Initialize learning rate schedule and optimizer
        optimizer_name = hparams.pop('optimizer', 'adamw')
        known_optimizers = {
            'adam': optax.adam,
            'adamw': optax.adamw,
            'sgd': optax.sgd
        }
        opt_class = known_optimizers.get(optimizer_name.lower())
        assert opt_class is not None, f'Unknown optimizer "{optimizer_name}"'
        
        lr = hparams.pop('lr', 1e-3)
        warmup = hparams.pop('warmup', 0)
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=lr,
            warmup_steps=warmup,
            decay_steps=int(num_epochs * num_steps_per_epoch),
            end_value=0.01 * lr
        )
        # Clip gradients at max value, and if requested apply weight decay
        transf = [optax.clip_by_global_norm(hparams.pop('gradient_clip', 1.0))]
        # Weight decay is integrated in adamw. The other optimizers do not take a
        # weight_decay argument, so it must not be forwarded to them.
        if opt_class is not optax.adamw and 'weight_decay' in hparams:
            transf.append(optax.add_decayed_weights(hparams.pop('weight_decay', 0.0)))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **hparams)
        )
        # Initialize training state
        self.state = TrainState.create(apply_fn=self.state.apply_fn,
                                       params=self.state.params,
                                       batch_stats=self.state.batch_stats,
                                       tx=optimizer,
                                       rng=self.state.rng)

    def train_model(self, train_loader, val_loader, test_loader=None, num_epochs=500):
        """Train the model, validate periodically and optionally test it.

        Args:
          train_loader: Iterable of training batches; must support ``len``.
          val_loader: Iterable of validation batches.
          test_loader: Optional iterable of test batches.
          num_epochs: Number of training epochs.

        Returns:
          The metrics of the best validation run, extended by the test metrics
          if ``test_loader`` is given. ``None`` if no validation was run and no
          ``test_loader`` was given.
        """
        # Train model for defined number of epochs
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        self.on_training_start()
        best_eval_metrics = None
        for epoch_idx in self.tracker(range(1, num_epochs+1), desc='Epochs'):
            train_metrics = self.train_epoch(train_loader)
            self.logger.log_metrics(train_metrics, step=epoch_idx)
            self.on_training_epoch_end(epoch_idx)
            if epoch_idx % self.check_val_every_n_epoch == 0:
                eval_metrics = self.eval_model(val_loader, log_prefix='val/')
                self.on_validation_epoch_end(epoch_idx, eval_metrics, val_loader)
                self.logger.log_metrics(eval_metrics, step=epoch_idx)
                self.save_metrics(f'eval_epoch_{str(epoch_idx).zfill(3)}', eval_metrics)
                if self.is_new_model_better(eval_metrics, best_eval_metrics):
                    best_eval_metrics = eval_metrics
                    self.save_model(step=epoch_idx)
                    self.save_metrics('best_eval', eval_metrics)
        if test_loader is not None:
            if best_eval_metrics is None:
                # No validation ran, hence no checkpoint was written.
                # Test the final state instead of trying to load one.
                best_eval_metrics = {}
            else:
                self.load_model()
            test_metrics = self.eval_model(test_loader, log_prefix='test/')
            # num_epochs equals the last epoch index and is defined even if
            # the loop body never ran.
            self.logger.log_metrics(test_metrics, step=num_epochs)
            self.save_metrics('test', test_metrics)
            best_eval_metrics.update(test_metrics)
        return best_eval_metrics

    def train_epoch(self, train_loader):
        """Run one epoch and return the step metrics averaged over all steps.

        Metric keys are prefixed with ``'train/'``.
        """
        metrics = defaultdict(float)
        num_train_steps = len(train_loader)
        for batch in self.tracker(train_loader, desc='Training', leave=False):
            self.state, step_metrics = self.train_step(self.state, batch)
            for key in step_metrics:
                metrics['train/' + key] += step_metrics[key] / num_train_steps
        metrics = jax.device_get(metrics)
        return metrics

    def eval_model(self, data_loader, log_prefix=''):
        """Evaluate the current state and return batch-size-weighted averages.

        The batch size is read from the first axis of the batch, or of its
        first element if the batch is a list/tuple. Metric keys are prefixed
        with ``log_prefix``.
        """
        metrics = defaultdict(float)
        num_elements = 0
        for batch in data_loader:
            step_metrics = self.eval_step(self.state, batch)
            batch_size = batch[0].shape[0] if isinstance(batch, (list, tuple)) else batch.shape[0]
            for key in step_metrics:
                metrics[key] += step_metrics[key] * batch_size
            num_elements += batch_size
        metrics = {(log_prefix + key): (metrics[key] / num_elements).item() for key in metrics}
        return metrics

    def tracker(self, iterator, **kwargs):
        """Wrap ``iterator`` in a tqdm progress bar unless running on a cluster."""
        if not self.cluster:
            return tqdm(iterator, **kwargs)
        else:
            return iterator

    def is_new_model_better(self, new_metrics, old_metrics):
        """Return whether ``new_metrics`` beat ``old_metrics``.

        The first key found among ``'val/val_metric'`` (lower is better),
        ``'val/acc'`` (higher is better) and ``'val/loss'`` (lower is better)
        decides. Without previous metrics the new model is always better.
        """
        if old_metrics is None:
            return True
        for key, is_larger in [('val/val_metric', False), ('val/acc', True), ('val/loss', False)]:
            if key in new_metrics:
                if is_larger:
                    return new_metrics[key] > old_metrics[key]
                else:
                    return new_metrics[key] < old_metrics[key]
        assert False, f'No known metrics to log on: {new_metrics}'

    def save_metrics(self, filename, metrics):
        """Write ``metrics`` as JSON to ``<log_dir>/metrics/<filename>.json``."""
        with open(os.path.join(self.log_dir, f'metrics/{filename}.json'), 'w') as f:
            json.dump(metrics, f, indent=4)

    def on_training_start(self):
        """Hook called once before the first epoch; no-op by default."""
        pass

    def on_training_epoch_end(self, epoch_idx):
        """Hook called after every training epoch; no-op by default."""
        pass

    def on_validation_epoch_end(self, epoch_idx, eval_metrics, val_loader):
        """Hook called after every validation run; no-op by default."""
        pass

    def save_model(self, step=0):
        """Checkpoint the parameters and batch statistics into ``self.log_dir``."""
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                    overwrite=True)

    def load_model(self):
        """Restore parameters and batch statistics from ``self.log_dir``.

        The state is re-created with the current optimizer. If no optimizer
        has been set yet (``init_optimizer`` has not run), a placeholder SGD
        optimizer is used so the state can still be built.
        """
        state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=None)
        # self.state always exists after __init__, but its tx is None until
        # init_optimizer has run, so the fallback has to test tx itself.
        tx = self.state.tx if self.state.tx is not None else optax.sgd(0.1)   # Default optimizer
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=tx,
                                       rng=self.state.rng
                                      )


In [None]:
# Construct a TrainState with every required field set to None; the extra
# batch_stats and rng fields are not passed and keep their None defaults.
TrainState(step=0, apply_fn=None, params=None, tx=None, opt_state=None)

In [None]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.utils.data as data
import torch

DATASET_PATH = '../data/'
CHECKPOINT_PATH = '../saved_models/guide4/'

# Transformations applied on each image => bring them into a numpy array
# Per-channel normalization constants, applied after scaling pixels by 1/255.
DATA_MEANS = np.array([0.49139968, 0.48215841, 0.44653091])
DATA_STD = np.array([0.24703223, 0.24348513, 0.26158784])
def image_to_numpy(img):
    """Convert an image to a normalized float32 numpy array.

    The pixel values are divided by 255 and standardized with ``DATA_MEANS``
    and ``DATA_STD``, which broadcast over the last axis.
    """
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    # The float64 constants upcast the result; cast back so batches keep the
    # float32 dtype requested above.
    return img.astype(np.float32)

# Collate function producing numpy arrays instead of torch tensors
def numpy_collate(batch):
    """Stack a list of samples into numpy arrays.

    Arrays are stacked along a new first axis, tuple/list samples are
    collated field by field (recursively), anything else goes through
    ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # Regroup [(a0, b0), (a1, b1), ...] into [(a0, a1, ...), (b0, b1, ...)]
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)


# Evaluation images are only converted to normalized numpy arrays.
test_transform = image_to_numpy
# Training images are additionally flipped at random to reduce overfitting.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    image_to_numpy,
])

# The CIFAR10 training data is loaded twice, once per transform, and split with
# the same seed. That way the validation part gets the same split but is
# served without augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set = data.random_split(train_dataset,
                              [45000, 5000],
                              generator=torch.Generator().manual_seed(42))[0]
val_set = data.random_split(val_dataset,
                            [45000, 5000],
                            generator=torch.Generator().manual_seed(42))[1]

# Held-out test data
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# Data loaders: the training loader shuffles with a seeded generator and drops
# the last incomplete batch; both evaluation loaders share one configuration.
train_loader = data.DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    collate_fn=numpy_collate,
    num_workers=8,
    persistent_workers=True,
    generator=torch.Generator().manual_seed(42),
)
eval_loader_kwargs = dict(
    batch_size=128,
    shuffle=False,
    drop_last=False,
    collate_fn=numpy_collate,
    num_workers=4,
    persistent_workers=True,
)
val_loader = data.DataLoader(val_set, **eval_loader_kwargs)
test_loader = data.DataLoader(test_set, **eval_loader_kwargs)

In [None]:
class MLPClassifier(nn.Module):
    """MLP classifier built from Dropout -> Dense -> BatchNorm -> swish blocks."""
    # Output size of each hidden Dense layer, in order
    hidden_dims : Sequence[int]
    # Output size of the final Dense layer
    num_classes : int
    # Rate passed to every Dropout layer
    dropout_prob : float = 0.0

    @nn.compact
    def __call__(self, x, train=True):
        """Return the logits for ``x``.

        ``train`` switches dropout on and makes BatchNorm use batch statistics
        instead of its running averages.
        """
        # Flatten everything except the leading (batch) axis
        hidden = x.reshape(x.shape[0], -1)
        for width in self.hidden_dims:
            hidden = nn.Dropout(self.dropout_prob)(hidden, deterministic=not train)
            hidden = nn.Dense(width)(hidden)
            hidden = nn.BatchNorm()(hidden, use_running_average=not train)
            hidden = nn.swish(hidden)
        hidden = nn.Dropout(self.dropout_prob)(hidden, deterministic=not train)
        return nn.Dense(self.num_classes)(hidden)

In [None]:
class MLPClassTrainer(TrainerModule):
    """Trainer for ``MLPClassifier`` with optional Optuna trial reporting."""

    def __init__(self,
                 hidden_dims : Sequence[int],
                 num_classes : int,
                 dropout_prob : float,
                 trial : Any = None,
                 **kwargs):
        """Build the trainer; ``trial`` is an optional Optuna trial used for pruning.

        All remaining keyword arguments are forwarded to ``TrainerModule``.
        """
        mlp_hparams = dict(hidden_dims=hidden_dims,
                           num_classes=num_classes,
                           dropout_prob=dropout_prob)
        super().__init__(model_class=MLPClassifier,
                         model_hparams=mlp_hparams,
                         **kwargs)
        self.trial = trial

    def create_functions(self):
        """Return the classification ``(train_step, eval_step)`` functions."""

        def loss_function(params, batch_stats, rng, batch, train):
            # Returns the loss plus (next rng, updated model state or None, accuracy)
            imgs, labels = batch
            rng, dropout_rng = random.split(rng)
            variables = {'params': params, 'batch_stats': batch_stats}
            output = self.model.apply(variables,
                                      imgs,
                                      train=train,
                                      rngs={'dropout': dropout_rng},
                                      mutable=['batch_stats'] if train else False)
            # With mutable collections, apply returns (logits, updated state)
            if train:
                logits, new_model_state = output
            else:
                logits, new_model_state = output, None
            targets = jax.nn.one_hot(labels, num_classes=self.model.num_classes)
            loss = optax.softmax_cross_entropy(logits, targets).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (rng, new_model_state, acc)

        def train_step(state, batch):
            def loss_fn(params):
                return loss_function(params, state.batch_stats, state.rng, batch, train=True)

            grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
            (loss, (rng, new_model_state, acc)), grads = grad_fn(state.params)
            state = state.apply_gradients(grads=grads,
                                          batch_stats=new_model_state['batch_stats'],
                                          rng=rng)
            return state, {'loss': loss, 'acc': acc}

        def eval_step(state, batch):
            # Only the accuracy is reported for evaluation
            _, aux = loss_function(state.params, state.batch_stats, state.rng, batch, train=False)
            return {'acc': aux[2]}

        return train_step, eval_step

    def run_model_init(self, exmp_input, init_rng):
        """Initialize the model on the images of the example batch."""
        imgs, _ = exmp_input
        init_rng, dropout_rng = random.split(init_rng)
        init_rngs = {'params': init_rng, 'dropout': dropout_rng}
        return self.model.init(init_rngs, x=imgs, train=True)

    def print_tabulate(self, exmp_input):
        """Print the model summary for the images of the example batch."""
        imgs, _ = exmp_input
        tab_rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}
        print(self.model.tabulate(rngs=tab_rngs, x=imgs, train=True))

    def on_validation_epoch_end(self, epoch_idx, eval_metrics, val_loader):
        """Report the validation accuracy to the trial and prune if requested."""
        if not self.trial:
            return
        self.trial.report(eval_metrics['val/acc'], step=epoch_idx)
        if self.trial.should_prune():
            raise optuna.exceptions.TrialPruned()

In [None]:
# Trainer for a two-layer MLP; one training batch serves as example input,
# and validation runs every fifth epoch.
trainer = MLPClassTrainer(
    hidden_dims=[512, 512],
    num_classes=10,
    dropout_prob=0.4,
    optimizer_hparams=dict(weight_decay=2e-4),
    logger_params=dict(base_log_dir=CHECKPOINT_PATH),
    exmp_input=next(iter(train_loader)),
    check_val_every_n_epoch=5,
)

In [None]:
# metrics = trainer.train_model(train_loader, 
#                               val_loader, 
#                               test_loader=test_loader, 
#                               num_epochs=200)

In [None]:
# print(f'Validation accuracy: {metrics["val/acc"]:4.2%}')
# print(f'Test accuracy: {metrics["test/acc"]:4.2%}')

## Automatic hyperparameter tuning with Optuna

In [None]:
import optuna

In [None]:
def objective(trial):
    """Optuna objective: train an MLP with sampled hyperparameters.

    Samples ``dropout_prob``, ``weight_decay`` and ``lr`` from ``trial``,
    trains for 200 epochs with validation every 5 epochs and returns the
    validation accuracy of the best validated model. The trial is passed to
    the trainer, which may raise ``optuna.exceptions.TrialPruned``.
    """
    # Each trial gets its own loaders so their workers can be shut down afterwards
    my_train_loader = data.DataLoader(train_set,
                                      batch_size=128,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=numpy_collate,
                                      num_workers=8,
                                      persistent_workers=True,
                                      generator=torch.Generator().manual_seed(42))
    my_val_loader = data.DataLoader(val_set,
                                    batch_size=128,
                                    shuffle=False,
                                    drop_last=False,
                                    collate_fn=numpy_collate,
                                    num_workers=4,
                                    persistent_workers=True)
    try:
        trainer = MLPClassTrainer(hidden_dims=[512, 512],
                                  num_classes=10,
                                  dropout_prob=trial.suggest_float('dropout_prob', 0, 0.6),
                                  optimizer_hparams={
                                      'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
                                      'lr': trial.suggest_float('lr', 1e-4, 1e-2, log=True)
                                  },
                                  logger_params={
                                      'base_log_dir': CHECKPOINT_PATH
                                  },
                                  exmp_input=next(iter(my_train_loader)),
                                  check_val_every_n_epoch=5,
                                  trial=trial)
        metrics = trainer.train_model(my_train_loader,
                                      my_val_loader,
                                      num_epochs=200)
        return metrics['val/acc']
    finally:
        # Persistent workers outlive an epoch, so stop them explicitly, also
        # when the trial is pruned or fails. _iterator is None if the loader
        # was never iterated.
        for loader in (my_train_loader, my_val_loader):
            iterator = getattr(loader, '_iterator', None)
            if iterator is not None:
                iterator._shutdown_workers()

In [None]:
# Create the study, backed by an SQLite file below CHECKPOINT_PATH. With
# load_if_exists=True an existing study of the same name is reused.
study = optuna.create_study(
    direction='maximize',
    study_name='mlp_cifar10',
    storage=f'sqlite:///{CHECKPOINT_PATH}/optuna_hparam_search.db',
    load_if_exists=True,
    pruner=optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=50),
)
# Run the trials one after another in this process
study.optimize(objective, n_jobs=1, n_trials=100)

In [None]:
# Print the value and the sampled hyperparameters of the best trial
trial = study.best_trial
print(f'Best Value: {trial.value}')
print('Best Params:')
for name, setting in trial.params.items():
    print(f'-> {name}: {setting}')

---

[![Star our repository](https://img.shields.io/static/v1.svg?logo=star&label=⭐&message=Star%20Our%20Repository&color=yellow)](https://github.com/phlippe/uvadlc_notebooks/)  If you found this tutorial helpful, consider ⭐-ing our repository.    
[![Ask questions](https://img.shields.io/static/v1.svg?logo=star&label=❔&message=Ask%20Questions&color=9cf)](https://github.com/phlippe/uvadlc_notebooks/issues)  For any questions, typos, or bugs that you found, please raise an issue on GitHub. 

---