In [10]:
!pip install tf-encrypted

! URL="https://github.com/openmined/PySyft.git" && FOLDER="PySyft" && if [ ! -d $FOLDER ]; then git clone -b dev --single-branch $URL; else (cd $FOLDER && git pull $URL && cd ..); fi;

!cd PySyft; python setup.py install  > /dev/null

!pip install --upgrade --force-reinstall lz4
!pip install --upgrade --force-reinstall websocket
!pip install --upgrade --force-reinstall websockets
!pip install --upgrade --force-reinstall zstd


You should consider upgrading via the 'pip install --upgrade pip' command.
Cloning into 'PySyft'...
remote: Enumerating objects: 29220, done.
remote: Total 29220 (delta 0), reused 0 (delta 0), pack-reused 29220
Receiving objects: 100% (29220/29220), 32.18 MiB | 8.11 MiB/s, done.
Resolving deltas: 100% (19444/19444), done.
zip_safe flag not set; analyzing archive contents...
Collecting lz4
  Using cached https://files.pythonhosted.org/packages/0a/c6/96bbb3525a63ebc53ea700cc7d37ab9045542d33b4d262d0f0408ad9bbf2/lz4-2.1.10-cp36-cp36m-manylinux1_x86_64.whl
Installing collected packages: lz4
  Found existing installation: lz4 2.1.10
    Uninstalling lz4-2.1.10:
      Successfully uninstalled lz4-2.1.10
Successfully installed lz4-2.1.10
You should consider upgrading via the 'pip install --upgrade pip' command.
Collecting websocket
  Downloading https://files.pythonhosted.org/packages/f2/6d/a60d620ea575c885510c574909d2e3ed62129b121fa2df00ca1c81024c87/websocket-0.2.1.tar.gz

In [11]:
import torch
from torchvision import datasets, transforms
import os
import sys
# Make the local ./PySyft checkout importable from this kernel. It is appended,
# so anything earlier on sys.path (e.g. an installed `syft`) still wins.
module_path = os.path.abspath('./PySyft')
if not module_path in sys.path:
    sys.path.append(module_path)

In [12]:
import syft as sy
# Hook torch so tensors/modules gain PySyft methods such as .send()/.get(),
# which the training loop relies on.
hook = sy.TorchHook(torch)

# Two in-process simulated workers; the training set is federated across them.
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))




In [13]:
from dataclasses import dataclass


@dataclass
class Arguments:
    """Hyper-parameters and run options for this notebook.

    The defaults reproduce the original settings. Any field can be overridden by
    keyword, e.g. ``Arguments(epochs=1)`` for a quick smoke run.
    """
    batch_size: int = 64          # training mini-batch size for the FederatedDataLoader
    test_batch_size: int = 1000   # evaluation batch size
    epochs: int = 10
    lr: float = 0.01              # SGD learning rate
    momentum: float = 0.5         # NOTE: not passed to the optimizer in this notebook
    no_cuda: bool = False         # force CPU even when CUDA is available
    seed: int = 1                 # seed for torch.manual_seed below
    log_interval: int = 30        # print the training loss every N batches
    save_model: bool = False      # write mnist_cnn.pt after training


args = Arguments()

# Seed torch's global RNG for reproducible init/shuffling; the trailing ';'
# suppresses the returned Generator repr in the notebook output.
torch.manual_seed(args.seed);

<torch._C.Generator at 0x7f7d78627c10>

In [14]:
# Use the GPU unless it is unavailable or explicitly disabled via args.no_cuda.
use_cuda = not args.no_cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [16]:
# NOTE: the cache directory is named "F_MNIST_data" but holds plain MNIST here;
# the path is kept unchanged so existing downloads are reused.
DATA_ROOT = '~/.pytorch/F_MNIST_data/'

# Normalize constants are presumably the MNIST pixel mean/std — confirm if changed.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training data is split across the virtual workers; each batch stays on its owner.
mnist_trainset = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform).federate((bob, alice))
# Test data stays local: evaluation runs on the model after it has been fetched back.
mnist_testset = datasets.MNIST(DATA_ROOT, train=False, download=True, transform=transform)

federated_train_loader = sy.FederatedDataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)
# No shuffling for evaluation: loss and accuracy are order-independent sums, so
# shuffling would only add nondeterminism.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=args.test_batch_size, shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:03, 2578883.18it/s]                             


Extracting /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 45502.42it/s]                           
0it [00:00, ?it/s]

Extracting /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:01, 838737.84it/s]                            
0it [00:00, ?it/s]

Extracting /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 17345.56it/s]            


Extracting /home/akshay/.pytorch/F_MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [17]:
from torch import nn, optim
import torch.nn.functional as F
# Default new float tensors to the device type actually in use. Selecting
# torch.cuda.FloatTensor unconditionally breaks CPU-only runs and contradicts
# `device` when args.no_cuda is True.
torch.set_default_tensor_type(torch.cuda.FloatTensor if use_cuda else torch.FloatTensor)

