# Section Project: Federated Learning with Encrypted Gradient Aggregation

For the final project of this section, you will train an MNIST classifier with federated learning, aggregating the workers' updates using the encryption (fixed-precision encoding) and secret-sharing methods covered in this section.

## Import Modules

In [1]:
import random
import numpy as np

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

from fixed_adam import Adam

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import syft

# Seed every random number generator used in this notebook with one shared value.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hook PyTorch with PySyft; the hook object is handed to every VirtualWorker created below.
hook = syft.TorchHook(torch)



## Create Workers

In [2]:
# Number of simulated participants in the federation.
n_workers = 30

# One syft.VirtualWorker per id, with ids of the form "Worker:<index>".
workers = [syft.VirtualWorker(hook, id=worker_id)
           for worker_id in map("Worker:{:d}".format, range(n_workers))]

# Left disabled by the original author: registering every worker with all the others.
# for i in range(len(workers)):
#     workers[i].add_workers(workers[:i] + workers[i+1:])

## Prepare Data

### Load the MNIST Training & Test Datasets

In [3]:
# Build the training split first and the test split second, both stored under
# ../data (downloaded when missing) and both using the ToTensor transform.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='../data', train=is_train_split, download=True, transform=transforms.ToTensor())
    for is_train_split in (True, False)
)

# Report how many samples each split contains.
print("Training Set Size:", len(mnist_trainset))
print("Test Set Size:", len(mnist_testset))

Training Set Size: 60000
Test Set Size: 10000


### Create Federated Training Dataset

In [4]:
# Distribute the training set over the workers. The training cells below read
# each worker's part through `federated_mnist_trainset.datasets[worker.id]`.
# (`.federate` is presumably attached to the dataset by the PySyft hook - TODO confirm.)
federated_mnist_trainset = mnist_trainset.federate(workers)
# The printed summary lists the workers holding data and the total number of datapoints.
print(federated_mnist_trainset)

FederatedDataset
    Distributed accross: Worker:0, Worker:1, Worker:2, Worker:3, Worker:4, Worker:5, Worker:6, Worker:7, Worker:8, Worker:9, Worker:10, Worker:11, Worker:12, Worker:13, Worker:14, Worker:15, Worker:16, Worker:17, Worker:18, Worker:19, Worker:20, Worker:21, Worker:22, Worker:23, Worker:24, Worker:25, Worker:26, Worker:27, Worker:28, Worker:29
    Number of datapoints: 60000



## Create the Classifier

In [5]:
class MNISTClassifier(nn.Module):
    """Small convolutional classifier for 1x28x28 MNIST images.

    Three 3x3 convolutions, each followed by ReLU and a 2x2 max-pool, reduce
    the input to 8x4x4 = 128 features; a single linear layer then maps those
    features to 10 raw class scores (no softmax is applied).
    """

    def __init__(self):
        super().__init__()

        # Shape flow per sample: 1x28x28 -> 4x14x14 -> 6x7x7 -> 8x4x4.
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=8, kernel_size=3, padding=1)
        # Padding of 1 lets the 7x7 map pool down to 4x4 instead of 3x3.
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=1)
        self.fc = nn.Linear(in_features=128, out_features=10)

        # Shared non-linearity applied after every convolution.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the (batch, 10) class scores for a batch of 1x28x28 images."""
        conv_pool_stages = (
            (self.conv0, self.maxpool0),
            (self.conv1, self.maxpool1),
            (self.conv2, self.maxpool2),
        )
        for conv, pool in conv_pool_stages:
            x = pool(self.activation(conv(x)))

        # Flatten the 8x4x4 feature maps into 128-dimensional rows for the linear layer.
        return self.fc(x.view(-1, 128))

## Encrypted Federated Training

### Averaged Model Parameters Approach

In [6]:
# Global model: the training cell below copies it to every worker and overwrites
# its parameters with the workers' average at the end of each epoch.
model = MNISTClassifier()

In [7]:
# Averaged-parameters federated training.
# Each epoch: every worker trains its own model copy for `n_steps` mini-batches,
# then the copies are averaged parameter by parameter on secret-shared,
# fixed-precision values, and the average is written back to the global model
# and to every local copy.
n_epochs   = 10    # number of "train locally, then average" rounds
n_steps    = 10    # mini-batch updates per worker in each round
lr         = 1e-2
batch_size = 64

