# Implementation of FedGKT framework


In [None]:
# Import all the necessary libraries
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.utils import make_grid
from torch.autograd import Variable
import torch.optim as optim
from torchvision import models
from torchvision import transforms as tt
from PIL import Image
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import time
import argparse
import logging

sns.set()

# Reproducibility: a single seed drives Python's `random`, NumPy and PyTorch.
# Candidate seeds for repeated runs: 100, 0, 42.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# GPU-specific determinism: turn off cuDNN benchmark mode and request
# deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Argument Parser initialization

In [None]:
def add_args(parser):
    """Register every FedGKT command-line option on ``parser`` and parse ``sys.argv``.

    :param parser: an ``argparse.ArgumentParser`` to populate.
    :return: the parsed ``argparse.Namespace``.
    """
    # Training settings
    parser.add_argument('--model_client', type=str, default='resnet5', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--model_server', type=str, default='resnet32', metavar='N',
                        help='neural network used in training')
    # Dataset used
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')
    # Path for dataset storage
    parser.add_argument('--data_dir', type=str, default='./content', help='data directory')
    # How to partition the dataset across local workers
    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    # Help text now matches the actual default (it previously claimed 64).
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=1e-4)

    parser.add_argument('--epochs_client', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')
    # Number of clients
    parser.add_argument('--client_number', type=int, default=100, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--comm_round', type=int, default=20,
                        help='how many rounds of communication we should use')

    parser.add_argument('--gpu', type=int, default=0, help='gpu')

    parser.add_argument('--loss_scale', type=float, default=1024, help='Loss scaling, positive power of 2 values can improve fp16 convergence.')

    parser.add_argument('--no_bn_wd', action='store_true', help='Remove batch norm from weight decay')

    # Knowledge distillation
    parser.add_argument('--temperature', default=3.0, type=float, help='Input the temperature: default(3.0)')
    parser.add_argument('--epochs_server', type=int, default=5, metavar='EP', help='how many epochs will be trained on the server side')
    # Values tried: [0, 0.1, 0.01]. The default is a float so it has the same type
    # as a value passed on the command line (type=float).
    parser.add_argument('--alpha', default=1.0, type=float, help='Input the relative weight: default(1.0)')
    parser.add_argument('--optimizer', default="SGD", type=str, help='optimizer: SGD, Adam, etc.')
    parser.add_argument('--whether_training_on_client', default=1, type=int)
    parser.add_argument('--whether_distill_on_the_server', default=0, type=int)
    parser.add_argument('--client_model', default="resnet4", type=str)
    parser.add_argument('--weight_init_model', default="resnet32", type=str)
    parser.add_argument('--running_name', default="default", type=str)
    parser.add_argument('--sweep', default=0, type=int)
    parser.add_argument('--multi_gpu_server', action='store_true')
    parser.add_argument('--test', action='store_true', help='test mode, only run 1-2 epochs to test the bug of the program')
    parser.add_argument('--gpu_num_per_server', type=int, default=1, help='gpu_num_per_server')
    # Jupyter launches the kernel with "-f <connection file>"; accept and ignore it
    # so parse_args() does not fail inside a notebook.
    parser.add_argument('-f')
    args = parser.parse_args()
    return args

In [None]:
# parse python script input parameters
parser = argparse.ArgumentParser()
args = add_args(parser)
# Logging has not been configured yet at this point in the notebook (that happens
# in the dataset cell further down). At the default WARNING level the INFO record
# below would be silently dropped, so configure the root logger first.
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
logging.info(args)

## Class CIFAR-10 Truncated

In [None]:
import logging

import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR10

# Attach a default handler to the root logger, then lower its threshold to INFO
# so the partitioning / dataloader messages in this notebook are shown.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel("INFO")

# Image file suffixes (not referenced by the code visible in this notebook).
IMG_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp',
    '.pgm', '.tif', '.tiff', '.webp',
)


def pil_loader(path):
    """Read the image stored at ``path`` and return it converted to RGB."""
    # Give Image.open an explicit file handle that the ``with`` block closes, to avoid
    # a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def default_loader(path):
    """Load an image from ``path`` using the PIL-based loader (returns an RGB image)."""
    loader = pil_loader
    return loader(path)


class CIFAR10_truncated(data.Dataset):
    """CIFAR-10 split, optionally restricted to a subset of sample indices.

    Wraps ``torchvision.datasets.CIFAR10`` and keeps only its ``data`` array and its
    labels (as a NumPy array). When ``dataidxs`` is given, both arrays are sliced
    down to those indices, which lets each federated client own its own shard.

    Args:
        root: directory where CIFAR-10 is stored.
        dataidxs: optional indices of the samples to keep; ``None`` keeps the whole split.
        train: load the training split if True, the test split otherwise.
        transform: optional callable applied to each image in ``__getitem__``.
        target_transform: optional callable applied to each label in ``__getitem__``.
        download: forwarded unchanged to ``CIFAR10``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the requested split and return ``(data, target)``, optionally sliced by ``dataidxs``."""
        # Debug-level only: this runs for every dataset built (twice per client), and a
        # plain print flooded the notebook output with "download = True" lines.
        logging.debug("download = %s", self.download)
        cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

        # The train and test splits expose the same attributes, so no branching is needed.
        images = cifar_dataobj.data
        labels = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero channels 1 and 2 (last axis) of the selected images, in place.

        Args:
            index: integer indices into ``self.data`` (array or sequence).
        """
        # Assumes self.data is laid out (N, H, W, C), as the original per-channel
        # indexing implies. One vectorised assignment replaces the per-row loop.
        self.data[np.asarray(index), :, :, 1:3] = 0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

## Dataset initialization + Dataloaders definition + Dataset Splitting

### DataLoaders definition



In [None]:
import logging

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

def record_net_data_stats(y_train, net_dataidx_map):
    """Count the samples of each class held by every client.

    :param y_train: array of training labels.
    :param net_dataidx_map: {client index: indices into ``y_train`` owned by that client}
    :return: {client index: {class label: number of samples}}; only classes
             present on a client appear in its inner dict.
    """
    net_cls_counts = {}
    for client_id, sample_idxs in net_dataidx_map.items():
        classes, counts = np.unique(y_train[sample_idxs], return_counts=True)
        net_cls_counts[client_id] = dict(zip(classes, counts))
    logging.debug('Data statistics: %s' % str(net_cls_counts))
    return net_cls_counts
    
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
    """Return ``(train_loader, test_loader)`` for the requested dataset.

    Only CIFAR-10 is wired up: ``dataset`` is ignored and the call is always
    forwarded to ``get_dataloader_CIFAR10``.
    """
    loader_factory = get_dataloader_CIFAR10
    return loader_factory(datadir, train_bs, test_bs, dataidxs)

def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs=None):
    """
      Build the train/test DataLoaders for (a shard of) CIFAR-10.

      :param datadir: data directory
      :param train_bs: batch size for training
      :param test_bs:  batch size for test
      :param dataidxs: indexes of the training samples to keep (None = whole training split);
                       the test loader always covers the full test split
      :return: train DataLoader, test DataLoader
    """
    transform_train, transform_test = _data_transforms_cifar10()

    train_ds = CIFAR10_truncated(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
    test_ds = CIFAR10_truncated(datadir, train=False, transform=transform_test, download=True)

    # Drop the ragged last training batch only when at least one full batch exists.
    # Otherwise a client with fewer than `train_bs` samples would get zero batches
    # and never train.
    drop_last_train = len(train_ds) >= train_bs
    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=drop_last_train)
    # Keep the final partial batch so evaluation sees every test sample.
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)

    return train_dl, test_dl

class Cutout(object):
    """Zero out one randomly placed square patch of a (C, H, W) image tensor.

    The patch centre is drawn uniformly with NumPy's global RNG (row first,
    then column), and the patch is clipped at the image borders. The input
    tensor is modified in place and also returned.

    Args:
        length: nominal side of the patch; it spans ``length // 2`` pixels on
            each side of the centre.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        center_y = np.random.randint(height)
        center_x = np.random.randint(width)
        top, bottom = np.clip([center_y - half, center_y + half], 0, height)
        left, right = np.clip([center_x - half, center_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        img *= torch.from_numpy(keep).expand_as(img)
        return img

def _data_transforms_cifar10():
    """
        Build the CIFAR-10 preprocessing pipelines.

        return: (train_transform, valid_transform)
          * train: array -> PIL image, reflect-padded random 32x32 crop, random
            horizontal flip, tensor conversion, per-channel normalisation, Cutout(16)
          * valid: tensor conversion followed by the same normalisation
    """
    # Per-channel normalisation constants (presumably CIFAR-10 training-set statistics).
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        Cutout(16),  # applied last, on the normalised tensor
    ]
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(valid_steps)

def load_cifar10_data(datadir):
    """
      Load the full CIFAR-10 train and test splits as raw arrays.

      :param datadir: data directory (download=True is forwarded to CIFAR10_truncated)
      :return: (X_train, y_train, X_test, y_test)
    """
    train_transform, test_transform = _data_transforms_cifar10()
    arrays = []
    # Train split first, then test, so the result unpacks as X_train, y_train, X_test, y_test.
    for is_train, split_transform in ((True, train_transform), (False, test_transform)):
        split = CIFAR10_truncated(datadir, train=is_train, download=True, transform=split_transform)
        arrays.extend((split.data, split.target))
    return tuple(arrays)

def partition_data(dataset, datadir, partition, n_nets, alpha):
    """
      Split the CIFAR-10 training set among ``n_nets`` clients.

      :param dataset: dataset name (unused; CIFAR-10 is always loaded)
      :param datadir: storage folder
      :param partition: "homo" (uniform random split), "hetero" (per-class
                        Dirichlet split) or "hetero-fix" (split read from files)
      :param n_nets: number of clients
      :param alpha: concentration parameter of the Dirichlet draw used by "hetero"
      :return: X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts
               where net_dataidx_map is {client index: training-sample indices} and
               traindata_cls_counts is {client index: {class label: sample count}}
      :raises ValueError: if ``partition`` is not one of the supported methods
    """
    logging.info("*********partition data***************")
    X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
    n_train = X_train.shape[0]

    if partition == "homo":
        idxs = np.random.permutation(n_train)
        batch_idxs = np.array_split(idxs, n_nets)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_nets)}

    elif partition == "hetero":
        min_size = 0
        # Derive the label set from the data instead of hard-coding 10 classes
        # (np.unique gives 0..9 in the same order for CIFAR-10).
        classes = np.unique(y_train)
        N = y_train.shape[0]
        logging.info("N = " + str(N))
        net_dataidx_map = {}

        # Redraw the whole split until every client holds at least 10 samples.
        while min_size < 10:
            idx_batch = [[] for _ in range(n_nets)]
            for k in classes:
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                # Fraction of class k assigned to each client, drawn with concentration alpha.
                proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
                # Balance: clients already holding >= N / n_nets samples receive no more.
                proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                # Cumulative fractions -> split points inside idx_k.
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(n_nets):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    elif partition == "hetero-fix":
        # NOTE(review): read_net_dataidx_map / read_data_distribution are not defined in
        # the visible code; this branch raises NameError unless they are provided elsewhere.
        dataidx_map_file_path = './data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt'
        net_dataidx_map = read_net_dataidx_map(dataidx_map_file_path)

    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError("Unsupported partition method: %r (expected 'homo', 'hetero' or 'hetero-fix')"
                         % (partition,))

    if partition == "hetero-fix":
        distribution_file_path = './data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt'
        traindata_cls_counts = read_data_distribution(distribution_file_path)
    else:
        traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map)

    # net_dataidx_map and traindata_cls_counts each have one entry per client (n_nets).
    return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts

