# Federated Learning Walkthrough

In this notebook you will get hands-on experience programming a federated learning simulation. Note that this notebook uses programming abstractions to simulate the clients and servers that operate in a federated learning setup; it serves as an example that illustrates the steps taken in a federated learning system.

First, let's import our libraries, load the MNIST dataset, and define our model.

In [None]:
# Import modules
import os
import sys
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import random
import torch.nn.functional as F
import copy


In [None]:
# Load MNIST Dataset
d = './data'
# makedirs(exist_ok=True) creates the directory only if needed, without the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(d, exist_ok=True)

# Convert each image to a tensor, then normalize with mean 0.5 and std 1.0.
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = datasets.MNIST(root=d, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=d, train=False, transform=trans, download=True)

# Loaders over the full (unpartitioned) train and test splits.
batch_size = 32
global_train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True)
global_test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False)


In [None]:
# Define MNIST model
class MLP(nn.Module):
    """Two-layer fully connected network: 784 inputs -> 128 hidden -> 10 outputs."""

    IN_FEATURES = 28 * 28
    HIDDEN_FEATURES = 128
    OUT_FEATURES = 10

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(self.IN_FEATURES, self.HIDDEN_FEATURES)
        self.fc2 = nn.Linear(self.HIDDEN_FEATURES, self.OUT_FEATURES)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 raw output scores per row."""
        flat = x.view(-1, self.IN_FEATURES)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Client Creation

Next, let's create a set of 10 clients that logically represent our set of participants that are performing federated learning. Each client has its own partition of the training dataset and its own local model.

In [None]:
# Number of simulated clients; also the number of partitions the training set is split into.
n_clients = 10

def create_client(client_id, local_dataset, batch_size=32, lr=0.01, momentum=0.9):
  """Build the state dictionary for one simulated client.

  Args:
    client_id: Identifier stored under the "client_id" key.
    local_dataset: Dataset wrapped by this client's DataLoader.
    batch_size: Mini-batch size used by the DataLoader.
    lr: Learning rate of the client's SGD optimizer.
    momentum: Momentum of the client's SGD optimizer.

  Returns:
    A dict with the keys "client_id", "local_dataset" (a shuffling
    DataLoader over ``local_dataset``, not the raw dataset), "local_model"
    (a freshly constructed MLP) and "optimizer" (SGD bound to that model's
    parameters).
  """
  model = MLP()
  loader = torch.utils.data.DataLoader(
      dataset=local_dataset,
      batch_size=batch_size,
      shuffle=True)
  # The optimizer is created once per client and kept in the client dict.
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
  return {"client_id": client_id,
          "local_dataset": loader,
          "local_model": model,
          "optimizer": optimizer}

# Round-robin partition: training sample number k is assigned to client k mod n_clients
local_datasets = [[] for _ in range(n_clients)]
for index, sample in enumerate(train_set):
  shard = local_datasets[index % n_clients]
  shard.append(sample)

# Build one client per partition, numbered in partition order
clients = [create_client(cid, shard) for cid, shard in enumerate(local_datasets)]

# Federated Learning Training
Now let's perform the federated learning loop! Recall that the steps of federated learning are:

* Model broadcasting
* Local training
* Model aggregation
* Model update

The next code block will simulate this process with the controlling logic acting as the central server.

In [None]:
def client_load_model(client, model):
  """Overwrite the weights of the client's local model with those of ``model``."""
  broadcast_weights = model.state_dict()
  local_model = client["local_model"]
  local_model.load_state_dict(broadcast_weights)

def client_local_training(client):
  """Run one pass of training over the client's local data, in place.

  Reads "optimizer", "local_dataset" and "local_model" from ``client``. For
  every ``(x, target)`` batch yielded by the local dataset it computes the
  cross-entropy loss of the model output and takes one optimizer step.

  Args:
    client: Client dict holding the three keys listed above.

  Returns:
    None. The local model's parameters and the optimizer state are updated
    in place.
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = client["optimizer"]
  loader = client["local_dataset"]
  model = client["local_model"]

  # Make sure the model is in training mode before updating it.
  model.train()
  # Tensors are passed straight to the model; the torch.autograd.Variable
  # wrapper is deprecated and no longer needed.
  for x, target in loader:
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()

def server_aggregate_models(clients):
  """Return a state dict holding the element-wise mean of all clients' local models.

  Every client in ``clients`` contributes with equal weight. The result
  starts as a deep copy of the first client's state dict, and each of its
  entries is then replaced by the average of that entry across all clients.
  """
  state_dicts = [client["local_model"].state_dict() for client in clients]
  n_models = len(state_dicts)
  result = copy.deepcopy(state_dicts[0])
  for name in list(result):
    summed = sum(sd[name] for sd in state_dicts)
    result[name] = summed / n_models
  return result

def evaluate(model, dataset):
  """Print the loss and accuracy of ``model`` over ``dataset``.

  The printed loss is the sum of the per-batch cross-entropy losses (it is
  not divided by the number of batches). The accuracy is the fraction of
  samples whose highest-scoring output matches the target.

  Args:
    model: Callable mapping a batch ``x`` to per-class scores.
    dataset: Iterable yielding ``(x, target)`` batches.

  Returns:
    None. The result is printed as ``"Loss: %f, acc: %f"``. If ``dataset``
    yields no samples, the accuracy is printed as ``nan``.
  """
  criterion = nn.CrossEntropyLoss()
  total_loss, total, correct = 0.0, 0, 0
  # Evaluation needs no gradients, so skip building the autograd graph.
  with torch.no_grad():
    for x, target in dataset:
      out = model(x)
      _, pred_label = torch.max(out, 1)
      total_loss += criterion(out, target).item()
      total += x.size(0)
      # .item() keeps the counter a plain Python int, so the division below
      # is ordinary float division regardless of the PyTorch version.
      correct += (pred_label == target).sum().item()
  # Accuracy is undefined without samples; report nan rather than dividing by zero.
  accuracy = correct / total if total else float("nan")
  print("Loss: %f, acc: %f" % (total_loss, accuracy))

def federated_learning_loop(rounds=1000, client_participation_fraction=.2):
  """Simulate federated learning, with this function acting as the central server.

  Each round broadcasts the global model to every client, lets a random
  subset of clients train locally, averages the client models, loads the
  average into the global model and prints its test loss and accuracy.

  Args:
    rounds: Number of communication rounds to run.
    client_participation_fraction: Probability with which each client,
      independently of the others, performs local training in a round.

  Returns:
    None. Progress is reported through the per-round evaluation printout.
  """
  global_model = MLP()

  for _ in range(rounds):

    # Broadcast global model to clients
    for c in clients:
      client_load_model(c, global_model)

    # Selected clients perform local training; each client is drawn
    # independently, so the number of participants varies per round
    # (and may be zero).
    for c in clients:
      if random.random() <= client_participation_fraction:
        client_local_training(c)

    # Aggregate the models
    # NOTE(review): the average is taken over ALL clients, including those
    # that did not train this round and still hold the broadcast weights.
    # Confirm this is intended rather than averaging participants only.
    aggregated_model = server_aggregate_models(clients)

    # Update global model
    global_model.load_state_dict(aggregated_model)

    # Evaluation
    evaluate(global_model, global_test_loader)

# Run the federated learning simulation; test loss and accuracy are printed after every round.
federated_learning_loop()
