<a href="https://colab.research.google.com/github/shenghaoG/FL_combined_DP/blob/master/FedAvg_suc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Environment setup for a hosted notebook runtime: installs tf-encrypted and a
# source checkout of PySyft, then refreshes a few transport/compression packages.
# NOTE(review): apart from the later torch/torchvision cells nothing here is
# version-pinned, so a fresh run may pick up different package versions.
!pip install tf-encrypted

# Clone the `dev` branch of PySyft into ./PySyft, or run `git pull` inside the
# folder when it already exists.
! URL="https://github.com/openmined/PySyft.git" && FOLDER="PySyft" && if [ ! -d $FOLDER ]; then git clone -b dev --single-branch $URL; else (cd $FOLDER && git pull $URL && cd ..); fi;

# Install PySyft from the checkout; only stdout is discarded, stderr still shows.
!cd PySyft; python setup.py install  > /dev/null

import os
import sys
# Append the absolute path of the checkout to sys.path (once) so that it can be
# imported from this kernel.
module_path = os.path.abspath(os.path.join('./PySyft'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
# Force-reinstall lz4, websocket, websockets and zstd, and upgrade msgpack.
# NOTE(review): presumably these work around versions preinstalled on the
# runtime that conflict with PySyft — confirm they are all still required.
!pip install --upgrade --force-reinstall lz4
!pip install --upgrade --force-reinstall websocket
!pip install --upgrade --force-reinstall websockets
!pip install --upgrade --force-reinstall zstd
!pip install --upgrade msgpack

In [0]:
# Pin PyTorch to 1.1.0.
# NOTE(review): presumably the version the PySyft checkout expects — confirm.
!pip install torch==1.1.0

In [0]:
# Pin torchvision to 0.3.0.
# NOTE(review): presumably chosen to match the torch 1.1.0 pin — confirm.
!pip install torchvision==0.3.0

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy
import random

In [0]:
# Create the PySyft hook for torch and the three virtual workers used by the
# notebook; the workers are constructed in the order bob, alice, secure_worker.
hook = sy.TorchHook(torch)
bob, alice, secure_worker = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "secure_worker")
)

# Register on every worker the two other workers it may talk to.
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])

In [0]:
class Arguments:
    """Plain container for the hyper-parameters and run options of this notebook.

    Every option is stored as an instance attribute, so individual runs can
    override a value on the instance without affecting other instances.
    """

    def __init__(self):
        defaults = {
            "batch_size": 64,          # batch size of the federated train loader
            "test_batch_size": 1000,   # batch size of the test loader
            "epochs": 10,              # number of passes over the training data
            "lr": 0.001,               # learning rate handed to optim.SGD
            "momentum": 0.5,           # not read by the code visible in this notebook
            "no_cuda": False,          # set True to force CPU even if CUDA exists
            "seed": 1,                 # seed passed to torch.manual_seed
            "log_interval": 100,       # only referenced by disabled logging code
            "save_model": True,        # not read by the code visible in this notebook
            "gauss_mu": 0,             # Gaussian mean; not read by visible code
            "gauss_sigma": 0.001,      # Gaussian sigma; not read by visible code
        }
        for option, value in defaults.items():
            setattr(self, option, value)

args = Arguments()

# Seed torch's RNG from the configured seed.
torch.manual_seed(args.seed)

# CUDA is used only when it is both allowed by the config and actually available.
use_cuda = (not args.no_cuda) and torch.cuda.is_available()

if use_cuda:
    device = torch.device("cuda")
    # Extra keyword arguments forwarded to the federated train loader.
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [0]:
# MNIST training split (downloaded to ../data if missing), converted to tensors
# and normalised, then distributed over bob and alice with `.federate(...)`.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean/std — confirm.
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        '../data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ).federate((bob, alice)),
    batch_size=args.batch_size, shuffle=True, **kwargs)

# Make CUDA float tensors the default only when CUDA is actually in use.
# Doing this unconditionally breaks CPU-only runtimes and contradicts
# `device` when `args.no_cuda` is set.
if use_cuda:
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Number of batches per epoch in the federated loader.
print(len(federated_train_loader))

In [0]:
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions with 2x2 max-pooling, then two linear layers.

    For 1x28x28 inputs the feature map after the second pooling is 64x4x4,
    which is the flattened size `fc1` expects. `forward` returns per-class
    log-probabilities (log-softmax over 10 outputs).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation yields the same weights.
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(4 * 4 * 64, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10) for the input batch `x`."""
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        # Flatten to (batch, 1024) for the fully connected head.
        flat = features.view(-1, 4 * 4 * 64)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [0]:
# Local (non-federated) loader over the MNIST test split, using the same
# tensor conversion and normalisation as the training data.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=args.test_batch_size,
    shuffle=True,
)
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.8f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    return correct



In [0]:
# Federated training run: bob and alice each train their own copy of `Net` on
# the batches they hold; alice's model is evaluated on the local test set every
# 20 batches and the best hit count is recorded.
#
# No aggregation across workers happens in this cell. The averaging and
# Gaussian-noise experiments that used to sit inside the loop were commented
# out (they referenced `sigma` and `t`, which are not defined in the visible
# notebook) and have been removed here; recover them from version control if
# they are needed again.
args.epochs = 10
import time


def _train_step(worker_model, worker_optimizer, data, target):
    """Run one SGD step of `worker_model` on one batch and return the fetched loss.

    Args:
        worker_model: the model copy that was sent to the worker owning `data`.
        worker_optimizer: optimizer built over `worker_model.parameters()`.
        data, target: one batch yielded by the federated train loader.

    Returns:
        The `.data` of the loss after it has been retrieved with `.get()`.
    """
    worker_model.train()
    data, target = data.to(device), target.to(device)
    worker_optimizer.zero_grad()
    output = worker_model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    worker_optimizer.step()
    return loss.get().data


start = time.time()
best_correct_list = []

best_correct = 0
# Use the configured device instead of hard-coding CUDA so CPU runs work too.
model = Net().to(device)
model.train()
# Both workers start from identical copies of the freshly initialised model.
model_bob = model.copy().send(bob)
model_alice = model.copy().send(alice)
optimizer_bob = optim.SGD(params=model_bob.parameters(), lr=args.lr)
optimizer_alice = optim.SGD(params=model_alice.parameters(), lr=args.lr)

for epoch in range(1, args.epochs + 1):
    for idx, (data, target) in enumerate(federated_train_loader):
        # Train the model that lives on the worker holding this batch.
        if data.location == bob:
            loss_bob = _train_step(model_bob, optimizer_bob, data, target)
        elif data.location == alice:
            loss_alice = _train_step(model_alice, optimizer_alice, data, target)

        if idx % 20 == 0:
            # NOTE(review): the original cell evaluated model_alice while it was
            # still sent to alice and re-sent both models after every batch; the
            # matching `.get()` calls only existed in commented-out code. The
            # model is now fetched before the local evaluation and sent back
            # afterwards — confirm against the PySyft version in use.
            model_alice = model_alice.get()
            correct = test(args, model_alice, device, test_loader)
            if correct > best_correct:
                best_correct = correct
            model_alice = model_alice.send(alice)

# Best accuracy as a fraction of the test set (10000 samples for MNIST).
best_correct_list.append(best_correct / len(test_loader.dataset))
print('\n')

print(best_correct_list)
print("done! use time ", (time.time() - start) / 60)