In [1]:
###############################
##### importing libraries #####
###############################

import os
import random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset   

torch.backends.cudnn.benchmark=True

In [2]:
num_clients = 20
num_selected = 6
num_rounds = 50
epochs = 6
batch_size = 32

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # what is this?
])

In [5]:
df_train = datasets.CIFAR10('../data', train=True, download=True, transform=transform_train)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 30.2MB/s] 


Extracting ../data/cifar-10-python.tar.gz to ../data


In [None]:
# split data for a number of num_clients
df_train_split = torch.utils.data.random_split(df_train, [int(df_train.data.shape[0] / num_clients) for _ in range(num_clients)])


In [18]:
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in df_train_split]

In [19]:
# Normalizing the test images
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [20]:
# Load test images to a test loader
# A test loader in PyTorch is .. ?
#    simply DataLoader configured for testing 

test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=False, transform=transform_test), batch_size=batch_size, shuffle=True)

In [21]:
cfg = {
    'VGG11':[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__() 

        # Are these a common use in Torch, dividing up features and classifier layers?
        ##   Seems like it, to make it easier for transfer learning, modularity, clarity, and memory optimization (Claude)
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 512), # input and outputs
            nn.ReLU(True), # in PyTorch, the function is added sequentially
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10),
        )

    def forward(self, X):
        out = self.features(X)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        output = F.log_softmax(out, dim=1)
        return output
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for c in cfg:
            if c == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, c, kernel_size=3, padding=1),
                           nn.BatchNorm2d(c),
                           nn.ReLU(inplace=True)] # Why use inplace? ## Saves memory and some performance consideration(Claude)
                in_channels = c
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)



In [30]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [31]:
def client_update(client_model, optimizer, train_loader, epoch=5):
    """
    This function updates/trains client model on client data
    """
    client_model.train()
    for e in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()

In [None]:
def server_aggregate(global_model, client_models):
    """
    This function has aggregation method 'mean'
    """
    ### This will take simple mean of the weights of models ###
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

    

In [33]:
def test(global_model, test_loader):
    """This function test the global model on test data and returns test loss and test accuracy """
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

---

In [34]:
device

device(type='mps')

In [35]:

#### global model ##########
global_model =  VGG('VGG11').to(device)

############## client models ##############
client_models = [ VGG('VGG11').to(device) for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronizing with global model 

############### optimizers ################
opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]

In [36]:
loss_train = []
loss_test = []
acc_train = []
acc_test = []

In [37]:
for r in range(num_rounds):
    client_idx = np.random.permutation(num_clients)[:num_selected]
    loss = 0
    for i in tqdm(range(num_selected)):
        loss += client_update(client_models[i], opt[i], train_loader[client_idx[i]], epoch=epochs)

    loss_train.append(loss)
    # server aggregate
    server_aggregate(global_model, client_models)
    
    test_loss, acc = test(global_model, test_loader)
    loss_test.append(test_loss)
    acc_test.append(acc)
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))


100%|██████████| 6/6 [01:17<00:00, 12.87s/it]


0-th round
average train loss 1.38 | test loss 2.28 | test acc: 0.194


100%|██████████| 6/6 [01:10<00:00, 11.68s/it]


1-th round
average train loss 1.83 | test loss 1.45 | test acc: 0.463


100%|██████████| 6/6 [01:10<00:00, 11.78s/it]


2-th round
average train loss 1.54 | test loss 1.13 | test acc: 0.600


100%|██████████| 6/6 [01:26<00:00, 14.49s/it]


3-th round
average train loss 2.02 | test loss 1.1 | test acc: 0.611


100%|██████████| 6/6 [01:27<00:00, 14.52s/it]


4-th round
average train loss 1.1 | test loss 0.854 | test acc: 0.691


100%|██████████| 6/6 [01:25<00:00, 14.25s/it]


5-th round
average train loss 0.642 | test loss 0.828 | test acc: 0.711


100%|██████████| 6/6 [01:29<00:00, 14.85s/it]


6-th round
average train loss 1.18 | test loss 0.789 | test acc: 0.731


100%|██████████| 6/6 [01:28<00:00, 14.67s/it]


7-th round
average train loss 1.4 | test loss 0.721 | test acc: 0.750


100%|██████████| 6/6 [01:34<00:00, 15.78s/it]


8-th round
average train loss 0.922 | test loss 0.852 | test acc: 0.720


100%|██████████| 6/6 [01:50<00:00, 18.45s/it]


9-th round
average train loss 0.827 | test loss 0.717 | test acc: 0.767


100%|██████████| 6/6 [01:40<00:00, 16.83s/it]


10-th round
average train loss 0.969 | test loss 0.651 | test acc: 0.778


100%|██████████| 6/6 [01:31<00:00, 15.30s/it]


11-th round
average train loss 0.977 | test loss 0.741 | test acc: 0.765


100%|██████████| 6/6 [01:38<00:00, 16.35s/it]


12-th round
average train loss 1.35 | test loss 0.722 | test acc: 0.768


100%|██████████| 6/6 [01:35<00:00, 15.95s/it]


13-th round
average train loss 1.03 | test loss 0.703 | test acc: 0.779


100%|██████████| 6/6 [01:51<00:00, 18.58s/it]


14-th round
average train loss 0.993 | test loss 0.662 | test acc: 0.789


100%|██████████| 6/6 [01:46<00:00, 17.68s/it]


15-th round
average train loss 0.994 | test loss 0.572 | test acc: 0.813


