### Goal
This code example shows you how to save intermediate model checkpoints and use them during training. Why is this useful? Say the training loss suddenly spikes. You may be tempted to restart training from scratch with a lower learning rate. However, it is more efficient to intermittently save the best-performing model during training; such a saved model is called a checkpoint. Then, whenever the training loss suddenly increases, you can load the model checkpoint, decrease the learning rate, and continue training.

In [1]:
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms

import devtorch
%load_ext autoreload
%autoreload 2

In [2]:
class ANNClassifier(devtorch.DevModel):
    """Two-layer, bias-free fully connected network with leaky-ReLU activations."""

    def __init__(self, n_in, n_hidden, n_out):
        """Create both linear layers and initialise their weights.

        Args:
            n_in: Number of input features of the first linear layer.
            n_hidden: Number of units between the two linear layers.
            n_out: Number of output units of the second linear layer.
        """
        super().__init__()

        # The sizes are kept so that `hyperparams` can report them later.
        self._n_in, self._n_hidden, self._n_out = n_in, n_hidden, n_out

        self.layer1 = nn.Linear(n_in, n_hidden, bias=False)
        self.layer2 = nn.Linear(n_hidden, n_out, bias=False)

        # Both weight matrices use the same initialisation scheme.
        for layer in (self.layer1, self.layer2):
            self.init_weight(layer.weight, "glorot_uniform")

    @property
    def hyperparams(self):
        """Return the parent's hyperparameters with the layer sizes stored under "params"."""
        merged = dict(super().hyperparams)
        merged["params"] = {
            "n_in": self._n_in,
            "n_hidden": self._n_hidden,
            "n_out": self._n_out,
        }
        return merged

    def forward(self, x):
        """Flatten dimensions 1 through 3 of `x` and pass the result through both layers.

        A leaky ReLU is applied after each linear layer, including the last one.
        """
        flat = x.flatten(1, 3)
        hidden = F.leaky_relu(self.layer1(flat))
        out = self.layer2(hidden)
        return F.leaky_relu(out)

In [3]:
class CheckpointTrainer(devtorch.Trainer):
    """Trainer that checkpoints on every new best training loss and rolls back on loss spikes.

    After each epoch the latest training loss is compared with the lowest loss seen so far:
    a new minimum saves the model, while a loss above `spike_tolerance` times the minimum
    reloads the saved model, scales the learning rate by `lr_decay` and rebuilds the optimizer.
    """

    def __init__(self, root, model_id, model, train_dataset, n_epochs=100, batch_size=128, lr=0.001, device="cuda",
                 spike_tolerance=1.05, lr_decay=0.1):
        """Set up the underlying trainer and the checkpointing state.

        Args:
            root: Forwarded to the parent trainer as `root`.
            model_id: Forwarded to the parent trainer as `id`.
            model: Model to train.
            train_dataset: Dataset to train on.
            n_epochs: Forwarded to the parent trainer.
            batch_size: Forwarded to the parent trainer.
            lr: Initial learning rate, forwarded to the parent trainer.
            device: Forwarded to the parent trainer.
            spike_tolerance: Factor applied to the minimum loss; an epoch loss above
                `spike_tolerance * min_loss` triggers a checkpoint reload. The default of 1.05
                corresponds to a spike of more than 5%.
            lr_decay: Factor the learning rate is multiplied by on every checkpoint reload.
        """
        super().__init__(root=root, id=model_id, model=model, train_dataset=train_dataset, n_epochs=n_epochs, batch_size=batch_size, lr=lr, device=device)
        # Lowest training loss seen so far; infinity guarantees the first epoch is checkpointed.
        self._min_loss = np.inf
        self._spike_tolerance = spike_tolerance
        self._lr_decay = lr_decay

    @staticmethod
    def load_model(root, model_id):
        """Load the saved `ANNClassifier` identified by `root` and `model_id`.

        The model is rebuilt from the "params" entry stored under the "model" key of the
        hyperparameters that `devtorch.load_model` hands to the loader callback.
        """

        def model_loader(hyperparams):
            return ANNClassifier(**hyperparams["model"]["params"])

        return devtorch.load_model(root, model_id, model_loader)

    def loss(self, output, target, model):
        """Return the cross-entropy between `output` and `target` cast to long; `model` is unused."""
        return F.cross_entropy(output, target.long())

    # Here we overwrite the on_epoch_complete hook
    def on_epoch_complete(self, save):
        """Checkpoint on a new minimum loss, or roll back and decay the learning rate on a spike.

        `save` is accepted to match the hook signature but is not consulted: the checkpoint is
        written whenever a new minimum loss is reached.
        """
        train_loss = self.log["train_loss"][-1]

        # If a new minimum loss was achieved we save the model; otherwise, if the loss
        # exceeds the minimum by more than the tolerated factor, we load the checkpoint
        # and reduce the learning rate. Losses in between leave everything untouched.
        if train_loss < self._min_loss:
            print(f"Saving checkpoint train_loss={train_loss:.4f} < min_loss={self._min_loss:.4f}.")
            self._min_loss = train_loss
            self.save_model()
        elif train_loss > self._spike_tolerance * self._min_loss:
            print("=========> Loading checkpoint and decaying lr <=========")
            self.lr *= self._lr_decay
            # NOTE(review): the reloaded model is used exactly as devtorch.load_model returns
            # it - confirm it ends up on the training device.
            self.model = CheckpointTrainer.load_model(self.root, self.id)
            # The optimizer must be rebuilt so it tracks the parameters of the reloaded model
            # and uses the decayed learning rate.
            self.optimizer = self.optimizer_func(self.model.parameters(), self.lr, **self.optimizer_kwargs)
            

