In [None]:
# imports
import torch
import torch.nn as nn
# import fedAvg
from torchvision import datasets, transforms
from lenet5 import LeNet
import numpy as np
import pandas as pd
# from tqdm import tqdm # progress bar, not really necessary

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# --- local (per-client) training settings ---
BATCH_SIZE = 32        # mini-batch size used by every DataLoader
LEARNING_RATE = 0.002  # step size handed to the SGD optimizers
EPOCHS = 10            # user-defined number of epochs (passed as `epoch=` to client_update)

# --- federated learning settings ---
num_rounds = 10    # number of communication rounds
num_clients = 20   # total number of clients the training set is split across
num_selected = 6   # how many clients are picked for training in each round

In [None]:
# loaders and transformations
# Channel-wise mean / std handed to Normalize for both the training and test images.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Image augmentation (random crop + horizontal flip) followed by normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Normalizing the test images (no augmentation)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Loading CIFAR10 using torchvision.datasets.
# download=False: the dataset is expected to already be present under ./data.
traindata = datasets.CIFAR10('./data', train=True, download=False,
                             transform=transform_train)

# Dividing the training data into num_clients shards. The split is done exactly
# once (a second random_split would silently replace the first with a different
# partition). random_split requires the lengths to add up to len(traindata), so
# any remainder is spread over the first shards instead of being truncated away.
shard_size, leftover = divmod(len(traindata), num_clients)
shard_lengths = [shard_size + 1 if i < leftover else shard_size
                 for i in range(num_clients)]
traindata_split = torch.utils.data.random_split(traindata, shard_lengths)

# Creating one pytorch loader per client shard
train_loader = [torch.utils.data.DataLoader(x, batch_size=BATCH_SIZE, shuffle=True)
                for x in traindata_split]

# Loading the test images and wrapping them in a test_loader.
# Reuses transform_test; shuffling is off because the evaluation metrics
# (mean loss / accuracy) do not depend on sample order.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transform_test),
    batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Global (server-side) model. `.to(DEVICE)` is harmless on CPU, so placement
# needs no branch; the message is kept for the CPU case.
centralizedModel = LeNet().to(DEVICE)
if DEVICE != 'cuda':
    print("we're not on cuda")

# Kept so that any code still referring to `modelChosen` keeps working.
modelChosen = centralizedModel

# One INDEPENDENT model per selected client. Reusing a single instance for every
# entry (and for the global model) would make all clients train the very same
# weights, turning the averaging step into a no-op.
federatedModels = [LeNet().to(DEVICE) for _ in range(num_selected)]

# Every client starts from the weights of the central model.
for client_model in federatedModels:
    client_model.load_state_dict(centralizedModel.state_dict())

# One optimizer per client model, plus the loss shared by training and evaluation.
optimizers = [torch.optim.SGD(model.parameters(), lr=LEARNING_RATE) for model in federatedModels]
criterion = nn.CrossEntropyLoss()

