# FedMD (IV): cooperative training

In [None]:
# Clone the GitHub repository on the first run, or update the existing checkout

import os

# First run: the project folder does not exist yet, so clone it and move into it
if not os.path.isdir('/content/aml_project'):
  !git clone https://github.com/s295103/aml_project.git
  %cd /content/aml_project
else:
  # Later runs: make sure the working directory is the repository before pulling
  if os.getcwd() != "/content/aml_project":
    %cd /content/aml_project/
  !git pull origin

import torch
from utils import cifar_processing, load_model
from fedmd import load_clients, Server, save_clients, create_clients
import pickle
import matplotlib.pyplot as plt

ROOT = "/content"

# Every baseline folder lives under the FedMD results directory of the repository
_FEDMD_RESULTS = f"{ROOT}/aml_project/results/fedmd"
BL_PATH = f"{_FEDMD_RESULTS}/baselines"
UPPER_BL_PATH = f"{_FEDMD_RESULTS}/upper_baselines"
IID_LOWER_BL_PATH = f"{_FEDMD_RESULTS}/lower_baselines_iid"
NON_IID_LOWER_BL_PATH = f"{_FEDMD_RESULTS}/lower_baselines_non_iid"

# Public dataset used to compute the consensus: the CIFAR10 training set
PUB_TR_SET, _, _ = cifar_processing(False, 0, ROOT)

# CIFAR100: the TRAINING set supplies the clients' private data,
# the test set is used for testing the cooperative training
PR_TR_SET, _, TEST_SET = cifar_processing(True, 0, ROOT)

PR_NUM_CLASSES = len(PR_TR_SET.classes)

## Cooperative training with IID private data distributions

In [None]:
# Load lower baselines clients
clients = load_clients(IID_LOWER_BL_PATH)

# Load the partitions saved next to the IID lower baselines; the context
# manager closes the file handle as soon as the pickle has been read.
# NOTE(review): pickle can execute arbitrary code on load -- only open trusted files.
with open(f"{IID_LOWER_BL_PATH}/partitions.p", "rb") as partitions_file:
    partitions = pickle.load(partitions_file)

# Refuse a partitions file generated for another dataset or alpha value
if partitions["dataset"] != "CIFAR100" or partitions["alpha"] != 1000:
    raise Exception("Error: wrong partitions file")

# Attach to every client its own subset of the CIFAR100 training set.
# The lookup uses the client's name (the loop variable), not the literal
# string "name", which would hand the same entry to every client.
# NOTE(review): assumes partitions maps each client name to its sample
# indices -- confirm against the notebook that wrote partitions.p.
for name, client in clients.items():
    client.private_data = torch.utils.data.Subset(PR_TR_SET, partitions[name])

# Make results folder (plain Python instead of a shell escape; no-op if it exists)
COOP_TRAINING_IID_PATH = f"{ROOT}/coop_training_iid"
os.makedirs(COOP_TRAINING_IID_PATH, exist_ok=True)

# Server and clients run on the same device: GPU when available, CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize server and start the cooperative training
server_kwargs = dict(
    device=device,
    num_classes_pr_data=PR_NUM_CLASSES,
    path=COOP_TRAINING_IID_PATH,
    priv_train_epochs=5,
    pub_train_epochs=1,
    lr=1e-5,
)

server = Server(clients, PUB_TR_SET, TEST_SET, **server_kwargs)

# Declare kw arguments forwarded to every client
clients_kwargs = dict(
    device=device,
    lr=1e-2,
    #weight_decay = 1e-4,
    #momentum = 0.9,
    batch_size=128,
    path=COOP_TRAINING_IID_PATH,
    num_workers=8,
)

# Initialize training for all the clients: each one gets its own plain SGD
# optimizer over its model parameters and a cross-entropy criterion
for client in clients.values():
    params = client.model.parameters()
    optimizer = torch.optim.SGD(params, lr=clients_kwargs["lr"])#, momentum=clients_kwargs["momentum"], weight_decay=clients_kwargs["weight_decay"])
    criterion = torch.nn.CrossEntropyLoss()
    client.init_coop_training(optimizer, criterion, **clients_kwargs)

