# Example workflow using the utils

Let's build a simple MNIST classifier, using the Sacred trainer utilities (`pytorch_utils.sacred_trainer`) to make our job easier.

In [1]:
import writefile_run
filename = 'example_mnist.py'

In [2]:
%%writefile_run $filename

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import *
from pytorch_utils.updaters import *
import pytorch_utils.sacred_trainer as st
import torch.nn.functional as F
from sacred import Experiment
from sacred.observers import FileStorageObserver
from visdom_observer.visdom_observer import VisdomObserver
from torchvision.datasets import *
from torchvision.transforms import *

IPY = True
try:
    get_ipython()  # only defined when running inside IPython/Jupyter
except NameError:
    IPY = False

## The Sacred experiment

In [3]:
%%writefile_run $filename -a


ex = Experiment('mnistclassifier_example', interactive=IPY)
save_dir = 'MnistClassifier'
ex.observers.append(VisdomObserver())
ex.observers.append(FileStorageObserver.create(save_dir))

## Get the data

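The first step is to load the MNIST dataset. MNIST ships with a training split and a test split, so we carve a validation set out of the training split, which gives us training, validation and test sets.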

In [4]:
%%writefile_run $filename -a


@ex.config
def dataset_config():
    val_split=0.1
    batch_size=32

In [5]:
%%writefile_run $filename -a


@ex.capture
def make_dataloaders(val_split, batch_size):
    
    # get dataset ('Z:/MNIST' is the data directory used in this example; change it for your machine)
    total_train_mnist = MNIST('Z:/MNIST', download=True, transform=ToTensor())
    test_mnist = MNIST('Z:/MNIST', train=False, download=True, transform=ToTensor())
    
    training_number = len(total_train_mnist)

    # random_split needs sizes that add up to the dataset length,
    # so the validation size is taken as the remainder
    train_num = int(training_number*(1-val_split))
    val_num = training_number - train_num

    train_mnist, val_mnist = torch.utils.data.random_split(total_train_mnist,
                    [train_num, val_num])
    
    # split into data loaders (only the training data needs shuffling)
    train_loader = DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_mnist, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_mnist, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader
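For the default config (`val_split=0.1`, `batch_size=32`) and MNIST's 60,000 training images, the split and batch counts can be worked out by hand. The pure-Python sketch below does that arithmetic; `split_sizes` and `num_batches` are illustrative helpers, not part of the utils. Computing the validation size as the remainder, rather than truncating both products separately, guarantees that the two sizes add up to the dataset length, which `random_split` requires. The 1688 training batches match the iteration count in the training log further down.

```python
import math

def split_sizes(total, val_split):
    """Return (train_size, val_size); the two always sum to total."""
    train_size = int(total * (1 - val_split))
    val_size = total - train_size  # remainder, so no sample is dropped
    return train_size, val_size

def num_batches(n_samples, batch_size):
    """Batches per epoch when the last partial batch is kept (DataLoader's default)."""
    return math.ceil(n_samples / batch_size)

train_size, val_size = split_sizes(60000, 0.1)
print(train_size, val_size)         # 54000 6000
print(num_batches(train_size, 32))  # 1688
print(num_batches(val_size, 32))    # 188
```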

## The model

Next, we define the model. A larger project would typically compose a hierarchy of many different modules into a final model, but for this simple example a single module is enough.

In [6]:
%%writefile_run $filename -a


@ex.config
def model_config():
    hidden_size = 400 # hidden units in hidden layer
    output_size = 10 # number of output labels

In [7]:
%%writefile_run $filename -a


class MnistClassifier(nn.Module):
    """A simple two layer fully connected neural network"""
    
    def __init__(self, hidden_size, output_size):
        super(MnistClassifier, self).__init__()
        
        self.hidden_layer = nn.Linear(28*28, hidden_size)
        
        self.output_layer = nn.Linear(hidden_size, output_size)
        
        
    def forward(self, input):
        input = input.view(-1,28*28)
        
        out = F.relu(self.hidden_layer(input))
        out = self.output_layer(out)
        
        return out

In [8]:
%%writefile_run $filename -a


@ex.capture
def make_model(hidden_size, output_size):
    return MnistClassifier(hidden_size, output_size)
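As a quick sanity check on model size: each `nn.Linear(in_features, out_features)` layer holds an `in_features × out_features` weight matrix plus `out_features` biases. The small worked computation below applies this to the two layers of `MnistClassifier` with the default config (`hidden_size = 400`, `output_size = 10`); the helper names are illustrative only.

```python
def linear_params(in_features, out_features):
    # weight matrix plus one bias per output unit
    return in_features * out_features + out_features

def mnist_classifier_params(hidden_size=400, output_size=10, input_size=28 * 28):
    # hidden layer (784 -> hidden_size) plus output layer (hidden_size -> output_size)
    return linear_params(input_size, hidden_size) + linear_params(hidden_size, output_size)

print(linear_params(28 * 28, 400))  # 314000
print(mnist_classifier_params())    # 318010
```

With an instantiated model, PyTorch gives the same count via `sum(p.numel() for p in model.parameters())`.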

## Optimizer

Let's create an optimizer for training our model.

In [9]:
%%writefile_run $filename -a


@ex.config
def optimizer_config():
    lr=0.001 # learning rate
    weight_decay=0 # l2 regularization factor
    
@ex.capture
def make_optimizer(model, lr, weight_decay):
    return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

## The train on batch function

Next, we have to write a function with the following signature:  

`trainOnBatch(model, batch, optimizer) -> tuple of metrics`  

This function should perform one training step on the model, using the given optimizer and batch. We decide the format in which data arrives in the batch, since we build our own `Dataset` classes and `DataLoader`s (with custom collate functions where needed) for each application. The trainer utility only requires that this function take the three arguments shown above and return a tuple of scalar metrics.  

This function is called on each batch in the training `DataLoader`, and after each batch the metrics it returns are used to update a set of running metrics via the running metric updaters. The final running metrics are logged as scalars by the Sacred experiment.
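The trainer's own updaters presumably come from `pytorch_utils.updaters` (imported at the top) and are not shown in this notebook. Purely as a hypothetical illustration of the contract, the sketch below uses a plain running mean as the updater and a dummy train-on-batch function that returns a tuple of scalars; `running_mean_loop` and `dummy_train_on_batch` are made-up names, not the real `sacred_trainer` code.

```python
def running_mean_loop(train_on_batch, model, batches, optimizer):
    """Hypothetical sketch: call train_on_batch on every batch and keep a
    running mean of each returned metric (not the real sacred_trainer code)."""
    running = None
    for step, batch in enumerate(batches, start=1):
        metrics = train_on_batch(model, batch, optimizer)  # tuple of scalars
        if running is None:
            running = list(metrics)
        else:
            # incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k
            running = [m + (x - m) / step for m, x in zip(running, metrics)]
    return tuple(running) if running is not None else ()

def dummy_train_on_batch(model, batch, optimizer):
    # pretend the "loss" is the batch mean and the "accuracy" is fixed
    return sum(batch) / len(batch), 90.0

print(running_mean_loop(dummy_train_on_batch, None, [[1.0, 3.0], [4.0, 6.0]], None))
# (3.5, 90.0)
```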

In [10]:
%%writefile_run $filename -a


def classification_train_on_batch(model,batch,optimizer):
    
    # batch is tuple containing (tensor of images, tensor of labels)
    outputs = model(batch[0]) # forward pass
    
    # compute loss
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs,batch[1])
    
    # backward pass and weight update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # compute and return metrics
    loss = loss.detach().cpu().numpy()
    acc = st.accuracy(outputs, batch[1])
    
    return loss, acc

## The callback function

The next function the trainer utility uses is the `callback` function. It is optional, and its main purpose is validation, though it can also be used for other tasks such as learning rate scheduling. The basic signature required is as follows:  

`callback(model, val_loader) -> tuple of callback metrics`

The `callback` receives the model instance and the validation `DataLoader`. It is expected to return a tuple of scalar callback metrics. This function is called at the end of every epoch of training, and the metrics it returns are logged as scalars by the sacred experiment.
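Because the trainer calls the callback with only `(model, val_loader)`, any extra state it needs, such as a learning rate scheduler, has to be bound in beforehand, for example with a closure. The sketch below is hypothetical: `make_callback`, `validate` and `on_val_loss` are illustrative names, and `on_val_loss` stands in for something like the `step` method of a `torch.optim.lr_scheduler.ReduceLROnPlateau` instance, which takes the monitored metric as its argument. Toy stand-ins keep it runnable without PyTorch.

```python
def make_callback(validate, on_val_loss):
    """Hypothetical helper: adapt validate() plus an extra hook to the
    callback(model, val_loader) signature expected by the trainer."""
    def callback(model, val_loader):
        val_loss, val_acc = validate(model, val_loader)
        on_val_loss(val_loss)  # e.g. a learning rate scheduler step
        return val_loss, val_acc
    return callback

# toy stand-ins so the sketch runs without PyTorch
seen = []

def fake_validate(model, loader):
    return sum(loader) / len(loader), 97.5

callback = make_callback(fake_validate, seen.append)
print(callback(None, [0.25, 0.75]))  # (0.5, 97.5)
print(seen)                          # [0.5]
```

In this notebook you could, for instance, pass `classification_callback` (defined next) as `validate` and a scheduler's `step` method as `on_val_loss`.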

In [11]:
%%writefile_run $filename -a


def classification_callback(model, val_loader):
    with torch.no_grad(): # don't compute gradients
        criterion = nn.CrossEntropyLoss()
        
        model.eval() # eval mode
        
        batches = len(val_loader)
        loss=0
        acc=0
        for batch in val_loader:
            outputs = model(batch[0])
            loss += criterion(outputs,batch[1])
            acc += st.accuracy(outputs,batch[1])
        
        # find average loss and accuracy over whole validation set
        loss/= batches
        acc /= batches
        
        model.train() # go back to train mode
        
        # return metrics
        return loss.cpu().numpy(), acc
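`classification_callback` averages the per-batch metrics, so every batch counts equally. With 6,000 validation images and `batch_size=32`, the last batch holds only 16 images, so the result is a close approximation rather than the exact per-sample average (assuming `st.accuracy` returns a per-batch percentage, as the logged values suggest). If you want the exact figure, weight each batch by its size. A minimal pure-Python sketch with made-up numbers:

```python
def mean_of_batch_means(batch_values):
    # what classification_callback does: each batch counts equally
    return sum(batch_values) / len(batch_values)

def size_weighted_mean(batch_values, batch_sizes):
    # exact per-sample average: weight each batch mean by its batch size
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# made-up example: two full batches of 32 and a final partial batch of 16
accs = [100.0, 100.0, 50.0]
sizes = [32, 32, 16]
print(mean_of_batch_means(accs))        # 83.33333333333333
print(size_weighted_mean(accs, sizes))  # 90.0
```

In `classification_callback` this would amount to multiplying each batch's loss and accuracy by `len(batch[1])` and dividing the sums by the number of validation samples.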

In [12]:
%%writefile_run $filename -a


@ex.config
def train_config():
    epochs=10
    save_every=1 # epoch interval with which to save models

In [13]:
r = ex.run('print_config')

INFO - mnistclassifier_example - Running command 'print_config'
INFO - mnistclassifier_example - Started


Configuration (modified, added, typechanged, doc):
  batch_size = 32
  epochs = 10
  hidden_size = 400                  # hidden units in hidden layer
  lr = 0.001                         # learning rate
  output_size = 10                   # number of output labels
  save_every = 1                     # epoch interval with which to save models
  seed = 311952711                   # the random seed for this experiment
  val_split = 0.1
  weight_decay = 0                   # l2 regularization factor


INFO - mnistclassifier_example - Completed after 0:00:00


## The training

We have now defined all the functions we need to train the model. Let us write a Sacred command function to run the training. It calls the `make_*` functions above and then hands everything to `sacred_trainer.loop`, which runs the training loop for us. We pass in the data loaders, the train-on-batch and callback functions, the model and the optimizer, as well as the config variables defined earlier.
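The internals of `st.loop` are not shown in this notebook. As a rough mental model only, here is a hypothetical pure-Python outline of an epoch loop that consumes the same kinds of pieces: a train-on-batch function, an optional callback, `epochs` and `save_every`. The name `sketch_loop` and its `save` hook are made up; the real loop also does things this sketch omits, such as the progress bars and checkpoint file names visible in the output below.

```python
def sketch_loop(model, optimizer, train_on_batch, train_loader, epochs,
                callback=None, val_loader=None, save_every=1, save=None):
    """Hypothetical outline of an epoch loop (not the real st.loop)."""
    history = []
    for epoch in range(1, epochs + 1):
        # one training step per batch; keep the last batch's metrics for brevity
        train_metrics = None
        for batch in train_loader:
            train_metrics = train_on_batch(model, batch, optimizer)
        # optional end-of-epoch callback, e.g. validation
        val_metrics = callback(model, val_loader) if callback is not None else ()
        history.append((epoch, train_metrics, val_metrics))
        # checkpoint every `save_every` epochs
        if save is not None and epoch % save_every == 0:
            save(model, epoch)
    return history

saved = []
hist = sketch_loop(
    model=None, optimizer=None,
    train_on_batch=lambda m, b, o: (float(b),),
    train_loader=[1, 2, 3], epochs=4,
    callback=lambda m, v: (0.5,), val_loader=None,
    save_every=2, save=lambda m, e: saved.append(e),
)
print(len(hist), saved)  # 4 [2, 4]
```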

### Set up the loop command

In [14]:
%%writefile_run $filename -a


@ex.command
def train(_run):
    
    train_loader, val_loader, test_loader = make_dataloaders()
    model = make_model()
    optimizer = make_optimizer(model)
    
    st.loop(
        _run=_run,
        model=model,
        optimizer=optimizer,
        save_dir=save_dir,
        trainOnBatch=classification_train_on_batch,
        train_loader=train_loader,
        val_loader=val_loader,
        callback=classification_callback,
        **_run.config,
        **st.CLASSIFICATION,
    )
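The `st.loop` call forwards the whole Sacred config with `**_run.config`, next to explicit keyword arguments and `**st.CLASSIFICATION`; `st.loop` presumably accepts and ignores config keys it does not use, such as `seed` and `val_split`. This is convenient, but plain Python rules apply: if the same name arrives twice (for example, if you later added a `save_dir` entry to the config), the call fails with a `TypeError` before training starts. A minimal pure-Python illustration, with `fake_loop` as a hypothetical stand-in for any function that accepts arbitrary keyword arguments:

```python
def fake_loop(**kwargs):
    """Hypothetical stand-in: accepts any keyword arguments, returns their names."""
    return sorted(kwargs)

config = {"epochs": 10, "save_every": 1, "lr": 0.001}

# unpacking a config dict alongside explicit keyword arguments works...
print(fake_loop(model="m", **config))  # ['epochs', 'lr', 'model', 'save_every']

# ...but a key that is also passed explicitly raises TypeError
try:
    fake_loop(epochs=5, **config)
except TypeError as err:
    print("TypeError:", err)
```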

In [15]:
%%writefile_run $filename -a -dr


if __name__ == '__main__':
    ex.run_commandline()

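Now we can run the training, either here in the notebook or as a script from the command line, e.g. `python example_mnist.py train` (Sacred also accepts config overrides such as `python example_mnist.py train with epochs=5`).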

In [16]:
ex.run('train')

INFO - mnistclassifier_example - Running command 'train'
INFO - visdom - Visdom successfully connected to server
INFO - mnistclassifier_example - Started run with ID "33"
Epoch: 1: 100%|█████████████████████████████████████████████| 1688/1688 [00:16<00:00, 103.02it/s, acc=92.9, loss=0.248]


Callback metrics: val_loss=0.117990 val_acc=96.575798
MnistClassifier\33\epoch001_23-08_2358_val_loss0.1180_val_acc96.5758.statedict.pkl


Epoch: 2: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 117.16it/s, acc=97.1, loss=0.0985]


Callback metrics: val_loss=0.087185 val_acc=97.523271
MnistClassifier\33\epoch002_23-08_2358_val_loss0.0872_val_acc97.5233.statedict.pkl


Epoch: 3: 100%|██████████████████████████████████████████████| 1688/1688 [00:14<00:00, 115.97it/s, acc=98, loss=0.0645]


Callback metrics: val_loss=0.089751 val_acc=97.257314
MnistClassifier\33\epoch003_23-08_2359_val_loss0.0898_val_acc97.2573.statedict.pkl


Epoch: 4: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 120.47it/s, acc=98.5, loss=0.0468]


Callback metrics: val_loss=0.077829 val_acc=97.656250
MnistClassifier\33\epoch004_23-08_2359_val_loss0.0778_val_acc97.6562.statedict.pkl


Epoch: 5: 100%|██████████████████████████████████████████████| 1688/1688 [00:14<00:00, 117.69it/s, acc=99, loss=0.0327]


Callback metrics: val_loss=0.078846 val_acc=97.789229
MnistClassifier\33\epoch005_23-08_2359_val_loss0.0788_val_acc97.7892.statedict.pkl


Epoch: 6: 100%|████████████████████████████████████████████| 1688/1688 [00:16<00:00, 101.17it/s, acc=99.2, loss=0.0254]


Callback metrics: val_loss=0.073868 val_acc=97.905585
MnistClassifier\33\epoch006_23-08_2359_val_loss0.0739_val_acc97.9056.statedict.pkl


Epoch: 7: 100%|████████████████████████████████████████████| 1688/1688 [00:15<00:00, 108.12it/s, acc=99.4, loss=0.0194]


Callback metrics: val_loss=0.068194 val_acc=97.822473
MnistClassifier\33\epoch007_24-08_0000_val_loss0.0682_val_acc97.8225.statedict.pkl


Epoch: 8: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 115.17it/s, acc=99.5, loss=0.0151]


Callback metrics: val_loss=0.084216 val_acc=97.905585
MnistClassifier\33\epoch008_24-08_0000_val_loss0.0842_val_acc97.9056.statedict.pkl


Epoch: 9: 100%|████████████████████████████████████████████| 1688/1688 [00:15<00:00, 111.94it/s, acc=99.6, loss=0.0129]


Callback metrics: val_loss=0.083135 val_acc=97.855718
MnistClassifier\33\epoch009_24-08_0000_val_loss0.0831_val_acc97.8557.statedict.pkl


Epoch: 10: 100%|████████████████████████████████████████████| 1688/1688 [00:16<00:00, 99.33it/s, acc=99.7, loss=0.0106]


Callback metrics: val_loss=0.092859 val_acc=97.606383
MnistClassifier\33\epoch010_24-08_0000_val_loss0.0929_val_acc97.6064.statedict.pkl
MnistClassifier\33\epoch010_24-08_0000_val_loss0.0929_val_acc97.6064.statedict.pkl


INFO - mnistclassifier_example - Completed after 0:02:39


<sacred.run.Run at 0x2937d693550>
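The checkpoint names printed above embed the epoch and the validation metrics, so the best epoch can be picked from the names alone. A minimal sketch, assuming the `epochNNN_..._val_lossX_val_accY.statedict.pkl` pattern seen in this run's output; `best_checkpoint` is a hypothetical helper, and the list below copies three of the printed names. Note that the two criteria disagree in this run: the lowest validation loss is at epoch 7, while the highest validation accuracy is at epoch 6 (tied with epoch 8), even as training accuracy keeps climbing towards 99.7.

```python
import re

# matches names like epoch007_24-08_0000_val_loss0.0682_val_acc97.8225.statedict.pkl
CKPT_RE = re.compile(
    r"epoch\d+_.*_val_loss([\d.]+)_val_acc([\d.]+)\.statedict\.pkl$"
)

def best_checkpoint(names, key="val_loss"):
    """Hypothetical helper: lowest val_loss or highest val_acc among checkpoint names."""
    parsed = []
    for name in names:
        match = CKPT_RE.search(name)
        if match:
            parsed.append((name, float(match.group(1)), float(match.group(2))))
    if not parsed:
        raise ValueError("no names match the expected checkpoint pattern")
    if key == "val_loss":
        return min(parsed, key=lambda p: p[1])[0]
    if key == "val_acc":
        return max(parsed, key=lambda p: p[2])[0]
    raise ValueError("key must be 'val_loss' or 'val_acc'")

# three of the names printed during the run above
names = [
    "epoch006_23-08_2359_val_loss0.0739_val_acc97.9056.statedict.pkl",
    "epoch007_24-08_0000_val_loss0.0682_val_acc97.8225.statedict.pkl",
    "epoch010_24-08_0000_val_loss0.0929_val_acc97.6064.statedict.pkl",
]
print(best_checkpoint(names))                 # the epoch007_... name (lowest val_loss)
print(best_checkpoint(names, key="val_acc"))  # the epoch006_... name (highest val_acc)
```

To scan a whole run instead, you could pass `os.listdir(os.path.join(save_dir, '33'))`, where `33` is this run's ID.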