In [3]:
# Architecture and training code partially based on https://github.com/pytorch/examples/blob/master/mnist/main.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np


class Net(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    The first linear layer expects 9216 flattened features, which is what the
    two 3x3 convolutions followed by a 2x2 max-pool yield for a single-channel
    28x28 input (64 channels * 12 * 12).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D convolutional feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer is applied to the flattened (N, 128)
        # activations. Feeding a 2-D tensor to Dropout2d is deprecated in newer
        # PyTorch releases (it warns and is slated to become an error);
        # nn.Dropout is the documented drop-in replacement.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: input batch; the layer sizes above fit a (N, 1, 28, 28) tensor.

        Returns:
            Tensor of shape (N, 10) holding log-softmax scores along dim 1.
        """
        x = self.conv1(x)
        x = F.relu(x)
        # NOTE(review): there is no ReLU between conv2 and the pooling step;
        # confirm whether that is intentional before changing it.
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def client_update(client_model, optimizer, train_loader, epoch=5):
    """Train one client model on its local data for a number of epochs.

    Args:
        client_model: the model to train; it is switched to training mode.
        optimizer: optimizer that steps the parameters of ``client_model``.
        train_loader: iterable yielding ``(data, target)`` batches.
        epoch: number of passes over ``train_loader``.

    Returns:
        The loss of the last processed batch as a float, or ``None`` when no
        batch was processed (empty loader or ``epoch`` <= 0).
    """
    # Put the model that is actually being trained into training mode,
    # rather than whatever a global named `model` happens to refer to.
    client_model.train()

    # Move batches to wherever the model lives so the same code runs on CPU
    # and GPU; a parameter-less model leaves the batches where they are.
    first_param = next(client_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    loss = None
    for _ in range(epoch):
        for data, target in train_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    if loss is None:
        # Nothing was trained, so there is no loss to report.
        return None

    final_loss = loss.item()
    print('Train loss: {:.6f}'.format(final_loss))
    return final_loss

def server_aggregate(global_model, client_models):
    """Average the client weights into the global model and broadcast them back.

    Every entry of the global ``state_dict`` is replaced by the element-wise
    mean of the corresponding client entries; afterwards each client is
    overwritten with the aggregated state.

    Args:
        global_model: model receiving the averaged state.
        client_models: non-empty sequence of models with the same state layout
            as ``global_model``.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("server_aggregate needs at least one client model")

    # Build each client's state dict once instead of once per key.
    client_states = [client.state_dict() for client in client_models]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        stacked = torch.stack([state[k] for state in client_states], 0)
        if stacked.is_floating_point() or stacked.is_complex():
            global_dict[k] = stacked.mean(0)
        else:
            # mean() rejects integer tensors (e.g. BatchNorm's
            # num_batches_tracked); average in float and cast back.
            global_dict[k] = stacked.float().mean(0).to(stacked.dtype)
    global_model.load_state_dict(global_dict)

    aggregated_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(aggregated_state)

def test(global_model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

In [4]:
# Hyperparameters

num_clients = 100   # number of local data shards (one per simulated client)
num_selected = 5    # clients trained per round
num_rounds = 5      # federated rounds
epochs = 5          # local epochs per selected client
batch_size = 32

# Creating decentralized datasets

# One preprocessing pipeline shared by the train and test sets:
# convert to tensor, then normalize with the given (mean,), (std,).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

traindata = datasets.MNIST('./data', train=True, download=True, transform=mnist_transform)

# random_split requires the shard lengths to sum to the dataset size, so the
# remainder is spread over the first shards instead of being truncated away
# (which raises whenever num_clients does not divide the dataset size).
shard_size, leftover = divmod(len(traindata), num_clients)
shard_lengths = [shard_size + 1 if i < leftover else shard_size for i in range(num_clients)]
traindata_split = torch.utils.data.random_split(traindata, shard_lengths)
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=mnist_transform),
        batch_size=batch_size, shuffle=True)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:03, 2590267.69it/s]                             


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 60093.72it/s]                           
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:03, 424061.60it/s]                             
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 21795.03it/s]            

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!





In [5]:
# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_selected)]
for model in client_models:
    # every client starts from the same weights as the global model
    model.load_state_dict(global_model.state_dict())

# one optimizer per client model slot, reused across rounds
opt = [optim.SGD(model.parameters(), lr=0.001) for model in client_models]

# Running FL

