In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline 
%config Completer.use_jedi = False

# Communication Round

# Centralied Federeated Learning

# Decentralized Federated Learning

# Simulated Client Model Training on a Single System

In [2]:
import os, random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset

torch.backends.cudnn.benchmark=True

In [3]:
##### Hyperparameters for federated learning #########
num_clients = 20
num_selected = 6
num_rounds = 150
epochs = 5
batch_size = 32

In [4]:
# IID Case
#############################################################
##### Creating desired data distribution among clients  #####
#############################################################

# Image Augmentation
transform_train = transforms.Compose(
[
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [5]:
# Loading CIFAR10 using torchvision.datasets
traindata = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data\cifar-10-python.tar.gz to data


In [6]:
# Dividing the training data into num_clients, with each client having equal number of images
traindata_split = torch.utils.data.random_split(traindata, [int(traindata.data.shape[0]/num_clients) for _ in range(num_clients)])

In [7]:
# Creating a pytorch loader for a Deep Learning model
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

In [8]:
# Normalizing the test images
transform_test = transforms.Compose(
[
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [10]:
# Loading the test iamges and thus converting them into a test_loader
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, transform=transform_test),
    batch_size=batch_size,
    shuffle=True
)

In [31]:
#################################
##### Neural Network model #####
#################################
class VGG(nn.Module):
    cfg = { 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
                'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
                'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
                'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],}
    def __init__(self, vgg_name):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10)
        )
        
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        output = F.log_softmax(out, dim=1)
        return output
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(x)
                          ]
                in_channels = x
        
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        
        return nn.Sequential(*layers)
                

In [18]:
def client_update(client_model, optimizer, train_loader, epochs=5):
    """
    This function updates/trains client model on client data
    """
    client_model.train()
    for e in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()

In [19]:
def server_aggregate(global_model, client_models):
    """
    This function has aggregation method 'mean'
    """
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

In [20]:
def test(global_model, test_loader):
    """This function test the global model on test data and returns test loss and test accuracy """
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    acc = correct/len(test_loader.dataset)
    
    return test_loss, acc

In [23]:
############################################
#### Initializing models and optimizer  ####
############################################

In [37]:
#### global model ##########
global_model = VGG('VGG19').cuda()

In [33]:
############## client models ##############
client_models = [VGG('VGG19').cuda() for _ in range(num_selected)]

In [38]:
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronizing with global model 

In [40]:
# client optimizers
opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]

In [41]:
###### List containing info about learning #########
losses_train = []
losses_test = []
acc_train = []
acc_test = []


In [None]:
# Runnining FL
loss = 0
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # client update
    for i in tqdm(range(num_selected)):
        loss += client_update(client_models[i], opt[i], train_loader[client_idx[i]], epochs)
        
    losses_train.append(loss)
    
    # server aggregate
    server_aggregate(global_model, client_models)
    
    test_loss, acc = test(global_model, test_loader)
    
    losses_test.append(test_loss)
    acc_test.append(acc)
    
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:22<00:00, 13.82s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

0-th round
average train loss 1.91 | test loss 2.33 | test acc: 0.100


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:35<00:00,  5.92s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

1-th round
average train loss 4.01 | test loss 2.35 | test acc: 0.100


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:35<00:00,  5.94s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

2-th round
average train loss 6.01 | test loss 1.8 | test acc: 0.294


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.00s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

3-th round
average train loss 8.09 | test loss 1.89 | test acc: 0.312


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.16s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

4-th round
average train loss 9.91 | test loss 1.44 | test acc: 0.489


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.35s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

5-th round
average train loss 11.2 | test loss 1.19 | test acc: 0.569


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

6-th round
average train loss 13.5 | test loss 1.29 | test acc: 0.556


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.44s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

7-th round
average train loss 15.1 | test loss 1.13 | test acc: 0.610


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.41s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

8-th round
average train loss 16.7 | test loss 1.63 | test acc: 0.532


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.45s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

9-th round
average train loss 17.6 | test loss 0.898 | test acc: 0.691


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.28s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

10-th round
average train loss 19.3 | test loss 1.16 | test acc: 0.626


100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:35<00:00,  5.89s/it]
  0%|                                                                                            | 0/6 [00:00<?, ?it/s]

11-th round
average train loss 21 | test loss 0.949 | test acc: 0.689


 67%|████████████████████████████████████████████████████████                            | 4/6 [00:23<00:11,  5.94s/it]