# Section Project: Federated Learning

For this section's final project, you're going to train a model on the MNIST dataset, which is distributed across multiple devices, **without retrieving the raw gradients to the local machine**.

## Import Modules

In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import syft

# Seed every random number generator used in this notebook with the same value
# so that a fresh "Restart & Run All" starts from the same random state.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): TorchHook presumably patches torch with PySyft's remote-execution
# methods (`.send`, `.move`, `.get`, `.federate` are used further down) — confirm
# against the installed PySyft version.
hook = syft.TorchHook(torch)



## Create Workers

In [2]:
# Number of simulated devices that will each hold a shard of the training data.
n_workers = 30

# Build the ids first, then create one virtual worker per id.
worker_ids = ["Worker:{:d}".format(worker_index) for worker_index in range(n_workers)]
workers = [syft.VirtualWorker(hook, id=worker_id) for worker_id in worker_ids]
print(workers)

[<VirtualWorker id:Worker:0 #tensors:0>, <VirtualWorker id:Worker:1 #tensors:0>, <VirtualWorker id:Worker:2 #tensors:0>, <VirtualWorker id:Worker:3 #tensors:0>, <VirtualWorker id:Worker:4 #tensors:0>, <VirtualWorker id:Worker:5 #tensors:0>, <VirtualWorker id:Worker:6 #tensors:0>, <VirtualWorker id:Worker:7 #tensors:0>, <VirtualWorker id:Worker:8 #tensors:0>, <VirtualWorker id:Worker:9 #tensors:0>, <VirtualWorker id:Worker:10 #tensors:0>, <VirtualWorker id:Worker:11 #tensors:0>, <VirtualWorker id:Worker:12 #tensors:0>, <VirtualWorker id:Worker:13 #tensors:0>, <VirtualWorker id:Worker:14 #tensors:0>, <VirtualWorker id:Worker:15 #tensors:0>, <VirtualWorker id:Worker:16 #tensors:0>, <VirtualWorker id:Worker:17 #tensors:0>, <VirtualWorker id:Worker:18 #tensors:0>, <VirtualWorker id:Worker:19 #tensors:0>, <VirtualWorker id:Worker:20 #tensors:0>, <VirtualWorker id:Worker:21 #tensors:0>, <VirtualWorker id:Worker:22 #tensors:0>, <VirtualWorker id:Worker:23 #tensors:0>, <VirtualWorker id:Worker:

## Prepare Data

### Load the MNIST Training & Test Datasets

In [3]:
# Both splits share the same on-disk location and the same image -> tensor transform.
data_root = '../data'
to_tensor = transforms.ToTensor()

mnist_trainset = datasets.MNIST(
    root=data_root,
    train=True,
    download=True,
    transform=to_tensor,
)
mnist_testset = datasets.MNIST(
    root=data_root,
    train=False,
    download=True,
    transform=to_tensor,
)

print("Training Set Size:", len(mnist_trainset))
print("Test Set Size:", len(mnist_testset))

Training Set Size: 60000
Test Set Size: 10000


### Create Federated Training Dataset

In [4]:
# Split the training set across the virtual workers. The printed summary lists
# the workers the data is distributed across and the total number of datapoints.
# NOTE(review): `.federate` is not a stock torchvision Dataset method — presumably
# it is added by the PySyft hook; confirm against the installed PySyft version.
federated_mnist_trainset = mnist_trainset.federate(workers)
print(federated_mnist_trainset)

FederatedDataset
    Distributed accross: Worker:0, Worker:1, Worker:2, Worker:3, Worker:4, Worker:5, Worker:6, Worker:7, Worker:8, Worker:9, Worker:10, Worker:11, Worker:12, Worker:13, Worker:14, Worker:15, Worker:16, Worker:17, Worker:18, Worker:19, Worker:20, Worker:21, Worker:22, Worker:23, Worker:24, Worker:25, Worker:26, Worker:27, Worker:28, Worker:29
    Number of datapoints: 60000



## Create the Classifier

In [5]:
class MNISTClassifier(nn.Module):
    """Small convolutional classifier mapping 1x28x28 images to 10 class logits.

    The network is three identical stages (3x3 convolution with padding 1,
    ReLU, 2x2 max-pool) followed by one fully connected layer. Feature-map
    sizes per stage: 1x28x28 -> 4x14x14 -> 6x7x7 -> 8x4x4 (= 128 features).

    Batch normalization is intentionally absent: BatchNorm layers are
    incompatible with PySyft.
    """

    def __init__(self):
        super().__init__()

        # Stage 0: 1x28x28 -> 4x14x14
        self.conv0 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)

        # Stage 1: 4x14x14 -> 6x7x7
        self.conv1 = nn.Conv2d(4, 6, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 6x7x7 -> 8x4x4 (the pool is padded so the odd 7x7 map becomes 4x4)
        self.conv2 = nn.Conv2d(6, 8, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=1)

        # Classifier head on the flattened 8*4*4 = 128 features
        self.fc = nn.Linear(128, 10)

        # One ReLU instance shared by all three stages
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the unnormalized class scores (logits) for a batch of images."""
        stages = (
            (self.conv0, self.maxpool0),
            (self.conv1, self.maxpool1),
            (self.conv2, self.maxpool2),
        )

        features = x
        for conv, pool in stages:
            features = pool(self.activation(conv(features)))

        # Flatten to (batch, 128) before the linear layer
        return self.fc(features.view(-1, 128))

In [6]:
# Create the classifier on the local machine; the training cell is what sends
# it to (and moves it between) the workers.
model = MNISTClassifier()

## Federated Training

In [7]:
# Federated training.
# Each "epoch" below visits every worker once and draws ONE shuffled mini-batch
# from that worker's dataset (it is not a full pass over the data). Gradients are
# accumulated on the model while it travels from worker to worker, averaged
# remotely, and only then is the model retrieved for the optimizer step.
n_epochs   = 50
lr         = 2e-2
batch_size = 32

optimizer = optim.Adam(model.parameters(), lr=lr)
# Sum (not mean) reduction: losses and gradients are summed over all workers and
# divided by the total instance count once per epoch.
criterion = nn.CrossEntropyLoss(reduction='sum')

for i_epoch in range(n_epochs):
    print("Epoch {:d}:".format(i_epoch))
    
    # Gradients are cleared once per epoch, so the backward() calls in the worker
    # loop add up across workers.
    optimizer.zero_grad()
    # NOTE(review): sending to `syft.local_worker` presumably turns the model into a
    # pointer so that `.move(worker)` can be used below — confirm against PySyft docs.
    model = model.send(syft.local_worker)
    # Track the number of training examples in order to average the gradients later (+ additional stats)
    # These counters are sent/moved the same way as the model so they can be updated in place on each worker.
    instance_count = torch.tensor(0.).send(syft.local_worker)
    running_loss   = torch.tensor(0.).send(syft.local_worker)
    correct_count  = torch.tensor(0.).send(syft.local_worker)
    for i, worker in enumerate(workers):
        print("worker {:d}/{:d}".format(i, len(workers)-1), end='\r')
        # Relocate the model and the running statistics to the current worker.
        model.move(worker)
        instance_count.move(worker)
        running_loss.move(worker)
        correct_count.move(worker)

        # This worker's portion of the federated training set.
        dataset = federated_mnist_trainset.datasets[worker.id]
        dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        # Only the first (randomly shuffled) batch is used per worker per epoch.
        imgs, labels = next(iter(dataloader))

        instance_count.add_(imgs.shape[0]) # calling `.size()` on a remote tensor returns `Size([0])`, so getting `.shape` instead.
        
        preds = model(imgs)

        # Summed loss over the batch; backward() adds this batch's gradients to
        # whatever earlier workers already contributed in this epoch.
        loss = criterion(preds, labels)
        loss.backward()
        
        running_loss.add_(loss.data)
        correct_count.add_(torch.sum(torch.eq(preds.data.argmax(dim=1), labels)))
    
    # Turn the summed gradients into an average over all instances seen this epoch,
    # before the model is retrieved.
    for param in model.parameters():
        param.grad.div_(instance_count)

    # Bring the model (now carrying the averaged gradients) back to the local machine.
    model = model.get()

    # Per-instance loss and accuracy; only these aggregated scalars are retrieved.
    avg_loss     = running_loss.div_(instance_count).get().item()
    avg_accuracy = correct_count.div_(instance_count).get().item()
    
    print()
    print("    Training Loss:", avg_loss)
    print("    Training Accuracy:", avg_accuracy)

    # Apply the averaged gradients with Adam.
    optimizer.step()

Epoch 0:
worker 29/29
    Training Loss: 2.306910514831543
    Training Accuracy: 0.08124999701976776
Epoch 1:
worker 29/29
    Training Loss: 2.2990634441375732
    Training Accuracy: 0.10104166716337204
Epoch 2:
worker 29/29
    Training Loss: 2.2804172039031982
    Training Accuracy: 0.22604165971279144
Epoch 3:
worker 29/29
    Training Loss: 2.2394001483917236
    Training Accuracy: 0.4260416626930237
Epoch 4:
worker 29/29
    Training Loss: 2.1809535026550293
    Training Accuracy: 0.36666667461395264
Epoch 5:
worker 29/29
    Training Loss: 2.061296224594116
    Training Accuracy: 0.4645833373069763
Epoch 6:
worker 29/29
    Training Loss: 1.889341115951538
    Training Accuracy: 0.47708332538604736
Epoch 7:
worker 29/29
    Training Loss: 1.6500238180160522
    Training Accuracy: 0.5583333373069763
Epoch 8:
worker 29/29
    Training Loss: 1.3526575565338135
    Training Accuracy: 0.5885416865348816
Epoch 9:
worker 29/29
    Training Loss: 1.1131576299667358
    Training Accurac

## Evaluate the Final Model

In [8]:
# Evaluate the trained model on the (local, non-federated) MNIST test set.
test_dataloader = data.DataLoader(mnist_testset, batch_size=1024)
n_test_batches = len(test_dataloader)

# Running totals over the whole test set; `criterion` sums the loss per batch,
# so dividing by the instance count at the end gives per-instance averages.
test_loss = 0
instance_count = 0
correct_count = 0

with torch.no_grad():
    for batch_number, (test_imgs, test_labels) in enumerate(test_dataloader, start=1):
        print("Batch {:d}/{:d}".format(batch_number, n_test_batches), end='\r')

        test_preds = model(test_imgs)
        is_correct = torch.eq(test_preds.argmax(dim=1), test_labels)

        instance_count += test_imgs.size(0)
        test_loss += criterion(test_preds, test_labels).item()
        correct_count += is_correct.sum().item()

print()
print("Test Loss:", test_loss / instance_count)
print("Test Accuracy:", correct_count / instance_count)

Batch 10/10
Test Loss: 0.19153675537109374
Test Accuracy: 0.9393


---