# Federated Learning - MNIST Example

## Train a remote Deep Learning model
In this notebbok, we will show how to train a Federated Deep Learning with data hosted in Nodes.

We will consider that you are a Data Scientist and you do not know where data lives, you only have access to GridNetwork

## 0 - Previous setup

Components:

 - PyGrid Network      http://alice:7000
 - PyGrid Node Alice (http://bob:5000)
 - PyGrid Node Bob   (http://charlie:5001)

This tutorial assumes that these components are running in background. See [instructions](https://github.com/OpenMined/PyGrid/tree/dev/examples#how-to-run-this-tutorial) for more details.

### Import dependencies
Here we import core dependencies

In [16]:
import syft as sy
from syft.grid.public_grid import PublicGridNetwork
import numpy as np
import torch as th

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
from torchvision import datasets, transforms


### Syft and client configuration
Now we hook Torch and connect to the GridNetwork. This is the only sever you do not need to know node addresses (networks knows), but lets first define some useful parameters

In [2]:


NUM_CLASSES = 10
parties = 8
TAG_NAME = "mnist_test_"+str(parties)+"nodes_ns"
#TAG_NAME = "mnist_small"

#TAG_NAME = "NPC_500_2nodes"
# mnist_test_8nodes  mnist_test_4nodes  mnist_test  mnist_test_small(4)   mnist_test_small2 (2node)
grid_address = "http://203.145.221.20:80"  # address


AGG_EPOCH = 2
EPOCHS = 10
N_EPOCHS = AGG_EPOCH*EPOCHS  # number of epochs to train
N_TEST   = 128   # number of test
train_batch_size = 16
N_LOG = 1

LR = 0.01
momentum = 0.9
weight_decay = 1e-5

node_name = ["gridnode01","gridnode02","gridnode03","gridnode04","gridnode05","gridnode06","gridnode07","gridnode08"]
output_model_folder = 'model'

In [3]:
hook = sy.TorchHook(th)


# Connect direcly to grid nodes
my_grid = PublicGridNetwork(hook, grid_address)

## 1 - Define our Neural Network Arquitecture

Now we will define a Deep Learning Network, feel free to write your own model!

In [4]:
class Arguments():
    def __init__(self):
        self.test_batch_size = N_TEST
        self.epochs = N_EPOCHS
        self.lr = LR
        self.log_interval = N_LOG
        self.momentum = momentum
        self.weight_decay = weight_decay
        #self.device = th.device("cpu")
        
args = Arguments()

In [5]:
device = th.device("cuda" if th.cuda.is_available() else "cpu")
#device=[th.device("cuda:2"),th.device("cuda:3")]
if(th.cuda.is_available()):
     th.set_default_tensor_type(th.cuda.FloatTensor)

### small network (skip if use ResNet)

In [17]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        #self.conv2 = nn.Conv2d(32, 64, 3, 1)
        #self.dropout1 = nn.Dropout(0.25)
        #self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(5408, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        #print(x.size())
        x = self.conv1(x)        
        x = F.relu(x)        
        #x = self.conv2(x)
        #x = F.relu(x)
        x = F.max_pool2d(x, 2)        
        #x = self.dropout1(x)
        x = th.flatten(x, 1)        
        x = self.fc1(x)        
        x = F.relu(x)
        #x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


## 2 - Search for remote data

Once we have defined our Deep Learning Network, we need some data to train... Thanks to PyGridNetwork this is very easy, you just need to search for your tags of interest.

Notice that _search()_ method  returns a pointer tensor, so we will work with those keeping real tensors hosted in Alice and Bob

## search fixed amount of data

In [7]:
data = my_grid.search("#X_"+TAG_NAME)  # images
target = my_grid.search("#Y_"+TAG_NAME)  # labels

data = list(data.values())  # returns a pointer
target = list(target.values())  # returns a pointer

In [8]:
print(data)
print(target)

[[(Wrapper)>[PointerTensor | me:3448380049 -> gridnode01:97351606311]
	Tags: #X_mnist_test_8nodes_ns 
	Shape: torch.Size([7500, 1, 28, 28])
	Description: input mnist datapoinsts split 8 parties...], [(Wrapper)>[PointerTensor | me:18147373791 -> gridnode03:55766307204]
	Tags: #X_mnist_test_8nodes_ns 
	Shape: torch.Size([7500, 1, 28, 28])
	Description: input mnist datapoinsts split 8 parties...], [(Wrapper)>[PointerTensor | me:1064783392 -> gridnode04:65149853549]
	Tags: #X_mnist_test_8nodes_ns 
	Shape: torch.Size([7500, 1, 28, 28])
	Description: input mnist datapoinsts split 8 parties...], [(Wrapper)>[PointerTensor | me:48059717666 -> gridnode05:78174810595]
	Tags: #X_mnist_test_8nodes_ns 
	Shape: torch.Size([7500, 1, 28, 28])
	Description: input mnist datapoinsts split 8 parties...], [(Wrapper)>[PointerTensor | me:60555040428 -> gridnode06:49129793709]
	Tags: #X_mnist_test_8nodes_ns 
	Shape: torch.Size([7500, 1, 28, 28])
	Description: input mnist datapoinsts split 8 parties...], [(Wrap

## 3 - Train the model

In [9]:
from mnist_loader import read_mnist_data
transform = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,)),  #  mean and std 
                              ])
#npz_path = '../'+str(parties)+'Parties/data_party0.npz'
#trainloader,testloader = read_mnist_data(npz_path, batch = args.test_batch_size )
testset = datasets.MNIST('../8node/dataset2', download=True, train=False, transform=transform)
testloader = th.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=True)

### original mnist train data

In [None]:
trainset = datasets.MNIST('../8node/dataset2', download=False, train=True, transform=transform)
trainloader = th.utils.data.DataLoader(trainset, batch_size=args.test_batch_size, shuffle=True)

### initial model

In [18]:
def init_model(parties):
    model_list=[None] * parties
    optims_list=[None] * parties
    for i in range(parties):
        model_list[i] = Net()
#         model_list[i] = torchvision.models.resnet50(pretrained=False)
        
#         model_list[i].conv1 = th.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
#         batch = th.rand(4, 1, 224, 224) # in the NCHW format
#         model_list[i](batch).size()
        
        
        in_feature =  model_list[i].fc.in_features
        model_list[i].fc = nn.Linear(in_feature, NUM_CLASSES)  # ch
        model_list[i].to(device)
        #optims_list[i] = Optims(workers, optim=optim.Adam(params=model_list[i].parameters(),lr=args.lr,weight_decay=args.weight_decay))
        #optims_list[i] = Optims(workers, optim=optim.Adam(params=model_list[i].parameters(),lr=args.lr))
        optims_list[i] = Optims(workers, optim=optim.SGD(params=model_list[i].parameters(),lr=args.lr, momentum = args.momentum,weight_decay=args.weight_decay))
        #optims_list[i] = Optims(workers, optim=optim.SGD(params=model_list[i].parameters(),lr=args.lr, momentum = args.momentum))
    return model_list, optims_list

In [11]:
def avgWeight(model_list):
    model_tmp=[None] * parties
    optims_tmp=[None] * parties

    for idx, my_model in enumerate(model_list):
        
        model_tmp[idx] = my_model.state_dict()


    for key in model_tmp[0]:    
        print(key)
        model_sum = 0
        for model_tmp_content in model_tmp:        
            model_sum += model_tmp_content[key]
            #print(model_tmp_content[key])
        for i in range(len(model_tmp)):
            #print("model_sum={}".format(model_sum))
            #print("len:{}".format(len(model_tmp)))
            model_avg = model_sum/len(model_tmp)
            #print("model_avg={}".format(model_avg))
            model_tmp[i][key] = model_sum/len(model_tmp)
    for i in range(len(model_list)):    
        model_list[i].load_state_dict(model_tmp[i])
        optims_tmp[i] = Optims(workers, optim=optim.SGD(params=model_list[i].parameters(),lr=args.lr, momentum = args.momentum,weight_decay=args.weight_decay))
        #optims_tmp[i] = Optims(workers, optim=optim.Adam(params=model_list[i].parameters(),lr=args.lr))
    return model_list, optims_tmp

In [12]:
from syft.federated.floptimizer import Optims

workers =node_name[:parties]
criterion = nn.CrossEntropyLoss()
model_list, optims_list = init_model(parties)

In [13]:
def train(curr_model, curr_optims, args):
#    shuffle_list = np.arange(len(datalist))
#    np.random.shuffle(shuffle_list)
#    datalist = datalist[shuffle_list]
#    targetlist = targetlist[shuffle_list]

    
    

    for i in range(len(data)):
        
        curr_model[i].train()
        print(next(curr_model[i].parameters()).is_cuda )
        
        # This loop is for "a bunch of data" searched on the node.
        # Equals to an epoch for a node if there is only "one bunch of data" for a node. 
        loss_epoch = 0
        for j in range(len(data[i])):




            worker = data[i][j].location  # worker hosts data
            print(worker.id)
            if worker.id not in workers:
                print("not in worker list")
                continue


            data_device = data[i][j].to(device)
            target_device = target[i][j].to(device)

            curr_model[i].send(worker)  # send model to PyGridNode worker

            batch_remainder = len(data[i][j])%train_batch_size


            for k in range(len(data[i][j])//train_batch_size):
                optimizer = curr_optims[i].get_optim(worker.id)   

                optimizer.zero_grad()  
                pred = curr_model[i](data_device[k*train_batch_size:(k+1)*train_batch_size])
                loss = criterion(pred, target_device[k*train_batch_size:(k+1)*train_batch_size])
                #loss = F.nll_loss(pred, target[i][j])
                loss.backward()

                optimizer.step()
                loss_epoch += loss.get().item()
                
            k+=1

            if batch_remainder != 0:
                optimizer = curr_optims[i].get_optim(worker.id)   
                optimizer.zero_grad()  
                pred = curr_model[i](data_device[k*train_batch_size:k*train_batch_size+batch_remainder])
                loss = criterion(pred, target_device[k*train_batch_size:k*train_batch_size+batch_remainder])
                #loss = F.nll_loss(pred, target[i][j])
                loss.backward()

                optimizer.step()
                loss_epoch += loss.get().item()

            curr_model[i].get()  # get back the model
            print(next(curr_model[i].parameters()).is_cuda )


        th.save(curr_model[i].state_dict(), f'{output_model_folder}/checkpoint_{epoch}_{i}.pth')    

        if epoch % args.log_interval == 0:

            print('Train Epoch: {} | With {} data |: \tLoss: {:.6f}'.format(
                      epoch, worker.id,  loss_epoch))

    return curr_model

In [14]:
def test(test_model, args,fo):
    
    if epoch % args.log_interval == 0:
    
        test_model.eval()
        test_loss = 0
        correct = 0
        with th.no_grad():
            for data, target in testloader:
                data, target = data.to(device), target.to(device)
                output = test_model(data)
                loss = criterion(output, target)
                test_loss += loss  #F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
                pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(testloader.dataset)
        fo.write("{},{:.4f},{:.2f}\n".format(epoch, test_loss,100. * correct / len(testloader.dataset)))   
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(testloader.dataset),
            100. * correct / len(testloader.dataset)))

