# Setup

Important: before running this notebook, make sure you have the following directories in your Drive:
- MyDrive/understanding_federated_learning_project/
- MyDrive/understanding_federated_learning_project/fedmd_collaboration_results_1000.0/
- MyDrive/understanding_federated_learning_project/fedmd_collaboration_results_0.0/

In [None]:
# Mount Google Drive at /content/drive; later cells read and write
# checkpoint archives under MyDrive/understanding_federated_learning_project/.
from google.colab import files
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


## Install required libraries

In [None]:
!pip3 install 'torch'
!pip3 install 'torchvision'
!pip3 install 'pillow'
!pip3 install 'tqdm'
!pip3 install 'matplotlib'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Clone the GitHub repository 

In [None]:
import os

if not os.path.isdir('fedmd'):
  !git clone https://github.com/sergiuabed/fedmd
  %cd fedmd/
  !git checkout experiment
  %cd ..
else:
  %cd fedmd/
  !git pull origin
  %cd ..

if not os.path.isdir('fedmd/baselines_public_cifar10'):
  !mkdir fedmd/baselines_public_cifar10

Cloning into 'fedmd'...
remote: Enumerating objects: 2815, done.[K
remote: Total 2815 (delta 0), reused 0 (delta 0), pack-reused 2815[K
Receiving objects: 100% (2815/2815), 492.95 MiB | 29.90 MiB/s, done.
Resolving deltas: 100% (823/823), done.
/content/fedmd
Updating files: 100% (816/816), done.
Branch 'experiment' set up to track remote branch 'experiment' from 'origin'.
Switched to a new branch 'experiment'
/content


## Import libraries

In [None]:
from fedmd.models_implementations.resnet import ResNet
from fedmd.models_implementations.densenet import DenseNet
from fedmd.models_implementations.shufflenetv2 import ShuffleNetV2
from fedmd.models_implementations.utils import model_size, plot_stats
from fedmd.models_implementations.train_on_cifar import train_on_cifar
from fedmd.data_utils import read_data_splits
from fedmd.client.private_dataloader import ClientPrivateDataset
from torch.utils.data import DataLoader
from fedmd.client.client import Client

## Load CIFAR100 locally

In [None]:
%cd fedmd/data
!chmod +x setup_datasets.sh
!./setup_datasets.sh
%cd ../..

/content/fedmd/data
Downloading CIFAR100...
--2023-05-15 17:22:58--  https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 169001437 (161M) [application/x-gzip]
Saving to: ‘cifar-100-python.tar.gz’


2023-05-15 17:23:00 (84.3 MB/s) - ‘cifar-100-python.tar.gz’ saved [169001437/169001437]

cifar-100-python/
cifar-100-python/file.txt~
cifar-100-python/train
cifar-100-python/test
cifar-100-python/meta
Extracting images for pickle files...
#### Setting up CIFAR100 ####
Saving train images...
50000it [00:27, 1806.84it/s]
Saving test images...
10000it [00:06, 1596.74it/s]
/content


## Define hyperparameters
This cell defines the training hyperparameters and the constructor arguments for each client architecture.

In [None]:
import os

LR = 1e-1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 30
NUM_CLASSES = 100
GROUPS = 2

######################
# ALPHA selects which private data split `read_data_splits` loads and is embedded
# in every checkpoint / stats / Drive path. Alternative setting used: 0.00.
# NOTE(review): presumably the concentration parameter of the client data split
# (large = near-IID); confirm against fedmd.data_utils.
ALPHA = 1000.00
######################

# ResNet
LAYERS = [3, 3, 3]
RESNET_FILE_PATH = os.getcwd() + '/fedmd/baselines_public_cifar10/resnet20'

# DenseNet
K = 12
DENSENET_FILE_PATH = os.getcwd() + '/fedmd/baselines_public_cifar10/densenet'

# ShuffleNet
SHUFFLENET_FILE_PATH = os.getcwd() + '/fedmd/baselines_public_cifar10/shufflenetv2'
STAGES_REPEATS_BIG = [4, 8, 4]
STAGES_OUT_CHANNELS_BIG = [32, 64, 128, 256, 512]
STAGES_REPEATS_SMALL = [2, 4, 2]
STAGES_OUT_CHANNELS_SMALL = [16, 32, 64, 128, 256]

# FILE_PATH used to be reassigned once per architecture, so only the ShuffleNet
# value ever survived; it is kept at that value for backward compatibility.
FILE_PATH = SHUFFLENET_FILE_PATH

# Positional constructor arguments for each architecture.
RESNET20_BN_PARAMS = [LAYERS, NUM_CLASSES]  # ResNet(layers=LAYERS, num_classes=NUM_CLASSES)
RESNET20_GN_PARAMS = [LAYERS, NUM_CLASSES, "gn", GROUPS]  # ResNet(layers=LAYERS, num_classes=NUM_CLASSES, norm_layer="gn", groups=GROUPS)

DENSENET10_PARAMS = [K, 10, 1, NUM_CLASSES, False, GROUPS]  # DenseNet(K, 10, 1, NUM_CLASSES, False, GROUPS)
DENSENET20_PARAMS = [K, 20, 1, NUM_CLASSES, False, GROUPS]  # DenseNet(K, 20, 1, NUM_CLASSES, False, GROUPS)

SHUFFLENET_BIG = [STAGES_REPEATS_BIG, STAGES_OUT_CHANNELS_BIG, NUM_CLASSES, GROUPS]  # ShuffleNetV2(stages_repeats, stages_out_channels, NUM_CLASSES, GROUPS)
SHUFFLENET_SMALL = [STAGES_REPEATS_SMALL, STAGES_OUT_CHANNELS_SMALL, NUM_CLASSES, GROUPS]  # ShuffleNetV2(stages_repeats, stages_out_channels, NUM_CLASSES, GROUPS)

# Architecture name (as written in client_architectures.csv) -> (model class, constructor args).
archs_dir = {
    "/resnet20/resnet20_bn": (ResNet, RESNET20_BN_PARAMS),
    "/resnet20/resnet20_gn": (ResNet, RESNET20_GN_PARAMS),
    "/densenet/densenet10": (DenseNet, DENSENET10_PARAMS),
    "/densenet/densenet20": (DenseNet, DENSENET20_PARAMS),
    "/shufflenetv2/shufflenetbig": (ShuffleNetV2, SHUFFLENET_BIG),
    "/shufflenetv2/shufflenetsmall": (ShuffleNetV2, SHUFFLENET_SMALL),
}


In [None]:
from fedmd.models_implementations.train_on_cifar import _training, _data_processing
from fedmd.client.private_dataloader import ClientPrivateDataset
from fedmd.data_utils import read_data_splits
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision import transforms

BATCH_SIZE = 128
NUM_WORKERS = 1

# Shared preprocessing for the public (CIFAR-10) and private test (CIFAR-100) images.
public_train_transform = transforms.Compose(
    [
        transforms.ToTensor(),  # Turn PIL Image to torch.Tensor
        transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        ),  # Normalizes tensor with mean and standard deviation
    ]
)

# Public dataset: CIFAR-10 training split. shuffle=False keeps the batch order fixed
# (NOTE(review): presumably so all clients see the same public samples in the same
# order during collaboration; confirm against the Server implementation).
public_train_dataset = CIFAR10(root=".", train=True, transform=public_train_transform, download=True)
public_train_dataloader = DataLoader(public_train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, drop_last=True)

PRIVATE_TRAIN_DATA_DIR = os.path.join('.', 'fedmd', 'data', 'cifar100', 'data', 'train')  # location of json files storing the data splits
PRIVATE_TEST_DATA_DIR = os.path.join('.', 'fedmd', 'data', 'cifar100', 'data', 'test')

client_ids, train_data, test_data = read_data_splits(PRIVATE_TRAIN_DATA_DIR, PRIVATE_TEST_DATA_DIR, ALPHA)
# client_ids: list of client ids
# train_data: dictionary with key=client_id and value=(dictionary storing the data of the client)
# test_data: dictionary storing the data for validation. It is not a dictionary of dictionaries

# The same test set for all clients: the CIFAR-100 *test* split (train=False).
# Using train=True here would evaluate clients on the training images their
# private splits are drawn from.
private_test_dataset = CIFAR100(root=".", train=False, transform=public_train_transform, download=True)
private_test_dataloader = DataLoader(private_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Defaults for a fresh run; the optional restore cell below overrides them.
LOAD_CHECKPOINTS_FROM_PAST_ROUNDS = False
last_round_performed = -1


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 90860903.52it/s]


Extracting ./cifar-10-python.tar.gz to .
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:01<00:00, 95730697.08it/s] 