def load_partition_data_cifar10(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size):
    """
      Partition CIFAR-10 among clients and build the global and per-client loaders.

      :param dataset: dataset name
      :param data_dir: storage folder
      :param partition_method: partition method ("homo", "hetero" or "hetero-fix")
      :param partition_alpha: Dirichlet concentration used by the "hetero" partition
      :param client_number: number of clients involved
      :param batch_size: number of images in a batch (train and test)
      :return: train_data_num (training samples over all clients),
               test_data_num (test samples), train_data_global, test_data_global,
               data_local_num_dict {client: #samples}, train_data_local_dict {client: loader},
               test_data_local_dict {client: loader}, class_num
    """
    X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data(
        dataset, data_dir, partition_method, client_number, partition_alpha)

    class_num = len(np.unique(y_train))
    logging.info("traindata_cls_counts = " + str(traindata_cls_counts))
    train_data_num = sum(len(net_dataidx_map[r]) for r in range(client_number))

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, batch_size, batch_size)
    logging.info("train_dl_global number = " + str(len(train_data_global)))
    logging.info("test_dl_global number = " + str(len(test_data_global)))
    # Count test *samples*, consistent with train_data_num. len(loader) would give
    # the number of batches instead.
    test_data_num = len(test_data_global.dataset)

    # Per-client datasets
    data_local_num_dict = dict()
    train_data_local_dict = dict()
    test_data_local_dict = dict()

    for client_idx in range(client_number):
        dataidxs = net_dataidx_map[client_idx]
        local_data_num = len(dataidxs)
        data_local_num_dict[client_idx] = local_data_num
        logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

        # dataidxs restricts only the training split; each client's test loader
        # covers the full test split.
        train_data_local, test_data_local = get_dataloader(dataset, data_dir, batch_size, batch_size,
                                                           dataidxs)
        logging.info("client_idx = %d, batch_num_train_local = %d, batch_num_test_local = %d" % (
            client_idx, len(train_data_local), len(test_data_local)))
        train_data_local_dict[client_idx] = train_data_local
        test_data_local_dict[client_idx] = test_data_local
    return train_data_num, test_data_num, train_data_global, test_data_global, data_local_num_dict, \
        train_data_local_dict, test_data_local_dict, class_num


### Dataset partition

In [None]:
def init_training_device(process_ID, fl_worker_num, gpu_num_per_machine):
    """Pick the torch device for a process.

    Process 0 always uses ``cuda:0`` (or CPU when CUDA is unavailable). Any other
    process ``p`` is mapped to worker ``p - 1``, and worker ``i`` is assigned GPU
    ``i % gpu_num_per_machine``.
    """
    if process_ID != 0:
        gpu_of_worker = {worker: worker % gpu_num_per_machine for worker in range(fl_worker_num)}
        logging.info(gpu_of_worker)
        if torch.cuda.is_available():
            device = torch.device("cuda:" + str(gpu_of_worker[process_ID - 1]))
        else:
            device = torch.device("cpu")
        logging.info(device)
        return device

    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_data(args, dataset_name):
    """
      Partition CIFAR-10 according to ``args`` and gather every loader/statistic.

      :param args: parsed arguments; reads dataset, data_dir, partition_method,
                   partition_alpha, client_number and batch_size
      :param (str) dataset_name: unused (``args.dataset`` is what gets passed on)
      :return: (list) [train_data_num, test_data_num, train_data_global, test_data_global,
               train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]
    """
    (train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = load_partition_data_cifar10(
        args.dataset,
        args.data_dir,
        args.partition_method,
        args.partition_alpha,
        args.client_number,
        args.batch_size,
    )

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]

# Device initialization
device = init_training_device(
    process_ID=0,
    fl_worker_num=1,
    gpu_num_per_machine=args.gpu_num_per_server,
)

# Load and partition the data.
# NOTE: the CIFAR-10 dataloaders are built with shuffle=True. If more than one
# client epoch is used and the batch order must stay the same across epochs,
# shuffle has to be set to False in the dataloader definition.
dataset = load_data(args, args.dataset)
(train_data_num, test_data_num, train_data_global, test_data_global,
 train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
 class_num) = dataset

INFO:root:*********partition data***************


download = True
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./content/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./content/cifar-10-python.tar.gz to ./content
download = True
Files already downloaded and verified


