In [None]:
import sys

from google.colab import drive

# Mount Google Drive inside the Colab runtime so files under /content/drive are reachable.
drive.mount('/content/drive')

# Put the project folder first on the module search path.
# Presumably this is where the project-local modules imported further down
# (options, utils, models_server, ...) live -- TODO confirm the folder.
project_dir = '/content/drive/MyDrive/FL2022/FedGKT'
sys.path.insert(0, project_dir)

In [None]:
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

from options import args_parser
from utils import get_dataset, get_data_loaders, baseline_data, client_syn, get_n_params
from models_server import ResNet50
from models_client import ResNet8
from server import GKTServerTrainer
from client import GKTClientTrainer

logger = SummaryWriter('../logs')

In [None]:
# parameters
# The options are handed to args_parser() through sys.argv (presumably it parses
# sys.argv with argparse -- confirm in options.py), so the kernel's own argv is
# replaced with the flags for this experiment.
sys.argv = [
    '',
    '--iid=1',  # 0 -> NONiid, 1 -> iid
    '--num_users=100',
    '--lr=0.01',
    '--local_ep=1',
    '--epochs=10',
    '--optimizer=adam',
    '--norm=batch_norm',
    '--local_bs=128',
    '--dataset=cifar',
    '--loss=CrossEntropyLoss',
    '--gpu=/device:GPU:0',
]
args = args_parser()

# Seed the random number generators before any stochastic step so that a
# "Restart & Run All" reproduces the same run. The client sampling done with
# np.random.choice in the training loop depends on the NumPy seed.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of clients sampled in each communication round.
num_selected = int(args.num_users * args.frac)
baseline_num = 100  # number of baseline images to be saved on the global server
alpha = 0.5  # partition (original note); forwarded to GKTClientTrainer -- TODO confirm meaning
unbalanced = False  # if True the clients will contain different number of classes
temperature = 3.0  # forwarded to both the server and the client trainers
communication_rounds = 10  # between 10 and 40

In [None]:
# Build the train/test datasets and the per-user groups; `user_groups` is indexed
# by client id when the client trainers are created.
train_dataset, test_dataset, user_groups = get_dataset(
    iid=args.iid,
    unbalanced=unbalanced,
    num_users=args.num_users,
)

In [None]:
# Server-side model; the normalisation type comes from the --norm option.
server_model = ResNet50(n_type=args.norm)
# Last expression of the cell, so its return value is displayed
# (presumably the number of model parameters -- see utils.get_n_params).
get_n_params(server_model)


In [None]:

# Client-side model, built with its default constructor arguments.
client_model = ResNet8()
# Last expression of the cell, so its return value is displayed
# (presumably the number of model parameters -- see utils.get_n_params).
get_n_params(client_model)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# A torch.device never compares equal to a plain string, so the device *type*
# must be compared; `device == 'cuda'` was always False and left gpu at [0]
# even on a GPU runtime.
# NOTE(review): the flag is kept as a one-element list because that is what is
# handed to GKTClientTrainer -- confirm the trainer really expects a list.
gpu = [1 if device.type == 'cuda' else 0]
print(device)

In [None]:
# Move each model to the selected device and switch it to training mode.
# Both calls return the model itself, so they can be chained; the final
# expression is still the client model, which the notebook displays.
server_model.to(device).train()
client_model.to(device).train()

In [None]:
# init server
server_trainer = GKTServerTrainer(
    server_model, args.num_users, args.lr, args.epochs, device, args.optimizer, temperature
)

# init clients: one trainer per user id, each with its own group from user_groups.
# NOTE(review): every trainer receives the same `client_model` object; whether
# GKTClientTrainer copies it is not visible here -- confirm that sharing is intended.
idxs_users = range(args.num_users)
clients_trainer = [
    GKTClientTrainer(
        client_model, train_dataset, test_dataset,
        user_groups[idx], idx, gpu, args.optimizer, args.local_bs,
        args.lr, args.local_ep, temperature, alpha,
    )
    for idx in idxs_users
]

for communication_round in range(1, communication_rounds + 1):
    print(f'\nCommunication Round: {communication_round} \n')

    # Sample the participating clients for this round (at least one, no repeats).
    m = max(num_selected, 1)
    idxs_chosen_users = np.random.choice(range(args.num_users), m, replace=False)
    print(idxs_chosen_users)

    # Local phase: each chosen client trains and hands its results to the server.
    for idx in idxs_chosen_users:
        (extracted_feature_dict, logits_dict, labels_dict,
         extracted_feature_dict_test, labels_dict_test) = clients_trainer[idx].train()

        server_trainer.add_local_trained_result(
            idx, extracted_feature_dict, logits_dict, labels_dict,
            extracted_feature_dict_test, labels_dict_test,
        )

    # Server phase: train on what the chosen clients uploaded in this round.
    server_trainer.train(communication_round, idxs_chosen_users)

    # send global logits to client
    for idx in idxs_chosen_users:
        clients_trainer[idx].update_large_model_logits(server_trainer.get_global_logits(idx))

# get lists of train loss and accuracy
train_loss, train_accuracy = server_trainer.get_loss_acc_list()

In [None]:
# Save the server's training curves as a CSV whose name encodes the run settings.
#
# The tags are plain values. Wrapping the conditional expressions in [...] built
# one-element lists, which put their repr (e.g. "['iid']", "[128]") into the name.
iid = 'iid' if args.iid else 'nonIID'
# Tag the run with the `unbalanced` flag that was actually passed to get_dataset,
# not args.unequal, which is never set for this run.
unb = 'unbalanced' if unbalanced else 'balanced'
# The local batch size is appended to the norm tag only for batch norm.
bs = args.local_bs if args.norm == 'batch_norm' else ''
# path='/content/drive/MyDrive/FL2022/GKT/ResultsGKT/'
# NOTE(review): this folder differs from the project folder added to sys.path at the
# top of the notebook (.../MyDrive/FL2022/FedGKT) -- confirm it is the intended one.
path = '/content/drive/MyDrive/Colab_Notebooks/FL2022/FedGKT/ResultsGKT/'
filename = f"fedGKT_commRound_{communication_rounds}_{iid}_{unb}_{args.norm}{bs}_{args.epochs}_lr_{args.lr}_optimizer_{args.optimizer}_rik.csv"

# Create the results folder first so a missing directory cannot discard the
# results of the whole training run.
os.makedirs(path, exist_ok=True)

# One row per entry of the server's loss/accuracy lists.
data = list(zip(train_loss, train_accuracy))
pd.DataFrame(data, columns=['train_loss', 'train_accuracy']).to_csv(os.path.join(path, filename))