In [1]:
# From https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-1-a04894f78029

In [2]:
import sys; sys.path.insert(0, '../..')

import syft as sy
from src.models.QuantizationModel import s_quantization

from src.deeplearning.Optimizer import DianaOptimizer

In [3]:
###############################
##### importing libraries #####
###############################

import os
import copy
import random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset   
torch.backends.cudnn.benchmark=True

In [14]:
##### Hyperparameters for federated learning #########
num_clients = 10
num_selected = num_clients
num_rounds = 21
epochs = 1
batch_size = 64
quantization_param = 1
lr = 1e-3
dl_net = "VGGKo"

In [5]:
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

trans = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (1.0,))
                                 ])

# TODO : To test, not sure what it does exactly.
dataset = Subset(MNIST(root="./", download=True, train=True, transform=trans), range(300))

# TODO : To test, not sure what it does exactly.
test_loader = DataLoader(MNIST(root="./", download=True, train=False, transform=trans))

loaders = []
for _ in range(num_clients):
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
    loaders.append(loader)


In [6]:
#################################
##### Neural Network model #####
#################################

cfg = {
    'VGGKo' : [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        output = F.log_softmax(out, dim=1)
        return output

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim


class TwoLayersModel(nn.Module):
    def __init__(self):
        super(TwoLayersModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)



In [8]:
#################################
##### Defensive programming #####
#################################

def check_model_has_been_updated(old_model, new_model, state):
    updated = False
    cpt = 0
    for old_param, new_param in zip(old_model.parameters(), new_model.parameters()):
        cpt += 1
        if not torch.equal(old_param, new_param):
            updated = True
    assert updated == True, "The %s model hasn't been updated." %state

In [9]:
#################################
##### Central server model #####
#################################

class CentralServer:
    
    def __init__(self, test_loader):
        self.test_loader = test_loader
        self.model =  TwoLayersModel()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        
    def server_aggregate(self, aggregated_grad):
        """
        This function has aggregation method 'mean'
        """
        ### This will take simple mean of the weights of models ###
        old_model = copy.deepcopy(self.model)

        for cpt, p in enumerate(self.model.parameters()):
            p.grad = aggregated_grad[cpt]
        self.optimizer.step()
        
        # Checking that the parameters has been updated.
        check_model_has_been_updated(old_model, self.model, "global")
        
    def test(self):
        """This function test the global model on test data and returns test loss and test accuracy """
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data, target
                output = self.model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        acc = correct / len(self.test_loader.dataset)

        return test_loss, acc


In [10]:
class Worker:
    
    def __init__(self, train_loader, ID: int):
        self.ID = ID
        self.train_loader = train_loader
        self.model = TwoLayersModel()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        
    def client_update(self, local_epoch=1):
        """
        This function updates/trains client model on client data
        """
        self.model.train() # Sets the module in training mode.

        local_grad = []

        for e in range(local_epoch):
            # We handle a single batch.
            data, target = next(iter(self.train_loader))
            #print("ID " + str(self.ID) + ": " + str(target))
            
            y_pred = self.model(data).squeeze()
            loss = self.criterion(y_pred, target)
            
    
            loss.backward()
            
            
            
            # Local compression
            for p in self.model.parameters():
                local_grad.append(p.grad.data)#s_quantization(p.grad, quantization_param))
                
            # We do not carry out the optimization step which update models. 
            # The model will be updated with the compressed aggregation of all gradients.
            #optimizer.step()

        return loss, local_grad
    
    def set_grad_tensor(self, aggregated_grad):
        #old_model = copy.deepcopy(self.model)
                
        for cpt, p in enumerate(self.model.parameters()):
            p.grad.data = aggregated_grad[cpt]

        
        # Checking that the parameters has been updated.
        #check_model_has_been_updated(old_model, self.model, "local")
        
    def optimizer_step(self):
        self.optimizer.step()
        

In [11]:
###### Defining all the workers #########
workers = [Worker(loaders[i], i) for i in range(num_clients)]
central_server = CentralServer(test_loader)

In [12]:
###### List containing info about learning #########
losses_train = []
losses_test = []
acc_train = []
acc_test = []
# Runnining FL

for r in range(1, num_rounds + 1):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # client update
    loss = 0
    grad = []
    
    # Taking into account partial participation
    for i in range(num_selected):
        local_loss, local_grad = workers[client_idx[i]].client_update(local_epoch=epochs)
        loss += local_loss
        if not grad:
            grad = local_grad
        else:
            for i in range(len(grad)):
                grad[i] += local_grad[i]
                
    # Bidirectional compression:
    omega = []
    for i in range(len(grad)):
        omega.append(grad[i] / num_selected)#s_quantization(grad[i] / num_selected, quantization_param))
        
    print("==--> averaged loss :", loss / num_selected)
    
    losses_train.append(loss)
    # server aggregate
    central_server.server_aggregate(omega)
    for worker in workers:
        worker.set_grad_tensor(omega)
        worker.optimizer_step()
   
    if r % 10 == 0:
        test_loss, acc = central_server.test()
        losses_test.append(test_loss)
        acc_test.append(acc)
        print('%d-th round' % r)
        print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))

==--> averaged loss : tensor(2.3079, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(2.3055, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(535.2595, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(2.8539e+11, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(1.6131e+20, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(1.0203e+29, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(inf, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
10-th round
average train loss nan | test loss nan | test acc: 0.098
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(nan, grad_fn=<DivBackward0>)
==--> averaged loss : tensor(

In [13]:
for cpt, p in enumerate(central_server.model.parameters()):
    if cpt == 0:
        print(p[0])

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, n