INFO:root:N = 50000
INFO:root:traindata_cls_counts = {0: {0: 86, 1: 69, 2: 2, 3: 34, 4: 83, 5: 17, 6: 21, 7: 209}, 1: {0: 2, 1: 28, 2: 37, 3: 135, 4: 11, 5: 62, 7: 132, 8: 47, 9: 318}, 2: {0: 1, 1: 3, 2: 6, 3: 3, 5: 7, 6: 2, 7: 114, 8: 5, 9: 35}, 3: {0: 36, 1: 1, 2: 21, 3: 70, 5: 47, 6: 7, 7: 5, 8: 226, 9: 5}, 4: {0: 46, 1: 54, 2: 46, 3: 37, 4: 8, 5: 296, 6: 82}, 5: {0: 114, 1: 123, 2: 68, 3: 49, 4: 8, 5: 132, 6: 89}, 6: {0: 1, 2: 32, 3: 24, 4: 180, 5: 1, 7: 16, 8: 166, 9: 54}, 7: {0: 66, 2: 9, 3: 38, 4: 32, 5: 8, 6: 24, 8: 1, 9: 30}, 8: {0: 103, 1: 34, 2: 1, 3: 86, 4: 55, 5: 5, 6: 105, 7: 45}, 9: {1: 10, 2: 250, 3: 1, 4: 21, 5: 24, 6: 186, 7: 12}, 10: {0: 106, 1: 19, 3: 182, 4: 15, 5: 114, 6: 82}, 11: {0: 47, 2: 103, 3: 1, 4: 19, 5: 36, 6: 1, 7: 1, 8: 314}, 12: {0: 42, 1: 45, 2: 56, 3: 12, 5: 60, 6: 86, 7: 95, 8: 6}, 13: {0: 45, 1: 80, 2: 87, 3: 59, 4: 33, 5: 2, 6: 27, 7: 296}, 14: {0: 82, 1: 26, 2: 22, 3: 4, 4: 18, 5: 17, 6: 231, 7: 19, 8: 9, 9: 197}, 15: {0: 24, 1: 120, 2: 23, 3: 42

download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:train_dl_global number = 195
INFO:root:test_dl_global number = 39
INFO:root:client_idx = 0, local_sample_number = 521


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 0, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 1, local_sample_number = 772


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 1, batch_num_train_local = 3, batch_num_test_local = 39
INFO:root:client_idx = 2, local_sample_number = 176


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 2, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 3, local_sample_number = 418


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 3, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 4, local_sample_number = 569


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 4, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 5, local_sample_number = 583


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 5, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 6, local_sample_number = 474


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 6, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 7, local_sample_number = 208


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 7, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 8, local_sample_number = 434


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 8, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 9, local_sample_number = 504


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 9, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 10, local_sample_number = 518


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 10, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 11, local_sample_number = 522


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 11, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 12, local_sample_number = 402


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 12, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 13, local_sample_number = 629


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 13, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 14, local_sample_number = 625


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 14, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 15, local_sample_number = 576


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 15, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 16, local_sample_number = 349


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 16, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 17, local_sample_number = 271


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 17, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 18, local_sample_number = 309


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 18, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 19, local_sample_number = 625


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 19, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 20, local_sample_number = 168


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 20, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 21, local_sample_number = 571


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 21, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 22, local_sample_number = 587


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 22, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 23, local_sample_number = 572


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 23, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 24, local_sample_number = 1008


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 24, batch_num_train_local = 3, batch_num_test_local = 39
INFO:root:client_idx = 25, local_sample_number = 709


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 25, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 26, local_sample_number = 537


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 26, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 27, local_sample_number = 575


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 27, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 28, local_sample_number = 501


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 28, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 29, local_sample_number = 553


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 29, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 30, local_sample_number = 582


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 30, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 31, local_sample_number = 512


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 31, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 32, local_sample_number = 506


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 32, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 33, local_sample_number = 607


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 33, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 34, local_sample_number = 717


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 34, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 35, local_sample_number = 539


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 35, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 36, local_sample_number = 572


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 36, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 37, local_sample_number = 536


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 37, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 38, local_sample_number = 565


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 38, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 39, local_sample_number = 328


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 39, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 40, local_sample_number = 770


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 40, batch_num_train_local = 3, batch_num_test_local = 39
INFO:root:client_idx = 41, local_sample_number = 648


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 41, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 42, local_sample_number = 449


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 42, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 43, local_sample_number = 283


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 43, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 44, local_sample_number = 325


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 44, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 45, local_sample_number = 335


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 45, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 46, local_sample_number = 581


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 46, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 47, local_sample_number = 367


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 47, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 48, local_sample_number = 508


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 48, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 49, local_sample_number = 540


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 49, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 50, local_sample_number = 527


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 50, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 51, local_sample_number = 591


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 51, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 52, local_sample_number = 631


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 52, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 53, local_sample_number = 299


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 53, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 54, local_sample_number = 442


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 54, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 55, local_sample_number = 405


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 55, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 56, local_sample_number = 665


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 56, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 57, local_sample_number = 513


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 57, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 58, local_sample_number = 464


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 58, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 59, local_sample_number = 635


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 59, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 60, local_sample_number = 531


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 60, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 61, local_sample_number = 517


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 61, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 62, local_sample_number = 570


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 62, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 63, local_sample_number = 611


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 63, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 64, local_sample_number = 141


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 64, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 65, local_sample_number = 216


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 65, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 66, local_sample_number = 264


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 66, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 67, local_sample_number = 293


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 67, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 68, local_sample_number = 532


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 68, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 69, local_sample_number = 457


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 69, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 70, local_sample_number = 446


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 70, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 71, local_sample_number = 226


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 71, batch_num_train_local = 0, batch_num_test_local = 39
INFO:root:client_idx = 72, local_sample_number = 582


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 72, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 73, local_sample_number = 368


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 73, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 74, local_sample_number = 820


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 74, batch_num_train_local = 3, batch_num_test_local = 39
INFO:root:client_idx = 75, local_sample_number = 659


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 75, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 76, local_sample_number = 592


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 76, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 77, local_sample_number = 491


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 77, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 78, local_sample_number = 636


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 78, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 79, local_sample_number = 532


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 79, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 80, local_sample_number = 322


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 80, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 81, local_sample_number = 553


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 81, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 82, local_sample_number = 334


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 82, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 83, local_sample_number = 662


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 83, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 84, local_sample_number = 320


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 84, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 85, local_sample_number = 557


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 85, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 86, local_sample_number = 434


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 86, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 87, local_sample_number = 455


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 87, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 88, local_sample_number = 289


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 88, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 89, local_sample_number = 487


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 89, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 90, local_sample_number = 506


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 90, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 91, local_sample_number = 569


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 91, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 92, local_sample_number = 643


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 92, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 93, local_sample_number = 515


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 93, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 94, local_sample_number = 513


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 94, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 95, local_sample_number = 514


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 95, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 96, local_sample_number = 512


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 96, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 97, local_sample_number = 495


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 97, batch_num_train_local = 1, batch_num_test_local = 39
INFO:root:client_idx = 98, local_sample_number = 619


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 98, batch_num_train_local = 2, batch_num_test_local = 39
INFO:root:client_idx = 99, local_sample_number = 509


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:client_idx = 99, batch_num_train_local = 1, batch_num_test_local = 39


## Client/Server Models definition

### ResNet Server model definition

In [None]:
'''
ResNet for CIFAR-10/100 Dataset.
Reference:
1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2. https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua
3. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. https://arxiv.org/abs/1512.03385
'''
import logging

import torch
import torch.nn as nn

# NOTE(review): no name 'ResNet' is defined in the visible code (the class in
# this cell is ResNet_server), so `from <module> import *` would raise
# AttributeError if this cell were exported as a module — confirm and list the
# real public names (e.g. the server factories) instead.
__all__ = ['ResNet']


class MyGroupNorm(nn.Module):
    """GroupNorm with two groups, usable as a one-argument ``norm_layer`` factory.

    The wrapped ``nn.GroupNorm`` is stored as ``self.norm``, so its parameters
    appear in the state dict as ``<prefix>.norm.weight`` / ``<prefix>.norm.bias``.

    Args:
        num_channels: channels of the input; must be divisible by the fixed
            group count (2).
    """

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.GroupNorm(
            num_groups=2,
            num_channels=num_channels,
            eps=1e-5,
            affine=True,
        )

    def forward(self, x):
        """Return ``x`` normalised by the wrapped GroupNorm."""
        return self.norm(x)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, which keeps the spatial size
    unchanged for a 3x3 kernel at stride 1.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d`` with the given stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3), as used in ResNet-18/34.

    The block outputs ``planes`` channels (``expansion == 1``). Any stride is
    applied by the first convolution, so whenever the block changes
    resolution or width the caller must pass a ``downsample`` module that
    reshapes the shortcut accordingly.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Registration order fixes the state-dict key order and the order in
        # which parameters are randomly (re)initialised, so it is kept as is.
        self.conv1 = conv3x3(inplanes, planes, stride)  # carries the stride
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += shortcut
        return self.relu(h)


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 reduce -> 3x3 -> 1x1 expand) as in ResNet-50+.

    The block emits ``planes * expansion`` (= ``4 * planes``) channels. The
    inner width is ``int(planes * (base_width / 64.)) * groups``; stride,
    grouping and dilation are all applied by the middle 3x3 convolution. A
    ``downsample`` module must reshape the shortcut whenever the block
    changes resolution or width.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner_width = int(planes * (base_width / 64.)) * groups
        out_width = planes * self.expansion

        # Registration order fixes the state-dict key order and the order in
        # which parameters are randomly (re)initialised, so it is kept as is.
        self.conv1 = conv1x1(inplanes, inner_width)
        self.bn1 = make_norm(inner_width)
        self.conv2 = conv3x3(inner_width, inner_width, stride, groups, dilation)
        self.bn2 = make_norm(inner_width)
        self.conv3 = conv1x1(inner_width, out_width)
        self.bn3 = make_norm(out_width)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)


