In [1]:
import matplotlib
matplotlib.use('Agg')
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from utils.customloader import CustomDataset, DatasetSplit
from utils.dataloader import get_dataloader
from utils.train_glob import train_global_model, test_model

from models.Update import LocalUpdate
from models.Fed import FedAvg

import random
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import numpy as np
import copy



import torch
class Args:
    """Static hyper-parameter container used in place of a parsed argparse namespace.

    All settings are class attributes, so ``Args`` and ``Args()`` expose the
    same values.
    """

    # --- federated arguments ---
    epochs = 50       # number of global rounds run by the training loop
    num_users = 10    # number of worker processes started in every round

    # The next three are not read in the visible notebook cells; presumably
    # consumed by the project helpers (local epochs / batch sizes) -- confirm there.
    local_ep = 3
    local_bs = 100
    bs = 128

    lr = 0.01         # learning rate passed to optim.SGD
    momentum = 0.5    # momentum passed to optim.SGD

    # NOTE(review): CNNCifar hard-codes 3 input channels and never reads
    # num_channels; the value 1 looks like a leftover -- confirm before relying on it.
    num_channels = 1
    num_classes = 10  # width of the classifier output and number of label groups

    # Was the argparse action string 'store_true'; a real boolean keeps the
    # same truthiness while no longer looking like an un-parsed CLI option.
    verbose = True
    seed = 1          # seed applied to the random, numpy and torch RNGs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    

# One shared configuration object for every later cell.
args = Args()

# --- reproducibility: seed the Python, NumPy and PyTorch RNGs from the configured value ---
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7f3db4242230>

# Define Dataloader /  Model / Optimizer / Loss

In [2]:
# Convert images to tensors, then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split, served shuffled in batches of 1000.
dataset_train = datasets.CIFAR10(
    root='./data/cifar',
    train=True,
    transform=transform,
    download=True,
)
global_train_loader = DataLoader(dataset=dataset_train, batch_size=1000, shuffle=True)

