# FastResNet hyperparameter tuning with [Ax](https://ax.dev/) on CIFAR10

In this notebook we provide an example of hyperparameter tuning with the [Ax](https://ax.dev/) package. We will train a ResNet model from [David Page's cifar10-fast repository](https://github.com/davidcpage/cifar10-fast) on CIFAR10.


### Why Ax?

The Ax documentation answers this question best: https://ax.dev/docs/why-ax.html

> Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and continuous (e.g., integer or floating point)-valued configurations using Bayesian optimization. This makes it suitable for a wide range of applications.

There are also other interesting packages such as [ray-tune](https://ray.readthedocs.io/en/latest/tune.html), [optuna](https://github.com/pfnet/optuna) and many others. As a side note, optuna provides an example with Ignite [here](https://github.com/pfnet/optuna/blob/master/examples/ignite_simple.py).


### Fast ResNet model

We will reimplement a ResNet model from David Page's [cifar10-fast repository](https://github.com/davidcpage/cifar10-fast), which trains very fast (94% test accuracy in 26 seconds on an NVIDIA V100). For the sake of simplicity, we will not apply all the preprocessing used in the repository (please see the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb) for details).


### Setup dependencies

Please install the following packages:
- `torchvision`
- `Ax`
- `tensorboard`


In [None]:
!pip install pytorch-ignite tensorboardX ax-platform

In [None]:
import sys
sys.path.insert(0, "../../")

In [None]:
import torch
import ignite

torch.__version__, ignite.__version__

### Setup model

The cifar10-fast model is inspired by the ResNet family of models and, in order to run fast, it uses various tricks such as:
- `conv + batch norm + activation + pool` -> `conv + pool + batch norm + activation`
- `batchnorm` -> `ghost batchnorm` -> `frozen ghost batchnorm`
- `ReLU` -> `CeLU`
- data whitening as a non-learnable convolution operation (we will not implement it)

Network architecture looks like this:

![fastresnet](https://github.com/abdulelahsm/ignite/blob/update-tutorials/examples/notebooks/assets/fastresnet_v2.svg?raw=1)


Please see the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb) for more details.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class GhostBatchNorm(nn.BatchNorm2d):
    """
    From : https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb

    Batch norm seems to work best with batch size of around 32. The reasons presumably have to do 
    with noise in the batch statistics and specifically a balance between a beneficial regularising effect 
    at intermediate batch sizes and an excess of noise at small batches.
    
    Our batches are of size 512 and we can't afford to reduce them without taking a serious hit on training times, 
    but we can apply batch norm separately to subsets of a training batch. This technique, known as 'ghost' batch 
    norm, is usually used in a distributed setting but is just as useful when using large batches on a single node. 
    It isn't supported directly in PyTorch but we can roll our own easily enough.
    """
    def __init__(self, num_features, num_splits, eps=1e-05, momentum=0.1, weight=True, bias=True):
        super(GhostBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum)
        self.weight.data.fill_(1.0)
        self.bias.data.fill_(0.0)
        self.weight.requires_grad = weight
        self.bias.requires_grad = bias        
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features*self.num_splits))
        self.register_buffer('running_var', torch.ones(num_features*self.num_splits))

    def train(self, mode=True):
        if (self.training is True) and (mode is False):
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
        return super(GhostBatchNorm, self).train(mode)
        
    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            return F.batch_norm(
                input.view(-1, C*self.num_splits, H, W), self.running_mean, self.running_var, 
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W) 
        else:
            return F.batch_norm(
                input, self.running_mean[:self.num_features], self.running_var[:self.num_features], 
                self.weight, self.bias, False, self.momentum, self.eps)

        
class IdentityResidualBlock(nn.Module):

    def __init__(self, num_channels, 
                 conv_ksize=3, conv_pad=1,
                 gbn_num_splits=16):
        super(IdentityResidualBlock, self).__init__()
        self.res1 = nn.Sequential(
            Conv2d(num_channels, num_channels, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            GhostBatchNorm(num_channels, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3)         
        )
        self.res2 = nn.Sequential(
            Conv2d(num_channels, num_channels, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            GhostBatchNorm(num_channels, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3)    
        )

    def forward(self, x):
        residual = x
        x = self.res1(x)
        x = self.res2(x)
        return x + residual
    

# We override conv2d to get proper padding for kernel size = 2   
class Conv2d(nn.Conv2d):
    
    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(*args, **kwargs)
        if self.kernel_size == (2, 2):
            self.forward = self.ksize_2_forward
            self.ksize_2_padding = (0, self.padding[0], 0, self.padding[1])
            self.padding = (0, 0)
        
    def ksize_2_forward(self, x):
        x = F.pad(x, pad=self.ksize_2_padding)
        return super(Conv2d, self).forward(x)

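To make the ghost batch norm idea from the `GhostBatchNorm` docstring more concrete, here is a minimal pure-Python sketch on scalar activations. It is not the implementation used in this notebook: `ghost_batch_norm_1d` is a hypothetical helper without affine parameters or running statistics. It only mirrors how the `view(-1, C*self.num_splits, H, W)` call assigns sample `n` to ghost batch `n % num_splits`, so that each ghost batch is normalized with its own statistics.

```python
from statistics import fmean, pvariance


def ghost_batch_norm_1d(values, num_splits, eps=1e-5):
    """Normalize scalar activations separately inside each ghost batch.

    Hypothetical helper for illustration only (no weight/bias, no running statistics).
    """
    if len(values) % num_splits != 0:
        raise ValueError("Batch size must be divisible by num_splits")
    out = [0.0] * len(values)
    for k in range(num_splits):
        # Ghost batch k holds samples k, k + num_splits, k + 2 * num_splits, ...
        indices = range(k, len(values), num_splits)
        group = [values[i] for i in indices]
        mean = fmean(group)
        var = pvariance(group, mu=mean)  # biased variance of the ghost batch
        for i in indices:
            out[i] = (values[i] - mean) / (var + eps) ** 0.5
    return out


# Two interleaved ghost batches with very different scales
batch = [1.0, 10.0, 3.0, 30.0, 5.0, 50.0, 7.0, 70.0]
normalized = ghost_batch_norm_1d(batch, num_splits=2)
# Both ghost batches end up with (almost) the same normalized values
print([round(v, 3) for v in normalized])
```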
In [None]:
class FastResNet(nn.Module):
        
    def __init__(self, num_classes=10, 
                 fmap_factor=64, conv_ksize=3, conv_pad=1, 
                 gbn_num_splits=512 // 32,                  
                 classif_scale=0.0625):
        super(FastResNet, self).__init__()
                
        self.prep = nn.Sequential(
            Conv2d(3, fmap_factor, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            GhostBatchNorm(fmap_factor, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3)
        )

        self.layer1 = nn.Sequential(
            Conv2d(fmap_factor, fmap_factor * 2, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 2, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),
            IdentityResidualBlock(fmap_factor * 2,
                                  conv_ksize=conv_ksize, conv_pad=conv_pad, 
                                  gbn_num_splits=gbn_num_splits)
        )
        
        self.layer2 = nn.Sequential(
            Conv2d(fmap_factor * 2, fmap_factor * 4, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 4, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),            
        )
        
        self.layer3 = nn.Sequential(
            Conv2d(fmap_factor * 4, fmap_factor * 8, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),
            nn.MaxPool2d(kernel_size=2),
            GhostBatchNorm(fmap_factor * 8, num_splits=gbn_num_splits, weight=False),
            nn.CELU(alpha=0.3),
            IdentityResidualBlock(fmap_factor * 8, 
                                  conv_ksize=conv_ksize, conv_pad=conv_pad, 
                                  gbn_num_splits=gbn_num_splits)
        )
        
        self.pool = nn.MaxPool2d(kernel_size=4)
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(fmap_factor * 8, num_classes)
        )
        # use the constructor argument so that `classif_scale` can actually be tuned
        self.scale = torch.tensor(classif_scale, requires_grad=False)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        y = self.classifier(x)
        return y * self.scale
      

In [None]:
model = FastResNet(10, fmap_factor=64)

In [None]:
def print_num_params(model, display_all_modules=False):
    total_num_params = 0
    for n, p in model.named_parameters():
        num_params = 1
        for s in p.shape:
            num_params *= s
        if display_all_modules: print(f"{n}: {num_params}")
        total_num_params += num_params
    print("-" * 50)
    print(f"Total number of parameters: {total_num_params:.2e}")
    

print_num_params(model)

In [None]:
model = FastResNet(10, fmap_factor=64, conv_ksize=2)

print_num_params(model)

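The custom `Conv2d` class exists because a 2x2 kernel with symmetric padding changes the spatial size of the feature map. A quick arithmetic sketch shows the effect; `conv_output_size` is a helper introduced here for illustration and assumes stride 1 and no dilation.

```python
def conv_output_size(size, kernel_size, pad_before, pad_after):
    """Output length along one spatial dimension of a stride-1 convolution without dilation."""
    return size + pad_before + pad_after - kernel_size + 1


# 3x3 kernel with symmetric padding of 1 keeps the 32x32 CIFAR10 resolution
same_3x3 = conv_output_size(32, kernel_size=3, pad_before=1, pad_after=1)
# 2x2 kernel with the same symmetric padding would grow the feature map by one pixel
symmetric_2x2 = conv_output_size(32, kernel_size=2, pad_before=1, pad_after=1)
# The overridden Conv2d pads only one side (0 before, 1 after), which keeps the size
one_sided_2x2 = conv_output_size(32, kernel_size=2, pad_before=0, pad_after=1)

print(same_3x3, symmetric_2x2, one_sided_2x2)  # 32 33 32
```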
### Setup dataflow

We will set up the dataflow using `torchvision` transforms and will not follow the suggestions of the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb). The data augmentations used to transform the dataset are:

- Random Crop
- Flip left-right
- Cutout

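Cutout is implemented in the next cells with `RandomErasing(scale=(0.0625, 0.0625), ratio=(1.0, 1.0))`, which torchvision applies with probability 0.5 by default. As a rough sketch of what these arguments mean for a 32x32 image, the hypothetical helper `erased_patch_size` below follows the area-and-aspect-ratio logic as we understand it from torchvision; it is not a torchvision function.

```python
import math


def erased_patch_size(height, width, scale, ratio):
    """Height and width of the erased rectangle for a fixed area fraction and aspect ratio."""
    erase_area = height * width * scale
    patch_h = int(round(math.sqrt(erase_area * ratio)))
    patch_w = int(round(math.sqrt(erase_area / ratio)))
    return patch_h, patch_w


# 6.25% of a 32x32 image with aspect ratio 1 is an 8x8 square
patch = erased_patch_size(32, 32, scale=0.0625, ratio=1.0)
print(patch)  # (8, 8)
```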
In [None]:
import torch
from torchvision.transforms import Compose, Pad, RandomHorizontalFlip, RandomErasing, RandomCrop, Normalize

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.cifar import CIFAR10


train_transform = Compose([
    Pad(4),
    RandomCrop(32),
    RandomHorizontalFlip(),
    ToTensor(),    
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    RandomErasing(scale=(0.0625, 0.0625), ratio=(1.0, 1.0))
])


test_transform = Compose([
    ToTensor(),    
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


train_ds = CIFAR10("/tmp/cifar10", train=True, download=True, transform=train_transform)
# the test set must not be augmented, so it uses `test_transform`
test_ds = CIFAR10("/tmp/cifar10", train=False, download=True, transform=test_transform)

In [None]:
def get_train_test_loaders():
    train_loader = DataLoader(train_ds, batch_size=512, num_workers=10, shuffle=True, drop_last=True, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=512, num_workers=10, shuffle=False, drop_last=False, pin_memory=True)
    return train_loader, test_loader

### Setup criterion, optimizer and lr scheduling

Following cifar10-fast, we will use the label smoothing trick, which improves the training speed and generalization of neural networks in classification problems.

In [None]:
import torch.nn as nn
import torch.optim as optim


class CriterionWithLabelSmoothing(nn.Module):
    
    def __init__(self, criterion, alpha=0.2):
        super(CriterionWithLabelSmoothing, self).__init__()
        self.criterion = criterion
        if self.criterion.reduction != 'none':
            raise ValueError("Input criterion should have reduction equal none")
        self.alpha = alpha
    
    def forward(self, logits, targets):
        loss = self.criterion(logits, targets)
        log_probs = torch.log_softmax(logits, dim=1)
        klloss = -log_probs.mean(dim=1)        
        out = (1.0 - self.alpha) * loss + self.alpha * klloss
        return out.mean(dim=0)

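The loss computed by `CriterionWithLabelSmoothing` mixes the usual cross-entropy with the mean negative log-probability over all classes. The pure-Python sketch below (with made-up logits and helper names introduced only for this illustration) checks that this mixture equals the cross-entropy against a smoothed target distribution `(1 - alpha) * one_hot + alpha / num_classes`.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


def smoothed_loss_mixture(logits, target, alpha):
    """(1 - alpha) * cross-entropy + alpha * mean negative log-probability over classes."""
    log_probs = log_softmax(logits)
    ce = -log_probs[target]
    uniform_term = -sum(log_probs) / len(log_probs)
    return (1.0 - alpha) * ce + alpha * uniform_term


def smoothed_loss_soft_target(logits, target, alpha):
    """Cross-entropy against the smoothed target distribution."""
    log_probs = log_softmax(logits)
    num_classes = len(logits)
    q = [alpha / num_classes + (1.0 - alpha) * (k == target) for k in range(num_classes)]
    return -sum(qk * lp for qk, lp in zip(q, log_probs))


logits = [2.0, -1.0, 0.5]  # made-up logits for a 3-class problem
a = smoothed_loss_mixture(logits, target=0, alpha=0.2)
b = smoothed_loss_soft_target(logits, target=0, alpha=0.2)
print(abs(a - b) < 1e-12)
```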
In [None]:
def get_criterion(alpha):
    # pass `alpha` through so that the label smoothing strength can be tuned
    return CriterionWithLabelSmoothing(nn.CrossEntropyLoss(reduction='none'), alpha=alpha)


def get_optimizer(model, momentum, weight_decay, nesterov):
    biases = [p for n, p in model.named_parameters() if "bias" in n]
    others = [p for n, p in model.named_parameters() if "bias" not in n]
    return optim.SGD(
        [{"params": others, "lr": 1.0, "weight_decay": weight_decay}, 
         {"params": biases, "lr": 1.0, "weight_decay": weight_decay / 64}], 
        momentum=momentum, nesterov=nesterov
    )


There is an implementation difference between the current PyTorch SGD and the SGD from cifar10-fast (both shown here with Nesterov momentum). The latter uses the Sutskever et al. formulation:
```
v = mu * prev_v - lr * (dw + weight_decay * w)
new_w = w + mu * v - lr * (dw + weight_decay * w)
```
while PyTorch's is
```
v = mu * prev_v + dw + weight_decay * w
new_w = w - lr * (mu * v + dw + weight_decay * w)
```

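A small scalar experiment can illustrate when the two SGD formulations matter. The sketch below uses a toy objective `f(w) = w ** 2` and helper names chosen for this illustration only. With a constant learning rate both updates produce the same trajectory, but they diverge as soon as the learning rate changes between iterations, which is the case with the schedule used in this notebook.

```python
def grad(w):
    """Gradient of the toy objective f(w) = w ** 2."""
    return 2.0 * w


def run_sutskever(lrs, mu=0.9, weight_decay=0.0, w=1.0):
    v = 0.0
    for lr in lrs:
        g = grad(w) + weight_decay * w
        v = mu * v - lr * g  # the learning rate is folded into the velocity
        w = w + mu * v - lr * g
    return w


def run_pytorch(lrs, mu=0.9, weight_decay=0.0, w=1.0):
    v = 0.0
    for lr in lrs:
        g = grad(w) + weight_decay * w
        v = mu * v + g  # the velocity accumulates raw gradients
        w = w - lr * (mu * v + g)
    return w


constant_lrs = [0.1, 0.1, 0.1]
varying_lrs = [0.05, 0.1, 0.2]  # a short warm-up

const_gap = abs(run_sutskever(constant_lrs) - run_pytorch(constant_lrs))
varying_gap = abs(run_sutskever(varying_lrs) - run_pytorch(varying_lrs))
print(f"constant lr gap: {const_gap:.2e}, varying lr gap: {varying_gap:.2e}")
```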
In [None]:
from ignite.contrib.handlers import PiecewiseLinear, ParamGroupScheduler

In [None]:
def get_lr_scheduler(optimizer, lr_max_value, lr_max_value_epoch, num_epochs, epoch_length):
    milestones_values = [
        (0, 0.0), 
        (epoch_length * lr_max_value_epoch, lr_max_value), 
        (epoch_length * num_epochs - 1, 0.0)
    ]
    lr_scheduler1 = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values, param_group_index=0)

    milestones_values = [
        (0, 0.0), 
        (epoch_length * lr_max_value_epoch, lr_max_value * 64), 
        (epoch_length * num_epochs - 1, 0.0)
    ]
    lr_scheduler2 = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values, param_group_index=1)

    lr_scheduler = ParamGroupScheduler(
        [lr_scheduler1, lr_scheduler2],
        ["lr scheduler (non-biases)", "lr scheduler (biases)"]
    )
    
    return lr_scheduler

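The schedule built by `get_lr_scheduler` is a linear warm-up to the peak value followed by a linear decay to zero. The following pure-Python sketch reproduces the same milestones without Ignite; `piecewise_linear` is a helper written for this illustration, and the concrete numbers (97 iterations per epoch, peak at epoch 4) are assumptions chosen as an example.

```python
def piecewise_linear(milestones, step):
    """Linearly interpolate between (step, value) milestones; clamp outside the range."""
    if step <= milestones[0][0]:
        return milestones[0][1]
    for (s0, v0), (s1, v1) in zip(milestones, milestones[1:]):
        if s0 <= step <= s1:
            return v0 + (v1 - v0) * (step - s0) / (s1 - s0)
    return milestones[-1][1]


# Assumed values for illustration: 97 iterations per epoch (50000 // 512 with drop_last=True)
epoch_length, num_epochs = 97, 20
lr_max_value, lr_max_value_epoch = 0.4, 4
milestones = [
    (0, 0.0),
    (epoch_length * lr_max_value_epoch, lr_max_value),
    (epoch_length * num_epochs - 1, 0.0),
]

for step in (0, 194, 388, 1164, 1939):
    lr = piecewise_linear(milestones, step)
    # The bias parameter group follows the same shape with a 64x larger peak
    print(step, round(lr, 4), round(64 * lr, 4))
```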
In [None]:
%matplotlib inline

num_epochs = 25
lr_max_value = 0.4
milestones_values = [(0, 0.0), (num_epochs // 5, lr_max_value), (num_epochs - 1, 0.0)]

PiecewiseLinear.plot_values(num_epochs, param_name="lr", milestones_values=milestones_values)

### Setup hyperparameter tuning

Now we are ready to set up hyperparameter tuning. We will optimize the following parameters to get higher accuracy on the test dataset, with training limited to 20 epochs:

- learning rate peak value: `[0.1, 1.0]`
- SGD momentum: `[0.7, 1.0]`
- weight decay: `[1e-4, 1e-3]`
- label smoothing `alpha`: `[0.1, 0.5]`
- number of features (`fmap_factor`): integer in `[48, 80]`
- convolution kernel size: `3` or `2`
- ...

In [None]:
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting

init_notebook_plotting()

First, we need to create an evaluation function that receives the experiment parameters and returns the test accuracy.

The search space of input parameters is defined as a list of dictionaries with the following keys:
- "name": the parameter name,
- "type": the parameter type ("range", "choice" or "fixed"),
- "bounds": for range parameters,
- "values": for choice parameters, and
- "value": for fixed parameters.

The parameters object provided for a single experiment is a dictionary mapping each parameter name to its value.

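To illustrate how a search space of this form relates to the parameters dictionary received by the evaluation function, here is a minimal sketch that draws one random configuration. `sample_parameters` and `example_space` are hypothetical names introduced for this illustration; Ax itself chooses configurations with its own strategies, and treating integer bounds as an integer parameter is an assumption of this sketch.

```python
import random


def sample_parameters(space, rng):
    """Draw one random configuration from a list of parameter dictionaries."""
    params = {}
    for p in space:
        if p["type"] == "range":
            low, high = p["bounds"]
            if isinstance(low, int) and isinstance(high, int):
                # assumption of this sketch: integer bounds mean an integer parameter
                params[p["name"]] = rng.randint(low, high)
            else:
                params[p["name"]] = rng.uniform(low, high)
        elif p["type"] == "choice":
            params[p["name"]] = rng.choice(p["values"])
        elif p["type"] == "fixed":
            params[p["name"]] = p["value"]
        else:
            raise ValueError(f"Unknown parameter type: {p['type']}")
    return params


example_space = [
    {"name": "lr_max_value", "type": "range", "bounds": [0.1, 1.0]},
    {"name": "fmap_factor", "type": "range", "bounds": [48, 80]},
    {"name": "conv_ksize", "type": "choice", "values": [2, 3]},
    {"name": "num_classes", "type": "fixed", "value": 10},
]
params = sample_parameters(example_space, random.Random(0))
print(params)
```

The resulting dictionary has the same shape as the `parameters` argument of the evaluation function: one entry per parameter name.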

Links: 
- [Ax Parameters API](https://ax.dev/api/core.html#module-ax.core.parameter)
- [Ax optimize function](https://ax.dev/api/service.html#ax.service.managed_loop.optimize)
- [Ax parameters search space example](https://ax.dev/tutorials/gpei_hartmann_service.html#2.-Set-up-experiment)

In [None]:
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events, convert_tensor
from ignite.metrics import Accuracy
from ignite.contrib.handlers import TensorboardLogger, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler, GradsHistHandler, \
    global_step_from_engine

In [None]:
# Transfer the batch to the device and cast the inputs to half precision (fp16)
def prepare_batch_fp16(batch, device=None, non_blocking=True):
    x, y = batch
    return (convert_tensor(x, device=device, non_blocking=non_blocking).half(),
            convert_tensor(y, device=device, non_blocking=non_blocking))

In [None]:
torch.backends.cudnn.benchmark = True

In [None]:
num_epochs = 17


def run_experiment(parameters):
    device = 'cuda'
    fast_mode = parameters.get("fast_mode", True)
    
    # setup model
    model = FastResNet(
        num_classes=10, 
        fmap_factor=parameters.get("fmap_factor"), 
        conv_ksize=parameters.get("conv_ksize"),
        classif_scale=parameters.get("classif_scale")
    ).to(device).half()
    
    # setup dataloaders 
    train_loader, test_loader = get_train_test_loaders()
    
    # setup solver
    criterion = get_criterion(parameters.get("alpha")).to(device)
    optimizer = get_optimizer(
        model, 
        parameters.get("momentum"), 
        parameters.get("weight_decay"),
        parameters.get("nesterov")
    )
    lr_scheduler = get_lr_scheduler(
        optimizer, 
        parameters.get("lr_max_value"),
        parameters.get("lr_max_value_epoch"),        
        num_epochs=num_epochs,
        epoch_length=len(train_loader)
    )
    
    # setup ignite trainer
    trainer = create_supervised_trainer(model, optimizer, criterion, 
                                        device=device, non_blocking=True,
                                        prepare_batch=prepare_batch_fp16)
    
    # setup learning rate scheduler
    trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    
    # setup tensorboard logger
    exp_log_name = f"exp_{parameters.get('fmap_factor')}_{parameters.get('conv_ksize')}_" + \
        f"{parameters.get('alpha'):.2}_{parameters.get('lr_max_value'):.4}"
    tb_logger = TensorboardLogger(log_dir=f"/tmp/tb_logs/{exp_log_name}")
    
    if not fast_mode:
        # - log learning rate
        tb_logger.attach(trainer, OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)

        # - log training batch loss
        tb_logger.attach(trainer, OutputHandler(tag="training", output_transform=lambda x: {"batch loss": x}), 
                         event_name=Events.ITERATION_COMPLETED)

        # - log model grads
        tb_logger.attach(trainer, GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED)    
    
        # setup a progress bar
        ProgressBar().attach(trainer, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)    
        
    # setup evaluator
    def output_transform(output):
        y_pred, y = output
        y_pred = y_pred.float()
        return y_pred, y

    metrics = {
        "test accuracy": Accuracy(output_transform=output_transform)
    }
    evaluator = create_supervised_evaluator(model, metrics=metrics, 
                                            device=device, non_blocking=True, 
                                            prepare_batch=prepare_batch_fp16)
    
    # evaluate the model every 3 epochs (unless in fast mode) and always after the last epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def run_evaluation(engine):
        c1 = (engine.state.epoch - 1) % 3 == 0
        c2 = engine.state.epoch == engine.state.max_epochs
        if (c1 and not fast_mode) or c2:
            evaluator.run(test_loader)
    
    if not fast_mode:
        # - log test accuracy
        tb_logger.attach(evaluator, 
                         OutputHandler(tag="validation", metric_names="all", 
                                       global_step_transform=global_step_from_engine(trainer)), 
                         event_name=Events.EPOCH_COMPLETED)

    trainer.run(train_loader, max_epochs=num_epochs)    
    test_acc = evaluator.state.metrics['test accuracy']
    
    # dump hparams/result to Tensorboard
    tb_logger.writer.add_hparams(parameters, {'hparam/test_accuracy': test_acc})

    tb_logger.close()    
    return test_acc

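The evaluation schedule inside `run_experiment` depends on `fast_mode`. As a small sketch, the helper `evaluation_epochs` (introduced here for illustration) mirrors the condition used in `run_evaluation` and lists the epochs after which the evaluator runs.

```python
def evaluation_epochs(max_epochs, fast_mode, every=3):
    """Epochs after which the evaluator runs, mirroring the condition in run_evaluation."""
    epochs = []
    for epoch in range(1, max_epochs + 1):
        periodic = (epoch - 1) % every == 0
        last = epoch == max_epochs
        if (periodic and not fast_mode) or last:
            epochs.append(epoch)
    return epochs


print(evaluation_epochs(20, fast_mode=False))  # [1, 4, 7, 10, 13, 16, 19, 20]
print(evaluation_epochs(20, fast_mode=True))   # [20]
```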
The original training configuration gives us the following result:

In [None]:
batch_size = 512
num_epochs = 20

run_experiment(
    parameters={
        "fmap_factor": 64,
        "conv_ksize": 3,
        "classif_scale": 0.0625,
        "alpha": 0.2,
        "momentum": 0.9,
        "weight_decay": 5e-4,
        "nesterov": True,
        "lr_max_value": 1.0,
        "lr_max_value_epoch": num_epochs // 5,
        "fast_mode": False
    }
)

#### Setup parameters search space

In [None]:
parameters_space = [
    {
        "name": "fmap_factor",
        "type": "range",
        "bounds": [48, 80],
    },
    {
        "name": "conv_ksize",
        "type": "choice",
        "values": [2, 3],
    },
    {
        "name": "classif_scale",
        "type": "range",
        "bounds": [0.00625, 0.250],
    },
    {
        "name": "alpha",
        "type": "range",
        "bounds": [0.1, 0.5],
    },
    {
        "name": "momentum",
        "type": "range",
        "bounds": [0.7, 1.0],
    },
    {
        "name": "weight_decay",
        "type": "range",
        "bounds": [1e-4, 1e-3],
        "value_type": "float",
    },    
    {
        "name": "nesterov",
        "type": "choice",
        "values": [True, False],
    },
    {
        "name": "lr_max_value",
        "type": "range",
        "bounds": [0.1, 1.0],
    },
    {
        "name": "lr_max_value_epoch",
        "type": "range",
        "bounds": [1, 10],
    },
]


### Start tuning

In [None]:
num_epochs = exp_num_epochs = 20


best_parameters, values, experiment, model = optimize(
    parameters=parameters_space,
    evaluation_function=run_experiment,
    objective_name='test accuracy',
    total_trials=30
)


We found the best parameters, which give the following outcome:

In [None]:
means, covariances = values
print(f"\nBest parameters: {best_parameters}\n")
print(f"Test accuracy: {means} ± {covariances}")

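The printed `means` and `covariances` are containers rather than plain numbers. The sketch below formats a single metric, assuming that `means` is a dictionary keyed by metric name and that `covariances` is either `None` or a nested dictionary keyed by metric name, which is our understanding of the output of Ax's managed loop. `format_best_value` is a hypothetical helper, and the numbers in the example are placeholders, not results from this notebook.

```python
import math


def format_best_value(values, metric_name):
    """Format the mean and the square root of the variance for one metric."""
    means, covariances = values
    mean = means[metric_name]
    if covariances is None:
        return f"{metric_name}: {mean:.4f}"
    variance = covariances[metric_name][metric_name]
    return f"{metric_name}: {mean:.4f} ± {math.sqrt(variance):.4f}"


# Placeholder numbers for illustration only, not results from this notebook
placeholder_values = ({"test accuracy": 0.9}, {"test accuracy": {"test accuracy": 0.0001}})
print(format_best_value(placeholder_values, "test accuracy"))  # test accuracy: 0.9000 ± 0.0100
```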
Let's plot contours showing test accuracy as a function of two hyperparameters.

In [None]:
render(plot_contour(model=model, param_x='lr_max_value', param_y='momentum', metric_name='test accuracy'))

Let's retrain the model with the best found parameters and compare with the previous baseline:

In [None]:
batch_size = 512
num_epochs = 20

best_parameters_copy = dict(best_parameters)
best_parameters_copy['fast_mode'] = False

run_experiment(
    parameters=best_parameters_copy
)

In TensorBoard we can observe the "HPARAMS" tab:

![hparams](https://github.com/abdulelahsm/ignite/blob/update-tutorials/examples/notebooks/assets/ax_hparams.png?raw=1)