In [None]:
def train(train_loader, model, criterion, optimizer, device):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches that also exposes
            a ``dataset`` attribute supporting ``len()``.
        model: module whose forward pass returns a pair; the first element is
            what gets passed to ``criterion``.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer stepping the parameters of ``model``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(model, optimizer, epoch_loss)``, where ``epoch_loss`` is the sum of
        the batch losses weighted by batch size, divided by
        ``len(train_loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # forward / backward / update for this batch
        optimizer.zero_grad()
        outputs, _ = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        # weight by batch size so a smaller last batch is accounted for correctly
        weighted_loss_sum += batch_loss.item() * inputs.size(0)

    return model, optimizer, weighted_loss_sum / len(train_loader.dataset)

def client_update(model, optimizer, train_loader,device,criterion ,epoch=5):
    """
    This function updates/trains client model on client data.

    Runs ``train`` ``epoch`` times on ``train_loader`` and reports the loss of
    the last epoch.

    Args:
        model: client model to train (updated in place by ``train``).
        optimizer: optimizer attached to ``model``.
        train_loader: loader over this client's data.
        device: device the batches are moved to.
        criterion: loss callable forwarded to ``train``.
        epoch: number of local epochs; must be at least 1.

    Returns:
        The training loss returned by ``train`` for the final epoch.

    Raises:
        ValueError: if ``epoch`` is smaller than 1 (no epoch would run, so
            there would be no loss to return).
    """
    if epoch < 1:
        raise ValueError("client_update needs epoch >= 1, got %r" % (epoch,))

    for _ in range(epoch):
        model, optimizer, train_loss = train(train_loader, model,
                                criterion, optimizer, device)

    print(train_loss)
    return train_loss

def server_aggregate(global_model, client_models):
    """
    This function has aggregation method 'mean'.

    Every entry of the global state dict is replaced by the element-wise mean of
    the corresponding client entries (cast to float before averaging), the
    result is loaded into ``global_model``, and the averaged state is then
    loaded back into every client model.

    Args:
        global_model: model receiving the averaged weights.
        client_models: non-empty sequence of models whose state dicts share the
            keys of ``global_model``.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if len(client_models) == 0:
        raise ValueError("server_aggregate needs at least one client model")

    # Build each client's state dict once instead of once per key.
    client_states = [client.state_dict() for client in client_models]

    ### This will take simple mean of the weights of models ###
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([state[k].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)

    # Broadcast the averaged weights back to the clients.
    averaged_state = global_model.state_dict()
    for model in client_models:
        model.load_state_dict(averaged_state)



def get_accuracy(model, data_loader, device):
    '''
    Function for computing the accuracy of the predictions over the entire data_loader.

    Args:
        model: module whose forward pass returns a pair; the second element is
            used for the arg-max prediction. The model is switched to eval mode.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        A 0-dim float tensor holding ``correct / total``. For an empty
        ``data_loader`` the result is NaN (0 / 0) instead of an exception.
    '''

    # Start from a tensor (not the int 0) so `.float()` below is valid even
    # when the loader yields no batch at all.
    correct_pred = torch.zeros((), dtype=torch.long, device=device)
    n = 0

    with torch.no_grad():
        model.eval()
        for X, y_true in data_loader:

            X = X.to(device)
            y_true = y_true.to(device)

            _, y_prob = model(X)
            _, predicted_labels = torch.max(y_prob, 1)

            n += y_true.size(0)
            correct_pred += (predicted_labels == y_true).sum()

    return correct_pred.float() / n
    
def test(valid_loader, model, criterion, device):
    model.eval()
    running_loss = 0

    for X, y_target in valid_loader:
        # if there is a GPU

        X = X.to(device)
        y_target = y_target.to(device)

        # prediction and loss

        # call model forward()
        y_predict, _ = model(X)
        # get loss
        loss = criterion(y_predict, y_target)
        running_loss += loss.item() * X.size(0)

    epoch_loss = running_loss / len(valid_loader.dataset)
    return model, epoch_loss
        

In [None]:
###### training #######
###### Lists containing info about learning (one entry per round) #########
losses_train = []   # summed training loss of the selected clients
losses_test = []    # loss of the aggregated model on the test set
acc_train = []      # mean training accuracy of the selected clients
acc_test_arr = []   # accuracy of the aggregated model on the test set

for r in range(num_rounds):
    # select random clients
    # we select in the total number of clients, a random array of clients of size num_selected at each round
    client_idx = np.random.permutation(num_clients)[:num_selected]

    loss = 0  # running sum of the clients' final-epoch training losses for this round
    train_acc_sum = 0.0
    for i in range(num_selected):
        client_loader = train_loader[client_idx[i]]
        loss += client_update(federatedModels[i], optimizers[i], client_loader, DEVICE, criterion, epoch=EPOCHS)
        train_acc = get_accuracy(federatedModels[i], client_loader, DEVICE)
        train_acc_sum += float(train_acc)

        print(train_acc)
        print(loss)
    losses_train.append(loss)
    acc_train.append(train_acc_sum / num_selected)
    server_aggregate(centralizedModel, federatedModels)

    # Test global model on data. test() returns (model, loss); keep only the
    # loss, otherwise the '%0.3g' formatting below fails on the tuple.
    test_loss = test(test_loader, centralizedModel, criterion, DEVICE)[1]
    test_acc = get_accuracy(centralizedModel, test_loader, DEVICE)

    losses_test.append(test_loss)
    acc_test_arr.append(test_acc)
    print('%d-th round' % r)

    print('average train loss %0.3g ' % (loss / num_selected))
    print(' test loss %0.3g ' % test_loss)
    print('test acc: %0.3f' % float(test_acc))

    # TODO: check the accuracy against the function from the federated learning
    # reference and compare with the centralised model.