Extracting ./cifar-100-python.tar.gz to .


## IMPORTANT: run the following cell only if you have checkpoints from previous rounds stored in Google Drive (set `last_round_performed` to the latest saved round). If this is the first time you run this notebook, skip this cell.

In [None]:
if not os.path.isdir('fedmd_results'):
  LOAD_CHECKPOINTS_FROM_PAST_ROUNDS = True

  last_round_performed = 5 # change this variable to the latest round you have ran and saved on your google drive
  drive_path = f"/content/drive/MyDrive/understanding_federated_learning_project/fedmd_collaboration_results_{ALPHA}/fedmd_results_round{last_round_performed}.zip"
  !unzip {drive_path} -d .
  !mv ./content/fedmd_results .
  !rmdir ./content

Archive:  /content/drive/MyDrive/understanding_federated_learning_project/fedmd_collaboration_results_0.0/fedmd_results_round5.zip
   creating: ./content/fedmd_results/
   creating: ./content/fedmd_results/client92/
  inflating: ./content/fedmd_results/client92/best_model_0.0.pth  
  inflating: ./content/fedmd_results/client92/checkpoint_0.0.pth  
 extracting: ./content/fedmd_results/client92/stats_0.0.csv  
   creating: ./content/fedmd_results/client54/
   creating: ./content/fedmd_results/client62/
   creating: ./content/fedmd_results/client81/
  inflating: ./content/fedmd_results/client81/best_model_0.0.pth  
  inflating: ./content/fedmd_results/client81/checkpoint_0.0.pth  
 extracting: ./content/fedmd_results/client81/stats_0.0.csv  
   creating: ./content/fedmd_results/client15/
   creating: ./content/fedmd_results/client60/
  inflating: ./content/fedmd_results/client60/best_model_0.0.pth  
  inflating: ./content/fedmd_results/client60/checkpoint_0.0.pth  
  inflating: ./content/

