In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from functools import partial

In [3]:
# The goal is to implement callbacks to allow inserting different operations during training

## Download Dataset

In [4]:
# Download a dataset to train
from fastai import datasets
import gzip
import pickle

MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(torch.tensor, (x_train,y_train,x_valid,y_valid))

x_train,y_train,x_valid,y_valid = get_data()

In [5]:
x_train.shape, y_train.shape, x_valid.shape, y_valid.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

## Load Data

In [6]:
class MNIST_Dataset(Dataset):
    """This class holds and resizes MNIST datasets."""
    
    def __init__(self, x, y):
        self.x = self.mnist_resize(x)
        self.y = y
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, i):
            return self.x[i], self.y[i]  
        
    def mnist_resize(self,x):
        return x.view(-1,1,28,28)

In [7]:
ds_train = MNIST_Dataset(x_train,y_train)
ds_valid = MNIST_Dataset(x_valid,y_valid)

dl_train = DataLoader(ds_train, batch_size=64, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=64, shuffle=False)

## Model

In [8]:
class Lambda(nn.Module):
    """This class 'simulates' a Pytorch nn layern and allows to insert functions with
    a single input and single output.
    """
    
    def __init__(self,func):
        super().__init__()
        self.func = func
        
    def __call__(self,x):
        return self.func(x)

In [9]:
def flatten(x):
    return x.view(x.shape[0],-1)

In [10]:
x_flatten = flatten(torch.Tensor(100,32,1,1))
x_flatten.shape

torch.Size([100, 32])

In [11]:
bs = 64

model = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 8, 5, stride=2, padding=2), 
                          nn.LeakyReLU(negative_slope=0.1),
                          nn.BatchNorm2d(8)), #14
            nn.Sequential(nn.Conv2d(8, 16, 3, stride=1, padding=2),
                          nn.LeakyReLU(negative_slope=0.1),
                          nn.BatchNorm2d(16)), #7
            nn.Sequential(nn.Conv2d(16, 32, 3, stride=1, padding=2),
                          nn.LeakyReLU(negative_slope=0.1),
                          nn.BatchNorm2d(32)), #4
            nn.Sequential(nn.Conv2d(32, 32, 3, stride=1, padding=2),
                          nn.LeakyReLU(negative_slope=0.1),
                          nn.BatchNorm2d(32)), #2
            nn.AdaptiveAvgPool2d(1),
            Lambda(flatten),
            nn.Linear(32,10)
        )
