In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms, models
import torch
import sys

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

# from torch.utils.tensorboard import SummaryWriter
# %load_ext tensorboard

In [2]:
sys.argv = ['','--num_channels','3','--model','resnet','--epochs','60',
            '--gpu','3','--num_users','3','--dataset', 'cifar', '--lr','0.00001', 
            '--local_ep','1', '--bs','69', '--local_b', '11']

In [3]:
args = args_parser()

In [4]:
args

Namespace(bs=69, dataset='cifar', epochs=60, frac=0.1, gpu=3, iid=False, kernel_num=9, kernel_sizes='3,4,5', local_bs=11, local_ep=1, lr=1e-05, max_pool='True', model='resnet', momentum=0.5, norm='batch_norm', num_channels=3, num_classes=10, num_filters=32, num_users=3, seed=1, split='user', stopping_rounds=10, verbose=False)

In [5]:
classDict = {'T-shirt/top':0, 'Trouser':1, 'Pullover':2, 'Dress':3, 'Coat':4, 'Sandal':5, 'Shirt':6, 'Sneaker':7, 'Bag':8, 'Ankle boot':9}

# Define a function to separate CIFAR classes by class index

def get_class_i(x, y, i):
    """
    x: trainset.train_data or testset.test_data
    y: trainset.train_labels or testset.test_labels
    i: class label, a number between 0 to 9
    return: x_i
    """
    # Convert to a numpy array
    y = np.array(y)
    # Locate position of labels that equal to i
    pos_i = np.argwhere(y == i)
    # Convert the result into a 1-D list
    pos_i = list(pos_i[:,0])
    # Collect all data that match the desired label
#     x_i = [x[j] for j in pos_i]
    
    return pos_i

In [6]:
def create_dict(dataset_train):
    
    frac = int(len(dataset_train.data) * 0.05)
    x_reserve = dataset_train.data[:frac]
    y_reserve = dataset_train.targets[:frac]
    x_train = dataset_train.data[frac:]
    y_train = dataset_train.targets[frac:]
    
    reserved = get_class_i(x_reserve, y_reserve, classDict['T-shirt/top']) \
               +get_class_i(x_reserve, y_reserve, classDict['Trouser']) \
               +get_class_i(x_reserve, y_reserve, classDict['Pullover']) \
               +get_class_i(x_reserve, y_reserve, classDict['Dress']) \
               +get_class_i(x_reserve, y_reserve, classDict['Coat']) \
               +get_class_i(x_reserve, y_reserve, classDict['Sandal']) \
               +get_class_i(x_reserve, y_reserve, classDict['Shirt']) \
               +get_class_i(x_reserve, y_reserve, classDict['Sneaker']) \
               +get_class_i(x_reserve, y_reserve, classDict['Bag']) \
               +get_class_i(x_reserve, y_reserve, classDict['Ankle boot'])

    train1 = get_class_i(x_train, y_train, classDict['T-shirt/top']) \
             +get_class_i(x_train, y_train, classDict['Trouser']) \
             +get_class_i(x_train, y_train, classDict['Pullover'])

    train2 = get_class_i(x_train, y_train, classDict['Dress']) \
             +get_class_i(x_train, y_train, classDict['Coat']) \
             +get_class_i(x_train, y_train, classDict['Sandal'])
    
    train3 = get_class_i(x_train, y_train, classDict['Shirt']) \
             +get_class_i(x_train, y_train, classDict['Sneaker']) \
             +get_class_i(x_train, y_train, classDict['Bag']) \
             +get_class_i(x_train, y_train, classDict['Ankle boot'])
    
    dict_users = {0: set(reserved+train3), 1:set(train1), 2:set(train2)}
    return dict_users

In [None]:
if __name__ == '__main__':
#     writer = SummaryWriter('../../runs/') 
    reslist = []
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    reslist = []
    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        dataset_train = datasets.FashionMNIST('../../../data/fmnist', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.FashionMNIST('../../../data/fmnist', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = create_dict(dataset_train)
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'resnet' and args.dataset == 'cifar':
        net_glob = models.resnet50(pretrained=True)
        num_ftrs = net_glob.fc.in_features
        net_glob.fc = torch.nn.Linear(num_ftrs, 10)
        net_glob.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        net_glob.to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    for iter in range(args.epochs):
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 3)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)
        
        # testing
        net_glob.eval()
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        print("Testing accuracy: {:.2f}".format(acc_test))

#         writer.add_scalar('train/loss_federated', loss_avg, iter)
#         writer.add_scalar('valid/accuracy_federated', acc_test.data.numpy()/100., iter)
        reslist.append(acc_test.data.numpy()/100.)
        net_glob.train()

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [11]:
reslist

[0.1397999954223633,
 0.3611000061035156,
 0.44810001373291014,
 0.49650001525878906,
 0.5650999832153321,
 0.6027999877929687,
 0.4343000030517578,
 0.647300033569336,
 0.6658000183105469,
 0.645199966430664,
 0.6773999786376953,
 0.7038999938964844,
 0.29270000457763673,
 0.544000015258789,
 0.7112999725341796,
 0.7286000061035156,
 0.7366999816894532,
 0.7440000152587891,
 0.7280000305175781,
 0.7527999877929688,
 0.7616999816894531,
 0.7647000122070312,
 0.7708000183105469,
 0.7683999633789063,
 0.7708999633789062,
 0.7758999633789062,
 0.7776999664306641,
 0.779000015258789,
 0.7769000244140625,
 0.7619000244140625,
 0.7851000213623047,
 0.7898999786376953,
 0.7901000213623047,
 0.7922000122070313,
 0.7961000061035156,
 0.7962999725341797,
 0.7972000122070313,
 0.8005000305175781,
 0.8012000274658203,
 0.7998999786376954,
 0.8056999969482422,
 0.8076000213623047,
 0.8091999816894532,
 0.8036000061035157,
 0.8090000152587891,
 0.7926000213623047,
 0.8102999877929687,
 0.81290000915

### Testing ground

In [12]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

confusion_matrix = torch.zeros(10, 10)
net_glob.eval()
# testing
test_loss = 0
correct = 0
data_loader = DataLoader(dataset_test, batch_size=args.bs)
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader):
    if args.gpu != -1:
        data, target = data.to(args.device), target.to(args.device)
    log_probs = net_glob(data)
    # sum up batch loss
    test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
    # get the index of the max log-probability
    y_pred = log_probs.data.max(1, keepdim=True)[1]
    correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    for t, p in zip(target.data, y_pred.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

In [13]:
print(confusion_matrix)

tensor([[752.,   6.,  40.,  76.,  12.,   1.,  87.,   3.,  23.,   0.],
        [  4., 951.,   6.,  27.,   6.,   0.,   3.,   0.,   3.,   0.],
        [ 12.,   0., 729.,  21., 139.,   0.,  90.,   0.,   9.,   0.],
        [ 21.,   8.,  17., 875.,  40.,   0.,  32.,   0.,   7.,   0.],
        [  1.,   1., 126.,  42., 705.,   0., 116.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   3.,   0., 933.,   0.,  44.,   2.,  18.],
        [125.,   6., 154.,  66., 102.,   0., 509.,   0.,  38.,   0.],
        [  0.,   0.,   0.,   0.,   0.,  42.,   0., 918.,   1.,  39.],
        [  1.,   1.,   9.,   4.,   2.,   6.,  10.,   8., 959.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   7.,   1.,  71.,   1., 919.]])