In [18]:
def train_acc(test_model, args):
    
    if epoch % args.log_interval == 0:
    
        test_model.eval()
        test_loss = 0
        correct = 0
        with th.no_grad():
            for data, target in trainloader:
                data, target = data.to(device), target.to(device)
                output = test_model(data)
                loss = criterion(output, target)
                test_loss += loss  #F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
                pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(trainloader.dataset)

        print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(trainloader.dataset),
            100. * correct / len(trainloader.dataset)))

In [None]:
#AGG_EPOCH = 2

output_file_name = TAG_NAME+'_AGG'+str(AGG_EPOCH)+'.csv'
fo = open(output_file_name, "w")

for epoch in range(N_EPOCHS):
    
    
    current_model = train(model_list,optims_list, args)
    
    print("----before aggregation----")
    for test_model in current_model:
        test(test_model, args,fo)
        
     
    if (epoch+1) % AGG_EPOCH ==0:
        ## model avg
        model_list,optims_list = avgWeight(current_model)
        print("----after aggregation----")
        for test_model in model_list:
            test(test_model, args,fo)
fo.close()

True
gridnode01
True
Train Epoch: 0 | With gridnode01 data |: 	Loss: 3605.654568
True
gridnode03
True
Train Epoch: 0 | With gridnode03 data |: 	Loss: 3507.154406
True
gridnode04
True
Train Epoch: 0 | With gridnode04 data |: 	Loss: 3198.300802
True
gridnode05
True
Train Epoch: 0 | With gridnode05 data |: 	Loss: 3658.790111
True
gridnode06
True
Train Epoch: 0 | With gridnode06 data |: 	Loss: 4034.247991
True
gridnode07
True
Train Epoch: 0 | With gridnode07 data |: 	Loss: 3253.042799
True
gridnode08
True
Train Epoch: 0 | With gridnode08 data |: 	Loss: 3744.923020
----before aggregation----

