In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms, models
import torch
import sys

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

In [2]:
sys.argv = ['','--iid','--num_channels','3','--model','cnn','--epochs','70',
            '--gpu','0','--num_users','3','--dataset', 'cifar', '--lr','0.01', 
            '--local_ep','1']

In [3]:
args = args_parser()

In [4]:
args

Namespace(bs=128, dataset='cifar', epochs=70, frac=0.1, gpu=0, iid=True, kernel_num=9, kernel_sizes='3,4,5', local_bs=10, local_ep=1, lr=0.01, max_pool='True', model='cnn', momentum=0.5, norm='batch_norm', num_channels=3, num_classes=10, num_filters=32, num_users=3, seed=1, split='user', stopping_rounds=10, verbose=False)

In [5]:
classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}

# Define a function to separate CIFAR classes by class index

def get_class_i(x, y, i):
    """
    x: trainset.train_data or testset.test_data
    y: trainset.train_labels or testset.test_labels
    i: class label, a number between 0 to 9
    return: x_i
    """
    # Convert to a numpy array
    y = np.array(y)
    # Locate position of labels that equal to i
    pos_i = np.argwhere(y == i)
    # Convert the result into a 1-D list
    pos_i = list(pos_i[:,0])
    # Collect all data that match the desired label
#     x_i = [x[j] for j in pos_i]
    
    return pos_i

In [6]:
def create_dict(dataset_train):
    
    frac = int(len(dataset_train.data) * 0.05)
    x_reserve = dataset_train.data[:frac]
    y_reserve = dataset_train.targets[:frac]
    x_train = dataset_train.data[frac:]
    y_train = dataset_train.targets[frac:]
    
    reserved = get_class_i(x_reserve, y_reserve, classDict['plane']) \
               +get_class_i(x_reserve, y_reserve, classDict['car']) \
               +get_class_i(x_reserve, y_reserve, classDict['bird']) \
               +get_class_i(x_reserve, y_reserve, classDict['cat']) \
               +get_class_i(x_reserve, y_reserve, classDict['deer']) \
               +get_class_i(x_reserve, y_reserve, classDict['dog']) \
               +get_class_i(x_reserve, y_reserve, classDict['frog']) \
               +get_class_i(x_reserve, y_reserve, classDict['horse']) \
               +get_class_i(x_reserve, y_reserve, classDict['ship']) \
               +get_class_i(x_reserve, y_reserve, classDict['truck'])

    train1 = get_class_i(x_train, y_train, classDict['plane']) \
             +get_class_i(x_train, y_train, classDict['car']) \
             +get_class_i(x_train, y_train, classDict['bird'])

    train2 = get_class_i(x_train, y_train, classDict['cat']) \
             +get_class_i(x_train, y_train, classDict['deer']) \
             +get_class_i(x_train, y_train, classDict['dog'])
    train3 = get_class_i(x_train, y_train, classDict['frog']) \
             +get_class_i(x_train, y_train, classDict['horse']) \
             +get_class_i(x_train, y_train, classDict['ship']) \
             +get_class_i(x_train, y_train, classDict['truck'])
    
    dict_users = {0: set(reserved+train3), 1:set(train1), 2:set(train2)}
    return dict_users

In [7]:
if __name__ == '__main__':
    writer = SummaryWriter('../../runs/') 
    
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
#             dict_users = cifar_iid(dataset_train, args.num_users)
            dict_users = create_dict(dataset_train)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'resnet' and args.dataset == 'cifar':
        net_glob = models.resnet18(pretrained=True)
        num_ftrs = net_glob.fc.in_features
        net_glob.fc = torch.nn.Linear(num_ftrs, 10)
        net_glob.to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    for iter in range(args.epochs):
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 3)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)
        
        # testing
        net_glob.eval()
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        print("Testing accuracy: {:.2f}".format(acc_test))

        writer.add_scalar('train/loss_federated2', loss_avg, iter)
        writer.add_scalar('valid/accuracy_federated2', acc_test.data.numpy()/100., iter)
        
        net_glob.train()

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

Files already downloaded and verified
Files already downloaded and verified
CNNCifar(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Round   0, Average loss 2.104
Testing accuracy: 34.00
Round   1, Average loss 1.676
Testing accuracy: 45.00
Round   2, Average loss 1.517
Testing accuracy: 49.00
Round   3, Average loss 1.421
Testing accuracy: 51.00
Round   4, Average loss 1.350
Testing accuracy: 54.00
Round   5, Average loss 1.286
Testing accuracy: 56.00
Round   6, Average loss 1.229
Testing accuracy: 58.00
Round   7, Average loss 1.180
Testing accuracy: 59.00
Round   8, Average loss 1.139
Testing accuracy: 60.00
Round   9, Average loss 1.103
Testing a

### Testing ground

In [18]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

confusion_matrix = torch.zeros(10, 10)
net_glob.eval()
# testing
test_loss = 0
correct = 0
data_loader = DataLoader(dataset_test, batch_size=args.bs)
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader):
    if args.gpu != -1:
        data, target = data.to(args.device), target.to(args.device)
    log_probs = net_glob(data)
    # sum up batch loss
    test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
    # get the index of the max log-probability
    y_pred = log_probs.data.max(1, keepdim=True)[1]
    correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    for t, p in zip(target.data, y_pred.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

In [19]:
print(confusion_matrix)

tensor([[588.,  43.,  61.,  44.,  42.,  15.,  14.,  16., 111.,  66.],
        [ 15., 791.,   7.,  10.,   5.,  13.,   9.,  14.,  36., 100.],
        [ 58.,   5., 467., 112., 105.,  90.,  61.,  55.,  36.,  11.],
        [ 27.,  13.,  74., 435.,  88., 182.,  71.,  62.,  20.,  28.],
        [ 22.,   7., 117., 102., 507.,  69.,  51.,  90.,  23.,  12.],
        [ 19.,   5.,  60., 253.,  60., 475.,  28.,  75.,  13.,  12.],
        [  5.,  13.,  72.,  99.,  63.,  51., 659.,  13.,  11.,  14.],
        [ 21.,   6.,  33.,  76.,  76.,  80.,  11., 663.,   7.,  27.],
        [ 61.,  49.,  15.,  29.,  18.,  12.,   7.,   9., 763.,  37.],
        [ 30., 116.,  10.,  40.,  12.,  13.,  15.,  18.,  54., 692.]])
