In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import syft as sy
import copy
import numpy as np
import time
from opacus import PrivacyEngine
import time
from datetime import timedelta
from datetime import datetime
from torchsummary import summary

import importlib
importlib.import_module('FLDataset')
from FLDataset import load_dataset, getActualImgs, CovidDataset, Rescale, ToTensor
from utils import averageModels, averageGradients
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Arguments:
    """Hyper-parameters and run options for the federated training experiment."""

    def __init__(self):
        # --- federation layout -------------------------------------------
        self.images = 3012                # size of the global training set
        self.clients = 10                 # number of virtual workers to create
        self.rounds = 1000                # federated rounds to run
        self.epochs = 1                   # local epochs per client per round
        self.local_batches = 20           # batch size used by the data loaders
        # --- optimisation / regularisation ------------------------------
        self.lr = 0.02                    # SGD learning rate
        self.dropout1 = 0.25              # dropout probability after the conv stack
        self.dropout2 = 0.5               # dropout probability after the first FC layer
        # --- client sampling ---------------------------------------------
        self.C = 0.9                      # fraction of clients selected each round
        self.drop_rate = 0.1              # fraction of selected clients that drop out
        # --- reproducibility / logging -----------------------------------
        self.torch_seed = 0               # seed passed to torch.manual_seed
        self.log_interval = 10            # log every N local batches
        self.iid = 'iid'                  # split mode forwarded to load_dataset
        # --- derived quantities ------------------------------------------
        self.split_size = self.images // self.clients      # images per client (floored)
        self.samples = self.split_size / self.images       # one client's share of the data
        # --- runtime switches --------------------------------------------
        self.use_cuda = True              # use the GPU when one is available
        self.save_model = True            # write checkpoints during training
        self.save_model_interval = 200    # rounds between checkpoints

# Build the run configuration, then resolve the device all tensors/models go to.
args = Arguments()

# CUDA is used only when it is both requested in the config and actually available.
use_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)
# GPU-only loader options; presumably meant to be forwarded to DataLoader
# (no visible cell passes them on -- TODO confirm whether they are still needed).
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}


cuda


In [3]:
# Hook PySyft into torch and register one virtual worker per federated client.
# Each client is a dict; later cells add its data, model, optimiser and loss.
hook = sy.TorchHook(torch)
clients = [
    {'hook': sy.VirtualWorker(hook, id="client{}".format(client_no))}
    for client_no in range(1, args.clients + 1)
]

print(clients)
print("number of clients : ", len(clients))