Test set: Average loss: 121.5383, Accuracy: 979/10000 (10%)


Test set: Average loss: 22.2885, Accuracy: 1024/10000 (10%)


Test set: Average loss: 0.0312, Accuracy: 1319/10000 (13%)


Test set: Average loss: 18.8499, Accuracy: 1023/10000 (10%)


Test set: Average loss: 0.7293, Accuracy: 1052/10000 (11%)


Test set: Average loss: 1.1350, Accuracy: 1410/10000 (14%)


Test set: Average loss: 86.9587, Acc

## backup

In [36]:
import torch
NUM_CLASSES = 10
model = Net()
in_feature =  model.fc.in_features
model.fc = nn.Linear(in_feature, NUM_CLASSES)  # ch
model.to(device)
#with torch.no_grad():
for name, param in model.named_parameters():
    print("{} {}".format(name, param.requires_grad))

    if param.grad is not None:
        print(name, param.grad.sum())
    else:
        print(name, param.grad)

conv1.weight True
conv1.weight None
conv1.bias True
conv1.bias None
fc1.weight True
fc1.weight None
fc1.bias True
fc1.bias None
fc2.weight True
fc2.weight None
fc2.bias True
fc2.bias None


### check weight

In [25]:
model_list[3].state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1328, -0.1019,  0.0785],
                        [ 0.2806, -0.3180, -0.2036],
                        [-0.2728,  0.1638,  0.2065]]],
              
              
                      [[[-0.2845,  0.1327, -0.0324],
                        [-0.2365, -0.2362, -0.3152],
                        [ 0.0370,  0.0167, -0.3030]]],
              
              
                      [[[-0.3574, -0.2948, -0.0616],
                        [ 0.1574,  0.2851,  0.2707],
                        [ 0.0625,  0.1973, -0.1525]]],
              
              
                      [[[ 0.2383,  0.1192, -0.0877],
                        [-0.2530,  0.1197, -0.0134],
                        [-0.1103,  0.0722,  0.2642]]],
              
              
                      [[[-0.2541,  0.3477, -0.2550],
                        [-0.2919, -0.2888, -0.1888],
                        [-0.1940,  0.0830,  0.3208]]],
              
              
               

