# FedMD (II): Training clients on the full private dataset

In [None]:
# Clone GitHub repository

import os

if not os.path.isdir('content/aml_project'):
  !git clone https://github.com/s295103/aml_project.git
  %cd /content/aml_project
else:
  if os.getcwd() != "/content/aml_project":
    %cd /content/aml_project/
  !git pull origin

import torch
from utils import cifar_processing, training, load_model
from fedmd import make_fc_layer, load_clients, save_clients

# Base directory used to build the paths below
ROOT = "/content"

# Folder containing the baseline clients
BL_PATH = f"{ROOT}/aml_project/results/fedmd/baselines"

# Restore the clients that were trained on the public data (CIFAR10)
clients = load_clients(BL_PATH)

# Load and preprocess CIFAR100; the second returned value is not used here
train_set, _, test_set = cifar_processing(True, 0, ROOT)
num_classes = len(train_set.classes)

# Replace each client's output layer with one sized for the CIFAR100 classes
for client in clients.values():
  model = client.model
  model.set_output_layer(make_fc_layer(model.fc_in_features, num_classes))

## Train on CIFAR100

In [None]:
# Make results folder
UPPER_BL_PATH = f"{ROOT}/upper_baselines"
# Create the folder from Python instead of shelling out to `mkdir`:
# exist_ok=True makes the call a no-op when the folder is already there, so no
# separate existence check (and no unquoted shell path) is needed.
os.makedirs(UPPER_BL_PATH, exist_ok=True)
  
# Keyword arguments forwarded to every `training` call below
kwargs = {
    # Data loading
    "batch_size": 128,
    # Use the GPU when one is available, otherwise fall back to the CPU
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_workers": 8,
    # Results folder handed to `training`
    "path": UPPER_BL_PATH,
    # Optimizer settings
    "lr": 1e-1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    # NOTE(review): num_epochs is 0 -- presumably no epoch is run with this
    # value; confirm it is intentional before relying on the results.
    "num_epochs": 0,
    "resume_file": None,
    "test_freq": 5,
}

# Run `training` for every client on CIFAR100 and print the value it returns
# as the client's final test accuracy (shown as a percentage).
for name, client in clients.items():
  print(f"Training client {name} on CIFAR100")
  model = client.model
  test_acc = training(name, model, train_set, test_set, **kwargs)
  print(f"Client {name} final test accuracy: {100*test_acc:.1f} %")

In [None]:
# Put the best saved weights back into each client's model, then serialize
# all clients into the results folder.
for name, client in clients.items():
    best_model_file = f"{UPPER_BL_PATH}/{name}_best_model.pth"
    client_data = load_model(best_model_file)
    # The loaded file stores the model state under the "weights" key
    best_weights = client_data["weights"]
    client.model.load_state_dict(best_weights)
save_clients(clients, UPPER_BL_PATH)