class ResNet_server(nn.Module):
    """Server-side ResNet of the FedGKT setup, for CIFAR-sized feature maps.

    ``forward`` starts directly at ``layer1``, which is built for 16 input
    channels (``self.inplanes = 16``); the input is presumably the
    ``extracted_features`` produced by the client's stem (confirm against the
    training loop). The server's own ``conv1``/``bn1``/``relu`` stem is still
    constructed, so state-dict keys and the random-initialisation order are
    unchanged, but it is not used in ``forward``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in ``layer1``, ``layer2`` and ``layer3``.
        num_classes: size of the classifier output.
        zero_init_residual: if True, zero the last BN weight of every residual
            branch so each block initially behaves like an identity.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: None or a 3-element sequence of bools.
            Element 0 replaces the stride-2 of ``layer2`` and element 1 that of
            ``layer3`` with dilation; element 2 has no fourth stage to act on.
            ``BasicBlock`` raises ``NotImplementedError`` for dilation > 1.
        norm_layer: normalisation factory for the residual stages (default
            ``nn.BatchNorm2d``); the unused stem ``bn1`` is always BatchNorm.
        KD: stored as ``self.KD``; not read by this class.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False):
        super(ResNet_server, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 16
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element says whether the corresponding stride-2 stage should
            # use a dilated convolution instead of striding.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.groups = groups
        self.base_width = width_per_group
        # Stem: built (and initialised) but skipped in forward(); kept so that
        # state-dict keys and the random-initialisation order stay unchanged.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        # The dilation flags are now honoured for the two stride-2 stages
        # (they used to be validated above and then silently ignored).
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self.KD = KD
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks with base width ``planes``.

        Only the first block may change resolution or width; it receives a
        1x1-conv + norm ``downsample`` shortcut when the stride or the channel
        count changes. With ``dilate`` the stride is folded into
        ``self.dilation`` instead: the first block keeps the previous dilation
        and the remaining blocks use the new one. Updates ``self.inplanes``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Classify 16-channel feature maps.

        Args:
            x: tensor of shape (B, 16, H, W) — 16 is the width ``layer1`` was
               built for.

        Returns:
            Logits of shape (B, num_classes).
        """
        # The stem (conv1/bn1/relu) is intentionally skipped here.
        x = self.layer1(x)  # B x (16 * expansion) x H x W
        x = self.layer2(x)  # B x (32 * expansion) x H/2 x W/2 (H x W if dilated)
        x = self.layer3(x)  # B x (64 * expansion) x H/4 x W/4 (less reduction if dilated)

        x = self.avgpool(x)  # B x (64 * expansion) x 1 x 1
        return self.fc(x.view(x.size(0), -1))  # B x num_classes


