# Training Script for ResNet on CIFAR-10/CIFAR-100


Note: this training script is adapted from https://github.com/gpleiss/temperature_scaling/blob/master/train.py.

The original script trains a 40-layer DenseNet-BC on CIFAR-100. I have adapted it for ResNet-18, ResNet-34 and ResNet-50 on CIFAR-100.

Originally, I also tried training AlexNet on CIFAR-100, but its accuracy was quite low (~54%), so I decided to use only ResNet for my experiments.

I also trained ResNet on CIFAR-10, but its calibration errors were already quite low, so I decided to use CIFAR-100 instead.


I trained ResNet-18 and ResNet-50 for about 50 epochs, which took roughly 2 hours in total on Colab with a T4 GPU.

# NOTE:

This training script saves the full training state (model, optimizer and scheduler state) together with the indices of the training-set images held out for validation, so that the same validation split can be reused for calibration in the main notebook.

In [None]:
# Code Libraries needed
# Torch and Torchvision (same as main notebook)

import torch
import os
import time
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
def load_data_train_cifar100(train_transform, batch_size=64, valid_size=5000):
    train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
    valid_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=False, transform=train_transform)
    indices = torch.randperm(len(train_set))
    train_indices = indices[:len(indices) - valid_size]
    valid_indices = indices[len(indices) - valid_size:] if valid_size else None

    # Creating data loaders for train, validation, and test sets
    train_loader = torch.utils.data.DataLoader(train_set, pin_memory=True, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size, sampler=SubsetRandomSampler(valid_indices))
    return train_loader, valid_loader, valid_indices

def load_data_train_cifar10(train_transform, batch_size=64, valid_size=5000):
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    valid_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=train_transform)
    indices = torch.randperm(len(train_set))
    train_indices = indices[:len(indices) - valid_size]
    valid_indices = indices[len(indices) - valid_size:] if valid_size else None

    # Creating data loaders for train, validation, and test sets
    train_loader = torch.utils.data.DataLoader(train_set, pin_memory=True, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size, sampler=SubsetRandomSampler(valid_indices))
    return train_loader, valid_loader, valid_indices

In [None]:
def get_resnet18_model(num_classes=100):
    model = torchvision.models.resnet18(pretrained=False, num_classes=num_classes)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model

def get_resnet34_model(num_classes=100):
    model = models.resnet34(pretrained=False, num_classes=num_classes)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model

def get_resnet50_model(num_classes=100):
    model = models.resnet50(pretrained=False, num_classes=num_classes)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model

In [None]:
class AlexNet(nn.Module):
    """
    AlexNet-style CNN for 3x32x32 inputs such as CIFAR images.

    classes (int): number of output logits - default 100
    dropout (float): drop probability of the two Dropout layers in the
        classifier - default 0.1

    For a 32x32 input the feature extractor shrinks the spatial size
    32 -> 16 (stride-2 conv) -> 7 (pool) -> 3 (pool) -> 1 (pool), which is why
    the first linear layer expects exactly 256 * 1 * 1 features; inputs whose
    final feature map is not 1x1 will fail there.
    """

    def __init__(self, classes=100, dropout=0.1):
        super().__init__()
        # Module order is significant: checkpoint keys such as ``features.0.weight``
        # and ``classifier.6.weight`` are derived from the Sequential indices.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 1 * 1, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, classes),
        )

    def forward(self, x):
        """Return unnormalised class logits, one row per input image."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


def get_alexnet_model(num_classes=100):
    """Return an untrained AlexNet with ``num_classes`` outputs (default 100)."""
    return AlexNet(classes=num_classes)

In [None]:
def build_cifar_transforms(mean, std):
    """
    Build the (train, test) preprocessing pipelines for 32x32 CIFAR images.

    mean, std (sequence of 3 floats): per-channel normalisation statistics

    The training pipeline takes random 32x32 crops of the 4-pixel padded image
    and applies random horizontal flips before converting to a tensor and
    normalising; the test pipeline only converts and normalises.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return train_transform, test_transform