stats = server.coop_training()

In [None]:
# Restore the best saved weights on every client, drop its private data,
# then serialize the whole set of clients
for client_name in clients:
    checkpoint = load_model(f"{COOP_TRAINING_IID_PATH}/{client_name}_best_model.pth")
    restored = clients[client_name]
    restored.model.load_state_dict(checkpoint["weights"])
    restored.private_data = None
save_clients(clients, COOP_TRAINING_IID_PATH)

### Plot results

In [None]:
# Repoint the IID results path to the folder inside the cloned repository:
# the plotting cell reads stats.p from it and saves the figure there
COOP_TRAINING_IID_PATH = f"{ROOT}/aml_project/results/fedmd/coop_training_iid"

In [None]:
# Unpickle stats; the context manager closes the file once it has been read
with open(f"{COOP_TRAINING_IID_PATH}/stats.p", "rb") as stats_file:
    stats = pickle.load(stats_file)
acc_per_round = stats["acc"]
# Use the shortest per-client history so every curve covers the same rounds
num_rounds = min(len(r) for r in acc_per_round)

# Load upper and lower baselines data (accuracy stored with each best model)
client_names = list(create_clients().keys())
up_bl_acc = [
    load_model(f"{UPPER_BL_PATH}/{name}_best_model.pth")["accuracy"]
    for name in client_names
]
low_bl_acc = [
    load_model(f"{IID_LOWER_BL_PATH}/{name}_best_model.pth")["accuracy"]
    for name in client_names
]

# Plot baselines accuracies and accuracy across rounds
plt.figure(figsize=(9, 4))
x = list(range(num_rounds))
colors = ['b', 'g', 'r', 'c', 'm']
for i, client_name in enumerate(client_names):
    # Cycle through the palette so more than five clients do not raise IndexError
    color = colors[i % len(colors)]

    # Lower baseline: triangle marker at x = 0
    low_bl_y = 100*low_bl_acc[i]
    plt.plot(0, low_bl_y, f"{color}v")

    # Upper baseline: star marker at x = num_rounds
    # (the upper-baseline value is plotted here, not the lower one again)
    up_bl_y = 100*up_bl_acc[i]
    plt.plot(num_rounds, up_bl_y, f"{color}*")

    # Accuracy in percent, truncated to the common number of rounds
    y = [100*a for a in acc_per_round[i]][:num_rounds]
    plt.plot(x, y, f"{color}-o", label=client_name)

plt.xlabel("Round")
plt.ylabel("Test Accuracy [%]")
# This cell plots the IID experiment, so the title must say IID
plt.title("Accuracy per Round: IID")
plt.grid(True)
plt.legend()
plt.savefig(f"{COOP_TRAINING_IID_PATH}/results")

## Cooperative training with non-IID private data distributions

In [None]:
# Load lower baselines clients
clients = load_clients(NON_IID_LOWER_BL_PATH)

# Load the partitions saved next to the non-IID lower baselines; the context
# manager closes the file handle as soon as the pickle has been read.
# NOTE(review): pickle can execute arbitrary code on load -- only open trusted files.
with open(f"{NON_IID_LOWER_BL_PATH}/partitions.p", "rb") as partitions_file:
    partitions = pickle.load(partitions_file)

# Refuse a partitions file generated for another dataset or alpha value
if partitions["dataset"] != "CIFAR100" or partitions["alpha"] != 0:
    raise Exception("Error: wrong partitions file")

# Attach to every client its own subset of the CIFAR100 training set.
# The lookup uses the client's name (the loop variable), not the literal
# string "name", which would hand the same entry to every client.
# NOTE(review): assumes partitions maps each client name to its sample
# indices -- confirm against the notebook that wrote partitions.p.
for name, client in clients.items():
    client.private_data = torch.utils.data.Subset(PR_TR_SET, partitions[name])

# Make results folder (plain Python instead of a shell escape; no-op if it exists)
COOP_TRAINING_NON_IID_PATH = f"{ROOT}/coop_training_non_iid"
os.makedirs(COOP_TRAINING_NON_IID_PATH, exist_ok=True)

