# PyTorch + PySyft

[PySyft](https://github.com/OpenMined/PySyft/) is the first open-source Federated Learning framework for building secure and scalable models. PySyft is simply a hooked extension of PyTorch.

**How long does it take to do Federated Learning compared to normal PyTorch?**

The computation time is actually less than twice the time used for normal PyTorch execution! More precisely, the overhead is +91%.
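
To make the overhead figure concrete, here is a minimal sketch of how such a percentage is derived from two wall-clock timings. The timings are hypothetical placeholders, not measurements from this notebook; only the +91% figure comes from the text above.

```python
def overhead_percent(baseline_seconds, federated_seconds):
    """Extra time of the federated run, as a percentage of the baseline."""
    return 100.0 * (federated_seconds - baseline_seconds) / baseline_seconds


# Hypothetical timings, chosen so that the overhead matches the +91% quoted above.
baseline = 100.0   # seconds for the plain PyTorch run (assumed)
federated = 191.0  # seconds for the federated run (assumed)

overhead = overhead_percent(baseline, federated)
slowdown = federated / baseline
print("overhead: +{:.0f}%, slowdown: {:.2f}x".format(overhead, slowdown))
```

A +91% overhead corresponds to a 1.91x slowdown, which is why the run is still "less than twice" as long.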

We only had to **modify 10 lines of code** to upgrade the official PyTorch example on MNIST to a real Federated Learning task!

**Currently not working on GPU**  :(

Everything is explained in detail in this [blog post](https://dudeperf3ct.github.io/federated/learning/privacy/2019/02/08/Federated-Learning-and-Privacy/). This notebook replicates the results from the blog post and runs in Colab. Enjoy!


#### Run in Colab

You can run this notebook in Google Colab.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dudeperf3ct/DL_notebooks/blob/master/Federated%20Learning/federated_learning_pysyft.ipynb)

In [0]:
! pip install syft

Collecting syft
  Downloading https://files.pythonhosted.org/packages/f5/79/60b8478049c305c7e5b5f2908104d18ef9d22fcf8535bc49995a3be4a0eb/syft-0.1.6a1-py3-none-any.whl (102kB)
Collecting lz4 (from syft)
  Downloading https://files.pythonhosted.org/packages/83/fe/66da85ed881031de7cf7de9dd38cc98aec8859824c7bcd3e8a88d255f36d/lz4-2.1.6-cp36-cp36m-manylinux1_x86_64.whl (359kB)
Collecting websocket-client (from syft)
  Downloading https://files.pythonhosted.org/packages/38/54/684db2ba1b7a203602808446b8686ee786f93b4a7e080cdc440cc7e06e56/websocket_client-0.55.0-py2.py3-none-any.whl (200kB)
Collecting sphinx-rtd-theme (from syft)
  Downloading https://files.pythonhosted.org/packages/60/b4/4df37087a1d36755e3a3bfd2a30263f358d2dea21938240fa02313d45f51/sphinx_rtd_theme-0.4.3-py2.p

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

torch.manual_seed(42)

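# Detect whether a GPU is available. The device is only reported here:
# the federated training below runs on CPU (see the GPU note at the top).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Training on {} ...'.format(device))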

Training on cuda:0 ...


## Federated Learning

![image](https://blog.openmined.org/content/images/2019/02/Capture-d-e-cran-2019-02-25-a--17.45.36.png)



Code adapted from this [OpenMined blog post](https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines/).

In [0]:
import syft as sy  # <-- import the PySyft library
hook = sy.TorchHook(torch)  # <-- hook PyTorch, i.e. add extra functionality to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- and alice

In [0]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.log_interval = 10
        self.save_model = False

args = Arguments()

In [0]:
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Scanning and sending data to bob, alice...
Done!
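
The "Scanning and sending data to bob, alice..." step distributes the training set across the two workers. As a rough mental model only, the sketch below splits a list of samples into contiguous, near-equal shards, one per worker. The function `split_across_workers` is hypothetical and the contiguous, even split is an assumption; the notebook does not show how PySyft actually partitions the data.

```python
def split_across_workers(samples, worker_ids):
    """Split samples into contiguous, near-equal shards, one per worker (assumed scheme)."""
    shard_size = -(-len(samples) // len(worker_ids))  # ceiling division
    return {
        worker_id: samples[i * shard_size:(i + 1) * shard_size]
        for i, worker_id in enumerate(worker_ids)
    }


samples = list(range(10))  # stand-in for 10 (image, label) pairs
shards = split_across_workers(samples, ["bob", "alice"])
print({worker: len(shard) for worker, shard in shards.items()})
```

Every sample ends up on exactly one worker, so a batch always has a single location, which is what the training loop relies on when it sends the model to the batch's location.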


## CNN Model

In [0]:
# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # convolutional layer
        self.conv1 = nn.Conv2d(1, 32, 3, padding=0)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=0)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        # add sequence of convolutional and max pooling layers
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        x = x.view(-1, 12 * 12 * 64)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = F.log_softmax(self.output(x), dim=1)
        return x
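
The `12 * 12 * 64` input size of `fc1` follows from the layer shapes. The sketch below walks a 28x28 MNIST image through the two unpadded 3x3 convolutions and the 2x2 max pooling using the standard output-size formula; the helper `conv_output_size` is introduced here for illustration only.

```python
def conv_output_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                           # MNIST images are 28x28
size = conv_output_size(size, kernel=3)             # conv1: 3x3, no padding
size = conv_output_size(size, kernel=3)             # conv2: 3x3, no padding
size = conv_output_size(size, kernel=2, stride=2)   # 2x2 max pooling, stride 2

channels = 64                                       # output channels of conv2
flattened = size * size * channels
print(size, flattened)
```

The spatial size goes 28, 26, 24, 12, so the flattened feature vector has 12 * 12 * 64 = 9216 entries.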

In [0]:
model = Net()
#model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr)

In [0]:
def train(args, model, device, train_loader, optimizer, epoch):

    model.train()

    for batch_idx, (data, target) in enumerate(train_loader): # <-- the loader passed in is now a distributed dataset

        model.send(data.location) # <-- send the model to the right location
        #data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- get the model back

        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size, #batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
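
The `model.send(...)` / `model.get()` pattern means that the model travels to wherever the batch lives, is updated there, and comes back, while the data never leaves its worker. The toy below imitates that flow in plain Python with a one-parameter least-squares model. `ToyWorker` and its methods are hypothetical stand-ins and do not reflect how PySyft is implemented.

```python
class ToyWorker:
    """Stand-in for a remote worker that keeps its own (x, y) pairs private."""

    def __init__(self, name, data):
        self.name = name
        self.data = data

    def sgd_step(self, weight, lr):
        """Receive the weight, take one gradient step on local data, return it."""
        grad = sum(2 * (weight * x - y) * x for x, y in self.data) / len(self.data)
        return weight - lr * grad


# Both workers hold samples of the same relation y = 2 * x.
workers = [
    ToyWorker("bob", [(1.0, 2.0), (2.0, 4.0)]),
    ToyWorker("alice", [(3.0, 6.0), (4.0, 8.0)]),
]

weight = 0.0
for step in range(200):
    worker = workers[step % len(workers)]      # location of the current batch
    weight = worker.sgd_step(weight, lr=0.01)  # "send", update remotely, "get"

print(round(weight, 3))
```

Only the scalar `weight` crosses the worker boundary here; the training pairs are read only inside `ToyWorker.sgd_step`, and the weight still converges to a value close to 2.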


In [0]:
def test(args, model, device, test_loader):
  
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
          
            #data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


In [0]:
for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")




Test set: Average loss: 0.2150, Accuracy: 9366/10000 (94%)


Test set: Average loss: 0.1393, Accuracy: 9576/10000 (96%)


Test set: Average loss: 0.0996, Accuracy: 9687/10000 (97%)


Test set: Average loss: 0.0806, Accuracy: 9744/10000 (97%)


Test set: Average loss: 0.0674, Accuracy: 9779/10000 (98%)


Test set: Average loss: 0.0601, Accuracy: 9807/10000 (98%)


Test set: Average loss: 0.0553, Accuracy: 9828/10000 (98%)


Test set: Average loss: 0.0495, Accuracy: 9834/10000 (98%)


Test set: Average loss: 0.0446, Accuracy: 9844/10000 (98%)


Test set: Average loss: 0.0463, Accuracy: 9852/10000 (99%)
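
The percentages in the log are rounded to whole numbers by the `{:.0f}` format in `test()`, so the final "99%" is really 98.52%. A short sketch that recomputes the accuracies from the correct-prediction counts printed above:

```python
# Correct predictions per epoch, copied from the test output above.
correct_per_epoch = [9366, 9576, 9687, 9744, 9779, 9807, 9828, 9834, 9844, 9852]
total = 10000  # size of the MNIST test set

for epoch, correct in enumerate(correct_per_epoch, start=1):
    accuracy = 100.0 * correct / total
    print("epoch {:2d}: {:.2f}% (logged as {:.0f}%)".format(epoch, accuracy, accuracy))
```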