# Per-channel mean and std of CIFAR-10.
mean_cifar10 = [0.4915, 0.4822, 0.4466]
std_cifar10 = [0.2463, 0.2428, 0.2607]

# Per-channel mean and std of CIFAR-100.
mean_cifar100 = [0.5070, 0.4865, 0.4408]
std_cifar100 = [0.2664, 0.2555, 0.2750]

# The notebook trains on CIFAR-100; pass the CIFAR-10 statistics here as well
# when switching the data loaders to CIFAR-10.
train_transforms, test_transforms = build_cifar_transforms(mean_cifar100, std_cifar100)

In [None]:
# Build the CIFAR-100 train/validation loaders; use load_data_train_cifar10
# with the same arguments to train on CIFAR-10 instead.
train_loader, valid_loader, valid_indices = load_data_train_cifar100(
    train_transforms,
    batch_size=64,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 42895955.11it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
class Meter:
    """
    A little helper class which keeps track of statistics during an epoch.
    """
    def __init__(self, name, cum=False):
        """
        name (str or iterable): name of values for the meter
            If an iterable of size n, updates require a n-Tensor
        cum (bool): is this meter for a cumulative value (e.g. time)
            or for an averaged value (e.g. loss)? - default False
        """
        self.cum = cum
        if isinstance(name, str):
            name = (name,)
        self.name = name

        self._total = torch.zeros(len(self.name))
        self._last_value = torch.zeros(len(self.name))
        self._count = 0.0

    def update(self, data, n=1):
        """
        Update the meter
        data (Tensor, or float): update value for the meter
            Size of data should match size of ``name'' in the initialized args
        n (int): how many items ``data`` accounts for. ``data`` is added to the
            total as given, so for n > 1 it should be the sum over those items
            for ``value()`` to be a per-item mean.
        """
        self._count = self._count + n
        if torch.is_tensor(data):
            # Detach so the meter never holds on to an autograd graph (e.g. the
            # training loss); otherwise every batch's graph would be chained
            # into ``_total`` and kept alive for the whole epoch.
            self._last_value.copy_(data.detach())
        else:
            self._last_value.fill_(data)
        self._total.add_(self._last_value)

    def value(self):
        """
        Returns the value of the meter: the running sum if ``cum`` is set,
        otherwise the running mean (NaN before the first update, since 0 / 0).
        """
        if self.cum:
            return self._total
        return self._total / self._count

    def __repr__(self):
        return '\t'.join(['%s: %.5f (%.3f)' % (n, lv, v)
            for n, lv, v in zip(self.name, self._last_value, self.value())])


def run_epoch(loader, model, criterion, optimizer, epoch=0, n_epochs=0, train=True):
    """
    Run one pass over ``loader``, training or evaluating ``model``.

    loader: iterable of (inputs, targets) batches
    model (nn.Module): network; batches are moved to the device of its parameters
    criterion: loss function called as ``criterion(output, target)``
    optimizer: stepped once per batch when ``train`` is True (unused otherwise);
        its ``n_iters`` attribute counts the optimizer steps taken so far
    epoch, n_epochs (int): only used in the progress lines
    train (bool): train (gradients + optimizer steps) or evaluate (no gradients)

    Returns (time, loss, error) as 1-element tensors: total wall time of the
    pass, mean per-batch loss and mean per-batch top-1 error rate.
    """
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)

    if train:
        model.train()
        print('Training')
    else:
        model.eval()
        print('Evaluating')

    # Send each batch to wherever the model lives; CPU if it has no parameters.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))

    end = time.time()
    for i, (inputs, target) in enumerate(loader):
        inputs = inputs.to(device)
        target = target.to(device)

        # Forward pass; gradients are only recorded while training.
        with torch.set_grad_enabled(train):
            output = model(inputs)
            loss = criterion(output, target)

        if train:
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer.n_iters = getattr(optimizer, 'n_iters', 0) + 1

        # Accounting
        with torch.no_grad():
            _, predictions = torch.max(output, 1)  # indices of the max logits
            correct = predictions.eq(target).float().sum()
            error = 1 - correct / target.size(0)
        batch_time = time.time() - end
        end = time.time()

        # Log errors
        time_meter.update(batch_time)
        loss_meter.update(loss.detach())
        error_meter.update(error)
        print('  '.join([
            '%s: (Epoch %d of %d) [%04d/%04d]' % ('Train' if train else 'Eval',
                epoch, n_epochs, i + 1, len(loader)),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()

In [None]:
def train(checkpointing_dir, model, model_name, n_epochs, lr=0.01, wd=0.0001, momentum=0.9,
          *, train_loader=None, valid_loader=None, valid_indices=None, dataset_name='cifar100'):
    """
    Train ``model`` with Nesterov SGD and checkpoint the epoch with the lowest validation error.

    checkpointing_dir (str): existing directory the checkpoints are written to
    model (nn.Module): network to train; moved to the GPU with ``.cuda()``
    model_name (str): prefix of the checkpoint file names
    n_epochs (int): number of training epochs
    lr (float): initial learning rate, multiplied by 0.1 after epochs
        ``int(0.5 * n_epochs)`` and ``int(0.75 * n_epochs)``
    wd (float): weight decay passed to SGD
    momentum (float): Nesterov momentum passed to SGD
    train_loader, valid_loader: loaders to use; default to the notebook-level
        ``train_loader`` / ``valid_loader`` globals
    valid_indices: held-out indices saved next to each checkpoint; defaults to
        the notebook-level ``valid_indices`` global
    dataset_name (str): suffix of the checkpoint file names - default 'cifar100'

    Whenever the validation error improves, ``{model_name}_{dataset_name}.pth``
    (epoch, model/optimizer/scheduler state, best error) and
    ``valid_indices_{model_name}_{dataset_name}.pth`` are (over)written.

    Returns the best validation error seen (1.0 if it never went below 1).
    Raises ValueError if no validation loader is available.
    """
    # Fall back to the loaders / indices created at notebook level.
    notebook_globals = globals()
    if train_loader is None:
        train_loader = notebook_globals['train_loader']
    if valid_loader is None:
        valid_loader = notebook_globals['valid_loader']
    if valid_indices is None:
        valid_indices = notebook_globals['valid_indices']
    if valid_loader is None:
        raise ValueError('train() needs a validation loader to select the best checkpoint')

    model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                          weight_decay=wd, nesterov=True)
    # MultiStepLR compares milestones with integer epoch counters, so they must
    # be ints: a float such as 0.75 * 30 = 22.5 never matches and that decay
    # would silently be skipped.
    milestones = [int(0.5 * n_epochs), int(0.75 * n_epochs)]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    checkpoint_path = os.path.join(checkpointing_dir, f'{model_name}_{dataset_name}.pth')
    indices_path = os.path.join(checkpointing_dir, f'valid_indices_{model_name}_{dataset_name}.pth')

    # Train model
    best_error = 1
    for epoch in range(1, n_epochs + 1):
        run_epoch(
            loader=train_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=n_epochs,
            train=True,
        )
        # Advance the LR schedule only after this epoch's optimizer steps;
        # stepping it first skips the initial learning rate.
        scheduler.step()

        _, _, valid_error = run_epoch(
            loader=valid_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=n_epochs,
            train=False,
        )

        # Determine if model is the best
        if valid_error[0] < best_error:
            best_error = valid_error[0]
            print('New best error: %.4f' % best_error)

            # The validation indices are saved alongside the model so the same
            # held-out split can be reused for calibration.
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_error': best_error,
            }
            torch.save(state, checkpoint_path)
            torch.save(valid_indices, indices_path)

    return float(best_error)

In [None]:
model = get_resnet18_model(num_classes=100)
checkpointing_dir = "./trained_models"
# Refuse to proceed if a regular file already occupies the checkpoint path.
if os.path.exists(checkpointing_dir) and not os.path.isdir(checkpointing_dir):
    raise NotADirectoryError('%s is not a dir' % checkpointing_dir)
# exist_ok avoids a race between the existence check and directory creation.
os.makedirs(checkpointing_dir, exist_ok=True)
train(checkpointing_dir, model, "resnet18", n_epochs=30)