In [19]:
class Classifier(nn.Module):
    """LeNet-style CNN for 1x28x28 digit images; returns per-class log-probabilities."""

    def __init__(self):
        super().__init__()
        # Registered in the same order and under the same names, so parameter
        # order and state_dict keys are stable.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)   # 1x28x28 -> 20x24x24
        self.conv2 = nn.Conv2d(20, 50, 5, 1)  # 20x12x12 -> 50x8x8
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages: 1x28x28 -> 20x12x12 -> 50x4x4.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Flatten to (batch, 800) for the fully connected head.
        hidden = F.relu(self.fc1(x.view(-1, self.fc1.in_features)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [20]:
def train_one_epoch(model, optimizer, federated_loader, epoch, device, log_interval):
    """Run one epoch of federated training.

    For every batch, the model is sent to the worker that owns the batch
    (``data.location``), updated there, and fetched back with ``model.get()``.
    Every ``log_interval`` batches the remote loss is fetched and printed.
    """
    model.train()
    num_batches = len(federated_loader)
    for batch_idx, (data, target) in enumerate(federated_loader):  # batches come from each worker's dataset
        model.send(data.location)  # send the model to the worker holding this batch
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()              # 1) erase previous gradients (if they exist)
        output = model(data)               # 2) make a prediction
        loss = F.nll_loss(output, target)  # 3) calculate how much we missed
        loss.backward()                    # 4) figure out which weights caused us to miss
        optimizer.step()                   # 5) change those weights
        model.get()                        # get the model back (with gradients)

        if batch_idx % log_interval == 0:
            loss = loss.get()  # the loss lives on the worker; fetch it before reading it
            print('Epoch: {} [Training: {:.0f}%]\tLoss: {:.6f}'.format(epoch, 100. * batch_idx / num_batches, loss.item()))


def evaluate(model, loader, device):
    """Evaluate the (local) model on ``loader``.

    Returns:
        (avg_loss, correct): summed NLL loss divided by ``len(loader.dataset)``,
        and the number of samples whose argmax prediction matches the target.
    """
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    return total_loss / len(loader.dataset), correct


model = Classifier().to(device)
# Plain SGD: args.momentum is defined but not passed here.
# NOTE(review): momentum buffers would be created on whichever worker holds the
# model at the first step, which may not survive this send/get loop — verify
# before enabling.
optimizer = optim.SGD(model.parameters(), lr=args.lr)

for epoch in range(1, args.epochs + 1):
    train_one_epoch(model, optimizer, federated_train_loader, epoch, device, args.log_interval)
    # `test_loss` and `correct` stay module-level: a later cell reports `correct`.
    test_loss, correct = evaluate(model, test_loader, device)
    n_test = len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, n_test, 100. * correct / n_test))

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")

Epoch: 1 [Training: 0%]	Loss: 2.307303
Epoch: 1 [Training: 3%]	Loss: 2.152380
Epoch: 1 [Training: 6%]	Loss: 1.851272
Epoch: 1 [Training: 10%]	Loss: 1.258177
Epoch: 1 [Training: 13%]	Loss: 0.785922
Epoch: 1 [Training: 16%]	Loss: 0.752820
Epoch: 1 [Training: 19%]	Loss: 0.389547
Epoch: 1 [Training: 22%]	Loss: 0.668079
Epoch: 1 [Training: 26%]	Loss: 0.556559
Epoch: 1 [Training: 29%]	Loss: 0.298134
Epoch: 1 [Training: 32%]	Loss: 0.709528
Epoch: 1 [Training: 35%]	Loss: 0.243506
Epoch: 1 [Training: 38%]	Loss: 0.295973
Epoch: 1 [Training: 42%]	Loss: 0.437727
Epoch: 1 [Training: 45%]	Loss: 0.444640
Epoch: 1 [Training: 48%]	Loss: 0.462919
Epoch: 1 [Training: 51%]	Loss: 0.377572
Epoch: 1 [Training: 54%]	Loss: 0.231817
Epoch: 1 [Training: 58%]	Loss: 0.164112
Epoch: 1 [Training: 61%]	Loss: 0.186264
Epoch: 1 [Training: 64%]	Loss: 0.267224
Epoch: 1 [Training: 67%]	Loss: 0.161648
Epoch: 1 [Training: 70%]	Loss: 0.223413
Epoch: 1 [Training: 74%]	Loss: 0.181516
Epoch: 1 [Training: 77%]	Loss: 0.177804
Epo

Epoch: 7 [Training: 16%]	Loss: 0.039876
Epoch: 7 [Training: 19%]	Loss: 0.014503
Epoch: 7 [Training: 22%]	Loss: 0.074501
Epoch: 7 [Training: 26%]	Loss: 0.004399
Epoch: 7 [Training: 29%]	Loss: 0.027279
Epoch: 7 [Training: 32%]	Loss: 0.026506
Epoch: 7 [Training: 35%]	Loss: 0.014903
Epoch: 7 [Training: 38%]	Loss: 0.086426
Epoch: 7 [Training: 42%]	Loss: 0.024283
Epoch: 7 [Training: 45%]	Loss: 0.015972
Epoch: 7 [Training: 48%]	Loss: 0.087584
Epoch: 7 [Training: 51%]	Loss: 0.054718
Epoch: 7 [Training: 54%]	Loss: 0.057633
Epoch: 7 [Training: 58%]	Loss: 0.004831
Epoch: 7 [Training: 61%]	Loss: 0.016346
Epoch: 7 [Training: 64%]	Loss: 0.004659
Epoch: 7 [Training: 67%]	Loss: 0.081221
Epoch: 7 [Training: 70%]	Loss: 0.010923
Epoch: 7 [Training: 74%]	Loss: 0.062022
Epoch: 7 [Training: 77%]	Loss: 0.108107
Epoch: 7 [Training: 80%]	Loss: 0.097467
Epoch: 7 [Training: 83%]	Loss: 0.021479
Epoch: 7 [Training: 86%]	Loss: 0.022884
Epoch: 7 [Training: 90%]	Loss: 0.016166
Epoch: 7 [Training: 93%]	Loss: 0.066763


In [21]:
# Accuracy from the final epoch's evaluation pass.
n_test = len(test_loader.dataset)
print("Accuracy Obtained {:.4f}%".format(100. * correct / n_test))

Accuracy Obtained 98.8000%