[{'hook': <VirtualWorker id:client1 #objects:0>}, {'hook': <VirtualWorker id:client2 #objects:0>}, {'hook': <VirtualWorker id:client3 #objects:0>}, {'hook': <VirtualWorker id:client4 #objects:0>}, {'hook': <VirtualWorker id:client5 #objects:0>}, {'hook': <VirtualWorker id:client6 #objects:0>}, {'hook': <VirtualWorker id:client7 #objects:0>}, {'hook': <VirtualWorker id:client8 #objects:0>}, {'hook': <VirtualWorker id:client9 #objects:0>}, {'hook': <VirtualWorker id:client10 #objects:0>}]
number of clients :  10


In [4]:
# Load the global train/test datasets plus the per-client groups used to split them.
# The groups are indexed by client position and converted to index lists further down;
# args.iid is forwarded as the split mode (its exact semantics live in FLDataset -- not visible here).
global_train, global_test, train_group, test_group = load_dataset(args.clients, args.iid)

In [5]:
# Sanity check: print the size and type of everything load_dataset returned.
for loaded in (global_train, global_test, train_group, test_group):
    print(len(loaded))
    print(type(loaded))

3012
<class 'FLDataset.CovidDataset'>
753
<class 'FLDataset.CovidDataset'>
10
<class 'dict'>
10
<class 'dict'>


In [6]:
# Give every client its own train/test data, looked up by the client's position in `clients`.
for client_idx, fl_client in enumerate(clients):
    train_indices = list(train_group[client_idx])
    test_indices = list(test_group[client_idx])
    print("len(client", str(client_idx), "train set) = ", len(train_indices))
    # getActualImgs receives the dataset, the client's indices and the batch size.
    fl_client['trainset'] = getActualImgs(global_train, train_indices, args.local_batches)
    fl_client['testset'] = getActualImgs(global_test, test_indices, args.local_batches)
    # This client's share of the global training set.
    fl_client['samples'] = len(train_indices) / args.images

len(client 0 train set) =  301
len(client 1 train set) =  301
len(client 2 train set) =  301
len(client 3 train set) =  301
len(client 4 train set) =  301
len(client 5 train set) =  301
len(client 6 train set) =  301
len(client 7 train set) =  301
len(client 8 train set) =  301
len(client 9 train set) =  301


In [7]:
# Evaluation data for the aggregated model, read from ./test.csv.
# Rescale(32) presumably yields 32x32 images, matching the (1, 32, 32) input used for the model summary.
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
test_transform = transforms.Compose([Rescale(32), ToTensor()])
global_test_dataset = CovidDataset('./test.csv', transform=test_transform)
global_test_loader = DataLoader(
    global_test_dataset,
    batch_size=args.local_batches,
    shuffle=True,
)

In [8]:
class Net(nn.Module):
    """Small CNN with two convolutions and two linear layers, scoring 3 classes.

    The first linear layer expects 14*14*64 features. That is what two
    unpadded 3x3 convolutions followed by a 2x2 max-pool produce for a
    single-channel 32x32 input (32 -> 30 -> 28 -> 14).

    Dropout probabilities are read from the module-level ``args`` object
    (``args.dropout1`` / ``args.dropout2``) each time ``forward`` runs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=32,
                               kernel_size=3,
                               stride=1)
        self.conv2 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=3,
                               stride=1)
        self.fc1 = nn.Linear(14*14*64, 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        """Return the raw class scores (logits) of shape ``(batch, 3)``.

        The scores are deliberately NOT passed through softmax: this file
        trains and evaluates with ``nn.CrossEntropyLoss``, which applies
        log-softmax itself, so normalising here would apply softmax twice.
        ``argmax`` over the logits gives the same prediction as over the
        probabilities; call ``F.softmax(logits, dim=1)`` if probabilities
        are needed.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # F.dropout defaults to training=True; tie it to the module mode so
        # that model.eval() really disables dropout during evaluation.
        x = F.dropout(x, p=args.dropout1, training=self.training)
        # Flatten per sample so a wrong input size fails loudly in fc1 instead
        # of being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=args.dropout2, training=self.training)
        return self.fc2(x)

In [9]:
def ClientUpdate(args, device, client):
    """Train one client's local model in place for ``args.epochs`` epochs.

    Args:
        args: configuration object; ``epochs``, ``log_interval`` and
            ``local_batches`` are read.
        device: torch device each batch is moved to before the forward pass.
        client: dict holding the client's ``'model'``, ``'optim'``,
            ``'criterion'``, ``'trainset'`` (an iterable of ``(data, target)``
            batches that supports ``len``) and ``'hook'`` (only its ``id`` is
            used, for logging).

    Returns:
        None. The model and optimiser stored in ``client`` are updated in place.
    """
    model = client['model']
    optimizer = client['optim']
    criterion = client['criterion']
    train_loader = client['trainset']
    num_batches = len(train_loader)

    model.train()

    for epoch in range(1, args.epochs + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            # squeeze(1) drops the singleton label dimension before the loss is computed.
            loss = criterion(output, target.squeeze(1))
            loss.backward()
            optimizer.step()

            is_last_batch = batch_idx == num_batches - 1
            if batch_idx % args.log_interval == 0 or is_last_batch:
                # Report the loss of this batch as-is. It is not an accumulated
                # sum, so it must not be divided by the logging interval.
                # The sample counters assume every batch is full, so they are approximate.
                print('Model [{}] Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    client['hook'].id,
                    epoch, (batch_idx+1) * args.local_batches, num_batches * args.local_batches,
                    100. * (batch_idx+1) / num_batches, loss.item()))

In [10]:
def test(args, model, device, test_loader, name):
    model.eval()   
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for d in test_loader:
            data = d['image']
            target = d['label']
            data, target = data.to(device), target.to(device)
            if(str(device)=='cuda'):
                model.cuda()
            output = model(data.float())
#             test_loss += F.nll_loss(output, target.squeeze(1), reduction='sum').item() # sum up batch loss
            loss_fn = nn.CrossEntropyLoss(reduction='sum')
            test_loss += loss_fn(output, target.squeeze(1)).item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss for {} model: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        name, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

In [11]:
# TensorBoard writer used by the training loop to log the global model's accuracy per round.
# Created without arguments, so the default log directory is used
# (the tensorboard command in the last cell points at a `runs` folder).
writer = SummaryWriter()

In [12]:
# Seed torch right before constructing the network so the global model's initial weights are reproducible.
torch.manual_seed(args.torch_seed)
global_model = Net().to(device)
# Print a per-layer summary for a single-channel 32x32 input.
summary(global_model, (1, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             320
            Conv2d-2           [-1, 64, 28, 28]          18,496
            Linear-3                  [-1, 128]       1,605,760
            Linear-4                    [-1, 3]             387
Total params: 1,624,963
Trainable params: 1,624,963
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.60
Params size (MB): 6.20
Estimated Total Size (MB): 6.81
----------------------------------------------------------------


  response = command_method(*args_, **kwargs_)


In [None]:
# training
# Give every client its own model, optimiser and loss. Torch is reseeded with the
# same seed before each Net() so all clients start from identical weights.
for client in clients:
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = optim.SGD(client['model'].parameters(), lr=args.lr, momentum=0.9)
    client['criterion'] = nn.CrossEntropyLoss(reduction='mean')
    # Differential privacy is currently disabled; kept for reference.
#     client['pengine'] = PrivacyEngine(
#                                        client['model'],
#                                        batch_size=args.local_batches,
#                                        sample_size=len(client['trainset']),
#                                        alphas=range(2,32),
#                                        noise_multiplier=0,
#                                        max_grad_norm=1000
#                                     )
#     client['pengine'].attach(client['optim'])

# start training model
training_start_time = time.time()
for fed_round in range(args.rounds):
    # 1-based count of the round being run; used for display and checkpointing.
    completed_rounds = fed_round + 1

    print("")
    print("===================================================================")
    print("[round] = ", completed_rounds, "/", args.rounds)
    print("===================================================================")

    round_train_start_time = time.time()

#     uncomment if you want a random fraction for C every round
#     args.C = float(format(np.random.random(), '.1f'))

    # number of selected clients (at least one)
    m = int(max(args.C * args.clients, 1))

    # Selected devices -- seeded with the round index so the choice is reproducible.
    np.random.seed(fed_round)
    selected_clients_inds = np.random.choice(range(len(clients)), m, replace=False)
    selected_clients = [clients[i] for i in selected_clients_inds]

    # Active devices: the selected clients that did not drop out this round.
    # Keep at least one so the aggregation step never receives an empty list.
    np.random.seed(fed_round)
    num_active = max(int((1-args.drop_rate) * m), 1)
    active_clients_inds = np.random.choice(selected_clients_inds, num_active, replace=False)
    active_clients = [clients[i] for i in active_clients_inds]

    # Training
    client_cnt = 0
    for client in active_clients:
        print("* [client count] = ", client_cnt+1 , "/", len(active_clients))
        client_train_start_time = time.time()
        ClientUpdate(args, device, client)
        client_cnt += 1
        client_train_time = round(time.time()-client_train_start_time)
        print("* [client_train_time] = ", str(timedelta(seconds=(client_train_time))))
        print("---------------------------------------------------------------")

#         # Testing
#         for client in active_clients:
#             test(args, client['model'], device, client['testset'], client['hook'].id)

    # Averaging: averageModels (from utils, not visible here) presumably combines
    # the active clients' weights into the global model -- confirm in utils.
    global_model = averageModels(global_model, active_clients)

    # Testing the average model
    acc = test(args, global_model, device, global_test_loader, 'Global')
    # NOTE(review): the tag says "train" but acc is measured on global_test_loader;
    # the tag is left unchanged so existing TensorBoard runs stay comparable.
    writer.add_scalar("Accuracy/train", acc, fed_round)
    writer.flush()

    # Share the global model with the clients (all of them, not only the active ones)
    for client in clients:
        client['model'].load_state_dict(global_model.state_dict())

    # training time per round
    total_train_time = round(time.time()-training_start_time)
    round_train_time = round(time.time()-round_train_start_time)
    print("** [total train time]: ", str(timedelta(seconds=total_train_time)))
    print("** [round train time]: ", str(timedelta(seconds=round_train_time)))

    # Checkpoint after every `save_model_interval` completed rounds. Counting
    # completed rounds (1-based) means the final round is saved as well.
    if args.save_model and completed_rounds % args.save_model_interval == 0:
        date = datetime.now().strftime("%Y_%m_%d_%H%M")
        checkpoint_path = date + "_FedAvg_with_DP_round_" + str(completed_rounds) + ".pth"
        torch.save(global_model.state_dict(), checkpoint_path)
        print("model saved : " + checkpoint_path)


  current_tensor = hook_self.torch.native_tensor(*args, **kwargs)



[round] =  1 / 1000
* [client count] =  1 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  2 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  3 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  4 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  5 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  6 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  7 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  8 / 8
* [client_train_time] =  0:00:03
-----------------------------------------------------------

* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  2 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  3 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  4 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  5 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  6 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  7 / 8
* [client_train_time] =  0:00:03
---------------------------------------------------------------
* [client count] =  8 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------

Test set: Average loss for Global model: 

* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  2 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  3 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  4 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  5 / 8
* [client_train_time] =  0:00:03
---------------------------------------------------------------
* [client count] =  6 / 8
* [client_train_time] =  0:00:03
---------------------------------------------------------------
* [client count] =  7 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------
* [client count] =  8 / 8
* [client_train_time] =  0:00:04
---------------------------------------------------------------

Test set: Average loss for Global model: 

In [None]:
# To open TensorBoard and inspect the logged accuracy, run this command in a shell:
# tensorboard --logdir=/home/citi302/Desktop/Codefolder/FL_DP_covid/runs