<a href="https://colab.research.google.com/github/pankajattri/CSC591/blob/master/PreventAttacks_FederatedLearning_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
############
#INSTALL PySyft
##########

!git clone https://github.com/OpenMined/PySyft.git
!cd PySyft/
!pip install -r PySyft/pip-dep/requirements.txt
!pip install -r PySyft/pip-dep/requirements_udacity.txt
!python PySyft/setup.py install

Before Running the next step make sure to "Restart Runtime"

In [0]:
# Run this cell to add PySyft path 
import os
import sys
module_path = os.path.abspath(os.path.join('./PySyft'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import copy

Import PySyft and create hook to use torch libraries

In [43]:
import syft as sy  # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning




Create Four agents. Each of them will be used to train on the data independently (sequentially though!)

In [0]:
Agent_1 = sy.VirtualWorker(hook, id="Agent_1")
Agent_2 = sy.VirtualWorker(hook, id="Agent_2")
Agent_3 = sy.VirtualWorker(hook, id="Agent_3")
Agent_4 = sy.VirtualWorker(hook, id="Agent_4")

Arguments to be used for training. I have taken this directly from the PySyft tutorial

In [0]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 1
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [0]:
# Download dataset

mnist_full_train_dataset = datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))


Split the dataset into three parts. This is done to mimic training on the data at three different "time stamps". The first dataset will be used to train the model with no malicious data. We will include some malicious data in datasets 2 and 3 and make sure the model obtained from dataset 1 is not updated by a malicious agent's updated delta

In [0]:

ds_t1,ds_t2,ds_t3 = torch.utils.data.random_split(mnist_full_train_dataset,(40000,12000,8000))

# TO DO
How to manipulate dataset to include malicious updates???

In [0]:
temp = ds_t2
print(ds_t2[10][1])
temp[10][1] = 4
print(temp[10][1])

9


TypeError: ignored

Create three "federated" datastes. Each corresponding to the dataset created above. This is done to use "PySyft" infrastructure to mimic training model using four different slaves (or clients or agents)

In [0]:
federated_train_ds_t1 = sy.FederatedDataLoader( ds_t1.federate((Agent_1,Agent_2,Agent_3,Agent_4)),batch_size=64,shuffle=True, **kwargs)
federated_train_ds_t2 = sy.FederatedDataLoader( ds_t2.federate((Agent_1,Agent_2,Agent_3,Agent_4)),batch_size=64,shuffle=True, **kwargs)
federated_train_ds_t3 = sy.FederatedDataLoader( ds_t3.federate((Agent_1,Agent_2,Agent_3,Agent_4)),batch_size=64,shuffle=True, **kwargs)

In [0]:
# This is the dataset that will be used repeateadly to test model performance
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

CNN model. I have taken this directly from the PySyft tutorial. We can potentialy look at changig this but I have not spent time on this yet.

In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

This is the train function. This is sort of a hack solution to train the model on the 1st dataset using PySyft's federated training

In [0]:
def train(args, model, device, federated_train_loader, optimizer,epoch):
    
    model.train()
        
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        cal_grad_bkpropgt(data,target,batch_idx,federated_train_loader,model,device,epoch)

Test fucniton to test model performance

In [0]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    '''
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    '''
    return test_loss, 100. * correct / len(test_loader.dataset)

Helper function to calculate gradient. This will be used on the 1st dataset when training on non-malicious data to get a basline model

In [0]:
def cal_grad_bkpropgt(data,target,batch_idx,federated_train_loader,model,device,epoch):
  model.send(data.location) # <-- NEW: send the model to the right location
  data, target = data.to(device), target.to(device)
  optimizer.zero_grad()
  output = model(data)
  loss = F.nll_loss(output, target)
  loss.backward()
  optimizer.step()
  model.get() # <-- NEW: get the model back
  if batch_idx % args.log_interval == 0:
      loss = loss.get() # <-- NEW: get the loss back
      print('Agent: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          data.location,epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
          100. * batch_idx / len(federated_train_loader), loss.item()))
  
  #return model

Helper function to get delta's from each Agent when training on datasets 2 and 3

In [0]:
def cal_grad_bkpropgt_return_delta(data,target,batch_idx,federated_train_loader,model,device):
  #org_model_dict = model.state_dict()
  #print('Starting training on batch', batch_idx)
  model.send(data.location) # <-- NEW: send the model to the right location
  data, target = data.to(device), target.to(device)
  optimizer.zero_grad()
  output = model(data)
  loss = F.nll_loss(output, target)
  loss.backward()
  optimizer.step()
  model.get() # <-- NEW: get the model back
  if batch_idx % args.log_interval == 0:
      loss = loss.get() # <-- NEW: get the loss back
      print('Agent: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          data.location,1, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
          100. * batch_idx / len(federated_train_loader), loss.item()))
      
  #return model

Train model on dataset 1 (remember dataset 1 corresponds to training at timestamp 1)

In [55]:
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1,5):
  train(args, model, device, federated_train_ds_t1, optimizer,epoch)
  test(args, model, device, test_loader)

#if (args.save_model):
#    torch.save(model.state_dict(), "mnist_cnn.pt")




Test set: Average loss: 0.4698, Accuracy: 8512/10000 (85%)


Test set: Average loss: 0.1438, Accuracy: 9563/10000 (96%)