def resnet56_server(c, pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-110 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained.
    """
    logging.info("path = " + str(path))
    model = ResNet_server(Bottleneck, [6, 6, 6], num_classes=c, **kwargs)
    if pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # name = k[7:]  # remove 'module.' of dataparallel
            name = k.replace("module.", "")
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model


def resnet49_server(c, pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-49 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained.
    """
    logging.info("path = " + str(path))
    model = ResNet_server(Bottleneck, [5, 5, 6], num_classes=c, **kwargs)
    if pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # name = k[7:]  # remove 'module.' of dataparallel
            name = k.replace("module.", "")
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model

### ResNet Client model definition


#### Client model

In [None]:
'''
ResNet for CIFAR-10/100 Dataset.
Reference:
1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2. https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua
3. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. https://arxiv.org/abs/1512.03385
'''

import torch
import torch.nn as nn

# NOTE(review): no name 'ResNet' is defined in the visible code (the class in
# this cell is ResNet_client), so `from <module> import *` would raise
# AttributeError if this cell were exported as a module — confirm and list the
# real public names (e.g. the client factories) instead.
__all__ = ['ResNet']


class MyGroupNorm(nn.Module):
    """Two-group ``nn.GroupNorm`` packaged as a one-argument norm-layer factory.

    Meant as a drop-in ``norm_layer`` for the ResNet blocks. ``num_channels``
    must be divisible by 2; the parameters live under ``self.norm``.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.GroupNorm(2, num_channels, eps=1e-5, affine=True)

    def forward(self, x):
        normalized = self.norm(x)
        return normalized



def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a 3x3 convolution without bias whose padding equals ``dilation``.

    With stride 1 this padding keeps the spatial size unchanged.
    """
    same_padding = dilation
    layer = nn.Conv2d(in_planes, out_planes, 3, stride, same_padding,
                      dilation, groups, bias=False)
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution (channel projection) with ``stride``."""
    layer = nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)
    return layer


class BasicBlock(nn.Module):
    """ResNet basic block: conv3x3 -> norm -> ReLU -> conv3x3 -> norm, plus shortcut.

    ``expansion == 1``: the block outputs ``planes`` channels. Any stride is
    applied by the first convolution, so a shape-changing block needs a
    ``downsample`` module for its shortcut.

    Raises:
        ValueError: unless ``groups == 1`` and ``base_width == 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if (groups, base_width) != (1, 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Keep this registration order: it fixes state-dict key order and the
        # order of random parameter (re)initialisation.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        branch += shortcut
        return self.relu(branch)


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with a 4x channel expansion.

    Output channels are ``planes * expansion``; the inner width is
    ``int(planes * (base_width / 64.)) * groups``. Stride, groups and dilation
    are applied by the middle 3x3 convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        mid = int(planes * (base_width / 64.)) * groups
        # 1x1 reduce
        self.conv1 = conv1x1(inplanes, mid)
        self.bn1 = norm_layer(mid)
        # 3x3 spatial conv (carries stride / groups / dilation)
        self.conv2 = conv3x3(mid, mid, stride, groups, dilation)
        self.bn2 = norm_layer(mid)
        # 1x1 expand to planes * expansion
        self.conv3 = conv1x1(mid, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)


class ResNet_client(nn.Module):
    """Client-side (small) ResNet of the FedGKT setup.

    Only the stem (``conv1``/``bn1``/``relu``) and ``layer1`` are used: the
    stride-2 stages ``layer2``/``layer3`` are commented out, so only
    ``layers[0]`` matters. ``forward`` returns the class logits together with
    the 16-channel stem output, which the FedGKT training loop presumably
    sends to the server — confirm against the caller.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: block counts per stage; only ``layers[0]`` is used.
        num_classes: size of the classifier output.
        zero_init_residual: if True, zero the last BN weight of every residual
            branch so each block initially behaves like an identity.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: validated to be None or 3 elements long,
            but has no effect here (``layer1`` has stride 1 and is built
            without ``dilate``).
        norm_layer: normalisation factory for ``layer1`` (default
            ``nn.BatchNorm2d``). NOTE(review): the stem ``bn1`` is always
            ``nn.BatchNorm2d`` even when a GroupNorm factory is passed —
            confirm this is intended.
        KD: stored as ``self.KD``; not read by this class.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False):
        super(ResNet_client, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 16
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.groups = groups
        self.base_width = width_per_group

        # Stem: 3x3 conv, stride 1, padding 1 -> keeps H x W; 3 -> 16 channels.
        # Module construction order matters: it fixes state-dict key order and
        # the order of random initialisation.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)  # default init is overwritten by the kaiming_normal_ loop below
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # constructed but unused in forward
        self.layer1 = self._make_layer(block, 16, layers[0])
        # layer2/layer3 are disabled: the client keeps only the first stage.
        # self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        # self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(64 * block.expansion, num_classes)
        # fc input width = layer1 output channels (16 * block.expansion)
        self.fc = nn.Linear(16 * block.expansion, num_classes)
        # self.fc = nn.Linear(32 * block.expansion, num_classes)

        self.KD = KD
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks with base width ``planes``.

        Only the first block may change resolution or width; it receives a
        1x1-conv + norm ``downsample`` shortcut when the stride or the channel
        count changes. With ``dilate`` the stride is folded into
        ``self.dilation`` instead. Updates ``self.inplanes``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the client model.

        Args:
            x: input images with 3 channels (``conv1`` takes 3 input channels).

        Returns:
            Tuple ``(logits, extracted_features)``: logits of shape
            (B, num_classes) and the stem output of shape (B, 16, H, W).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # B x 16 x H x W (32 x 32 for CIFAR inputs)
        # x = self.maxpool(x)  # disabled, so the features keep full resolution
        extracted_features = x

        x = self.layer1(x)  # B x (16 * block.expansion) x H x W
        # x = self.layer2(x)  # B x 32 x 16 x 16
        # x = self.layer3(x)  # B x 64 x 8 x 8

        x = self.avgpool(x)  # B x (16 * block.expansion) x 1 x 1
        x_f = x.view(x.size(0), -1)  # B x (16 * block.expansion)
        logits = self.fc(x_f)  # B x num_classes
        return logits, extracted_features


def resnet5_56(c, pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-32 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained.
    """

    model = ResNet_client(BasicBlock, [1, 2, 2], num_classes=c, **kwargs)
    if pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # name = k[7:]  # remove 'module.' of dataparallel
            name = k.replace("module.", "")
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model


def resnet8_56(c, pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-32 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained.

    """

    model = ResNet_client(Bottleneck, [2, 2, 2], num_classes=c, **kwargs)
    if pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # name = k[7:]  # remove 'module.' of dataparallel
            name = k.replace("module.", "")
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model

### Client/Server models initialization


In [None]:
def create_client_model(args, n_classes):
    """Build the client model: ``resnet8_56`` (Bottleneck blocks, BatchNorm).

    Args:
        args: parsed command-line arguments; currently unused.
        n_classes: number of output classes.

    Returns:
        The model returned by ``resnet8_56(n_classes)``.
    """
    # GroupNorm variant — swap it in for the return below if desired:
    # ResNet_client(Bottleneck, [2, 2, 2], num_classes=n_classes, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=MyGroupNorm)
    return resnet8_56(n_classes)


def create_server_model(n_classes):
    """Build the server model: ``resnet56_server`` (Bottleneck blocks, BatchNorm).

    Args:
        n_classes: number of output classes.

    Returns:
        The model returned by ``resnet56_server(n_classes)``.
    """
    # GroupNorm variant — swap it in for the return below if desired.
    # NOTE(review): it uses [5, 5, 6] (the resnet49 layout), not resnet56's [6, 6, 6].
    # ResNet_server(Bottleneck,[5, 5, 6], num_classes=n_classes, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=MyGroupNorm)
    return resnet56_server(n_classes)

## Distributed Training 

### Utility functions for training

In [None]:
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_state_dict(file):
    """Load a checkpoint / state dict from ``file``.

    A plain ``torch.load`` is tried first. If it fails because the file contains
    CUDA tensors that cannot be restored on this host (older PyTorch raised
    ``AssertionError``; current releases raise ``RuntimeError``), the load is
    retried with every storage mapped to CPU.
    """
    try:
        pretrain_state_dict = torch.load(file)
    except (AssertionError, RuntimeError):
        # Keep each storage where it was deserialized (CPU) instead of its saved device.
        pretrain_state_dict = torch.load(file, map_location=lambda storage, location: storage)
    return pretrain_state_dict


def get_flat_params_from(model):
    """Concatenate all parameters of ``model`` into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order; each one's ``.data``
    is flattened with ``view(-1)`` before concatenation, so the result does not
    track gradients.
    """
    return torch.cat([weight.data.view(-1) for weight in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Copy consecutive slices of the 1-D ``flat_params`` into ``model``'s parameters in place.

    Parameters are filled in ``model.parameters()`` order, each taking as many
    elements as it holds and reshaped to its own size. This is the counterpart
    of ``get_flat_params_from``.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        chunk = flat_params[offset:offset + count]
        param.data.copy_(chunk.view(param.size()))
        offset += count


class RunningAverage():
    """Maintain the running arithmetic mean of a stream of values.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg.value()  # == 3.0
    ```
    """

    def __init__(self):
        self.steps = 0  # number of values passed to update()
        self.total = 0  # sum of those values

    def update(self, val):
        """Add ``val`` to the running sum."""
        self.total += val
        self.steps += 1

    def value(self):
        """Return the mean of all values seen so far, or 0.0 if update() was never called."""
        # Guard against ZeroDivisionError when e.g. a client uploaded no batches.
        if self.steps == 0:
            return 0.0
        return self.total / float(self.steps)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: scores of shape (batch, num_classes).
        target: class indices of shape (batch,).
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose target
        is among the k highest scores. A k larger than the number of classes is
        treated as "all classes", so it yields 100% instead of failing in topk.
    """
    # output.topk raises if k exceeds the class dimension (e.g. top-5 on 2 classes).
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    # (maxk, batch): row j holds every sample's j-th best prediction.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing past maxk simply takes all available rows.
        correct_k = correct[:k].reshape(-1).sum(0)
        res.append(correct_k.mul(100.0 / batch_size))

    return res


class KL_Loss(nn.Module):
    """Temperature-scaled KL-divergence distillation loss.

    Computes ``T * T * KL(p_teacher || p_student)`` averaged over the batch
    (``reduction='batchmean'``). Here ``p_teacher = softmax(teacher / T) + 1e-7``
    and ``log p_student = log_softmax(student / T)``. The small constant is added
    to the teacher probabilities, presumably to keep their logarithm finite.

    Args:
        temperature: softening temperature ``T`` (default 1).
    """

    def __init__(self, temperature=1):
        super(KL_Loss, self).__init__()
        self.T = temperature

    def forward(self, output_batch, teacher_outputs):
        """Return the distillation loss of student logits against teacher logits.

        Both inputs are (batch, num_classes) logits; classes are along dim 1.
        """
        temp = self.T
        student_log_prob = F.log_softmax(output_batch / temp, dim=1)
        teacher_prob = F.softmax(teacher_outputs / temp, dim=1) + 10 ** (-7)
        divergence = F.kl_div(student_log_prob, teacher_prob, reduction='batchmean')
        return temp * temp * divergence


class CE_Loss(nn.Module):
    """Temperature-scaled soft-target cross-entropy.

    Returns ``-T * T * sum(softmax(teacher / T) * log_softmax(student / T)) / B``,
    where the sum runs over every element and ``B = teacher_outputs.size(0)``.

    Args:
        temperature: softening temperature ``T`` (default 1).
    """

    def __init__(self, temperature=1):
        super(CE_Loss, self).__init__()
        self.T = temperature

    def forward(self, output_batch, teacher_outputs):
        """Return the soft cross-entropy of student logits against teacher logits.

        Both inputs are (batch, num_classes) logits; classes are along dim 1.
        """
        student_log_prob = F.log_softmax(output_batch / self.T, dim=1)
        teacher_prob = F.softmax(teacher_outputs / self.T, dim=1)
        batch_size = teacher_prob.size(0)
        weighted_sum = (student_log_prob * teacher_prob).sum()
        return -self.T * self.T * weighted_sum / batch_size


def save_dict_to_json(d, json_path):
    """Save a dict of numeric values to ``json_path`` as indented JSON.

    Values that expose ``.item()`` (NumPy scalars, single-element NumPy arrays,
    single-element torch tensors) are converted to plain Python numbers first,
    because the ``json`` module cannot serialize them. Other values (e.g. Python
    ``int``/``float``) are written unchanged.

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file
    """
    # Convert before opening so a bad value does not leave a truncated file behind.
    serializable = {k: (v.item() if hasattr(v, "item") else v) for k, v in d.items()}
    with open(json_path, 'w') as f:
        json.dump(serializable, f, indent=4)


# Exclude batch-norm parameters from weight decay (the original author reports 93.2 -> 93.48 accuracy).
# https://arxiv.org/pdf/1807.11205.pdf
def bnwd_optim_params(model, model_params, master_params):
    """Build optimizer param groups that exempt BatchNorm parameters from weight decay.

    Returns a list of two param-group dicts. The first holds the BatchNorm
    parameters with ``weight_decay=0``. The second holds every other parameter
    and uses the optimizer's default weight decay.
    """
    bn_group, other_group = split_bn_params(model, model_params, master_params)
    no_decay = {'params': bn_group, 'weight_decay': 0}
    default_decay = {'params': other_group}
    return [no_decay, default_decay]


def split_bn_params(model, model_params, master_params):
    """Partition ``master_params`` into BatchNorm and non-BatchNorm parameters.

    ``model_params`` and ``master_params`` are walked in lockstep. Each entry of
    ``master_params`` is classified by whether its counterpart in ``model_params``
    belongs to a ``_BatchNorm`` submodule of ``model``.

    Both arguments may be arbitrary iterables. If the *same* iterator object is
    passed for both (e.g. one ``model.parameters()`` generator bound to two names),
    it is materialized once and used on both sides. Previously it was zipped with
    itself, which paired consecutive parameters and dropped half of them.

    Returns:
        (bn_params, remaining_params): two lists of entries from ``master_params``.
    """
    if master_params is model_params:
        model_params = master_params = list(model_params)
    else:
        model_params = list(model_params)
        master_params = list(master_params)

    # model.modules() yields model itself plus all descendants, so a top-level
    # BatchNorm module is handled too. Tensors hash by identity, so this set
    # membership compares parameter objects, not tensor values.
    bn_module_params = {
        p
        for module in model.modules()
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        for p in module.parameters()
    }

    bn_params, remaining_params = [], []
    for p_mod, p_mast in zip(model_params, master_params):
        if p_mod in bn_module_params:
            bn_params.append(p_mast)
        else:
            remaining_params.append(p_mast)
    return bn_params, remaining_params


### GKTServerTrainer Class

In [None]:
import logging
import os
import shutil

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GKTServerTrainer(object):
    """Server side of FedGKT: trains the large server model on client feature maps.

    Clients upload per-batch extracted feature maps, logits and labels (plus test
    feature maps and labels) through ``add_local_trained_result``. The server
    trains ``server_model`` on those feature maps with cross-entropy, optionally
    plus a KL distillation term toward the client logits. It records its own
    logits per client/batch in ``server_logits_dict`` (read back through
    ``get_global_logits``), evaluates on the uploaded test features, and writes
    checkpoints under ``CHECKPOINT_DIR``.
    """

    # Directory for the 'last'/'best' checkpoints and the best-metrics json.
    CHECKPOINT_DIR = '/content/checkpoint/'

    def __init__(self, client_num, device, server_model, args):
        self.client_num = client_num
        self.device = device
        self.args = args

        # When using data parallel, increase the batch size accordingly
        # (single GPU = 64; 4 GPUs = 256). One epoch training time: single GPU (64) = 1:03;
        # 4 x GPUs (256) = 38s; 4 x GPUs (64) = 1:00. With the same batch size, the frequent
        # GPU-CPU-GPU communication makes multi-GPU training slower than a single GPU.
        self.model_global = server_model

        if args.multi_gpu_server and torch.cuda.device_count() > 1:
            # Use at most the first 4 visible GPUs; a hard-coded [0, 1, 2, 3] fails on 2-3 GPU hosts.
            device_ids = list(range(min(4, torch.cuda.device_count())))
            self.model_global = nn.DataParallel(self.model_global, device_ids=device_ids).to(device)

        self.model_global.train()
        self.model_global.to(self.device)

        # Materialize once: a generator would be exhausted after one pass, and passing
        # the same generator twice to bnwd_optim_params would zip it with itself.
        self.model_params = self.master_params = list(self.model_global.parameters())

        optim_params = bnwd_optim_params(self.model_global, self.model_params,
                                         self.master_params) if args.no_bn_wd else self.master_params

        if self.args.optimizer == "SGD":
            self.optimizer = torch.optim.SGD(optim_params, lr=self.args.lr, momentum=0.9,
                                             nesterov=True,
                                             weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)
        else:
            raise ValueError("Unsupported optimizer: %r (expected 'SGD' or 'Adam')" % (self.args.optimizer,))

        # 'max' mode: the scheduler is stepped with the best test accuracy.
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'max')

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)
        self.best_acc = 0.0

        # key: client_index; value: extracted_feature_dict
        self.client_extracted_feauture_dict = dict()

        # key: client_index; value: logits_dict
        self.client_logits_dict = dict()

        # key: client_index; value: labels_dict
        self.client_labels_dict = dict()

        # key: client_index; value: dict of server logits keyed by batch_index
        self.server_logits_dict = dict()

        # for test
        self.client_extracted_feauture_dict_test = dict()
        self.client_labels_dict_test = dict()

        self.model_dict = dict()
        self.sample_num_dict = dict()
        self.train_acc_list = []
        self.train_loss_list = []
        self.test_acc_list = []
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        # key: client_index; value: whether that client uploaded in the current round
        self.flag_client_model_uploaded_dict = dict()
        for idx in range(self.client_num):
            self.flag_client_model_uploaded_dict[idx] = False

    def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict,
                                 extracted_feature_dict_test, labels_dict_test):
        """Store one client's uploaded features/logits/labels and mark it as received."""
        logging.info("add_model. index = %d" % index)
        self.client_extracted_feauture_dict[index] = extracted_feature_dict
        self.client_logits_dict[index] = logits_dict
        self.client_labels_dict[index] = labels_dict
        self.client_extracted_feauture_dict_test[index] = extracted_feature_dict_test
        self.client_labels_dict_test[index] = labels_dict_test

        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True (and reset all flags) once every client has uploaded; else False."""
        if not all(self.flag_client_model_uploaded_dict[idx] for idx in range(self.client_num)):
            return False
        for idx in range(self.client_num):
            self.flag_client_model_uploaded_dict[idx] = False
        return True

    def get_global_logits(self, client_index):
        """Return the server logits (keyed by batch index) computed for ``client_index``."""
        return self.server_logits_dict[client_index]

    def train(self, round_idx):
        """Run the server-side training for one communication round."""
        if self.args.sweep == 1:
            self.sweep(round_idx)
        else:
            if self.args.whether_training_on_client == 1:
                self.train_and_distill_on_client(round_idx)
            else:
                self.do_not_train_on_client(round_idx)

    def train_and_distill_on_client(self, round_idx):
        """Train for the configured number of server epochs, then step the LR scheduler."""
        if self.args.test:
            epochs_server, _ = self.get_server_epoch_strategy_test()
        else:
            # add_args defines '--model_client'; fall back to it when args has no 'client_model'.
            client_model_name = getattr(self.args, "client_model", getattr(self.args, "model_client", None))
            if client_model_name == "resnet56":
                epochs_server, _ = self.get_server_epoch_strategy_reset56_2(round_idx)
            else:
                epochs_server = self.args.epochs_server

        # train according to the logits from the client
        self.train_and_eval(round_idx, epochs_server)

        # adjust the learning rate based on the number of epochs.
        # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
        self.scheduler.step(self.best_acc, epoch=round_idx)

    def do_not_train_on_client(self, round_idx):
        """Train the server for a single epoch, then step the LR scheduler."""
        self.train_and_eval(round_idx, 1)
        self.scheduler.step(self.best_acc, epoch=round_idx)

    def sweep(self, round_idx):
        """Train the server for ``args.epochs_server`` epochs, then step the LR scheduler."""
        self.train_and_eval(round_idx, self.args.epochs_server)
        self.scheduler.step(self.best_acc, epoch=round_idx)

    def get_server_epoch_strategy_test(self):
        """Test-mode schedule: one server epoch, with distillation back to the clients."""
        return 1, True

    # ResNet56
    def get_server_epoch_strategy_reset56(self, round_idx):
        """Return (epochs, whether_distill_back) for ``round_idx``; epochs decrease with rounds."""
        whether_distill_back = True
        if round_idx < 20:
            epochs = 20
        elif 20 <= round_idx < 30:
            epochs = 15
        elif 30 <= round_idx < 40:
            epochs = 10
        elif 40 <= round_idx < 50:
            epochs = 5
        elif 50 <= round_idx < 100:
            epochs = 5
        elif 100 <= round_idx < 150:
            epochs = 3
        elif 150 <= round_idx <= 200:
            epochs = 2
            whether_distill_back = False
        else:
            epochs = 1
            whether_distill_back = False
        return epochs, whether_distill_back

    # ResNet56-2
    def get_server_epoch_strategy_reset56_2(self, round_idx):
        """Return (args.epochs_server, True) regardless of ``round_idx``."""
        whether_distill_back = True
        epochs = self.args.epochs_server
        return epochs, whether_distill_back

    # not increase after 40 epochs
    def get_server_epoch_strategy2(self, round_idx):
        """Return (epochs, whether_distill_back) for ``round_idx``; epochs decrease with rounds."""
        whether_distill_back = True
        if round_idx < 20:
            epochs = 20
        elif 20 <= round_idx < 30:
            epochs = 15
        elif 30 <= round_idx < 40:
            epochs = 10
        elif 40 <= round_idx < 50:
            epochs = 8
        elif 50 <= round_idx < 100:
            epochs = 5
        elif 100 <= round_idx < 150:
            epochs = 3
        elif 150 <= round_idx <= 200:
            epochs = 1
            whether_distill_back = False
        else:
            epochs = 1
            whether_distill_back = False
        return epochs, whether_distill_back

    def train_and_eval(self, round_idx, epochs):
        """Train the server model for ``epochs`` passes, then evaluate and checkpoint.

        Only the last epoch's training metrics are recorded. After that epoch the
        model is evaluated on the uploaded test features. The latest weights and
        optimizer state are saved to ``CHECKPOINT_DIR/last.ph``. When test top-1
        accuracy is at least the best seen so far, that file is copied to
        ``best.pth`` and the metrics go to ``test_best_metrics.json``.
        """
        for epoch in range(epochs):
            logging.info("train_and_eval. round_idx = %d, epoch = %d" % (round_idx, epoch))
            train_metrics = self.train_large_model_on_the_server()

            if epoch != epochs - 1:
                continue

            self.train_loss_list.append(train_metrics['train_loss'])
            self.train_acc_list.append(train_metrics['train_accTop1'])
            # Evaluate for one epoch on validation set
            test_metrics = self.eval_large_model_on_the_server()

            test_acc = test_metrics['test_accTop1']
            self.test_acc_list.append(test_acc)

            # torch.save does not create missing directories.
            os.makedirs(self.CHECKPOINT_DIR, exist_ok=True)
            last_path = os.path.join(self.CHECKPOINT_DIR, 'last.ph')
            # Save latest model weights, optimizer and accuracy
            torch.save({'state_dict': self.model_global.state_dict(),
                        'optim_dict': self.optimizer.state_dict(),
                        'epoch': round_idx + 1,
                        'test_accTop1': test_metrics['test_accTop1'],
                        'test_accTop5': test_metrics['test_accTop5']}, last_path)

            if test_acc >= self.best_acc:
                logging.info("- Found better accuracy")
                self.best_acc = test_acc
                test_metrics['epoch'] = round_idx + 1
                save_dict_to_json(test_metrics, os.path.join(self.CHECKPOINT_DIR, "test_best_metrics.json"))
                shutil.copyfile(last_path, os.path.join(self.CHECKPOINT_DIR, 'best.pth'))

    def train_large_model_on_the_server(self):
        """Run one training pass over every client's uploaded feature maps.

        Fills ``server_logits_dict[client][batch]`` with the server's (train-mode)
        logits and returns the mean train loss / top-1 / top-5 accuracy.
        """
        # Empty the per-client logits dicts from the previous pass, then drop them.
        for key in self.server_logits_dict.keys():
            self.server_logits_dict[key].clear()
        self.server_logits_dict.clear()
        self.model_global.train()

        loss_avg = RunningAverage()
        accTop1_avg = RunningAverage()
        accTop5_avg = RunningAverage()

        for client_index in self.client_extracted_feauture_dict.keys():
            extracted_feature_dict = self.client_extracted_feauture_dict[client_index]
            logits_dict = self.client_logits_dict[client_index]
            labels_dict = self.client_labels_dict[client_index]

            s_logits_dict = dict()
            self.server_logits_dict[client_index] = s_logits_dict
            for batch_index in extracted_feature_dict.keys():
                batch_feature_map_x = torch.from_numpy(extracted_feature_dict[batch_index]).to(self.device)
                batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device)
                batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                output_batch = self.model_global(batch_feature_map_x)

                loss_true = self.criterion_CE(output_batch, batch_labels)
                if self.args.whether_distill_on_the_server == 1:
                    loss_kd = self.criterion_KL(output_batch, batch_logits)
                    loss = loss_kd + self.args.alpha * loss_true
                else:
                    loss = loss_true

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                metrics = accuracy(output_batch, batch_labels, topk=(1, 5))
                accTop1_avg.update(metrics[0].item())
                accTop5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())

                # Kept from train mode on purpose: clients continue their iteration
                # from these server logits.
                s_logits_dict[batch_index] = output_batch.cpu().detach().numpy()

        train_metrics = {'train_loss': loss_avg.value(),
                         'train_accTop1': accTop1_avg.value(),
                         'train_accTop5': accTop5_avg.value()}

        metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in train_metrics.items())
        logging.info("- Train metrics: " + metrics_string)
        return train_metrics

    def eval_large_model_on_the_server(self):
        """Evaluate the server model on every client's uploaded test feature maps.

        Returns the mean test loss (cross-entropy) and top-1 / top-5 accuracy.
        """
        self.model_global.eval()
        loss_avg = RunningAverage()
        accTop1_avg = RunningAverage()
        accTop5_avg = RunningAverage()
        with torch.no_grad():
            for client_index in self.client_extracted_feauture_dict_test.keys():
                extracted_feature_dict = self.client_extracted_feauture_dict_test[client_index]
                labels_dict = self.client_labels_dict_test[client_index]

                for batch_index in extracted_feature_dict.keys():
                    batch_feature_map_x = torch.from_numpy(extracted_feature_dict[batch_index]).to(self.device)
                    batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                    output_batch = self.model_global(batch_feature_map_x)
                    loss = self.criterion_CE(output_batch, batch_labels)

                    metrics = accuracy(output_batch, batch_labels, topk=(1, 5))
                    # only one element tensors can be converted to Python scalars
                    accTop1_avg.update(metrics[0].item())
                    accTop5_avg.update(metrics[1].item())
                    loss_avg.update(loss.item())

        test_metrics = {'test_loss': loss_avg.value(),
                        'test_accTop1': accTop1_avg.value(),
                        'test_accTop5': accTop5_avg.value()}

        metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in test_metrics.items())
        logging.info("- Test  metrics: " + metrics_string)
        return test_metrics

    def plot_train_loss_vs_rounds(self):
        """Plot the recorded server train loss against communication rounds (1-based)."""
        plt.figure()
        rounds = np.arange(1, len(self.train_loss_list) + 1)
        plt.plot(rounds, self.train_loss_list)
        plt.title("server loss vs communication rounds")
        plt.savefig("server_loss_x_round")
        plt.show()

    def plot_train_accuracy_vs_rounds(self):
        """Plot the recorded server train top-1 accuracy against communication rounds (1-based)."""
        plt.figure()
        rounds = np.arange(1, len(self.train_acc_list) + 1)
        plt.plot(rounds, self.train_acc_list)
        plt.title("server train accuracy vs communication rounds")
        plt.savefig("server_accuracy_x_round")
        plt.show()

    def plot_test_accuracy_vs_rounds(self):
        """Plot the recorded server test top-1 accuracy against communication rounds (1-based)."""
        plt.figure()
        rounds = np.arange(1, len(self.test_acc_list) + 1)
        plt.plot(rounds, self.test_acc_list)
        plt.title("server test accuracy vs communication rounds")
        plt.savefig("server_test_accuracy_x_round")
        plt.show()

### GKTClientTrainer Class

In [None]:
import logging

import torch
from torch import nn, optim
from torch.autograd import Variable

class GKTClientTrainer(object):
    """Client side of FedGKT: trains the small client model and extracts features for the server.

    ``train()`` optionally trains the client model locally. The loss is
    cross-entropy, plus ``args.alpha`` times a KL term toward the server logits
    once those have been received. It then returns, per batch index, the
    extracted feature maps, logits and labels of the training data, plus
    feature maps and labels of the test data.
    """

    def __init__(self, client_index, local_training_data, local_test_data, local_sample_number, device,
                 client_model, args):
        self.client_index = client_index
        # Both containers are indexed by client_index (presumably dicts of per-client
        # loaders -- confirm against the caller).
        self.local_training_data = local_training_data[client_index]
        self.local_test_data = local_test_data[client_index]

        self.local_sample_number = local_sample_number

        self.args = args

        self.device = device
        self.client_model = client_model

        logging.info("client device = " + str(self.device))
        self.client_model.to(self.device)

        # Materialize once: a generator would be exhausted after one pass, and passing
        # the same generator twice to bnwd_optim_params would zip it with itself.
        self.model_params = self.master_params = list(self.client_model.parameters())

        optim_params = bnwd_optim_params(self.client_model, self.model_params,
                                         self.master_params) if args.no_bn_wd else self.master_params

        if self.args.optimizer == "SGD":
            self.optimizer = torch.optim.SGD(optim_params, lr=self.args.lr, momentum=0.9,
                                             nesterov=True,
                                             weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)
        else:
            raise ValueError("Unsupported optimizer: %r (expected 'SGD' or 'Adam')" % (self.args.optimizer,))

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)

        # key: batch_index; value: server logits (numpy) for that batch
        self.server_logits_dict = dict()
        # mean training loss of every local epoch, across all rounds
        self.round_loss = []

    def get_sample_number(self):
        """Return the local sample count given at construction."""
        return self.local_sample_number

    def update_large_model_logits(self, logits):
        """Store the server logits (keyed by batch index) used for distillation."""
        self.server_logits_dict = logits

    def train(self):
        """Optionally train locally, then extract features for the server.

        Returns:
            (extracted_feature_dict, logits_dict, labels_dict,
             extracted_feature_dict_test, labels_dict_test) -- dicts keyed by batch
            index whose values are numpy arrays.
        """
        extracted_feature_dict = dict()
        logits_dict = dict()
        labels_dict = dict()
        extracted_feature_dict_test = dict()
        labels_dict_test = dict()

        if self.args.whether_training_on_client == 1:
            self.client_model.train()
            for epoch in range(self.args.epochs_client):
                batch_loss = []
                for batch_idx, (images, labels) in enumerate(self.local_training_data):
                    images = images.to(self.device)
                    # as_tensor accepts tensors and plain label sequences alike and,
                    # unlike torch.tensor, does not warn when given a tensor.
                    labels = torch.as_tensor(labels, dtype=torch.long, device=self.device)
                    self.optimizer.zero_grad()
                    log_probs, _ = self.client_model(images)
                    loss_true = self.criterion_CE(log_probs, labels)
                    if len(self.server_logits_dict) != 0:
                        # Server logits are keyed by batch index, so this assumes the loader
                        # yields batches in the same order on every pass (no shuffling) -- confirm.
                        large_model_logits = torch.from_numpy(self.server_logits_dict[batch_idx]).to(
                            self.device)
                        loss_kd = self.criterion_KL(log_probs, large_model_logits)
                        loss = loss_true + self.args.alpha * loss_kd
                    else:
                        loss = loss_true

                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())

                mean_loss = sum(batch_loss) / len(batch_loss)
                # With one local epoch per round this is also the per-round loss.
                self.round_loss.append(mean_loss)
                logging.info('client {} - Update Epoch: {} \tEpoch Loss: {:.6f}'.format(
                    self.client_index, epoch, mean_loss))

        self.client_model.eval()

        # Memory note from the original authors: with a large training set the host
        # process can be killed (signal 9 / out of memory) while these dicts are built;
        # they recommend ~256G of host memory or compressing the features.
        # no_grad: outputs are only converted to numpy, so no autograd graph is needed.
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.local_training_data):
                images = images.to(self.device)
                labels = torch.as_tensor(labels, dtype=torch.long, device=self.device)
                log_probs, extracted_features = self.client_model(images)

                extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy()
                logits_dict[batch_idx] = log_probs.cpu().detach().numpy()
                labels_dict[batch_idx] = labels.cpu().detach().numpy()

            for batch_idx, (images, labels) in enumerate(self.local_test_data):
                test_images = images.to(self.device)
                test_labels = torch.as_tensor(labels, dtype=torch.long, device=self.device)
                _, extracted_features_test = self.client_model(test_images)
                extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy()
                labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy()

        return extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test

    def plot_loss_vs_rounds(self):
        """Plot this client's recorded training loss against communication rounds (0-based)."""
        sns.set()
        # New figure so the curve is not drawn over a previously plotted one.
        plt.figure()
        rounds = np.arange(len(self.round_loss))
        plt.plot(rounds, self.round_loss)
        plt.title("client loss vs communication rounds")
        plt.savefig("client" + str(self.client_index) + "_loss")
        plt.show()

### FedML distributed training init

In [None]:
def init_server(args, device, comm, rank, size, model):
    """Build the FedGKT server trainer (the aggregator).

    ``size`` counts every process including the server itself, so the trainer is
    told there are ``size - 1`` clients. ``comm`` and ``rank`` are accepted but
    not used by this function.
    """
    return GKTServerTrainer(size - 1, device, model, args)

def init_client(args, device, comm, process_id, size, model, train_data_local_dict, test_data_local_dict,
                train_data_local_num_dict):
    """Build the FedGKT trainer for the client running as ``process_id``.

    Process 0 is the server, so the client index is ``process_id - 1``.
    ``comm`` and ``size`` are accepted but not used by this function.
    """
    client_index = process_id - 1
    return GKTClientTrainer(client_index, train_data_local_dict, test_data_local_dict,
                            train_data_local_num_dict, device, model, args)

def FedML_FedGKT_distributed(process_id, worker_number, device, comm, model, train_data_local_num_dict,
                             train_data_local_dict, test_data_local_dict, args):
    """Return the trainer for one simulated process.

    Process 0 gets the server trainer; every other process gets a client
    trainer built from the per-client data dictionaries.
    """
    if process_id != 0:
        # Clients receive the dictionaries that describe their local data.
        return init_client(args, device, comm, process_id, worker_number, model,
                           train_data_local_dict, test_data_local_dict, train_data_local_num_dict)
    # The server never sees raw images: it is trained on the feature maps the
    # clients send it, so it needs none of the data dictionaries.
    return init_server(args, device, comm, process_id, worker_number, model)

## Main function


In [None]:
# create models
# One process per client plus process 0, which is reserved for the server.
process_ids = list(range(args.client_number + 1))
size = len(process_ids)
# List containing the models: models[0] is the server model, models[1:] are client models.
# NOTE(review): this rebinds the global name `models`, shadowing the `torchvision.models`
# module imported at the top of the notebook. If create_server_model / create_client_model
# reference torchvision's `models` as a global, they will see this list instead; TODO confirm.
models = []
for i in range(size):
    if i == 0:
        models.append(create_server_model(class_num))
        print("Server model: ", models[i])
        # numel method returns the total number of elements in the input tensor.
        print(f"Parameters per layer: {[p.numel() for p in models[i].parameters()]} ")
        print(f"Total number of parameters: {sum([p.numel() for p in models[i].parameters()])}")
    else:
        # create client model
        models.append(create_client_model(args, class_num))
# Only describe a client model if at least one client exists (models[1] would raise otherwise).
if size > 1:
    print("Client model: ", models[1])
    # numel method returns the total number of elements in the input tensor.
    print(f"Parameters per layer: {[p.numel() for p in models[1].parameters()]} ")
    print(f"Total number of parameters: {sum([p.numel() for p in models[1].parameters()])}")

# Initialize the trainers: trainers[j] is the Trainer object associated with models[j],
# so trainers[0] is always the server trainer.
trainers = []
for j in range(size):
    trainer = FedML_FedGKT_distributed(process_id=j, worker_number=size, device=device, comm=None,
                                       model=models[j], train_data_local_num_dict=train_data_local_num_dict,
                                       train_data_local_dict=train_data_local_dict,
                                       test_data_local_dict=test_data_local_dict, args=args)
    trainers.append(trainer)

print("START TRAINING")

# Number of clients sampled per communication round (capped by the number of clients available).
CLIENTS_PER_ROUND = 10

# Simulate the distributed training framework
for curr_round in range(args.comm_round):
    # Client trainers live at indexes 1 .. len(trainers) - 1 (index 0 is the server).
    # np.arange's stop is exclusive, so the population must end at len(trainers):
    # the previous len(trainers) + 1 could select an out-of-range trainer index.
    num_clients = len(trainers) - 1
    clients_idx = np.random.choice(np.arange(1, len(trainers)),
                                   size=min(CLIENTS_PER_ROUND, num_clients), replace=False)

    # 1) Train each selected client for its local epochs (last selected first) and hand
    #    its results to the server. The server indexes clients by their slot in
    #    clients_idx, i.e. in the range [0, len(clients_idx) - 1].
    for slot in reversed(range(len(clients_idx))):
        trainer_idx = clients_idx[slot]
        # The client returns dictionaries with the information the server needs for
        # the alternating training.
        extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test = \
            trainers[trainer_idx].train()
        trainers[0].add_local_trained_result(index=slot, extracted_feature_dict=extracted_feature_dict,
                                             logits_dict=logits_dict, labels_dict=labels_dict,
                                             extracted_feature_dict_test=extracted_feature_dict_test,
                                             labels_dict_test=labels_dict_test)

    # 2) Server training on the collected client results.
    trainers[0].train(curr_round)

    # 3) Send the server's global logits back to every selected client.
    for slot in reversed(range(len(clients_idx))):
        trainer_idx = clients_idx[slot]
        global_logit = trainers[0].get_global_logits(client_index=slot)
        trainers[trainer_idx].update_large_model_logits(logits=global_logit)

NameError: ignored

### Plots


In [None]:
# Clients plots: loss curves of the first (up to) four client trainers.
# Slicing instead of fixed indexes 1..4 avoids an IndexError when fewer than four clients exist.
for client_trainer in trainers[1:5]:
    client_trainer.plot_loss_vs_rounds()
# Server plots (trainers[0] is always the server trainer)
trainers[0].plot_train_loss_vs_rounds()
trainers[0].plot_train_accuracy_vs_rounds()
trainers[0].plot_test_accuracy_vs_rounds()