# Server and clients run on the same device: GPU when available, CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize server and start the cooperative training
server_kwargs = dict(
    device=device,
    num_classes_pr_data=PR_NUM_CLASSES,
    path=COOP_TRAINING_NON_IID_PATH,
    priv_train_epochs=4,
    pub_train_epochs=1,
    lr=1e-1,
)

server = Server(clients, PUB_TR_SET, TEST_SET, **server_kwargs)

# Declare kw arguments forwarded to every client
clients_kwargs = dict(
    device=device,
    lr=1e-1,
    #weight_decay = 1e-4,
    #momentum = 0.9,
    batch_size=128,
    path=COOP_TRAINING_NON_IID_PATH,
    num_workers=8,
)

# Initialize training for all the clients: each one gets its own plain SGD
# optimizer over its model parameters and a cross-entropy criterion
for client in clients.values():
    params = client.model.parameters()
    optimizer = torch.optim.SGD(params, lr=clients_kwargs["lr"])#, momentum=clients_kwargs["momentum"], weight_decay=clients_kwargs["weight_decay"])
    criterion = torch.nn.CrossEntropyLoss()
    client.init_coop_training(optimizer, criterion, **clients_kwargs)

stats = server.coop_training()

In [None]:
# Restore the best saved weights on every client, drop its private data,
# then serialize the whole set of clients
for client_name in clients:
    checkpoint = load_model(f"{COOP_TRAINING_NON_IID_PATH}/{client_name}_best_model.pth")
    restored = clients[client_name]
    restored.model.load_state_dict(checkpoint["weights"])
    restored.private_data = None
save_clients(clients, COOP_TRAINING_NON_IID_PATH)

### Plot results

In [None]:
# Repoint the non-IID results path to the folder inside the cloned repository:
# the plotting cell reads stats.p from it and saves the figure there
COOP_TRAINING_NON_IID_PATH = f"{ROOT}/aml_project/results/fedmd/coop_training_non_iid"

In [None]:
# Unpickle stats; the context manager closes the file once it has been read
with open(f"{COOP_TRAINING_NON_IID_PATH}/stats.p", "rb") as stats_file:
    stats = pickle.load(stats_file)
acc_per_round = stats["acc"]
# Use the shortest per-client history so every curve covers the same rounds
num_rounds = min(len(r) for r in acc_per_round)

# Load upper and lower baselines data (accuracy stored with each best model)
client_names = list(create_clients().keys())
up_bl_acc = [
    load_model(f"{UPPER_BL_PATH}/{name}_best_model.pth")["accuracy"]
    for name in client_names
]
low_bl_acc = [
    load_model(f"{NON_IID_LOWER_BL_PATH}/{name}_best_model.pth")["accuracy"]
    for name in client_names
]

# Plot baselines accuracies and accuracy across rounds
plt.figure(figsize=(9, 4))
x = list(range(num_rounds))
colors = ['b', 'g', 'r', 'c', 'm']
for i, client_name in enumerate(client_names):
    # Cycle through the palette so more than five clients do not raise IndexError
    color = colors[i % len(colors)]

    # Lower baseline: triangle marker at x = 0
    low_bl_y = 100*low_bl_acc[i]
    plt.plot(0, low_bl_y, f"{color}v")

    # Upper baseline: star marker at x = num_rounds
    # (the upper-baseline value is plotted here, not the lower one again)
    up_bl_y = 100*up_bl_acc[i]
    plt.plot(num_rounds, up_bl_y, f"{color}*")

    # Accuracy in percent, truncated to the common number of rounds
    y = [100*a for a in acc_per_round[i]][:num_rounds]
    plt.plot(x, y, f"{color}-o", label=client_name)

plt.xlabel("Round")
plt.ylabel("Test Accuracy [%]")
plt.title("Accuracy per Round: Non-IID")
plt.grid(True)
plt.legend()
plt.savefig(f"{COOP_TRAINING_NON_IID_PATH}/results")