Test set: Average loss: 0.1006, Accuracy: 9701/10000 (97%)


Test set: Average loss: 0.0882, Accuracy: 9722/10000 (97%)

CPU times: user 5min 51s, sys: 15.9 s, total: 6min 7s
Wall time: 6min 8s


Function to train on datasets 2 and 3 corresponding to time stamps t2 and t3.

In [0]:
def train_subsequent_trainings(args, model,Current_model, device, federated_train_loader, optimizer):

  #model.train()
     
  for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
      
      if data.location == Agent_1:
        #print('Working on Agent 1')
        cal_grad_bkpropgt_return_delta(data,target,batch_idx,federated_train_loader,model,device)
          

      
      elif data.location == Agent_2:
        
        #print('Working on Agent 2')
        delta_Agent_1 = {name: Current_model.state_dict()[name] - model.state_dict()[name] for name in Current_model.state_dict() if name in model.state_dict()}
        model.load_state_dict(Current_model.state_dict())
        cal_grad_bkpropgt_return_delta(data,target,batch_idx,federated_train_loader,model,device)
        
      
      elif data.location == Agent_3:
        #print('Working on Agent 3')
        delta_Agent_2 = {name: Current_model.state_dict()[name] - model.state_dict()[name] for name in Current_model.state_dict() if name in model.state_dict()}
        model.load_state_dict(Current_model.state_dict())
        cal_grad_bkpropgt_return_delta(data,target,batch_idx,federated_train_loader,model,device)
        
      
      elif data.location == Agent_4:
        #print('Working on Agent 4')
        delta_Agent_3 = {name: Current_model.state_dict()[name] - model.state_dict()[name] for name in Current_model.state_dict() if name in model.state_dict()}
        model.load_state_dict(Current_model.state_dict())
        cal_grad_bkpropgt_return_delta(data,target,batch_idx,federated_train_loader,model,device)

      else:
        pass
  
  delta_Agent_4 = {name: Current_model.state_dict()[name] - model.state_dict()[name] for name in Current_model.state_dict() if name in model.state_dict()}

  
  return delta_Agent_1, delta_Agent_2, delta_Agent_3, delta_Agent_4    


In [57]:
Current_model = Net().to(device)
Current_model.load_state_dict(model.state_dict())

#train_subsequent_trainings(args, model, device, federated_train_ds_t2, optimizer)

<All keys matched successfully>

In [58]:
%%time

# Get deltas from each agent for dataset 2 (corresponding to tiemstamp 2)

delta_Agent_1, delta_Agent_2, delta_Agent_3, delta_Agent_4 = train_subsequent_trainings(args, model,Current_model, device, federated_train_ds_t2, optimizer)

#test(args, model, device, test_loader)




CPU times: user 24.3 s, sys: 1.14 s, total: 25.5 s
Wall time: 25.5 s


In [0]:
from collections import Counter

def updated_weights(model,delta):
  for name,param in model.state_dict().items():
    new_weights = {name: model.state_dict()[name] - delta[name] for name in model.state_dict()}
  return new_weights

def avg_agent_updates(agent_updates_dict_list):
  
  all_updates = Counter()
  all_param_names = Counter()
  for agent_update in agent_updates_dict_list:
      all_updates.update(agent_update)
      all_param_names.update(agent_update.keys())

  averaged_updates_delta = {x: (1.0 * all_updates[x])/all_param_names[x] for x in all_updates.keys()}

  return averaged_updates_delta

def test_updates(agent_updates,test_loader):
  total_number_agents = len(agent_updates)
  
  for idx in range (0,total_number_agents):
    current_agent_delta = agent_updates[idx]
    mod_weights = updated_weights(Current_model,current_agent_delta)
    Test_model.load_state_dict(mod_weights)
    current_agent_loss, current_agent_accuracy = test(args, Test_model, device, test_loader)
    print('Loss if updates included from agent: {}, {:.6f}. Accuracy: {:.2f}'.format(idx, current_agent_loss,current_agent_accuracy))
    average_delta_other_Agents = avg_agent_updates([x for i,x in enumerate(agent_updates) if i!=idx])
    mod_weights = updated_weights(Current_model,average_delta_other_Agents)
    Test_model.load_state_dict(mod_weights)
    all_other_agent_loss, all_other_agent_accuracy = test(args, Test_model, device, test_loader)
    print('Loss if updates included from all other agents: {:.6f}. Acuracy: {:.2f}'.format(all_other_agent_loss,all_other_agent_accuracy))

In [67]:
test_updates([delta_Agent_1, delta_Agent_2, delta_Agent_3, delta_Agent_4],test_loader)

Loss if updates included from agent: 0 0.084581. Accuracy: 97.48
Loss if updates included from all other agents: 0.07483155288696289. Acuracy: 97.88
Loss if updates included from agent: 1 0.078709. Accuracy: 97.55
Loss if updates included from all other agents: 0.07650294570922851. Acuracy: 97.76
Loss if updates included from agent: 2 0.077368. Accuracy: 97.81
Loss if updates included from all other agents: 0.07810172119140625. Acuracy: 97.70
Loss if updates included from agent: 3 0.080095. Accuracy: 97.55
Loss if updates included from all other agents: 0.07612780494689941. Acuracy: 97.81
