# Example workflow using the utils

Let's build a simple classifier for MNIST using the sacred trainer utils to make our job easier.

In [1]:
import writefile_run
filename = 'example_mnist.py'

In [2]:
%%writefile_run $filename

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from pytorch_utils.updaters import *
import pytorch_utils.sacred_trainer as st
import torch.nn.functional as F
from sacred import Experiment
from sacred.observers import FileStorageObserver
from visdom_observer.visdom_observer import VisdomObserver
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

IPY = True
try:
    get_ipython()
except NameError: # get_ipython is only defined inside IPython/Jupyter
    IPY = False
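The check above decides whether to create the sacred experiment in interactive mode by looking for `get_ipython`, which IPython defines but plain Python does not. The same check can be wrapped in a small helper; a minimal sketch, with the hypothetical name `running_in_ipython`, that catches only `NameError` so unrelated errors are not silently swallowed:

```python
def running_in_ipython() -> bool:
    """Hypothetical helper: True when running under IPython/Jupyter."""
    try:
        get_ipython()  # noqa: F821 -- only defined by IPython
    except NameError:
        return False
    return True


IPY = running_in_ipython()
print(IPY)  # False when run as a plain Python script
```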

## The sacred experiment

In [3]:
%%writefile_run $filename -a


ex = Experiment('mnistclassifier_example', interactive=IPY)
save_dir = 'MnistClassifier'
ex.observers.append(FileStorageObserver.create(save_dir))
ex.observers.append(VisdomObserver())

## Get the data

The first step is to load the MNIST dataset, split its training set into training and validation subsets, and wrap these, along with the separate MNIST test set, in `DataLoader`s.

In [4]:
%%writefile_run $filename -a


@ex.config
def dataset_config():
    val_split=0.1 # fraction of the training set held out for validation
    batch_size=32 # samples per batch

In [5]:
%%writefile_run $filename -a


@ex.capture
def make_dataloaders(val_split, batch_size):
    
    # get dataset
    total_train_mnist = MNIST('Z:/MNIST', download=True, transform=ToTensor())
    test_mnist = MNIST('Z:/MNIST', train=False, download=True, transform=ToTensor())
    
    training_number = len(total_train_mnist)

    # random_split needs lengths that sum exactly to the dataset size,
    # so the validation size is taken as the remainder
    train_num = int(training_number*(1-val_split))
    val_num = training_number - train_num

    train_mnist, val_mnist = torch.utils.data.random_split(total_train_mnist,
                    [train_num, val_num])
    
    # split into data loaders; only the training set needs shuffling
    train_loader = DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_mnist, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_mnist, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader
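MNIST's training set has 60,000 images, so the default `val_split=0.1` gives 54,000 training and 6,000 validation images. Since a `DataLoader` keeps the final partial batch by default (`drop_last=False`), the number of batches per epoch is the ceiling of samples over `batch_size`, which matches the 1688 iterations per epoch in the training log further down. A quick check in plain Python (the helper names are hypothetical), deriving the validation size as the remainder so the two lengths always sum to the dataset size, as `random_split` requires:

```python
import math
from typing import Tuple


def split_sizes(total: int, val_split: float) -> Tuple[int, int]:
    """Train/validation sizes that always sum to `total` (hypothetical helper)."""
    train_num = int(total * (1 - val_split))
    return train_num, total - train_num


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches a DataLoader yields with drop_last=False (hypothetical helper)."""
    return math.ceil(num_samples / batch_size)


train_num, val_num = split_sizes(60000, 0.1)
print(train_num, val_num)                 # 54000 6000
print(batches_per_epoch(train_num, 32))   # 1688
print(batches_per_epoch(val_num, 32))     # 188
```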

## The model

Now, we create a bunch of modules and integrate them into a final model. Normally, there would be a hierarchy of many different modules, but for this simple example, we can make do with a single module.

In [6]:
%%writefile_run $filename -a


@ex.config
def model_config():
    hidden_size = 400 # hidden units in hidden layer
    output_size = 10 # number of output labels

In [7]:
%%writefile_run $filename -a


class MnistClassifier(nn.Module):
    """A simple two-layer fully connected neural network"""
    
    def __init__(self, hidden_size, output_size):
        super(MnistClassifier, self).__init__()
        
        self.hidden_layer = nn.Linear(28*28, hidden_size)
        
        self.output_layer = nn.Linear(hidden_size, output_size)
        
        
    def forward(self, input):
        input = input.view(-1,28*28)
        
        out = F.relu(self.hidden_layer(input))
        out = self.output_layer(out)
        
        return out
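`forward` flattens each 1×28×28 MNIST image into a 784-element vector (`view(-1, 28*28)`), so a batch of shape `(32, 1, 28, 28)` becomes `(32, 784)`, which the two `nn.Linear` layers map to `(32, 400)` and then `(32, 10)`. A quick plain-Python count of the trainable parameters this gives with the default config (the helper name is hypothetical):

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of an nn.Linear layer (bias=True is the default)."""
    return in_features * out_features + out_features


hidden_size, output_size = 400, 10                  # values from model_config
hidden = linear_params(28 * 28, hidden_size)        # 784*400 + 400
output = linear_params(hidden_size, output_size)    # 400*10 + 10
print(hidden, output, hidden + output)              # 314000 4010 318010
```

In the notebook, `sum(p.numel() for p in MnistClassifier(400, 10).parameters())` should give the same total.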

In [8]:
%%writefile_run $filename -a


@ex.capture
def make_model(hidden_size, output_size):
    return MnistClassifier(hidden_size, output_size)

## Optimizer

Let's create an optimizer for training our model.

In [9]:
%%writefile_run $filename -a


@ex.config
def optimizer_config():
    lr=0.001 # learning rate
    weight_decay=0 # L2 regularization factor
    
@ex.capture
def make_optimizer(model, lr, weight_decay):
    return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
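To show where the two config values enter, here is one Adam update for a single scalar parameter in plain Python, using PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`. PyTorch's `optim.Adam` applies `weight_decay` by adding `weight_decay * param` to the gradient (coupled L2 regularization, unlike the decoupled decay of `optim.AdamW`); this sketch follows that textbook form and is not PyTorch's actual implementation:

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    grad = grad + weight_decay * param               # L2 penalty folded into the gradient
    m = betas[0] * m + (1 - betas[0]) * grad         # running mean of gradients
    v = betas[1] * v + (1 - betas[1]) * grad ** 2    # running mean of squared gradients
    m_hat = m / (1 - betas[0] ** t)                  # bias corrections for the zero init
    v_hat = v / (1 - betas[1] ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(round(p, 6))  # 0.999
```

On the first step the bias corrections make the update roughly `lr` times the sign of the gradient, regardless of the gradient's magnitude. With `weight_decay=0`, as in the config above, the penalty term vanishes.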

## The train on batch function

Next, we have to write a function with the following signature:  

`trainOnBatch(model, batch, optimizer) -> tuple of metrics`  

This function should perform one training step on the model using the given optimizer and batch. The format of the batch is up to us, since we build our own `Dataset` classes and `DataLoader`s (with custom collate functions where needed) for each application. The trainer utility only requires that this function take the three arguments shown above and return a tuple of scalar metrics.

This function is called on each batch from the training `DataLoader`, and the metrics it returns are accumulated batch by batch by a set of running-metric updaters. The final running metrics are logged as scalars by the sacred experiment.
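What `pytorch_utils.updaters` computes exactly is not shown here, so the following is only a hypothetical stand-in for the trainer's inner loop: it calls a `trainOnBatch`-style function on every batch and keeps a plain running mean of each returned metric (the updaters could equally use, say, an exponential moving average). The names `run_epoch` and `fake_train_on_batch` are made up for illustration:

```python
def run_epoch(train_on_batch, model, loader, optimizer):
    """Hypothetical inner loop: running mean of each metric over all batches."""
    running = []
    for n, batch in enumerate(loader, start=1):
        metrics = train_on_batch(model, batch, optimizer)
        if not running:
            running = [0.0] * len(metrics)
        # incremental mean: new_mean = old_mean + (x - old_mean) / n
        running = [r + (float(x) - r) / n for r, x in zip(running, metrics)]
    return tuple(running)


def fake_train_on_batch(model, batch, optimizer):
    # stand-in returning (loss, accuracy) derived from the batch value
    return batch, 10 * batch


print(run_epoch(fake_train_on_batch, None, [1, 2, 3], None))  # (2.0, 20.0)
```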

In [10]:
%%writefile_run $filename -a


def classification_train_on_batch(model,batch,optimizer):
    
    # batch is tuple containing (tensor of images, tensor of labels)
    outputs = model(batch[0]) # forward pass
    
    # compute loss
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs,batch[1])
    
    # backward pass and weight update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # compute and return metrics
    loss = loss.detach().cpu().numpy()
    acc = st.accuracy(outputs, batch[1])
    
    return loss, acc
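For reference, `nn.CrossEntropyLoss` with its default `reduction='mean'` combines a log-softmax with the negative log-likelihood and averages over the batch. The training log shows accuracies such as `acc=93`, so `st.accuracy` appears to return a percentage; the helper below assumes plain arg-max accuracy in percent, which is a guess about `st.accuracy`, not its documented behaviour. A plain-Python sketch of both metrics (helper names hypothetical):

```python
import math


def cross_entropy(logits, labels):
    """Mean of -log softmax(logits)[label] over the batch."""
    total = 0.0
    for row, label in zip(logits, labels):
        top = max(row)  # subtract the max for numerical stability
        log_sum_exp = top + math.log(sum(math.exp(z - top) for z in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def accuracy_percent(logits, labels):
    """Percentage of rows whose arg-max equals the label (assumed st.accuracy behaviour)."""
    correct = sum(row.index(max(row)) == label for row, label in zip(logits, labels))
    return 100.0 * correct / len(labels)


logits = [[2.0, 0.0], [1.0, 3.0], [0.0, 1.0], [2.0, 1.0]]
labels = [0, 1, 1, 1]
print(cross_entropy(logits, labels))
print(accuracy_percent(logits, labels))  # 75.0
```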

## The callback function

The next function the trainer utility uses is the `callback` function. It is optional, and its main purpose is validation, although it can also be used for other tasks such as learning rate scheduling. The basic signature required is as follows:  

`callback(model, val_loader) -> tuple of callback metrics`

The `callback` receives the model instance and the validation `DataLoader`. It is expected to return a tuple of scalar callback metrics. This function is called at the end of every epoch of training, and the metrics it returns are logged as scalars by the sacred experiment.

In [11]:
%%writefile_run $filename -a


def classification_callback(model, val_loader):
    with torch.no_grad(): # don't compute gradients
        criterion = nn.CrossEntropyLoss()
        
        model.eval() # eval mode
        
        batches = len(val_loader)
        loss=0
        acc=0
        for batch in val_loader:
            outputs = model(batch[0])
            loss += criterion(outputs,batch[1])
            acc += st.accuracy(outputs,batch[1])
        
        # find average loss and accuracy over the whole validation set
        loss/= batches
        acc /= batches
        
        model.train() # go back to train mode
        
        # return metrics
        return loss.cpu().numpy(), acc
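`classification_callback` averages the per-batch means, which gives every batch equal weight. With 6,000 validation images and `batch_size=32`, the last of the 188 batches holds only 16 images (6000 = 187 × 32 + 16), so it counts twice as much per image as the others. The effect is tiny here, but a sample-weighted mean avoids it; in the callback this would mean accumulating each batch's mean times its size (e.g. `batch[1].size(0)`) and dividing by the total count. A plain-Python sketch with a hypothetical helper:

```python
def weighted_mean(batch_means, batch_sizes):
    """Dataset-level mean from per-batch means, weighting each batch by its size."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


means = [0.25, 1.0]   # e.g. a full batch and a smaller final batch
sizes = [32, 16]
print(sum(means) / len(means))      # plain average of batch means: 0.625
print(weighted_mean(means, sizes))  # sample-weighted: (8 + 16) / 48 = 0.5
```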

In [12]:
%%writefile_run $filename -a


@ex.config
def train_config():
    epochs=10
    save_every=1 # epoch interval with which to save models

In [13]:
r = ex.run('print_config')

INFO - mnistclassifier_example - Running command 'print_config'
INFO - mnistclassifier_example - Started


Configuration (modified, added, typechanged, doc):
  batch_size = 32
  epochs = 10
  hidden_size = 400
  lr = 0.001
  output_size = 10
  save_every = 1
  seed = 562109891                   # the random seed for this experiment
  val_split = 0.1
  weight_decay = 0


INFO - mnistclassifier_example - Completed after 0:00:00


## The training

Now we have defined all the functions we need to train our model. Let us write the sacred command function that runs the training. It calls the functions above to build the data loaders, model and optimizer, and then uses the `sacred_trainer.loop` function to do the rest. We need to pass in the data loaders, the train-on-batch and callback functions, the model and the optimizer, as well as the config variables we defined earlier.

### Set up the loop command

In [14]:
%%writefile_run $filename -a


@ex.command
def train(_run):
    
    train_loader, val_loader, test_loader = make_dataloaders()
    model = make_model()
    optimizer = make_optimizer(model)
    
    st.loop(
        _run=_run,
        model=model,
        optimizer=optimizer,
        save_dir=save_dir,
        trainOnBatch=classification_train_on_batch,
        train_loader=train_loader,
        val_loader=val_loader,
        callback=classification_callback,
        **_run.config,
        **st.CLASSIFICATION,
    )
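`train` unpacks the whole sacred config into `st.loop` next to explicit keyword arguments and `st.CLASSIFICATION`. Since entries such as `hidden_size` and `seed` are passed too and the run succeeds, `st.loop` evidently tolerates extra keyword arguments. The pattern only breaks if a key appears twice: Python raises a `TypeError` if, for example, a `save_dir` entry were ever added to the config. A plain-Python illustration, where `loop` is a hypothetical stand-in for `st.loop`:

```python
def loop(**kwargs):
    """Hypothetical stand-in for a function accepting arbitrary keyword arguments."""
    return kwargs


config = {"epochs": 10, "save_every": 1}
print(loop(save_dir="MnistClassifier", **config))  # fine: no overlapping keys

config_with_save_dir = dict(config, save_dir="other_dir")
error_message = ""
try:
    loop(save_dir="MnistClassifier", **config_with_save_dir)
except TypeError as err:
    error_message = str(err)  # "... got multiple values for keyword argument 'save_dir'"
print(error_message)
```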

In [15]:
%%writefile_run $filename -a -dr


if __name__ == '__main__':
    ex.run_commandline()

Now we can run it either in the notebook, as below, or as a script from the command line (e.g. `python example_mnist.py train`).

In [16]:
ex.run('train')

INFO - mnistclassifier_example - Running command 'train'
INFO - visdom - Visdom successfully connected to server
INFO - mnistclassifier_example - Started run with ID "28"
Epoch: 1: 100%|███████████████████████████████████████████████| 1688/1688 [00:13<00:00, 127.24it/s, acc=93, loss=0.243]


Callback metrics: val_loss=0.131072 val_acc=96.392952
MnistClassifier\28\epoch001_23-08_2225_val_loss0.1311_val_acc96.3930.statedict.pkl


Epoch: 2: 100%|██████████████████████████████████████████████| 1688/1688 [00:13<00:00, 123.13it/s, acc=97, loss=0.0986]


Callback metrics: val_loss=0.107750 val_acc=96.908245
MnistClassifier\28\epoch002_23-08_2225_val_loss0.1078_val_acc96.9082.statedict.pkl


Epoch: 3: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 119.90it/s, acc=98.1, loss=0.0623]


Callback metrics: val_loss=0.086531 val_acc=97.672872
MnistClassifier\28\epoch003_23-08_2226_val_loss0.0865_val_acc97.6729.statedict.pkl


Epoch: 4: 100%|█████████████████████████████████████████████| 1688/1688 [00:13<00:00, 121.05it/s, acc=98.6, loss=0.044]


Callback metrics: val_loss=0.086501 val_acc=97.523271
MnistClassifier\28\epoch004_23-08_2226_val_loss0.0865_val_acc97.5233.statedict.pkl


Epoch: 5: 100%|██████████████████████████████████████████████| 1688/1688 [00:13<00:00, 122.96it/s, acc=99, loss=0.0312]


Callback metrics: val_loss=0.091303 val_acc=97.556516
MnistClassifier\28\epoch005_23-08_2226_val_loss0.0913_val_acc97.5565.statedict.pkl


Epoch: 6: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 119.58it/s, acc=99.2, loss=0.0256]


Callback metrics: val_loss=0.087152 val_acc=97.722739
MnistClassifier\28\epoch006_23-08_2226_val_loss0.0872_val_acc97.7227.statedict.pkl


Epoch: 7: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 118.46it/s, acc=99.4, loss=0.0199]


Callback metrics: val_loss=0.087594 val_acc=97.888963
MnistClassifier\28\epoch007_23-08_2227_val_loss0.0876_val_acc97.8890.statedict.pkl


Epoch: 8: 100%|████████████████████████████████████████████| 1688/1688 [00:13<00:00, 123.38it/s, acc=99.5, loss=0.0149]


Callback metrics: val_loss=0.088916 val_acc=97.672872
MnistClassifier\28\epoch008_23-08_2227_val_loss0.0889_val_acc97.6729.statedict.pkl


Epoch: 9: 100%|████████████████████████████████████████████| 1688/1688 [00:14<00:00, 118.85it/s, acc=99.6, loss=0.0137]


Callback metrics: val_loss=0.095934 val_acc=97.988697
MnistClassifier\28\epoch009_23-08_2227_val_loss0.0959_val_acc97.9887.statedict.pkl


Epoch: 10: 100%|███████████████████████████████████████████| 1688/1688 [00:13<00:00, 123.21it/s, acc=99.7, loss=0.0104]


Callback metrics: val_loss=0.114686 val_acc=97.523271
MnistClassifier\28\epoch010_23-08_2227_val_loss0.1147_val_acc97.5233.statedict.pkl
MnistClassifier\28\epoch010_23-08_2227_val_loss0.1147_val_acc97.5233.statedict.pkl


INFO - mnistclassifier_example - Completed after 0:02:25


<sacred.run.Run at 0x1d9fecf5630>
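The log shows that the trainer saves one checkpoint per epoch under `MnistClassifier/<run id>/`, with the epoch number, a timestamp and the validation metrics encoded in the filename. Assuming that naming pattern (inferred from the output above, not from documentation of `sacred_trainer`), a minimal sketch for picking the best checkpoint after a run; the helper `best_checkpoint` is hypothetical:

```python
import re

# Pattern inferred from the checkpoint names printed above (an assumption)
CHECKPOINT_RE = re.compile(
    r"epoch(?P<epoch>\d+)_.*_val_loss(?P<val_loss>[\d.]+)"
    r"_val_acc(?P<val_acc>[\d.]+)\.statedict\.pkl$"
)


def best_checkpoint(filenames, metric="val_acc", higher_is_better=True):
    """Return the filename with the best value of `metric` encoded in its name."""
    scored = []
    for name in filenames:
        match = CHECKPOINT_RE.search(name)
        if match:
            scored.append((float(match.group(metric)), name))
    if not scored:
        raise ValueError("no filenames matched the checkpoint pattern")
    pick = max if higher_is_better else min
    return pick(scored)[1]


saved = [
    r"MnistClassifier\28\epoch003_23-08_2226_val_loss0.0865_val_acc97.6729.statedict.pkl",
    r"MnistClassifier\28\epoch004_23-08_2226_val_loss0.0865_val_acc97.5233.statedict.pkl",
    r"MnistClassifier\28\epoch009_23-08_2227_val_loss0.0959_val_acc97.9887.statedict.pkl",
    r"MnistClassifier\28\epoch010_23-08_2227_val_loss0.1147_val_acc97.5233.statedict.pkl",
]
print(best_checkpoint(saved))                                           # ...epoch009...
print(best_checkpoint(saved, metric="val_loss", higher_is_better=False))  # ...epoch003...
```

The filenames round metrics to four decimals, so epochs 3 and 4 both read `val_loss0.0865`; selecting by `val_loss` here returns epoch 3 (the tie is broken by filename) even though epoch 4's logged loss (0.086501) is slightly lower than epoch 3's (0.086531). In practice the list of names could come from `glob.glob` over the run directory.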