# Held-out test split, served in a fixed order in batches of 1000.
dataset_test = datasets.CIFAR10(
    root='./data/cifar',
    train=False,
    transform=transform,
    download=True,
)
test_loader = DataLoader(dataset=dataset_test, batch_size=1000, shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch.nn as nn
class CNNCifar(nn.Module):
    """Small CNN classifier: two conv + max-pool stages followed by three linear layers.

    The first linear layer expects 16 * 5 * 5 features per sample, which is
    what the conv/pool stack produces for 3-channel 32x32 inputs.
    """

    def __init__(self, args):
        """Create the layers; ``args.num_classes`` sets the size of the output layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=args.num_classes)

    def forward(self, x):
        """Return the raw class scores (no softmax applied) for a batch ``x``."""
        # Two identical stages: convolution -> ReLU -> 2x2 max-pool.
        feature_maps = self.pool(F.relu(self.conv1(x)))
        feature_maps = self.pool(F.relu(self.conv2(feature_maps)))

        # Flatten each sample's feature maps into one vector for the linear layers.
        flattened = feature_maps.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.fc1(flattened))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


In [4]:
# Loss function handed to the worker processes and to test_model.
sloss = F.cross_entropy

# Build the global model, move it to the configured device and switch it to training mode.
net_glob = CNNCifar(args=args)
net_glob.to(args.device)
net_glob.train()

# Plain SGD over the global model's parameters, configured from args.
optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)


In [5]:
# Independent deep copy of the global model; net_glob has not been trained at this point,
# so the copy holds the initial weights.
checkpoint_globalnet = copy.deepcopy(net_glob)

# Define Distribution of Data

In [6]:
# Group the training-sample indices by class label.
unique, counts = np.unique(global_train_loader.dataset.targets, return_counts=True)
print(unique)
print(counts)

# argsort over the labels lists the sample indices class by class, in ascending label
# order -- the same order in which np.unique reports `unique` / `counts`.
sorted_y = copy.deepcopy(global_train_loader.dataset.targets)
sorted_index_y = np.argsort(np.squeeze(sorted_y))

# Running totals of the per-class counts give the start/end offset of every class's run
# inside sorted_index_y (computed once instead of re-summing counts[:i] per class).
# Assumes args.num_classes does not exceed the number of distinct labels.
class_bounds = np.concatenate(([0], np.cumsum(counts)))

class_dist = [
    np.array(sorted_index_y[class_bounds[c]:class_bounds[c + 1]], dtype=np.int64)
    for c in range(args.num_classes)
]

# One row of sample indices per class; a regular 2-D array only when all classes
# are equally sized (true for the counts printed above).
non_iid = np.array(class_dist)

[0 1 2 3 4 5 6 7 8 9]
[5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
0
1
2
3
4
5
6
7
8
9


In [7]:
# Cut every class's index list into one near-equal shard per user.
# The class and user counts come from args instead of a hard-coded 10, so this cell
# stays consistent with the cells that build class_dist and launch the workers.
individual = [
    np.array_split(class_dist[c], args.num_users)
    for c in range(args.num_classes)
]

# User u receives shard u of every class, i.e. a share of each label.
user_dist = [
    np.concatenate([individual[c][u] for c in range(args.num_classes)]).astype(np.int64)
    for u in range(args.num_users)
]

# One row of sample indices per user; a regular 2-D array only when the shards are
# equally sized (class sizes divisible by the number of users).
iid = np.array(user_dist)

# Train Model

In [8]:
from utils.multiprocessing import multi_train_local_dif
from torch.utils.data.sampler import Sampler
from torchvision import datasets, transforms
import torch.multiprocessing as mp

In [None]:
# Index partition handed to the workers; `iid` gives every user a share of every class.
distribution = iid

if __name__ == '__main__':

    # 'fork' lets the children inherit the parent's in-memory objects.
    mp.set_start_method('fork', force=True)
    # One intra-op thread in this process -- presumably to avoid oversubscribing the
    # CPU while num_users workers run at once (confirm).
    torch.set_num_threads(1)

    # The federated model is a deep copy, so net_glob keeps its initial weights.
    checkpoint_globalnet11 = copy.deepcopy(net_glob)

    # `epoch` / `user_idx` replace the original reuse of `i` for both loops, which
    # shadowed the epoch index inside the round.
    for epoch in range(args.epochs):

        print('--------------------------------------------')
        print("\n\n\nstart training epoch : " + str(epoch) + "\n\n\n")
        print('--------------------------------------------')

        procs = []
        loss_locals = []
        w_locals = []

        # Fresh result queues for this round: one for losses, one for weights.
        q_l = mp.Queue()
        q_w = mp.Queue()

        # Start one worker per user, all training from the current global weights.
        for user_idx in range(args.num_users):
            p = mp.Process(
                target=multi_train_local_dif,
                args=(q_l, q_w, args, user_idx, sloss, global_train_loader,
                      distribution, checkpoint_globalnet11),
            )
            procs.append(p)
            p.start()

        # Collect one loss and one weight set per started worker (assumes each worker
        # puts exactly one item on each queue -- confirm in multi_train_local_dif).
        # Queue.get() takes no process argument: the original `get(p)` only passed the
        # Process object as the truthy `block` flag. Results arrive in completion order.
        for _ in procs:
            loss_locals.append(q_l.get())
            w_locals.append(q_w.get())

        # Join only after the queues are drained; joining a process that still has
        # buffered queue items can deadlock.
        for p in procs:
            p.join()

        print('--------------------------------------------\n\n')
        # Aggregate the local weights into the new global weights, then evaluate them.
        w_glob = FedAvg(w_locals)
        checkpoint_globalnet11.load_state_dict(w_glob)
        test_model(checkpoint_globalnet11, test_loader, sloss, args)
        print('\n\n--------------------------------------------')
        





# Test Model

In [10]:
# Baseline before federated learning: the training loop only updates the deep copy
# checkpoint_globalnet11, so net_glob still holds its initial weights here.
print('Before Federated Learning')
test_model(net_glob, test_loader, sloss)

Before Federated Learning

Test set: Average loss: 2.30375 
Accuracy: 943/10000 (9.43%)



In [11]:
# After federated learning: evaluate checkpoint_globalnet11, the model that received the
# aggregated weights at the end of every round of the training loop.
print('After Federated Learning -- checkpoint')
test_model(checkpoint_globalnet11, test_loader, sloss)

After Federated Learning -- checkpoint

Test set: Average loss: 1.13745 
Accuracy: 6007/10000 (60.07%)