In [4]:
root = "../../data"  # where to save the checkpoint
model_id = "ann"  # the name of the checkpoint - if none is provided, devtorch auto-generates one.
# Fall back to the CPU when no CUDA device is available so the example still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ANNClassifier(784, 2000, 10)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# The datasets are stored in the same directory as the checkpoint.
train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)

trainer = CheckpointTrainer(root, model_id, model, train_dataset, n_epochs=10, batch_size=128, lr=0.01, device=device)
trainer.train(save=True)

INFO:trainer:Completed epoch 0 with loss 408.0909495726228 in 7.5060s
Saving checkpoint train_loss=408.0909 < min_loss=inf.
INFO:trainer:Completed epoch 1 with loss 92.43864496052265 in 7.4333s
Saving checkpoint train_loss=92.4386 < min_loss=408.0909.
INFO:trainer:Completed epoch 2 with loss 56.715450895018876 in 7.4289s
Saving checkpoint train_loss=56.7155 < min_loss=92.4386.
INFO:trainer:Completed epoch 3 with loss 45.24590137088671 in 7.4299s
Saving checkpoint train_loss=45.2459 < min_loss=56.7155.
INFO:trainer:Completed epoch 4 with loss 29.816457504639402 in 7.4320s
Saving checkpoint train_loss=29.8165 < min_loss=45.2459.
INFO:trainer:Completed epoch 5 with loss 22.350258031627163 in 7.4326s
Saving checkpoint train_loss=22.3503 < min_loss=29.8165.
INFO:trainer:Completed epoch 6 with loss 53.35907748359023 in 7.4237s
INFO:trainer:Completed epoch 7 with loss 13.679868848761544 in 7.4394s
Saving checkpoint train_loss=13.6799 < min_loss=22.3503.
INFO:trainer:Completed epoch 8 with los

In [5]:
def eval_metric(output, target):
    """Return how many rows of `output` have their maximum (along dim 1) at the index in `target`.

    The count is returned as a plain Python number.
    """
    predicted = output.argmax(dim=1)
    n_correct = predicted.eq(target).sum()
    return n_correct.cpu().item()

# Evaluate the model the trainer currently holds: a checkpoint reload inside
# on_epoch_complete rebinds trainer.model to a newly loaded object, which would leave
# the notebook-level `model` pointing at the abandoned weights.
scores = devtorch.compute_metric(trainer.model, test_dataset, eval_metric, batch_size=256)
print(f"Accuracy = {torch.Tensor(scores).sum()/len(test_dataset)}")

Accuracy = 0.9639000296592712


**Exercise**: This tutorial just outlines one variant of how you might optimize the training process. Perhaps you can think of better and more exotic ways of doing this. As an exercise, you could try to extend the CheckpointTrainer to load model checkpoints and decay the learning rate whenever the training loss does not improve over a certain number of epochs.