In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms, models
import torch
import sys
import os
from PIL import Image

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets, models

# from torch.utils.tensorboard import SummaryWriter
# %load_ext tensorboard

In [2]:
# Emulate a command line so that args_parser() can be driven from inside the
# notebook. The first entry stands in for the program name.
# NOTE(review): the Namespace output recorded further down shows lr=0.1 even
# though '--lr 0.01' is passed here -- the saved output may be stale; confirm.
sys.argv = [''] + [
    token
    for flag, value in (
        ('--num_channels', '3'),
        ('--model', 'resnet'),
        ('--epochs', '60'),
        ('--gpu', '2'),
        ('--num_users', '3'),
        ('--dataset', 'cifar'),
        ('--lr', '0.01'),
        ('--local_ep', '1'),
        ('--bs', '32'),
    )
    for token in (flag, value)
]

In [3]:
# Build the experiment options from the sys.argv list assembled in the previous cell.
args = args_parser()

In [4]:
# Echo the parsed options so the run configuration is recorded with the notebook.
args

Namespace(bs=32, dataset='cifar', epochs=60, frac=0.1, gpu=2, iid=False, kernel_num=9, kernel_sizes='3,4,5', local_bs=10, local_ep=1, lr=0.1, max_pool='True', model='resnet', momentum=0.5, norm='batch_norm', num_channels=3, num_classes=10, num_filters=32, num_users=3, seed=1, split='user', stopping_rounds=10, verbose=False)