## Create clients
Clients are created based on the information in:
  - "fedmd/client/client_architectures.csv" (which contains entries of the form "client_id,architecture")
  - "fedmd/independent_train_{ALPHA}" (which stores checkpoints of the models for each client. These checkpoints are obtained after training on both public and private datasets)
  
Note: if you have already run this notebook for one or more rounds, you can load the backup stored on Google Drive and continue from where you stopped.

In [None]:
import csv
import os
import torch.nn as nn
from fedmd.models_implementations.utils import load_model
from fedmd.client.client import Client

# Each record maps a "client_id" (0, 1, 2, ...) to an "architecture", which is
# one of the keys of `archs_dir` defined in the hyperparameters cell.
filename = "fedmd/client/client_architectures.csv"

client_archs = {}  # NOTE(review): documented as client_id -> model instance, but never populated
clients = []
with open(filename, 'r') as arch_file:
  for line in csv.reader(arch_file):
    # Skip the header row and blank lines (a blank row would otherwise raise IndexError).
    if not line or line[0] == 'client_id':
      continue
    client_id = str(line[0])
    architecture = line[1]

    # Instantiate the client's model from its (class, constructor args) entry.
    net_class, params = archs_dir[architecture]
    model = net_class(*params)

    # Resume from this client's FedMD checkpoint when one was restored; otherwise start
    # from the checkpoint produced by independent training on the public + private data.
    resumed_path = 'fedmd_results/client' + client_id + '/checkpoint_' + str(ALPHA) + '.pth'
    if LOAD_CHECKPOINTS_FROM_PAST_ROUNDS and os.path.exists(resumed_path):
      data = load_model(resumed_path)
    else:
      data = load_model('fedmd/independent_train_' + str(ALPHA) + '/client' + client_id + '/checkpoint_' + str(ALPHA) + '.pth')

    weights = data['weights']
    model.load_state_dict(weights)

    private_train_dataset = ClientPrivateDataset(train_data[client_id], train=True)
    private_train_dataloader = DataLoader(private_train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

    new_client = Client(client_id, public_train_dataloader, private_train_dataloader, private_test_dataloader, None, model, architecture, "cuda")
    clients.append(new_client)



In [None]:
import csv
import os

# client_id (str) -> accuracy from the most recent round recorded in that client's stats file.
clients_accuracies = {}
if LOAD_CHECKPOINTS_FROM_PAST_ROUNDS:
  for cl_id in range(0, 100):
    stats_path = f"fedmd_results/client{cl_id}/stats_{ALPHA}.csv"
    if not os.path.exists(stats_path):
      continue
    with open(stats_path, 'r') as stats_file:
      # Skip the "round,accuracy" header and blank lines; column 1 holds the accuracy.
      accuracies = [float(row[1]) for row in csv.reader(stats_file) if row and row[0] != "round"]
    # A stats file may contain only a header; such clients are left out rather
    # than crashing on accuracies[-1].
    if accuracies:
      clients_accuracies[str(cl_id)] = accuracies[-1]

  print(clients_accuracies)

{'1': 0.0163443094629156, '5': 0.022158727621483375, '7': 0.013107416879795396, '16': 0.011588874680306905, '17': 0.012887627877237852, '18': 0.015764865728900256, '20': 0.025275735294117647, '22': 0.014046515345268542, '23': 0.01676390664961637, '25': 0.03123001918158568, '26': 0.024176790281329925, '27': 0.016264386189258313, '29': 0.02227861253196931, '30': 0.019301470588235295, '31': 0.0357456841432225, '37': 0.02042039641943734, '40': 0.034646739130434784, '43': 0.020899936061381075, '45': 0.02131953324808184, '46': 0.0154451726342711, '50': 0.010070332480818414, '52': 0.022878037084398978, '55': 0.026474584398976982, '56': 0.009750639386189258, '57': 0.029871323529411766, '60': 0.02485613810741688, '63': 0.028812340153452687, '70': 0.022238650895140665, '71': 0.03216911764705882, '72': 0.029171994884910485, '73': 0.023817135549872123, '74': 0.02563539002557545, '75': 0.015944693094629155, '76': 0.018242487212276216, '80': 0.028472666240409207, '81': 0.016284367007672634, '82': 0.

## Create "fedmd_results" folder

In [None]:
import os

# Ensure fedmd_results/client<i>/ exists for every client id 0..99.
# os.makedirs(..., exist_ok=True) is idempotent, so this also fills in any client
# folders that are missing from a restored archive, and is safe to re-run.
PATH = os.getcwd() + '/fedmd_results'

c_ids = [str(i) for i in range(100)]
for c_id in c_ids:
  client_path = PATH + '/client' + c_id
  os.makedirs(client_path, exist_ok=True)

In [None]:
import torch
from torch.utils.data import Subset, DataLoader
from fedmd.server.server import Server
from fedmd.models_implementations.utils import save_model

NUM_ROUNDS = 50
SUBSET_BATCH_SIZE = 64 #128
NUM_SAMPLES_PER_ROUND = 39936 #5120 #the batch size was set to 128 and NUM_SAMPLES_PER_ROUND must be a multiple of the batch size

PATH = os.getcwd() + '/fedmd_results'

server = Server(clients, NUM_ROUNDS, public_train_dataloader, NUM_SAMPLES_PER_ROUND, SUBSET_BATCH_SIZE, ALPHA, "cuda")

starting_round = 0
if LOAD_CHECKPOINTS_FROM_PAST_ROUNDS is True:
  starting_round = last_round_performed + 1
  for c in clients_accuracies.keys():
    # update the dictionary of accuracies in the server with the latest accuracy of each client
    # some clients may have not participated in any rounds previously. They will have the accuracy obtained after "independent training". These accuracies are
    # loaded internally by the server

    # note: the purpose of this dictionary of accuracy is for computing the consensus using the weighted average, where clients with more accurate models
    # have higher weights
    server.accuracies[c] = clients_accuracies[c]

client_maxacc = {str(i): -1 for i in range(100)}

for round in range(starting_round, NUM_ROUNDS):
  print(f"############################################################## Begin round {round} ########################################################################")
  client_acc = server.perform_round()

  for c in client_acc.keys():
    # Save the best model
    if client_acc[c] > client_maxacc[c]:
      client_maxacc[c] = client_acc[c]
      path_best = PATH + "/client" + c + '/best_model_' + str(ALPHA) + '.pth'
      save_model(clients[int(c)]._model, path_best, round, client_maxacc[c])
    # Checkpoint
    path_checkpoint = PATH + "/client" + c + '/checkpoint_' + str(ALPHA) + '.pth'
    save_model(clients[int(c)]._model, path_checkpoint, round, client_acc[c])

    # Record stats
    stats_path = PATH + '/client' + c + '/stats_' + str(ALPHA) + '.csv'
    with open(stats_path, "a") as f:
      if round == 0:
        f.write("round,accuracy\n")
      f.write(f"{round},{client_acc[c]}\n")

  zip_name = f"/content/fedmd_results_round{round}.zip"
  destination_path = f"/content/drive/MyDrive/understanding_federated_learning_project/fedmd_collaboration_results_{ALPHA}"
  !zip -r {zip_name} /content/fedmd_results/
  !cp {zip_name} {destination_path}
  

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  adding: content/fedmd_results/client45/stats_0.0.csv (deflated 24%)
  adding: content/fedmd_results/client45/checkpoint_0.0.pth (deflated 8%)
  adding: content/fedmd_results/client36/ (stored 0%)
  adding: content/fedmd_results/client36/best_model_0.0.pth (deflated 10%)
  adding: content/fedmd_results/client36/stats_0.0.csv (stored 0%)
  adding: content/fedmd_results/client36/checkpoint_0.0.pth (deflated 10%)
  adding: content/fedmd_results/client43/ (stored 0%)
  adding: content/fedmd_results/client43/best_model_0.0.pth (deflated 8%)
  adding: content/fedmd_results/client43/stats_0.0.csv (stored 0%)
  adding: content/fedmd_results/client43/checkpoint_0.0.pth (deflated 8%)
  adding: content/fedmd_results/client22/ (stored 0%)
  adding: content/fedmd_results/client22/best_model_0.0.pth (deflated 9%)
  adding: content/fedmd_results/client22/stats_0.0.csv (stored 0%)
  adding: content/fedmd_results/client22/checkpoint_0.0.

KeyboardInterrupt: ignored