<a href="https://colab.research.google.com/github/s295103/aml_project/blob/main/train_baselines.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training ResNet20
In this notebook we train ResNet-20 on CIFAR100 in a centralized way, to obtain a baseline against which the FedAvg algorithm can be compared.


## Clone GitHub repository and import libraries

In [None]:
import os

if not os.path.isdir('content/aml_project'):
  !git clone https://github.com/s295103/aml_project.git
  %cd /content/aml_project
else:
  if os.getcwd() != "/content/aml_project":
    %cd /content/aml_project/
  !git pull origin

import torch
import matplotlib.pyplot as plt
from utils import *
from architectures import *

# Base directory of the workspace. It is passed to cifar_processing and used
# to build the output path of the training cell below.
ROOT = "/content"

# Create the models directory under ROOT. The path is derived from ROOT (the
# original shell command hard-coded /content/models, which would diverge from
# the check if ROOT were changed); exist_ok makes the cell safe to re-run.
os.makedirs(f"{ROOT}/models", exist_ok=True)

## Train ResNet20 on CIFAR100

In [None]:
# Output directory for this experiment: handed to training() as `path`, and
# read back by the plotting cell for the *_stats.csv files and the figure.
PATH = f"{ROOT}/models/cifar100"

# Create exactly the directory that PATH points to. The original
# `!mkdir models/cifar100/` was relative to the current working directory
# (the repository checkout after the first cell), so PATH itself was never
# created, and the command printed an error on every re-run.
os.makedirs(PATH, exist_ok=True)

# Get training and test set from CIFAR100 (the middle return value is unused
# in this notebook).
TRAIN_SET, _, TEST_SET = cifar_processing(True, 0, ROOT)

# Define hyperparameters and other training arguments; the dict is unpacked
# into every training() call below so both baselines share the same settings.
KWARGS = dict(
      batch_size = 128,
      # Use the GPU when one is available, otherwise fall back to the CPU.
      device = "cuda" if torch.cuda.is_available() else "cpu",
      num_workers = 8,
      path = PATH,
      lr = 1e-1,
      momentum = 0.9,
      weight_decay = 1e-4,
      num_epochs = 160,
      resume_file = None,
      test_freq = 10
)

NUM_CLASSES = len(TRAIN_SET.classes)
GROUPS = 2
LAYERS = [3, 3, 3]

# One entry per baseline: (run name given to training(), extra keyword
# arguments for the ResNet constructor). The run names match the prefixes of
# the *_stats.csv files read by the plotting cell.
# NOTE(review): "bn"/"gn" presumably select batch/group normalisation —
# confirm against architectures.ResNet.
BASELINE_RUNS = (
    ("resnet20bn", dict(norm_layer="bn")),
    ("resnet20gn", dict(groups=GROUPS, norm_layer="gn")),
)

# Build, train and report each baseline in turn with the shared KWARGS.
# After the loop, `net` and `test_acc` refer to the last run (resnet20gn).
for run_name, resnet_kwargs in BASELINE_RUNS:
    net = ResNet(BasicBlock, LAYERS, NUM_CLASSES, **resnet_kwargs)
    print(f"Num Parameters = {model_size(net)}")
    test_acc = training(run_name, net, TRAIN_SET, TEST_SET, **KWARGS)
    print(f"Test Accuracy: {100*test_acc:.1f} %")

## Plot loss and test accuracy per epoch

In [None]:
# Load the per-epoch statistics written for each baseline run.
epochs_bn, loss_bn, acc_bn = read_stats(f"{PATH}/resnet20bn_stats.csv")
epochs_gn, loss_gn, acc_gn = read_stats(f"{PATH}/resnet20gn_stats.csv")

# Rescale accuracies by 100 so they can be plotted on a percentage axis.
acc_bn = [100*a for a in acc_bn]
acc_gn = [100*a for a in acc_gn]

# Two side-by-side panels: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(9, 4))

# Each panel is described by (axes, BN series, GN series, y-axis label) and
# configured identically.
panels = (
    (loss_ax, loss_bn, loss_gn, "Loss"),
    (acc_ax, acc_bn, acc_gn, "Accuracy [%]"),
)

for ax, bn_series, gn_series, y_label in panels:
    ax.set(
        # The x range follows the BN run's epochs, as in the original cell.
        xlim=(0, max(epochs_bn)),
        # Leave one unit of headroom above the largest value of either run.
        ylim=(0, max(bn_series + gn_series)+1),
        xlabel="Epochs",
        ylabel=y_label,
    )
    ax.grid()
    ax.plot(epochs_bn, bn_series, label="ResNet20_BN")
    ax.plot(epochs_gn, gn_series, label="ResNet20_GN")
    ax.legend()

fig.tight_layout()
# No extension is given, so matplotlib picks its default output format.
fig.savefig(PATH + "/results")