In [5]:
class CustomDataSet(Dataset):
    """Map-style dataset that reads images from disk by file name.

    Parameters
    ----------
    main_dir : str
        Directory that every entry of ``x`` is joined onto.
    x : sequence of str
        Image file names, indexed in parallel with ``y``.
    y : sequence
        Labels; ``len(y)`` defines the dataset length.
    transform : callable
        Applied to the RGB ``PIL.Image`` before it is returned.
    """

    def __init__(self, main_dir, x, y, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.y = y
        self.all_imgs = x

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label)``."""
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        # Open inside a context manager so the underlying file is released
        # deterministically; convert() yields an independent in-memory image.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image, self.y[idx]

In [6]:
# Integer encoding of the two smile classes.
classDict = dict(not_smiling=0, smiling=1)

# Helper for locating every sample that carries a given class label.

def get_class_i(x, y, i):
    """
    Find the positions of all labels equal to ``i``.

    x: not used by the lookup; kept so the signature stays unchanged
    y: sequence of labels (anything ``np.array`` accepts)
    i: class label to look for
    return: list of indices into ``y`` whose label equals ``i``
            (indices, not the matching samples themselves)
    """
    labels = np.array(y)
    # First axis of the non-zero coordinates == row index of every match.
    matches = np.nonzero(labels == i)[0]
    return list(matches)

In [7]:
# Pre-processing pipelines and dataset handles for CelebA.
print('==> Preparing data..')

# Values shared by every pipeline below.
# NOTE(review): these mean/std values look like the widely used CIFAR-10
# channel statistics rather than CelebA ones -- confirm this is intended.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)
_CROP_SIZE = (218, 178)
_CELEBA_ROOT = '../../../data/celebA'
_SMILING_COLUMN = 31  # column of `attr` used as the smile label

# Transformations A: self-contained train / test pipelines.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(_CROP_SIZE, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Transformations B: individual building blocks, composed below.
RC = transforms.RandomCrop(_CROP_SIZE, padding=4)
RHF = transforms.RandomHorizontalFlip()
RVF = transforms.RandomVerticalFlip()
NRM = transforms.Normalize(_NORM_MEAN, _NORM_STD)
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Training pipeline (with augmentation) and evaluation pipeline (without).
transform_with_aug = transforms.Compose([RC, RHF, TT, NRM])
transform_no_aug = transforms.Compose([TT, NRM])

# Downloading/loading the CelebA splits.
trainset = torchvision.datasets.CelebA(
    root=_CELEBA_ROOT,
    split='train',
    download=True,
    transform=transform_with_aug,
)
testset = torchvision.datasets.CelebA(
    root=_CELEBA_ROOT,
    split='test',
    download=True,
    transform=transform_no_aug,
)

classDict = {'not_smiling': 0, 'smiling': 1}

# Aliases for the datasets, plus the smile column of their attribute tables.
x_train, x_test = trainset, testset
y_train = trainset.attr[:, _SMILING_COLUMN]
y_test = testset.attr[:, _SMILING_COLUMN]

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [8]:
def create_dict(x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test,
                num_samples=None):
    """Partition a contiguous index range into three per-user shards.

    The range ``[0, num_samples)`` is cut at roughly 31.7% and 63.3%, and the
    resulting slices are assigned to users 1, 2 and 0 respectively.

    Parameters
    ----------
    x_train, y_train, y_test
        Not used by the split; kept so the signature stays unchanged.
    x_test
        Its length is the default number of indices to distribute.
    num_samples : int, optional
        Number of indices to distribute. Defaults to ``len(x_test)``, which
        is the original behaviour.

    Returns
    -------
    dict
        ``{0: third slice, 1: first slice, 2: second slice}``, each a list of
        consecutive integer indices.

    NOTE(review): the default sizes the split by the *test* set, while the
    training cell uses these indices against the training dataset -- pass
    ``num_samples=len(x_train)`` if the whole training set should be shared out.
    """
    if num_samples is None:
        num_samples = len(x_test)

    first_cut = int(num_samples * 0.316666666667)
    second_cut = int(num_samples * 0.633333333333)

    shard_a = list(range(0, first_cut))
    shard_b = list(range(first_cut, second_cut))
    shard_c = list(range(second_cut, num_samples))

    dict_users = {0: shard_c, 1: shard_a, 2: shard_b}
    return dict_users

In [9]:
if __name__ == '__main__':
    # Federated-averaging driver: build the dataset and per-user index split,
    # create the global model, then alternate local training, weight averaging
    # and a test-set evaluation for `args.epochs` communication rounds.
    # Everything stays at module level on purpose so later cells can reuse
    # `args`, `net_glob`, `dataset_test` and `reslist`.

    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # Apply the parsed seed so user sampling and torch-side randomness are repeatable.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Per-round test accuracy (acc_test / 100), inspected in a later cell.
    reslist = []

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        if args.iid:
            dataset_train = datasets.CIFAR10('../../../data/celebA', train=True, download=True, transform=trans_cifar)
            dataset_test = datasets.CIFAR10('../../../data/celebA', train=False, download=True, transform=trans_cifar)
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            # Non-iid 'cifar' run: CelebA images with the smile label, split by create_dict().
            dataset_train = CustomDataSet('../../../data/celebA/celeba/img_align_celeba', x=x_train.filename, y=y_train, transform=transform_with_aug)
            dataset_test = CustomDataSet('../../../data/celebA/celeba/img_align_celeba', x=x_test.filename, y=y_test, transform=transform_no_aug)
            dict_users = create_dict()
    else:
        # sys.exit instead of the interactive `exit` helper, which is not the
        # same object under an IPython kernel.
        sys.exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'resnet' and args.dataset == 'cifar':
        # Pretrained ResNet-50 with the final layer replaced by a 2-way classifier.
        net_glob = models.resnet50(pretrained=True)
        num_ftrs = net_glob.fc.in_features
        net_glob.fc = torch.nn.Linear(num_ftrs, 2)
        net_glob.to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        # Flattened input size = product of the sample's dimensions.
        len_in = 1
        for dim in img_size:
            len_in *= dim
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        sys.exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    # `round_idx` rather than `iter`: at module level the loop variable would
    # otherwise shadow the builtin for every later cell.
    for round_idx in range(args.epochs):
        w_locals, loss_locals = [], []
        # At least three participants per round, but never more than exist:
        # sampling without replacement raises if m exceeds the population.
        # NOTE(review): create_dict() only defines users 0-2, so --num_users > 3
        # on the CelebA branch can still hit a missing key in dict_users.
        m = min(max(int(args.frac * args.num_users), 3), args.num_users)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            # Each user trains on its own copy so the global model is untouched.
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
        loss_train.append(loss_avg)

        # testing
        net_glob.eval()
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        print("Testing accuracy: {:.2f}".format(acc_test))

        # float() accepts a one-element tensor on any device as well as a plain number.
        reslist.append(float(acc_test) / 100.)
        net_glob.train()

    # plot loss curve
    os.makedirs('./save', exist_ok=True)  # savefig does not create missing directories
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
# Per-round test accuracy; each entry is acc_test / 100 appended by the training loop.
reslist

[0.49969944000244143,
 0.5003005599975586,
 0.5003005599975586,
 0.49969944000244143,
 0.5273018646240234,
 0.49969944000244143,
 0.7937581634521484,
 0.8616371154785156,
 0.8935978698730469,
 0.8976555633544921,
 0.907925033569336,
 0.9066226196289062,
 0.9022142028808594,
 0.9087265777587891,
 0.9094279479980468,
 0.9124837493896485,
 0.9133854675292968,
 0.9121831512451172,
 0.9114317321777343,
 0.915589599609375,
 0.9152890777587891,
 0.9175934600830078,
 0.9090772247314454,
 0.9139865875244141,
 0.9180944061279297,
 0.9183448791503906,
 0.917142562866211,
 0.9159403228759766,
 0.9194970703125,
 0.9217513275146484,
 0.9204488372802735,
 0.9185953521728516,
 0.9191964721679687,
 0.911632080078125,
 0.9221019744873047,
 0.9149383544921875,
 0.9199980163574218,
 0.9163911437988281,
 0.9154393005371094,
 0.9164412689208984,
 0.9223524475097656,
 0.9153391265869141,
 0.9226530456542968,
 0.9252079010009766,
 0.9240557098388672,
 0.9212503814697266,
 0.5358180618286132,
 0.92530807495117

### Testing ground

In [13]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Evaluate the global model on the test set and count outcomes in a 2x2
# matrix indexed as [true label, predicted label].
confusion_matrix = torch.zeros(2, 2)
net_glob.eval()
# testing
test_loss = 0
correct = 0
data_loader = DataLoader(dataset_test, batch_size=args.bs)
l = len(data_loader)  # number of batches
# Inference only: without no_grad() every forward pass would record an
# autograd graph that is never used.
with torch.no_grad():
    for data, target in data_loader:
        # args.device is already 'cpu' when no GPU is selected, so the move is
        # safe unconditionally.
        data, target = data.to(args.device), target.to(args.device)
        log_probs = net_glob(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.view_as(y_pred)).long().cpu().sum()

        # Tally on the CPU with plain ints; the matrix itself lives on the CPU.
        for t, p in zip(target.view(-1).cpu().tolist(), y_pred.view(-1).cpu().tolist()):
            confusion_matrix[t, p] += 1

In [14]:
# Show the accumulated counts; the matrix is indexed as [target, prediction].
print(confusion_matrix)

tensor([[9327.,  648.],
        [ 832., 9155.]])