# `round_idx` rather than `round`, so the builtin round() is not shadowed
# for the rest of the notebook session.
for round_idx in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # client update: slot i trains on the data shard of the i-th selected client
    for i in range(num_selected):
        client_update(client_models[i], opt[i], train_loader[client_idx[i]], epoch=epochs)

    # server aggregate
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, test_loader)

Train loss: 2.262230
Train loss: 2.235266
Train loss: 2.232000
Train loss: 2.267997
Train loss: 2.249782

Test loss: 2.2472, Accuracy: 2620/10000 (26%)

Train loss: 2.164184
Train loss: 2.178260
Train loss: 2.206033
Train loss: 2.205151
Train loss: 2.198921

Test loss: 2.1855, Accuracy: 3780/10000 (38%)

Train loss: 2.079116
Train loss: 2.085637
Train loss: 2.072364
Train loss: 2.144063
Train loss: 2.127028

Test loss: 2.0926, Accuracy: 4818/10000 (48%)

Train loss: 1.994810
Train loss: 1.936206
Train loss: 1.979727
Train loss: 1.950891
Train loss: 1.993379

Test loss: 1.9676, Accuracy: 5418/10000 (54%)

Train loss: 1.810959
Train loss: 1.921714
Train loss: 1.904141
Train loss: 1.787483
Train loss: 1.773445

Test loss: 1.7973, Accuracy: 5861/10000 (59%)



In [6]:
# Different conditions

def run_federated_rounds(num_selected, epochs, num_rounds, lr=0.001):
    """Train a fresh global model by averaging client updates, testing after each round.

    Each round draws ``num_selected`` distinct client indices at random out of
    the notebook-level ``num_clients``, trains one client model per drawn
    index on the matching entry of ``train_loader``, averages the client
    models into the global model and evaluates it on ``test_loader``.

    Args:
        num_selected: number of client models trained per round.
        epochs: local epochs passed to ``client_update``.
        num_rounds: number of federated rounds to run.
        lr: learning rate of each client's SGD optimizer.

    Returns:
        Tuple ``(global_model, client_models, opt, history)`` where ``history``
        holds one ``(test_loss, accuracy)`` pair per completed round.
    """
    global_model = Net().cuda()
    client_models = [Net().cuda() for _ in range(num_selected)]
    for client in client_models:
        # every client starts from the same weights as the global model
        client.load_state_dict(global_model.state_dict())

    opt = [optim.SGD(client.parameters(), lr=lr) for client in client_models]

    history = []
    for _ in range(num_rounds):
        # select random clients
        client_idx = np.random.permutation(num_clients)[:num_selected]

        # client update
        for i in range(num_selected):
            client_update(client_models[i], opt[i], train_loader[client_idx[i]], epoch=epochs)

        # server aggregate
        server_aggregate(global_model, client_models)
        history.append(test(global_model, test_loader))

    return global_model, client_models, opt, history


epochs = 10   # more local epochs

global_model, client_models, opt, history = run_federated_rounds(num_selected, epochs, num_rounds)
if history:
    test_loss, acc = history[-1]


num_selected = 10   # more clients (epochs stays at 10 from the run above)

global_model, client_models, opt, history = run_federated_rounds(num_selected, epochs, num_rounds)
if history:
    test_loss, acc = history[-1]

Train loss: 2.140615
Train loss: 2.093284
Train loss: 2.158375
Train loss: 2.140799
Train loss: 2.109223

Test loss: 2.1185, Accuracy: 4306/10000 (43%)

Train loss: 1.728553
Train loss: 1.687636
Train loss: 1.700896
Train loss: 1.610476
Train loss: 1.834261

Test loss: 1.7227, Accuracy: 6040/10000 (60%)

Train loss: 1.356174
Train loss: 1.155247
Train loss: 1.197011
Train loss: 1.267018
Train loss: 1.269918

Test loss: 1.2154, Accuracy: 6860/10000 (69%)

Train loss: 0.975063
Train loss: 0.661186
Train loss: 0.784656
Train loss: 0.827621
Train loss: 0.930010

Test loss: 0.8892, Accuracy: 7503/10000 (75%)

Train loss: 0.615026
Train loss: 0.667032
Train loss: 0.527943
Train loss: 0.809375
Train loss: 0.629045

Test loss: 0.7303, Accuracy: 7886/10000 (79%)

Train loss: 2.172666
Train loss: 2.125777
Train loss: 2.141700
Train loss: 2.061370
Train loss: 2.210159
Train loss: 2.071659
Train loss: 2.118256
Train loss: 2.160131
Train loss: 2.162329
Train loss: 2.110314

Test loss: 2.1613, Accur