100%|██████████| 6/6 [01:35<00:00, 15.91s/it]


16-th round
average train loss 0.569 | test loss 0.558 | test acc: 0.816


100%|██████████| 6/6 [01:36<00:00, 16.16s/it]


17-th round
average train loss 0.775 | test loss 0.56 | test acc: 0.826


100%|██████████| 6/6 [01:41<00:00, 16.92s/it]


18-th round
average train loss 0.449 | test loss 0.59 | test acc: 0.817


100%|██████████| 6/6 [01:33<00:00, 15.60s/it]


19-th round
average train loss 0.993 | test loss 0.561 | test acc: 0.820


100%|██████████| 6/6 [01:38<00:00, 16.40s/it]


20-th round
average train loss 0.855 | test loss 0.598 | test acc: 0.814


100%|██████████| 6/6 [01:52<00:00, 18.79s/it]


21-th round
average train loss 0.655 | test loss 0.624 | test acc: 0.810


100%|██████████| 6/6 [01:52<00:00, 18.80s/it]


22-th round
average train loss 0.818 | test loss 0.51 | test acc: 0.837


100%|██████████| 6/6 [01:34<00:00, 15.79s/it]


23-th round
average train loss 0.383 | test loss 0.492 | test acc: 0.847


100%|██████████| 6/6 [01:38<00:00, 16.34s/it]


24-th round
average train loss 0.238 | test loss 0.461 | test acc: 0.855


100%|██████████| 6/6 [01:43<00:00, 17.18s/it]


25-th round
average train loss 0.554 | test loss 0.623 | test acc: 0.819


100%|██████████| 6/6 [01:40<00:00, 16.81s/it]


26-th round
average train loss 0.721 | test loss 0.568 | test acc: 0.830


100%|██████████| 6/6 [01:48<00:00, 18.12s/it]


27-th round
average train loss 0.736 | test loss 0.527 | test acc: 0.842


100%|██████████| 6/6 [01:44<00:00, 17.34s/it]


28-th round
average train loss 0.572 | test loss 0.54 | test acc: 0.838


100%|██████████| 6/6 [01:37<00:00, 16.24s/it]


29-th round
average train loss 0.625 | test loss 0.543 | test acc: 0.843


100%|██████████| 6/6 [01:44<00:00, 17.43s/it]


30-th round
average train loss 0.102 | test loss 0.476 | test acc: 0.860


100%|██████████| 6/6 [01:51<00:00, 18.61s/it]


31-th round
average train loss 0.682 | test loss 0.665 | test acc: 0.820


100%|██████████| 6/6 [01:48<00:00, 18.05s/it]


32-th round
average train loss 0.517 | test loss 0.52 | test acc: 0.848


100%|██████████| 6/6 [01:44<00:00, 17.46s/it]


33-th round
average train loss 0.774 | test loss 0.779 | test acc: 0.800


100%|██████████| 6/6 [01:44<00:00, 17.39s/it]


34-th round
average train loss 0.71 | test loss 0.499 | test acc: 0.853


100%|██████████| 6/6 [01:45<00:00, 17.63s/it]


35-th round
average train loss 0.4 | test loss 0.451 | test acc: 0.867


100%|██████████| 6/6 [02:05<00:00, 20.91s/it]


36-th round
average train loss 0.236 | test loss 0.477 | test acc: 0.863


100%|██████████| 6/6 [01:54<00:00, 19.04s/it]


37-th round
average train loss 0.802 | test loss 0.543 | test acc: 0.847


100%|██████████| 6/6 [02:18<00:00, 23.01s/it]


38-th round
average train loss 0.102 | test loss 0.451 | test acc: 0.874


100%|██████████| 6/6 [02:16<00:00, 22.76s/it]


39-th round
average train loss 0.353 | test loss 0.529 | test acc: 0.851


100%|██████████| 6/6 [01:50<00:00, 18.47s/it]


40-th round
average train loss 0.138 | test loss 0.458 | test acc: 0.868


100%|██████████| 6/6 [02:00<00:00, 20.02s/it]


41-th round
average train loss 0.768 | test loss 0.573 | test acc: 0.843


100%|██████████| 6/6 [01:56<00:00, 19.44s/it]


42-th round
average train loss 0.565 | test loss 0.529 | test acc: 0.855


100%|██████████| 6/6 [01:47<00:00, 17.96s/it]


43-th round
average train loss 0.352 | test loss 0.599 | test acc: 0.846


100%|██████████| 6/6 [01:50<00:00, 18.43s/it]


44-th round
average train loss 0.11 | test loss 0.442 | test acc: 0.878


100%|██████████| 6/6 [01:39<00:00, 16.57s/it]


45-th round
average train loss 0.17 | test loss 0.43 | test acc: 0.878


100%|██████████| 6/6 [02:21<00:00, 23.62s/it]


46-th round
average train loss 1.07 | test loss 0.59 | test acc: 0.843


100%|██████████| 6/6 [01:53<00:00, 18.84s/it]


47-th round
average train loss 0.119 | test loss 0.456 | test acc: 0.879


100%|██████████| 6/6 [02:16<00:00, 22.68s/it]


48-th round
average train loss 0.266 | test loss 0.414 | test acc: 0.884


100%|██████████| 6/6 [02:24<00:00, 24.11s/it]


49-th round
average train loss 0.412 | test loss 0.588 | test acc: 0.855


In [39]:
torch.save(global_model.state_dict(), 'global_model.pt')