## 0. Importing Modules

The necessary modules are imported here. The notebook uses Python 3 with a CUDA-enabled build of PyTorch. The final training run was done on Kaggle.

In [None]:
import re

import numpy as np              # NumPy, for working with arrays/tensors
import matplotlib.pyplot as plt # For plotting

# PyTorch libraries:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

!pip install torchinfo
from torchinfo import summary

!wget https://github.com/huyvnphan/PyTorch_CIFAR10/archive/refs/tags/v3.0.1.zip
!unzip v3.0.1.zip
!pip install pytorch_lightning
!cp -r PyTorch_CIFAR10-3.0.1/cifar10_models cifar10_models
!python PyTorch_CIFAR10-3.0.1/train.py --download_weights 1
!rm -rf PyTorch_CIFAR10-3.0.1/
!rm -f v3.0.1.zip
!rm -f state_dicts.zip

# For pre-trained ResNet on CIFAR-10
from cifar10_models import resnet

# enable CUDA support if possible
if torch.cuda.is_available():
    print("Cuda (GPU support) is available and enabled!")
    device = torch.device("cuda")
else:
    print("Cuda (GPU support) is not available :(")
    device = torch.device("cpu")

## 1. Image Classification

This part of the implementation reproduces the image classification experiments of the original paper.

### 1.1. Loading the Dataset

For image classification, the paper trains on the ImageNet dataset, which is huge both in size and in number of images. With my limited computing resources, training on it would take too long, so I tried to reproduce the improvements with `CIFAR-10` in the image classification part of the experiments.
`CIFAR-10` is a standard dataset for deep learning applications, so it can be loaded easily with PyTorch utilities.

In [None]:
def load_cifar10():
    """
    Uses torchvision.datasets.CIFAR10 to load CIFAR-10.
    Downloads the dataset when necessary.
    Returns 2 datasets for train and validation.
    """

    # This part used to resize to 256, which caused dramatically high training time and faulty results.
    # It was noticed very late, so the number of epochs had to be decreased to obtain some result.
    TF = transforms.Compose([
        transforms.Resize(32),      # CIFAR-10 images are already 32x32
        transforms.CenterCrop(31),  # note: this yields 31x31 inputs
        transforms.ToTensor(),
        # these mean/std values are the ImageNet channel statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    trainset = torchvision.datasets.CIFAR10('./datasets/CIFAR10/', train=True, download=True, transform=TF)
    valset = torchvision.datasets.CIFAR10('./datasets/CIFAR10/', train=False, download=True, transform=TF)

    return trainset, valset

### 1.2. Define Modified ResNet Models

The paper uses a wider ResNet for the image classification experiments. It is `4x` wider than the original ResNet implementations. I obtained this model by imitating PyTorch's `2x` wider ResNet implementation. PyTorch hooks are also used to read intermediate results between blocks.

In general, the paper describes layer-wise output matching as being achieved with appropriately sized linear transformations. For convolutional layers, however, these cannot be merged into the network after the process. In the reference PyTorch implementation, the ResNet layer blocks have the same number of output channels regardless of the width setting anyway, so no linear transformation is needed in this case.
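
The channel bookkeeping behind this can be sketched in plain Python. The formula mirrors how torchvision's `Bottleneck` block derives its internal width from `width_per_group`; the helper name `bottleneck_channels` and the constant names are hypothetical and used only for illustration.

```python
EXPANSION = 4  # value of Bottleneck.expansion in torchvision


def bottleneck_channels(planes, width_per_group=64, groups=1):
    """Return (internal width, output channels) of one bottleneck block."""
    width = int(planes * (width_per_group / 64.0)) * groups
    return width, planes * EXPANSION


STAGE_PLANES = (64, 128, 256, 512)  # layer1..layer4 of a ResNet-50

for label, wpg in (("half width", 32), ("standard", 64), ("4x wide", 256)):
    rows = [bottleneck_channels(p, width_per_group=wpg) for p in STAGE_PLANES]
    print(label, rows)
```

Only the internal width of each block scales with `width_per_group`; the number of output channels per stage stays fixed, which is why block outputs of differently sized models can be compared directly.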

The code below is taken from a ResNet implementation trained on CIFAR-10 and modified where needed. The original repo for those networks can be found [here](https://github.com/huyvnphan/PyTorch_CIFAR10).

In [None]:
def wide_resnet50_4(**kwargs):
    kwargs['width_per_group'] = 64 * 4
    model = resnet.ResNet(models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs)
    return model

def hook(module, input, output):
    setattr(module, "_value_hook", output)

# Override forward-pass and initialization
class ResNetMod(resnet.ResNet):
    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, group=1, 
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
        super(ResNetMod, self).__init__(block, layers, num_classes, zero_init_residual, group, 
                 width_per_group, replace_stride_with_dilation, norm_layer)   
        self.layer_idx = 5
        
    def forward(self, x):
        """
        x: input tensor
        self.layer_idx: index of the targeted layer block, from 1 to 4;
                        5 runs the whole network including the classifier
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        if (self.layer_idx >= 1):
            x = self.layer1(x)
            
        if (self.layer_idx >= 2):
            x = self.layer2(x)
        
        if (self.layer_idx >= 3):
            x = self.layer3(x)
        if (self.layer_idx >= 4):
            x = self.layer4(x)
        
        if (self.layer_idx == 5):
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

def resnetmod50(**kwargs):
    kwargs['width_per_group'] = 64 // 2
    model = ResNetMod(models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs)
    return model

### 1.3. Create ResNet Models

We need to create a baseline ResNet50, a wide teacher ResNet50, and a modified ResNet50 for layer-wise imitation training. Hooks are also registered on the teacher model to capture the intermediate forward outputs between layers.

A pretrained ResNet50 with `1/2` width could not be found to serve as a baseline, so the student will be checked against the teacher. A baseline is still created to make sure that there is no bug and that the teacher's parameters do not get updated. To save time, pretrained models are used for the teacher and the baseline.
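
The hook mechanism can be tried on a small stand-in network before applying it to the teacher. This is a minimal sketch: the `toy` model and the `save_output` name are hypothetical, while the regular expression is the same one the cell below uses to select the four layer blocks.

```python
import re

import torch
import torch.nn as nn


def save_output(module, inputs, output):
    # Store the most recent forward output on the module itself.
    module._value_hook = output


# Hypothetical stand-in for a ResNet: two named blocks and a classifier head.
toy = nn.Sequential()
toy.add_module("layer1", nn.Sequential(nn.Linear(8, 16), nn.ReLU()))
toy.add_module("layer2", nn.Sequential(nn.Linear(16, 4), nn.ReLU()))
toy.add_module("fc", nn.Linear(4, 2))

hooked = []
for name, module in toy.named_modules():
    # Matches "layer1" or "layer2" but not submodules such as "layer1.0".
    if re.search(r"layer[1234]$", name):
        module.register_forward_hook(save_output)
        hooked.append(name)

with torch.no_grad():
    logits = toy(torch.randn(5, 8))

shapes = {name: tuple(getattr(toy, name)._value_hook.shape) for name in hooked}
print(hooked, shapes, tuple(logits.shape))
```

After one forward pass, every hooked block holds its latest output, so an intermediate target can be read without changing the model's `forward` method.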

In [None]:
teacher = resnet.resnet50(pretrained=True)
student = resnetmod50()
baseline = resnet.resnet50(pretrained=True)

for n, m in teacher.named_modules():
    match = re.search('layer[1234]$', n)
    if match:
        m.register_forward_hook(hook)

In [None]:
print('Teacher')
summary(teacher)

In [None]:
print('Student')
summary(student)

In [None]:
print('Baseline')
summary(baseline)

### 1.4. Training Process

We need to create a training method and train on the `CIFAR-10` dataset. I made use of the `CENG501` assignments for this part.
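
The training loops in the cell below take a `multibatch` argument so that several small batches can stand in for the batch size of 256 used in the paper. The idea of gradient accumulation can be sketched without PyTorch on a one-parameter least-squares model. Everything here (`grad_mse`, the data, the micro-batch size) is a hypothetical illustration, and dividing each micro-batch gradient by the number of accumulation steps is one common convention, not necessarily the scaling used in this notebook.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]

micro_batch = 2
accum_steps = len(xs) // micro_batch

# Gradient computed on the full batch at once.
full_grad = grad_mse(w, xs, ys)

# Same gradient obtained by accumulating scaled micro-batch gradients.
accumulated = 0.0
for k in range(accum_steps):
    lo, hi = k * micro_batch, (k + 1) * micro_batch
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(full_grad, accumulated)
```

With equally sized micro-batches, the accumulated value matches the full-batch gradient up to floating-point error, so a single optimizer step after `accum_steps` backward passes behaves like one large-batch step.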

To my understanding, the Batch Normalization decay is `0.9` by default. PyTorch phrases this parameter as `momentum`, with a default of `0.1`, and subtracts it from `1` in the source code.
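
The two conventions can be compared with a few lines of plain Python. This is a sketch of the running-statistic update only; the function names are hypothetical, and the update rule follows the one given in the PyTorch `BatchNorm` documentation.

```python
def torch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum weights the new batch statistic.
    return (1 - momentum) * running + momentum * batch_stat


def decay_style_update(running, batch_stat, decay=0.9):
    # Decay convention: decay weights the old running statistic.
    return decay * running + (1 - decay) * batch_stat


running_a = running_b = 0.0
for batch_mean in (4.0, 6.0, 5.0):
    running_a = torch_style_update(running_a, batch_mean)
    running_b = decay_style_update(running_b, batch_mean)

print(running_a, running_b)
```

Both updates produce the same running mean up to floating-point error, so a `momentum` of `0.1` in PyTorch corresponds to a decay of `0.9`.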

In [None]:
# epoch assumed to be starting from 0, ending at 89
def vanilla_lr(epoch):
    if (epoch < 5):
        return 0.025 * epoch # linearly increasing from 0 to 0.1 at epochs from 0 to 4
    if (epoch < 29):
        return 0.1
    if (epoch < 59):
        return 0.01  # reduced 10x at 30
    if (epoch < 79):
        return 0.001 # reduced 10x at 60
    return 0.0001    # reduced 10x at 80

def vanilla_train(model, dataloader, multibatch=1, verbose=True):
    """
    Returns: the loss history.
    """
    loss_history = []
    criterion = nn.CrossEntropyLoss()
    for epoch in range(90):
        optimizer = torch.optim.SGD(model.parameters(), lr=vanilla_lr(epoch), momentum=0.9, weight_decay=1e-4)
        for i, data in enumerate(dataloader, 0):    
            print('Training, batch:', i)
            # our batch:
            inputs, truth = data
            inputs = inputs.to(device)
            truth = truth.to(device)

            # zero the gradients at the start of each accumulation window, as PyTorch accumulates them
            # (gradients are gathered over `multibatch` batches since original paper uses batch size of 256)
            if i % multibatch == 0:
                optimizer.zero_grad()

            # obtain the scores
            outputs = model(inputs)
            # Calculate loss
            loss = criterion(outputs.to(device), truth)

            # backpropagate
            loss.backward()

            # update the weights
            # every `multibatch` batches since original paper uses batch size of 256
            if (i + 1) % multibatch == 0:
                optimizer.step()

            loss_history.append(loss.item())

        if verbose: print(f'Epoch {epoch + 1} / {90}: avg. loss of last 5 iterations {np.sum(loss_history[:-6:-1])/5}')

    return loss_history

def train_epoch(model, teacher, criterion, optimizer, dataloader, loss_history, multibatch=1):
    for i, data in enumerate(dataloader, 0):
        if (i + 1) % 100 == 0:
            print('At step: ', i + 1)
        # our batch:
        inputs, _ = data
        inputs = inputs.to(device)
        # No gradient calculation is needed on teacher, reduce memory footprint
        with torch.no_grad():
            truth = teacher(inputs)
            if model.layer_idx < 5:
                for n, m in teacher.named_modules():
                    if n == 'layer' + str(model.layer_idx):
                        truth = m._value_hook
                        break
            truth = truth.to(device, non_blocking=True)  # also works when only the CPU is available
                
        # zero the gradients at the start of each accumulation window, as PyTorch accumulates them
        # (gradients are gathered over `multibatch` batches since original paper uses batch size of 256)
        if i % multibatch == 0:
            optimizer.zero_grad()

        # obtain the scores
        outputs = model(inputs)
        
        # Calculate loss
        if model.layer_idx < 5:
            loss = criterion(outputs.to(device), truth)
        else:
            loss = criterion(F.log_softmax(outputs.to(device), dim=1), F.log_softmax(truth, dim=1)) # dim=1 so that log-probabilities are normalized over the classes of each sample
        
        # backpropagate
        loss.backward()

        # update the weights
        # every `multibatch` batches since original paper uses batch size of 256
        if (i + 1) % multibatch == 0:
            optimizer.step()

        loss_history.append(loss.item())

def win_train(model, teacher, dataloader, multibatch=1, verbose=True):
    """
    Returns: the loss history
    """
    loss_history = []
    criterion_s1 = nn.MSELoss()
    criterion_s2 = nn.KLDivLoss(reduction='batchmean', log_target=True) # We are converting teacher model output to log probabilities also
    criterion_s3 = nn.CrossEntropyLoss()
    optimizer_s2 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer_s3 = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler_s2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_s2, 4, eta_min=0, last_epoch=-1)
    
    for l in range(4): # for each layer block
        optimizer_s1 = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        scheduler_s1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_s1, 2, eta_min=0, last_epoch=-1)
        model.layer_idx = l + 1
        for epoch in range(2):
            train_epoch(model, teacher, criterion_s1, optimizer_s1, dataloader, loss_history, multibatch=multibatch)
            scheduler_s1.step()
            if verbose: 
                print(f'Epoch {epoch + 1} / {2}: avg. loss of last 5 iterations {np.sum(loss_history[:-6:-1])/5}')
    
    model.layer_idx = 5
    
    for epoch in range(4):
        train_epoch(model, teacher, criterion_s2, optimizer_s2, dataloader, loss_history, multibatch=multibatch)
        scheduler_s2.step()
        if verbose: 
            print(f'Epoch {epoch + 1} / {4}: avg. loss of last 5 iterations {np.sum(loss_history[:-6:-1])/5}')

    for epoch in range(4):
        for i, data in enumerate(dataloader, 0):    
            if (i + 1) % 100 == 0:
                print('At step: ', i + 1)
            # our batch:
            inputs, truth = data
            inputs = inputs.to(device)
            truth = truth.to(device)

            # zero the gradients at the start of each accumulation window, as PyTorch accumulates them
            # (gradients are gathered over `multibatch` batches since original paper uses batch size of 256)
            if i % multibatch == 0:
                optimizer_s3.zero_grad()

            # obtain the scores
            outputs = model(inputs)
            # Calculate loss
            loss = criterion_s3(outputs.to(device), truth)

            # backpropagate
            loss.backward()

            # update the weights
            # every `multibatch` batches since original paper uses batch size of 256
            if (i + 1) % multibatch == 0:
                optimizer_s3.step()

            loss_history.append(loss.item())

        if verbose: 
            print(f'Epoch {epoch + 1} / {4}: avg. loss of last 5 iterations {np.sum(loss_history[:-6:-1])/5}')

    return loss_history

# per-step batch size; with multibatch=8 below this accumulates to 32 * 8 = 256, the batch size used in the paper
batch_size = 32
# loaders for datasets
train_dataset, val_dataset = load_cifar10()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
eval_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

model = student.to(device)
teacher = teacher.to(device)
loss_history = win_train(model, teacher, train_loader, multibatch=8)
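
The learning rates that `win_train` uses in its first two stages follow a cosine annealing schedule. The closed-form expression from the PyTorch `CosineAnnealingLR` documentation can be evaluated by hand to see which rate each epoch receives. This is a sketch: `cosine_lr` is a hypothetical helper, and it assumes the scheduler is stepped once per epoch, as in the cell above.

```python
import math


def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing value for a 0-indexed epoch."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Stage 1: layer-wise imitation, 2 epochs per layer block, base lr 0.1.
stage1 = [cosine_lr(0.1, epoch, t_max=2) for epoch in range(2)]

# Stage 2: output imitation with KL divergence, 4 epochs, base lr 0.01.
stage2 = [cosine_lr(0.01, epoch, t_max=4) for epoch in range(4)]

print([round(lr, 5) for lr in stage1])
print([round(lr, 5) for lr in stage2])
```

The third stage uses a constant learning rate of `0.1` without a scheduler, so it is not included here.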

### 1.5. Quantitative Analysis

Again, using the code from the assignments, I plotted the loss function.

Longer training was actually needed. However, I initially made the crucial mistake of resizing the input images to 256, which is wrong because this modified ResNet works with `CIFAR-10` images natively, and this increased the training time. I noticed the mistake late, so even though training became faster without the resize, the remaining time was limited and I had to settle for producing at least some result. Still, the loss plot clearly indicates that layer-wise imitation is highly efficient, since the loss in the Kullback-Leibler divergence stage is observed to decrease more slowly than in the first stage.

In [None]:
%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 200

plt.plot(loss_history)
plt.xlabel('Iteration number')
plt.ylabel('Loss value')
plt.show()

### 1.6. Accuracy Analysis

Using the code from the PAs of the lecture and modifying it slightly, I was able to obtain the Top-1 accuracy values.

Unfortunately, the accuracy of the resulting student network is not high enough to draw conclusions. Still, 2 epochs per layer, 4 epochs of output imitation, and 4 epochs of fine-tuning add up to a very low number of epochs for a newly initialized network. Considering the total epoch count, good results are obtained.

There seems to be a small difference between the teacher and the baseline even though `no_grad` is used for the teacher during training. The difference is very small and could be related to the ResNet implementation used for `CIFAR-10`.
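
One possible contributor, offered here as an assumption rather than something verified in this notebook, is that `no_grad` only disables gradient tracking. A BatchNorm layer left in training mode still updates its running statistics on every forward pass. The sketch below checks this on a standalone `BatchNorm1d` layer with random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(16, 3) + 5.0  # batch whose mean is far from the initial running mean of 0

# Training mode: running statistics are updated even under no_grad.
before = bn.running_mean.clone()
bn.train()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Evaluation mode: running statistics stay fixed.
frozen = bn.running_mean.clone()
bn.eval()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(frozen, bn.running_mean)

print(changed_in_train, changed_in_eval)
```

If this applies here, switching the teacher to evaluation mode before it is used to produce training targets would keep its BatchNorm statistics identical to those of the baseline.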

In [None]:
def calc_top1(model, testloader):
    correct = 0
    total = 0

    model = model.to(device)

    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            if (i + 1) % 50 == 0:
                print('Eval at: ', i + 1)
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return (100 * correct / total)

student.eval()
teacher.eval()
baseline.eval()

accuracy_student = calc_top1(student, eval_loader)
print('Accuracy of the student: %d %%' % accuracy_student)
accuracy_teacher = calc_top1(teacher, eval_loader)
print('Accuracy of the teacher: %d %%' % accuracy_teacher)
accuracy_baseline = calc_top1(baseline, eval_loader)
print('Accuracy of the baseline: %d %%' % accuracy_baseline)