# One copy of the global model per worker (sent to that worker) and one optimizer per copy.
local_models     = [model.copy().send(worker) for worker in workers]
local_optimizers = [Adam(local_model.parameters(), lr=lr) for local_model in local_models]
# Default reduction: the loss is the mean over the batch (the evaluation below relies on this).
criterion        = nn.CrossEntropyLoss()
test_dataloader  = data.DataLoader(mnist_testset, batch_size=1024)

for i_epoch in range(n_epochs):
    print("Epoch {:d}:".format(i_epoch))

    # Progress line ends with '\r' so the per-step prints below overwrite it in place.
    print("worker 0/{:d} - Step 0/{:d}                                ".format(len(workers)-1, n_steps-1), end='\r')
    for i, (worker, local_model, local_optimizer) in enumerate(zip(workers, local_models, local_optimizers)):

        # This worker's part of the federated training set.
        dataset = federated_mnist_trainset.datasets[worker.id]
        dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for i_step in range(n_steps):

            # A fresh iterator is built on every step, so each step takes the first
            # batch of a new shuffle instead of walking through the dataset once.
            imgs, labels = next(iter(dataloader))

            preds = local_model(imgs)

            local_optimizer.zero_grad()
            loss = criterion(preds, labels)
            loss.backward()
            local_optimizer.step()

            # `.get()` fetches the scalar results for logging
            # (presumably loss/preds/labels are pointers to tensors on the worker - TODO confirm).
            loss = loss.data.clone().get().item()
            acc  = (preds.argmax(dim=1) == labels).float().mean().get().item()
            print("worker {:d}/{:d} - Step {:d}/{:d} | Loss={:.4f} | Accuracy={:.4f}        ".format(i, len(workers)-1, i_step, n_steps-1, loss, acc), end='\r')

    # Encrypted aggregation.
    # The inner zip(*...) transposes the per-model parameter lists, so `local_params`
    # holds the same parameter taken from every local model, in worker order.
    for global_param, local_params in zip(model.parameters(), zip(*[local_model.parameters() for local_model in local_models])):
        # Encode each local parameter as fixed precision and secret-share it across all workers.
        # TODO(review) - open question from the original author: each owner acts as its own
        # crypto_provider here; should a separate trusted crypto provider be used instead?
        # NOTE(review): the shares are only added together below; confirm against the PySyft
        # documentation whether the crypto_provider is involved in additions at all.
        # The trailing `.get()` presumably retrieves the shared tensor (not the plaintext)
        # from the owning worker - TODO confirm.
        local_param_shares = [local_param.clone().fix_prec().share(*workers, crypto_provider=worker).get()
                              for local_param, worker in zip(local_params, workers)]
        # Add the shared values, retrieve only the sum, decode it from fixed precision
        # back to floats, and divide by the number of local models to obtain the average.
        avg_param = sum(local_param_shares).get().float_prec() / len(local_params)
        # Overwrite the global parameter with the average of the local parameters.
        global_param.data.copy_(avg_param)
        # Push a copy of the updated global parameter to every worker so that all
        # local models start the next epoch from the same averaged weights.
        for local_param, worker in zip(local_params, workers):
            local_param.data.copy_(global_param.data.clone().send(worker))

    # Evaluate the averaged global model on the full test set.
    test_loss      = 0
    instance_count = 0
    correct_count  = 0
    with torch.no_grad():
        for imgs, labels in test_dataloader:
            instance_count += imgs.size(0)

            preds = model(imgs)

            # `criterion` returns the batch mean, so weight it by the batch size before
            # accumulating; dividing by `instance_count` then yields the per-sample mean.
            test_loss += criterion(preds, labels).item() * imgs.size(0)
            correct_count += (preds.argmax(dim=1) == labels).sum().item()

    # Empty print moves past the '\r' progress line before reporting the metrics.
    print()
    print("    Test Loss:", test_loss / instance_count)
    print("    Test Accuracy:", correct_count / instance_count)