model

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.1)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.1)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.1)
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (3): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.1)
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (4): AdaptiveAvgPool2d(output_size=1)
  (5): Lambda()
  (6): Linear(in_features=32, ou

In [12]:
# Initialize the random weights with Kaiming
for l in model:
    if isinstance(l, nn.Sequential):
        nn.init.kaiming_normal_(l[0].weight)
        l[0].bias.data.zero_()

## Callbacks

In [13]:
class Callback():
    """This class holds the expected structure for callbacks and shall be inherited from 
    other custom callbacks. To stop execution early different Exceptions are available 
    that can be raised. These exceptions shall be caught by the main training loop.
    """
    
    def set_trainer(self, train): 
        self.train=train
    
    def begin_training(self): pass
    
    def after_training(self): pass
    
    def cancel_training(self): pass
    
    def begin_epoch(self): pass
    
    def after_epoch(self): pass
    
    def cancel_epoch(self): pass
    
    def begin_batch(self): pass
    
    def after_batch(self): pass
    
    def cancel_batch(self): pass
    
    def after_pred(self): pass
    
    def after_loss(self): pass
    
    def after_backward(self): pass
    
    def after_step(self): pass
    
    def begin_validation(self): pass
    
    
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

In [14]:
class PrintAccuarcyCallback(Callback):
    
    def __init__(self):
        self.acc = []
    
    def accuracy(self, pred, yb): 
        return (torch.argmax(pred, dim=1)==yb).float().mean()
    
    def begin_validation(self):
        self.acc = []
    
    def after_pred(self):
        self.acc.append(self.accuracy(self.train.pred, self.train.yb).item())
        
    def after_epoch(self):
        mean_acc = sum(self.acc) / len(self.acc)
        
        print(f'The validation accuracy for epoch {self.train.epoch+1} is: {mean_acc:.4f}')

In [15]:
class CancelMidEpochCallback(Callback):
    
    def after_batch(self):
        print(self.train.batch)
        if self.train.batch >= 20:
            raise CancelTrainException
            
    def cancel_training(self):
        print('Canceled Training')

## Training Loop

In [16]:
class Trainer():
    """This class initializes a train that can be called with n epochs to run.
    It needs following input:
    model: A Pytorch nn.Module
    dl_train: A Pytorch DataLoader
    dl_valid: A Pytorch DataLoader
    optimizer: A Pytorch optimizer functipon
    loss_func: A Pytorch loss function
    callbacks: A list of Callback Class Instances
    """
    
    def __init__(self, model, dl_train, dl_valid, optimizer, loss_func, callbacks):
        self.model = model
        self.dl_train, self.dl_valid = dl_train, dl_valid
        self.opt = optimizer
        self.loss_func = loss_func
        self.cbs = callbacks
        self.valid_mode = False
        
        for cb in self.cbs:
            cb.set_trainer(self)
            
    def _run_one_batch(self, xb, yb):
        
        try:
            for cb in self.cbs: cb.begin_batch()                 # Begin Batch

            self.pred = self.model(xb)
            for cb in self.cbs: cb.after_pred()                  # After Pred
            self.loss = self.loss_func(self.pred,yb)
            for cb in self.cbs: cb.after_loss()                  # After Loss
            if self.valid_mode: return
            self.loss.backward()
            for cb in self.cbs: cb.after_backward()              # After Backward
                
            self.opt.step()
            for cb in self.cbs: cb.after_step()                  # After Step 
            self.opt.zero_grad()

            for cb in self.cbs: cb.after_batch()                 # After Batch
                
        except CancelBatchException: 
            for cb in self.cbs: cb.cancel_batch()                # Cancel Batch
    
    def _run_one_epoch(self, dl):
        self.batch = 1
        try:
            for xb, yb in dl:
                self.xb, self.yb = xb, yb

                self._run_one_batch(xb, yb)
                self.batch += 1
                 
        except CancelEpochException:
            for cb in self.cbs: cb.cancel_epoch()               # Cancel Epoch
              
    def __call__(self, epochs):
        for cb in self.cbs: cb.begin_training()                 # Begin Training
        try:
            for epoch in range(epochs):
                self.epoch = epoch
                for cb in self.cbs: cb.begin_epoch()            # Begin Epoch
                self._run_one_epoch(dl_train)
                            
                with torch.no_grad():
                    for cb in self.cbs: cb.begin_validation()   # Begin Validation
                    self.valid_mode = True
                    self._run_one_epoch(dl_valid)
                    self.valid_mode = False

                for cb in self.cbs: cb.after_epoch()            # After Epoch
                
                
                
        except CancelTrainException:
            for cb in self.cbs: cb.cancel_training()            # Cancel Training
                
        for cb in self.cbs: cb.after_training()                 # After Training

In [17]:
opt = torch.optim.SGD(model.parameters(), lr = 0.1)
loss_func = F.cross_entropy
cbs = [PrintAccuarcyCallback()]

train = Trainer(model, dl_train, dl_valid, opt, loss_func, cbs)

In [18]:
train(3)

The validation accuracy for epoch 1 is: 0.9170
The validation accuracy for epoch 2 is: 0.9580
The validation accuracy for epoch 3 is: 0.9695


In [19]:
opt = torch.optim.SGD(model.parameters(), lr = 0.1)
loss_func = F.cross_entropy
cbs = [PrintAccuarcyCallback(), CancelMidEpochCallback()]

train = Trainer(model, dl_train, dl_valid, opt, loss_func, cbs)

In [20]:
train(1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Canceled Training


In [21]:
# Fin