# FedMD (III): Training clients on their own private dataset

In [None]:
# Clone GitHub repository

import os

if not os.path.isdir('content/aml_project'):
  !git clone https://github.com/s295103/aml_project.git
  %cd /content/aml_project
else:
  %cd /content/aml_project/
  !git pull origin

import torch
import pickle
from utils import cifar_processing, training, load_model
from fedmd import load_clients, make_fc_layer, partition_dataset, save_clients

# Root directory used for the datasets and for the result folders
ROOT = "/content"

# Folder the baseline clients are loaded from
BL_PATH = f"{ROOT}/aml_project/results/fedmd/baselines"

# cifar_processing returns three values; only the first (train) and last (test) are kept.
# NOTE(review): the meaning of the `True, 0` arguments is defined in utils — confirm there.
_cifar_splits = cifar_processing(True, 0, ROOT)
TRAIN_SET, _, TEST_SET = _cifar_splits

# Number of target classes, taken from the training set's class list
NUM_CLASSES = len(TRAIN_SET.classes)

## Training on private data with IID data distributions

In [None]:
# Load clients trained on public data (CIFAR10)
clients = load_clients(BL_PATH)

# Replace each client's output layer with one sized for the CIFAR100 classes
for client in clients.values():
  client.model.set_output_layer(make_fc_layer(client.model.fc_in_features, NUM_CLASSES))

# Make results folder (no-op if it already exists; no shell call needed)
IID_LOWER_BL_PATH = f"{ROOT}/lower_baselines_iid"
os.makedirs(IID_LOWER_BL_PATH, exist_ok=True)

# Partition CIFAR100 among the clients
# NOTE(review): a large alpha presumably yields near-IID partitions — confirm in partition_dataset
alpha = 1000
partitions, _ = partition_dataset(TRAIN_SET, len(clients), alpha)

# Record of this run's partitioning: dataset name, alpha, and one entry per client
part_dict = dict(dataset="CIFAR100", alpha=alpha)
for partition, (name, client) in zip(partitions, clients.items()):
    client.private_data = torch.utils.data.Subset(TRAIN_SET, partition)
    # Key by the client's name; a literal "name" key would keep only the last partition
    part_dict[name] = partition

# Context manager guarantees the file is flushed and closed after the dump
with open(f"{IID_LOWER_BL_PATH}/partitions.p", "wb") as part_file:
    pickle.dump(part_dict, part_file)

# Keyword arguments forwarded unchanged to every training() call of the IID run
# NOTE(review): num_epochs is 0 here — confirm this is the intended value before running
kwargs = {
    "batch_size": 128,
    # Use the GPU when one is available, otherwise fall back to the CPU
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_workers": 8,
    # Results folder of the IID run
    "path": IID_LOWER_BL_PATH,
    "lr": 1e-1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "num_epochs": 0,
    "resume_file": None,
    "test_freq": 5,
}

# Train each client's model on its own private subset and print the accuracy training() returns
for name, client in clients.items():
  print(f"Training client {name} on its own private data")
  model, private_set = client.model, client.private_data
  test_acc = training(name, model, private_set, TEST_SET, **kwargs)
  print(f"Client {name} final test accuracy: {100*test_acc:.1f} %")

In [None]:
# Restore each client's best checkpoint, drop its private data, then serialize all clients
for name, client in clients.items():
    best_model_file = f"{IID_LOWER_BL_PATH}/{name}_best_model.pth"
    checkpoint = load_model(best_model_file)
    client.model.load_state_dict(checkpoint["weights"])
    # Cleared so the clients are saved without their private data
    client.private_data = None

save_clients(clients, IID_LOWER_BL_PATH)

## Training on private data with non-IID data distributions

In [None]:
# Load clients trained on public data (CIFAR10)
clients = load_clients(BL_PATH)

# Replace each client's output layer with one sized for the CIFAR100 classes
for client in clients.values():
  client.model.set_output_layer(make_fc_layer(client.model.fc_in_features, NUM_CLASSES))

# Make results folder (no-op if it already exists; no shell call needed)
NON_IID_LOWER_BL_PATH = f"{ROOT}/lower_baselines_non_iid"
os.makedirs(NON_IID_LOWER_BL_PATH, exist_ok=True)

# Partition CIFAR100 among the clients
# NOTE(review): alpha = 0 is presumably the most skewed (non-IID) setting — confirm in partition_dataset
alpha = 0
partitions, _ = partition_dataset(TRAIN_SET, len(clients), alpha)

# Fresh record for this run, so no stale dictionary from an earlier cell is reused
part_dict = dict(dataset="CIFAR100", alpha=alpha)
for partition, (name, client) in zip(partitions, clients.items()):
    client.private_data = torch.utils.data.Subset(TRAIN_SET, partition)
    # Key by the client's name; a literal "name" key would keep only the last partition
    part_dict[name] = partition

# Context manager guarantees the file is flushed and closed after the dump
part_filename = f"{NON_IID_LOWER_BL_PATH}/partitions.p"
with open(part_filename, "wb") as part_file:
    pickle.dump(part_dict, part_file)

# Keyword arguments forwarded unchanged to every training() call of the non-IID run
# NOTE(review): num_epochs is 0 here — confirm this is the intended value before running
kwargs = {
    "batch_size": 128,
    # Use the GPU when one is available, otherwise fall back to the CPU
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_workers": 8,
    # Results folder of the non-IID run
    "path": NON_IID_LOWER_BL_PATH,
    "lr": 1e-1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "num_epochs": 0,
    "resume_file": None,
    "test_freq": 5,
}

# Train each client's model on its own private subset and print the accuracy training() returns
for name, client in clients.items():
  print(f"Training client {name} on its own private data")
  model, private_set = client.model, client.private_data
  test_acc = training(name, model, private_set, TEST_SET, **kwargs)
  print(f"Client {name} final test accuracy: {100*test_acc:.1f} %")

In [None]:
# Restore each client's best checkpoint, drop its private data, then serialize all clients
for name, client in clients.items():
    best_model_file = f"{NON_IID_LOWER_BL_PATH}/{name}_best_model.pth"
    checkpoint = load_model(best_model_file)
    client.model.load_state_dict(checkpoint["weights"])
    # Cleared so the clients are saved without their private data
    client.private_data = None

save_clients(clients, NON_IID_LOWER_BL_PATH)