Epoch 0:
worker 29/29 - Step 9/9 | Loss=2.2068 | Accuracy=0.1719        
    Test Loss: 2.2098304954528807
    Test Accuracy: 0.3036
Epoch 1:
worker 29/29 - Step 9/9 | Loss=1.4151 | Accuracy=0.5469        
    Test Loss: 1.2848269987106322
    Test Accuracy: 0.6728
Epoch 2:
worker 29/29 - Step 9/9 | Loss=0.5334 | Accuracy=0.8125        
    Test Loss: 0.6508465280532837
    Test Accuracy: 0.7922
Epoch 3:
worker 29/29 - Step 9/9 | Loss=0.4554 | Accuracy=0.8594        
    Test Loss: 0.46136472125053407
    Test Accuracy: 0.8697
Epoch 4:
worker 29/29 - Step 9/9 | Loss=0.4278 | Accuracy=0.8750        
    Test Loss: 0.3562229021072388
    Test Accuracy: 0.8954
Epoch 5:
worker 29/29 - Step 9/9 | Loss=0.1859 | Accuracy=0.9531        
    Test Loss: 0.2978161085605621
    Test Accuracy: 0.9113
Epoch 6:
worker 29/29 - Step 9/9 | Loss=0.3252 | Accuracy=0.9219        
    Test Loss: 0.2570676740407944
    Test Accuracy: 0.9216
Epoch 7:
worker 29/29 - Step 9/9 | Loss=0.2954 | Accuracy=0.9062    

### Averaged Gradients Approach

In [8]:
# Fresh, untrained global model for the averaged-gradients approach; rebinding
# `model` replaces the model used by the averaged-parameters cells above.
model = MNISTClassifier()

In [9]:
# Averaged-gradients federated training.
# Each epoch: every worker computes the gradient of one mini-batch on its own model
# copy, the gradients are summed on secret-shared fixed-precision values, and a single
# optimizer step is applied to the global model, which is then re-broadcast to the workers.
n_epochs   = 100   # one global optimizer step per epoch
lr         = 2e-2
batch_size = 32

# Local copies are only used to compute gradients; the optimizer updates the global
# model locally (torch's optim.Adam, unlike the fixed_adam.Adam used in the cell above).
local_models = [model.copy().send(worker) for worker in workers]
optimizer    = optim.Adam(model.parameters(), lr=lr)
# Summed (not averaged) loss: losses and gradients are added across workers first
# and divided by the total number of instances afterwards.
criterion    = nn.CrossEntropyLoss(reduction='sum')

# Pre-allocate gradient buffers on the global model so they can be filled with copy_() below.
for param in model.parameters():
    param.grad = torch.zeros_like(param.data)