In [37]:
model_list[1].state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-3.1955e-02, -1.5596e-01,  1.3126e-01],
                        [ 1.1209e-01, -2.4137e-02, -1.2187e-01],
                        [-9.8066e-02, -1.5323e-02, -5.1719e-02]]],
              
              
                      [[[-3.2562e-02, -6.3484e-02, -1.0979e-01],
                        [-1.9259e-01, -1.4969e-01, -2.9122e-02],
                        [ 1.0481e-01, -6.9001e-02, -6.4005e-02]]],
              
              
                      [[[-2.0426e-01, -5.5101e-02, -6.6211e-02],
                        [ 8.7659e-02,  1.8548e-01,  7.9895e-02],
                        [-1.6249e-01,  6.4304e-02, -1.0850e-01]]],
              
              
                      [[[ 8.6541e-03, -5.4680e-02, -7.4961e-02],
                        [-1.2949e-01, -7.5531e-02,  1.5874e-02],
                        [ 3.1721e-03,  1.3787e-01,  1.6286e-01]]],
              
              
                      [[[-1.4306e-01,  1.7598e-02, -9.9530e-03

Et voilà! Here you are, you have trained a model on remote data using Federated Learning!

In [22]:
sdA = model.state_dict()
sdB = model_tmp.state_dict()

# Average all parameters
for key in sdA:
    print("A={}  B={}".format(sdA[key],sdB[key]))
#sdB[key] = (sdB[key] + sdA[key]) / 2.

# Recreate model and load averaged state_dict (or use modelA/B)
# model = nn.Linear(1, 1)
# model.load_state_dict(sdB)

# model_tmp

A=tensor([[[[-0.1136,  0.0119,  0.0643],
          [ 0.2360, -0.2665,  0.3035],
          [-0.1654,  0.0671,  0.0101]]],


        [[[ 0.2986,  0.0780, -0.2830],
          [ 0.1987, -0.0251,  0.3257],
          [-0.3063,  0.3294,  0.2357]]],


        [[[ 0.1051,  0.1047, -0.2119],
          [-0.0992,  0.0779,  0.0424],
          [ 0.1413, -0.1650,  0.0465]]],


        [[[-0.2344, -0.1020,  0.0777],
          [ 0.1065,  0.0464, -0.2997],
          [ 0.0662,  0.2825, -0.0874]]],


        [[[ 0.0143, -0.1704,  0.0094],
          [-0.2515, -0.2792, -0.2568],
          [-0.0121,  0.3080, -0.1566]]],


        [[[ 0.0484,  0.2494, -0.0696],
          [-0.2811, -0.0548, -0.0820],
          [-0.1393, -0.0216, -0.3121]]],


        [[[ 0.1564, -0.0025,  0.0601],
          [-0.1601, -0.2392, -0.0638],
          [-0.2685, -0.3021,  0.1344]]],


        [[[ 0.0864, -0.2734,  0.1815],
          [ 0.2332,  0.1287,  0.1620],
          [-0.3181,  0.3021,  0.3328]]],


        [[[-0.0173,  0.1830,  

### check if model in cuda

In [10]:
model = torchvision.models.resnet50(pretrained=True)

        
model.conv1 = th.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
batch = th.rand(4, 1, 224, 224) # in the NCHW format
model(batch).size()


in_feature =  model.fc.in_features
model.fc = nn.Linear(in_feature, NUM_CLASSES)  # ch
model.to(device)

ResNet(
  (conv1): Conv1d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [33]:
model.conv1.weight.type()

'torch.cuda.FloatTensor'

In [32]:
next(model.parameters()).device

device(type='cuda', index=0)