In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

# Training a neural network

This tutorial shows how to set up neural network training in a stateful way.
Why stateful? So that we can pause training, save the model, and resume training later.
This is especially valuable when training expensive models on large datasets, where a crash or interruption should not cost you the progress made so far.

## Generating the training objects

We will use a factory pattern to generate the training objects (e.g., the model, optimizer, loss function, and dataloader).

In [2]:
from ml_templates.state_evolution.train_with_checkpoints import (
    state_factory, 
    read_yaml,
    run_training
)

`state_factory` is the factory that generates these objects.
You pass `state_factory.generate` the type of object you want (as a `str`) along with a dict of its hyperparameters.
It then returns an instance of that object.

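Factories can also have subfactories, and that is the case here: the `state` factory presumably hands each `*_name`/`*_hyperparams` pair (model, dataloader, loss, optimizer) to a dedicated subfactory. Check out the `state_factory` module for the implementation details.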
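To illustrate the pattern, here is a minimal sketch of a name-to-constructor factory whose top-level `"state"` entry delegates to subfactories. Everything in it (`Factory`, `register`, `ToyModel`, `ToyOptimizer`, `ToyState`, and the `toy_*` factories) is hypothetical and deliberately simplified; it is not the `ml_templates` implementation, and the names are chosen so they do not shadow the real `state_factory` imported above.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


class Factory:
    """Hypothetical factory: maps a string name to a constructor taking hyperparameters."""

    def __init__(self) -> None:
        self._registry: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str, constructor: Callable[..., Any]) -> None:
        self._registry[name] = constructor

    def generate(self, name: str, hyperparams: Dict[str, Any]) -> Any:
        if name not in self._registry:
            raise KeyError(f"Unknown object type: {name!r}")
        # Hyperparameters are forwarded as keyword arguments.
        return self._registry[name](**hyperparams)


# Hypothetical stand-ins for the real model, optimizer, and state classes.
@dataclass
class ToyModel:
    seed: int


@dataclass
class ToyOptimizer:
    learning_rate: float


@dataclass
class ToyState:
    model: ToyModel
    optimizer: ToyOptimizer


# Subfactories: one per component type.
toy_model_factory = Factory()
toy_model_factory.register("toy_model", ToyModel)

toy_optimizer_factory = Factory()
toy_optimizer_factory.register("toy_adam", ToyOptimizer)


def make_toy_state(model_name: str, model_hyperparams: Dict[str, Any],
                   optimizer_name: str, optimizer_hyperparams: Dict[str, Any]) -> ToyState:
    # The top-level constructor delegates each component to its subfactory.
    return ToyState(
        model=toy_model_factory.generate(model_name, model_hyperparams),
        optimizer=toy_optimizer_factory.generate(optimizer_name, optimizer_hyperparams),
    )


toy_state_factory = Factory()
toy_state_factory.register("state", make_toy_state)

toy_state = toy_state_factory.generate("state", dict(
    model_name="toy_model", model_hyperparams=dict(seed=0),
    optimizer_name="toy_adam", optimizer_hyperparams=dict(learning_rate=1e-3),
))
print(toy_state)
```

Because every component is built from a name plus a plain dict, the whole state can be recreated from configuration alone, which is what makes the "these hyperparameters uniquely specify a state" idea below workable.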

Let's create a train state object:

In [3]:
# These hyperparameters uniquely specify a state.
state_hyperparams = dict(
    model_name             = "cnn",
    model_hyperparams      = dict(seed=0),
    dataloader_name        = "mnist",
    dataloader_hyperparams = dict(raw_data_dir="MNIST", batch_size=32),
    loss_name              = "mse",
    loss_hyperparams       = dict(),
    optimizer_name         = "adam",
    optimizer_hyperparams  = dict(learning_rate=1e-3)
)

# Use the factory to generate a state object.
state = state_factory.generate("state", state_hyperparams)
state

State(
  model=CNN(
    layers=[
      Conv2d(
        num_spatial_dims=2,
        weight=f32[32,1,3,3],
        bias=f32[32,1,1],
        in_channels=1,
        out_channels=32,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=((0, 0), (0, 0)),
        dilation=(1, 1),
        groups=1,
        use_bias=True,
        padding_mode='ZEROS'
      ),
      <wrapped function relu>,
      MaxPool2d(
        init=-inf,
        operation=<function max>,
        num_spatial_dims=2,
        kernel_size=(2, 2),
        stride=(1, 1),
        padding=((0, 0), (0, 0)),
        use_ceil=False
      ),
      Conv2d(
        num_spatial_dims=2,
        weight=f32[64,32,3,3],
        bias=f32[64,1,1],
        in_channels=32,
        out_channels=64,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=((0, 0), (0, 0)),
        dilation=(1, 1),
        groups=1,
        use_bias=True,
        padding_mode='ZEROS'
      ),
      <wrapped function relu>,
      MaxPool2d(
 

## Training the model

To keep track of training along the way, we use the `orbax.checkpoint` module to create a `CheckpointManager` object.
This object is passed to the training-loop function, `run_training`.
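To make the manager's role concrete, here is a toy, file-based sketch of the core bookkeeping a checkpoint manager performs: writing numbered checkpoints and pruning all but the newest `max_to_keep`. `ToyCheckpointManager` is hypothetical and pickle-based; it is not how orbax is implemented (orbax also supports, e.g., asynchronous saving via `enable_async_checkpointing`). Its `all_steps` and `latest_step` names echo methods that orbax's `CheckpointManager` provides, but the signatures here are simplified.

```python
import os
import pickle
import tempfile
from typing import Any, List, Optional


class ToyCheckpointManager:
    """Hypothetical, simplified checkpoint manager.

    Each checkpoint is a pickle file named by its step; only the newest
    `max_to_keep` checkpoints are retained on disk.
    """

    def __init__(self, directory: str, max_to_keep: int = 5) -> None:
        self.directory = directory
        self.max_to_keep = max_to_keep
        os.makedirs(directory, exist_ok=True)

    def _path(self, step: int) -> str:
        return os.path.join(self.directory, f"{step}.pkl")

    def all_steps(self) -> List[int]:
        # Temporary ".tmp" files are ignored because they do not end in ".pkl".
        names = os.listdir(self.directory)
        return sorted(int(name[:-4]) for name in names if name.endswith(".pkl"))

    def latest_step(self) -> Optional[int]:
        steps = self.all_steps()
        return steps[-1] if steps else None

    def save(self, step: int, state: Any) -> None:
        # Write to a temporary file, then rename, so a crash mid-write never
        # leaves a half-written file under a valid checkpoint name.
        tmp_path = self._path(step) + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, self._path(step))
        # Delete the oldest checkpoints beyond the retention limit.
        for old_step in self.all_steps()[:-self.max_to_keep]:
            os.remove(self._path(old_step))

    def restore(self, step: int) -> Any:
        with open(self._path(step), "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    manager = ToyCheckpointManager(tmp_dir, max_to_keep=5)
    for step in range(0, 200, 20):
        manager.save(step, {"step": step})
    kept_steps = manager.all_steps()
    latest = manager.restore(manager.latest_step())

print(kept_steps, latest)
```

With a save every 20 steps from 0 to 180 and `max_to_keep=5`, only steps 100 through 180 survive, so only recent checkpoints remain available to restore.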

Run the following cell, but manually interrupt it after about 20 seconds to simulate a crash.

In [16]:
import orbax.checkpoint as ocp

# Read in the hyperparameters
hyperparams = read_yaml("hyperparams.yml")

# Set up the checkpoint manager
ckpt_dir = hyperparams['train']['checkpoint_directory']
ckpt_options = ocp.CheckpointManagerOptions(max_to_keep=5, enable_async_checkpointing=True)
checkpoint_manager = ocp.CheckpointManager(directory=ckpt_dir, options=ckpt_options)

# Do training
run_training(hyperparams=hyperparams, reset=True, checkpoint_manager=checkpoint_manager)

# Close the checkpoint manager
checkpoint_manager.close()

Epoch 1/3	Batch 20/1875	Loss: 1.79795	Total run time:  0 h  0 m  6.4 s	Time since last restart:  0 h  0 m  6.4 s
Epoch 1/3	Batch 40/1875	Loss: 0.73513	Total run time:  0 h  0 m 13.1 s	Time since last restart:  0 h  0 m 13.0 s
Epoch 1/3	Batch 60/1875	Loss: 0.41160	Total run time:  0 h  0 m 19.4 s	Time since last restart:  0 h  0 m 19.3 s
Epoch 1/3	Batch 80/1875	Loss: 0.31436	Total run time:  0 h  0 m 26.1 s	Time since last restart:  0 h  0 m 26.1 s
Epoch 1/3	Batch 100/1875	Loss: 0.25971	Total run time:  0 h  0 m 32.7 s	Time since last restart:  0 h  0 m 32.6 s
Epoch 1/3	Batch 120/1875	Loss: 0.18693	Total run time:  0 h  0 m 38.7 s	Time since last restart:  0 h  0 m 38.6 s
Epoch 1/3	Batch 140/1875	Loss: 0.16544	Total run time:  0 h  0 m 44.7 s	Time since last restart:  0 h  0 m 44.6 s
Epoch 1/3	Batch 160/1875	Loss: 0.18894	Total run time:  0 h  0 m 51.1 s	Time since last restart:  0 h  0 m 51.0 s
Epoch 1/3	Batch 180/1875	Loss: 0.15636	Total run time:  0 h  0 m 57.5 s	Time since last rest

KeyboardInterrupt: 

Now run the next cell to resume training from the most recent checkpoint.
(You can interrupt this cell after about 20 seconds as well.)

In [18]:
run_training(hyperparams=hyperparams, reset=False, checkpoint_manager=checkpoint_manager)

Epoch 1/3	Batch 220/1875	Loss: 0.14781	Total run time:  0 h  1 m 40.9 s	Time since last restart:  0 h  0 m  6.3 s
Epoch 1/3	Batch 240/1875	Loss: 0.15307	Total run time:  0 h  1 m 47.8 s	Time since last restart:  0 h  0 m 13.2 s


KeyboardInterrupt: 

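Training continues from the last checkpoint rather than from scratch: the batch counter and the total run time carry over, while the time since the last restart starts again from zero.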
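The resume behaviour seen in the logs can be sketched in plain Python, assuming (as the two cells above suggest) that `reset=True` means "start over" and `reset=False` means "resume". The toy loop below is hypothetical: `train`, `SimulatedCrash`, and the dict used as a checkpoint store are illustrations, not the internals of `run_training`. Any work done after the last save is lost and recomputed on resume, so the checkpoint interval trades disk I/O against repeated work.

```python
from typing import Dict, Optional


class SimulatedCrash(Exception):
    """Hypothetical stand-in for a KeyboardInterrupt or a node failure."""


def train(num_batches: int, checkpoints: Dict[int, dict], reset: bool,
          save_every: int = 20, crash_at: Optional[int] = None) -> dict:
    """Hypothetical resumable loop; `checkpoints` maps batch index -> saved state."""
    if reset or not checkpoints:
        # Fresh start: discard old checkpoints and begin at batch 0.
        checkpoints.clear()
        state = {"batch": 0, "loss_sum": 0.0}
    else:
        # Resume from a copy of the most recent checkpoint.
        state = dict(checkpoints[max(checkpoints)])

    while state["batch"] < num_batches:
        if crash_at is not None and state["batch"] == crash_at:
            raise SimulatedCrash(f"crashed at batch {state['batch']}")
        # Placeholder for one optimizer step that produces a loss.
        state["loss_sum"] += 1.0 / (state["batch"] + 1)
        state["batch"] += 1
        if state["batch"] % save_every == 0:
            checkpoints[state["batch"]] = dict(state)  # store a copy, not a reference
    return state


store: Dict[int, dict] = {}
try:
    train(100, store, reset=True, crash_at=57)
except SimulatedCrash as err:
    print(err)
crash_checkpoint = max(store)  # last batch saved before the crash
final = train(100, store, reset=False)
uninterrupted = train(100, {}, reset=True)
print(crash_checkpoint, final["batch"], final["loss_sum"] == uninterrupted["loss_sum"])
```

Because this toy loop is deterministic, the resumed run ends with exactly the same `loss_sum` as an uninterrupted one. In real training, matching an uninterrupted run would also require restoring things like the data-loader position and random keys.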

## Saving and loading models

We can load the train state and/or the model from any retained checkpoint (with `max_to_keep=5`, the manager keeps the five most recent):

In [21]:
from ml_templates.state_evolution.train_with_checkpoints import load_state, load_model

# Loads the state from the latest checkpoint
latest_state = load_state(checkpoint_manager)

# Loads the state from a user-specified checkpoint
user_specified_step = checkpoint_manager.all_steps()[1]
another_state = load_state(checkpoint_manager, step=user_specified_step)

# Loads the model from the latest checkpoint
latest_model = load_model(checkpoint_manager)

# Loads the model from a user-specified checkpoint
another_model = load_model(checkpoint_manager, step=user_specified_step)

Once training is done, we can save the latest model as the final model:

In [22]:
from ml_templates.state_evolution.train_with_checkpoints import save_final_model, load_final_model

# Save the final model
save_final_model(latest_state, hyperparams)

# Load the final model
final_model = load_final_model(hyperparams)

final_model == latest_state.model

Array(True, dtype=bool)
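`save_final_model` takes the whole train state, yet `load_final_model` returns something equal to `latest_state.model`, which suggests the final artifact keeps only the model and drops training-only parts such as optimizer state. Below is a minimal sketch of that idea, assuming a dict-based state and pickle serialization; `save_final`, `load_final`, and the toy state are hypothetical, and the library may well use a different format.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_final(state: Dict[str, Any], path: str) -> None:
    """Hypothetical helper: write only the model part of a train state."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state["model"], f)
    os.replace(tmp_path, path)  # rename so readers never see a partial file


def load_final(path: str) -> Any:
    """Hypothetical helper: read back the model written by save_final."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Toy train state: model parameters plus optimizer bookkeeping.
toy_train_state = {
    "model": {"weight": [[0.1, -0.2], [0.3, 0.4]], "bias": [0.0, 0.5]},
    "optimizer": {"step": 240, "first_moment": [0.01, 0.02]},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    final_path = os.path.join(tmp_dir, "final_model.pkl")
    save_final(toy_train_state, final_path)
    toy_final_model = load_final(final_path)

print(toy_final_model == toy_train_state["model"])
```

Pickle keeps the sketch short, but it should only be used to load files you trust.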