for i_epoch in range(n_epochs):
    print("Epoch {:d}:".format(i_epoch))
    
    # NOTE(review): the copy_() into param.grad below needs the gradient tensors to still
    # exist after this call - confirm zero_grad() zeroes in place for the installed torch version.
    optimizer.zero_grad()

    # One list of shared gradients per global parameter, filled with one entry per worker.
    grad_shares_lists = [[] for _ in model.parameters()]

    instance_count_shares = []

    loss_shares           = []
    correct_count_shares  = []

    for i, (worker, local_model) in enumerate(zip(workers, local_models)):
        print("worker {:d}/{:d}".format(i, len(workers)-1), end='\r')

        # This worker's part of the federated training set; a new shuffled loader is
        # built every epoch and only its first batch is used.
        dataset = federated_mnist_trainset.datasets[worker.id]
        dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        imgs, labels = next(iter(dataloader))

        preds = local_model(imgs)

        loss = criterion(preds, labels)
        loss.backward()

        # Secret-share a fixed-precision copy of every gradient, then reset the local
        # gradient so it does not accumulate into the next epoch's backward pass.
        # (The trailing `.get()` presumably retrieves the shared tensor, not the
        # plaintext, from the owning worker - TODO confirm.)
        for grad_shares_list, param in zip(grad_shares_lists, local_model.parameters()):
            grad_shares_list.append(param.grad.clone().fix_prec().share(*workers, crypto_provider=worker).get())
            param.grad.zero_()

        # The batch size, summed loss and number of correct predictions are shared the same way.
        # The batch size is written into a zero tensor created on the worker before sharing.
        instance_count_shares.append(torch.tensor(0, dtype=torch.long).send(worker).add_(imgs.shape[0]).share(*workers, crypto_provider=worker).get())
        loss_shares.append(loss.data.clone().fix_prec().share(*workers, crypto_provider=worker).get())
        correct_count_shares.append((preds.data.argmax(dim=1) == labels).sum().share(*workers, crypto_provider=worker).get())

    # Only the sum over all workers is retrieved: total number of training instances this epoch.
    instance_count = sum(instance_count_shares).get().item()

    # Summed gradients divided by the total instance count give the mean gradient over
    # every sample processed this epoch; store it as the global model's gradient.
    for param, grad_shares_list in zip(model.parameters(), grad_shares_lists):
        param.grad.copy_(sum(grad_shares_list).get().float_prec() / instance_count)

    # Per-sample training loss and accuracy, aggregated over all workers.
    avg_loss     = sum(loss_shares).get().float_prec().item() / instance_count
    avg_accuracy = sum(correct_count_shares).get().item() / instance_count
    
    print()
    print("    Training Loss:", avg_loss)
    print("    Training Accuracy:", avg_accuracy)

    # Single global update using the aggregated gradient.
    optimizer.step()
    
    # Broadcast the updated global parameters so every local copy matches the global model.
    for global_param, local_params in zip(model.parameters(), zip(*[local_model.parameters() for local_model in local_models])):
        for worker, local_param in zip(workers, local_params):
            local_param.data.copy_(global_param.data.clone().send(worker))

Epoch 0:
worker 29/29
    Training Loss: 2.3048917134602864
    Training Accuracy: 0.084375
Epoch 1:
worker 29/29
    Training Loss: 2.2998311360677084
    Training Accuracy: 0.11354166666666667
Epoch 2:
worker 29/29
    Training Loss: 2.2929842631022135
    Training Accuracy: 0.11666666666666667
Epoch 3:
worker 29/29
    Training Loss: 2.2694302876790364
    Training Accuracy: 0.11354166666666667
Epoch 4:
worker 29/29
    Training Loss: 2.2462562561035155
    Training Accuracy: 0.09895833333333333
Epoch 5:
worker 29/29
    Training Loss: 2.1893885294596354
    Training Accuracy: 0.221875
Epoch 6:
worker 29/29
    Training Loss: 2.0800530751546225
    Training Accuracy: 0.37395833333333334
Epoch 7:
worker 29/29
    Training Loss: 1.925430170694987
    Training Accuracy: 0.453125
Epoch 8:
worker 29/29
    Training Loss: 1.72519900004069
    Training Accuracy: 0.5791666666666667
Epoch 9:
worker 29/29
    Training Loss: 1.5279937744140626
    Training Accuracy: 0.5458333333333333
Epoch 10

In [10]:
# Evaluate the trained global `model` on the MNIST test set and report the
# per-sample cross-entropy loss and the classification accuracy.
test_dataloader = data.DataLoader(mnist_testset, batch_size=1024)

test_loss      = 0
instance_count = 0
correct_count  = 0
with torch.no_grad():
    for i, (imgs, labels) in enumerate(test_dataloader, 1):
        # '\r' keeps the batch counter on a single, continuously updated line.
        print("Batch {:d}/{:d}".format(i, len(test_dataloader)), end='\r')
        instance_count += imgs.size(0)

        preds = model(imgs)

        # Request the summed loss explicitly instead of going through the notebook-level
        # `criterion`: that name is bound to a mean-reduction loss by the averaged-parameters
        # cell and to a sum-reduction loss by the averaged-gradients cell, so the reported
        # value would otherwise depend on which training cell ran last.
        test_loss += F.cross_entropy(preds, labels, reduction='sum').item()
        correct_count += (preds.argmax(dim=1) == labels).sum().item()

# Empty print moves past the '\r' progress line before reporting the metrics.
print()
print("Test Loss:", test_loss / instance_count)
print("Test Accuracy:", correct_count / instance_count)

Batch 10/10
Test Loss: 0.12317433090209962
Test Accuracy: 0.9614


---