# FedAvg

Vanilla FedAvg

In [1]:
import numpy as np

import copy
import os 
import gc

import torch
from torch import nn, optim, autograd
import torch.nn.functional as F
import torch.nn.init as init

from collections import OrderedDict
import matplotlib.pyplot as plt

# from torch.utils.data import DataLoader, Dataset
# from torchvision import datasets, transforms

# import torch.utils.data as data

In [2]:
# Report the visible CUDA devices, then make GPU 1 the current device
# (the next cell selects the device again from args.gpu).
# Guarded so the notebook does not crash on a CPU-only machine; the device
# selection in the next cell already falls back to 'cpu' in that case.
if torch.cuda.is_available():
    print(f'Current GPU: {torch.cuda.current_device()}')
    print(f'GPU Name: {torch.cuda.get_device_name()}')
    print(f'Number of GPUs: {torch.cuda.device_count()}')
    torch.cuda.set_device(1)  ## Setting cuda on GPU:1

Current GPU: 0
GPU Name: GeForce GTX 1080 Ti
Number of GPUs: 2


In [57]:
class Args:
    """Default experiment configuration for the FedAvg notebook.

    Every setting is a plain class attribute, so an ``Args()`` instance reads
    these as defaults. Later cells override or add attributes on the instance
    (e.g. ``args.device``, ``args.bias``).
    """

    # ---- federation / hardware -------------------------------------------
    num_users = 2
    seed = 1
    gpu = 1

    # ---- data ------------------------------------------------------------
    # CIFAR-10:  50000 training images (5000 per class), 10 classes,
    #            10000 test images (1000 per class).
    # CIFAR-100: 50000 training images (500 per class), 100 classes,
    #            10000 test images (100 per class).
    # MNIST:     60000 training images (min 5421, max 6742 per class),
    #            10000 test images (min 892, max 1135 per class). The code fixes
    #            5000 training images and 900 test images per class to stay
    #            consistent with CIFAR-10.
    # CIFAR-10 non-IID: 250 samples per label with 2 classes per client
    # (500 samples per client) is the benchmark setting.
    nsample_pc = 250            # number of samples per class for each client
    nclass = 2                  # number of classes (shards) for each client
    model = 'resnet9'           # architecture name; see init_nets for the supported names
    dataset = 'cifar100'        # options: mnist, cifar10, cifar100
    datadir = '../data/'
    logdir = '../logs/'
    savedir = '../save/'
    partition = 'noniid-#label20'
    alg = 'cluster_fl'
    beta = 0.1                  # forwarded to partition_data(beta=...)
    local_view = True
    batch_size = 10
    noise = 0
    noise_type = 'level'

    # ---- optimisation ----------------------------------------------------
    rounds = 30
    frac = 0.1                  # fraction of clients sampled per round
    local_bs = 10
    local_ep = 5
    lr = 0.01
    momentum = 0.5

    # ---- clustering-related settings (not read by the visible code) ------
    cluster_alpha = 3.5
    nclasses = 10               # NOTE(review): 10 although dataset defaults to cifar100 — confirm
    nsamples_shared = 2500
    n_basis = 3
    linkage = 'average'

    # ---- partition flags -------------------------------------------------
    noniid = True
    noniid_iid = False
    shard = True
    label = False
    split_test = False

    # ---- misc ------------------------------------------------------------
    print_freq = 50
    load_initial = ''           # path of a saved state_dict to start from ('' = fresh init)
    
args = Args()

# Make args.gpu the current CUDA device. torch.cuda.set_device raises on a
# CPU-only machine, so only call it when CUDA is available; args.device below
# then falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(args.gpu)  ## Setting cuda on GPU
#torch.manual_seed(args.seed)
#np.random.seed(args.seed)

args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

## Dataset NIID Benchmark 

In [58]:
from __future__ import absolute_import
from src.client import * 
from src.utils import *
from src.data import *

In [59]:
# import sys
# sys.path
# import os
# os.getcwd()

In [60]:
# Experiment-specific overrides applied on top of the Args defaults:
# CIFAR-100, 'noniid-#label20' partition over 100 clients, 10 % of clients per round.
overrides = {
    'dataset': 'cifar100',
    'partition': 'noniid-#label20',
    'num_users': 100,
    'rounds': 100,
    'frac': 0.1,
    'local_bs': 10,
    'local_ep': 10,
    'lr': 0.01,
    'momentum': 0.9,
    'bias': True,
}
for attr_name, attr_value in overrides.items():
    setattr(args, attr_name, attr_value)

In [61]:
# Global (un-partitioned) loaders. The 4th argument (32) is presumably the
# test batch size — confirm against get_dataloader's signature.
(train_dl_global, test_dl_global,
 train_ds_global, test_ds_global) = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)

for split_name, split_ds in (('train', train_ds_global), ('test', test_ds_global)):
    print(f"len {split_name}_ds_global:", len(split_ds))

Files already downloaded and verified
Files already downloaded and verified
len train_ds_global: 50000
len test_ds_global: 10000


## Model

In [62]:
from src.models import *

In [63]:
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1):
    """Return the output length of a convolution along one spatial axis.

    Uses the standard formula
    ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.
    """
    # Span actually covered by a dilated kernel.
    effective_kernel = dilation * (kernel_size - 1) + 1
    slack = Lin + 2 * padding - effective_kernel
    return int(np.floor(slack / float(stride) + 1))

In [64]:
class SimpleCNN(nn.Module):
    """LeNet-style CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

    The size bookkeeping in ``__init__`` assumes 3-channel 32x32 inputs, for
    which the flattened conv output is 16 * 5 * 5.

    Args:
        input_dim: flattened feature size fed to ``fc1`` (16 * 5 * 5 for 32x32 inputs).
        hidden_dims: sizes of the two hidden linear layers.
        output_dim: number of output logits.
        bias: whether the conv/linear layers use a bias term.

    Attributes:
        act: input tensor of each layer from the most recent forward pass, keyed by layer name.
        input_size / ksize / in_channel: per-layer input sizes, kernel sizes and input channels.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=10, bias=True):
        super(SimpleCNN, self).__init__()
        self.bias=bias
        self.act=OrderedDict()
        self.input_size =[]
        self.ksize=[]
        self.in_channel =[]

        self.conv1 = nn.Conv2d(3, 6, 5, bias=self.bias)
        self.input_size.append(32)
        self.ksize.append(5)
        self.in_channel.append(3)

        self.pool = nn.MaxPool2d(2, 2)

        # Spatial size after conv1 + pool.
        s=compute_conv_output_size(32,5)
        s=s//2

        self.input_size.append(s)
        self.ksize.append(5)
        self.in_channel.append(6)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=self.bias)

        # for now, we hard coded this network
        # i.e. we fix the number of hidden layers i.e. 2 layers
        s=compute_conv_output_size(s,5)
        s=s//2
        self.input_size.append(s*s*16)
        self.fc1 = nn.Linear(input_dim, hidden_dims[0], bias=self.bias)

        self.input_size.append(hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1], bias=self.bias)

        self.input_size.append(hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], output_dim, bias=self.bias)

    def forward(self, x):
        """Return logits for a batch ``x``, recording each layer's input in ``self.act``."""
        self.act['conv1']=x
        x = self.pool(F.relu(self.conv1(x)))

        self.act['conv2']=x
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten per sample, keeping the batch dimension explicit. A hard-coded
        # view(-1, 16 * 5 * 5) would silently regroup values across samples when
        # the spatial size is not 5x5; this way a mismatch raises in fc1 instead.
        x = x.view(x.size(0), -1)
        self.act['fc1']=x
        x = F.relu(self.fc1(x))

        self.act['fc2']=x
        x = F.relu(self.fc2(x))

        self.act['fc3']=x
        x = self.fc3(x)
        return x

In [65]:
def _build_net(args, dropout_p=0.5, bias=True):
    """Construct one freshly initialised network for ``args.model`` / ``args.dataset``.

    Raises:
        ValueError: if the model, or the dataset for that model, is not supported.
    """
    net = None
    if args.dataset == "generated":
        net = PerceptronModel().to(args.device)
    elif args.model == "mlp":
        # dataset -> (input_size, output_size, hidden_sizes)
        mlp_configs = {
            'covtype': (54, 2, [32, 16, 8]),
            'a9a': (123, 2, [32, 16, 8]),
            'rcv1': (47236, 2, [32, 16, 8]),
            'SUSY': (18, 2, [16, 8]),
        }
        if args.dataset in mlp_configs:
            input_size, output_size, hidden_sizes = mlp_configs[args.dataset]
            net = FcNet(input_size, hidden_sizes, output_size, dropout_p).to(args.device)
    elif args.model == "vgg":
        net = vgg11().to(args.device)
    elif args.model == "simple-cnn":
        if args.dataset in ("cifar10", "cinic10", "svhn"):
            net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=10, bias=bias).to(args.device)
        elif args.dataset in ("mnist", 'femnist', 'fmnist'):
            net = SimpleCNNMNIST(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=10).to(args.device)
        elif args.dataset == 'celeba':
            # Honour the bias flag here too, like the cifar10 branch.
            net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=2, bias=bias).to(args.device)
    elif args.model == "simple-cnn-3":
        if args.dataset == 'cifar100':
            net = SimpleCNN_3(input_dim=(16 * 3 * 5 * 5), hidden_dims=[120*3, 84*3], output_dim=100).to(args.device)
        elif args.dataset == 'tinyimagenet':
            net = SimpleCNNTinyImagenet_3(input_dim=(16 * 3 * 13 * 13), hidden_dims=[120*3, 84*3],
                                          output_dim=200).to(args.device)
    elif args.model == "vgg-9":
        if args.dataset in ("mnist", 'femnist'):
            net = ModerateCNNMNIST().to(args.device)
        elif args.dataset in ("cifar10", "cinic10", "svhn"):
            net = ModerateCNN().to(args.device)
        elif args.dataset == 'celeba':
            net = ModerateCNN(output_dim=2).to(args.device)
    elif args.model == 'resnet9':
        # NOTE(review): unlike every other branch, ResNet9 is not moved to args.device.
        # This may be deliberate (keeps the many per-client copies off the GPU) — confirm.
        if args.dataset == 'cifar100':
            net = ResNet9(in_channels=3, num_classes=100)
        elif args.dataset == 'tinyimagenet':
            net = ResNet9(in_channels=3, num_classes=200, dim=512*2*2)
    elif args.model == "resnet":
        net = ResNet50_cifar10().to(args.device)
    elif args.model == "vgg16":
        net = vgg16().to(args.device)

    if net is None:
        # Raise instead of exit(1): exiting would kill the notebook kernel, and an
        # unmatched dataset branch previously surfaced as an UnboundLocalError.
        raise ValueError(f"not supported yet: model={args.model!r}, dataset={args.dataset!r}")
    return net


def init_nets(args, dropout_p=0.5, bias=True):
    """Create the global model and one identically initialised model per client.

    Args:
        args: config object; reads model, dataset, device, num_users and load_initial.
        dropout_p: dropout probability forwarded to FcNet (mlp only).
        bias: bias flag forwarded to SimpleCNN.

    Returns:
        (users_model, net_glob, initial_state_dict, server_state_dict):
        users_model is a list of args.num_users model copies, each loaded with
        initial_state_dict; server_state_dict is an independent copy of
        initial_state_dict.

    Raises:
        ValueError: if the model/dataset combination is not supported.
    """
    # Build the architecture once. Every client copy is overwritten with the same
    # initial weights anyway, so constructing num_users + 1 networks is wasted work.
    net_glob = _build_net(args, dropout_p=dropout_p, bias=bias)

    if args.load_initial:
        # Start from a saved state_dict instead of the fresh random initialisation.
        initial_state_dict = torch.load(args.load_initial)
        net_glob.load_state_dict(initial_state_dict)
    else:
        initial_state_dict = copy.deepcopy(net_glob.state_dict())
    server_state_dict = copy.deepcopy(initial_state_dict)

    users_model = []
    for _ in range(args.num_users):
        user_net = copy.deepcopy(net_glob)
        user_net.load_state_dict(initial_state_dict)
        users_model.append(user_net)

    return users_model, net_glob, initial_state_dict, server_state_dict

print('MODEL: {}, Dataset: {}'.format(args.model, args.dataset))

# Models are created later, after the client data has been partitioned:
# users_model, net_glob, initial_state_dict, server_state_dict = init_nets(args, dropout_p=0.5, bias=args.bias)

# print(net_glob)

# total = 0
# for name, param in net_glob.named_parameters():
#     print(name, param.size())
#     total += np.prod(param.size())
#     #print(np.array(param.data.cpu().numpy().reshape([-1])))
#     #print(isinstance(param.data.cpu().numpy(), np.array))
# print(total)

MODEL: resnet9, Dataset: cifar100


In [66]:
import torch.nn as nn

def conv_bn_relu_pool(in_channels, out_channels, pool=False):
    """Return ``Sequential(Conv2d 3x3, BatchNorm2d, ReLU[, MaxPool2d(2)])``.

    The 3x3 convolution uses padding=1, so it preserves the spatial size; the
    optional 2x2 max-pool halves it.
    """
    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)  # alternative: nn.GroupNorm(32, out_channels)
    modules = [conv, norm, nn.ReLU(inplace=True)]
    modules += [nn.MaxPool2d(2)] if pool else []
    return nn.Sequential(*modules)

class ResNet9(nn.Module):
    """ResNet-9 classifier.

    The model has a conv stem, then three downsampling stages (the first and
    third each followed by a residual pair of convs), a 4x4 max-pool and a
    linear head.

    Args:
        in_channels: channels of the input images.
        num_classes: number of output logits.
        dim: flattened feature size after the final 4x4 max-pool. 512 matches
            32x32 inputs; init_nets passes 512*2*2 for tinyimagenet.
    """

    def __init__(self, in_channels, num_classes, dim=512):
        super().__init__()
        # Submodules are registered in this exact order: it fixes the order of
        # random parameter initialisation and of the state_dict keys.
        self.prep = conv_bn_relu_pool(in_channels, 64)
        self.layer1_head = conv_bn_relu_pool(64, 128, pool=True)
        self.layer1_residual = nn.Sequential(
            conv_bn_relu_pool(128, 128),
            conv_bn_relu_pool(128, 128),
        )
        self.layer2 = conv_bn_relu_pool(128, 256, pool=True)
        self.layer3_head = conv_bn_relu_pool(256, 512, pool=True)
        self.layer3_residual = nn.Sequential(
            conv_bn_relu_pool(512, 512),
            conv_bn_relu_pool(512, 512),
        )
        # Attribute name kept so existing references to `net.MaxPool2d` still work.
        self.MaxPool2d = nn.Sequential(nn.MaxPool2d(4))
        self.linear = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        h = self.layer1_head(self.prep(x))
        h = self.layer1_residual(h) + h
        h = self.layer3_head(self.layer2(h))
        h = self.layer3_residual(h) + h
        pooled = self.MaxPool2d(h)
        features = pooled.view(pooled.size(0), -1)
        return self.linear(features)

## Client Data Loading

In [67]:
print('Loading {}, {} for all clients'.format(args.dataset, args.partition))

# Split the training data across args.num_users clients according to args.partition.
args.local_view = True
partition_result = partition_data(args.dataset, args.datadir, args.logdir, args.partition,
                                  args.num_users, beta=args.beta, local_view=args.local_view)
(X_train, y_train, X_test, y_test,
 net_dataidx_map, net_dataidx_map_test,
 traindata_cls_counts, testdata_cls_counts) = partition_result

# Rebuild the global loaders, this time with 128 as the fourth get_dataloader argument.
(train_dl_global, test_dl_global,
 train_ds_global, test_ds_global) = get_dataloader(args.dataset, args.datadir, args.batch_size, 128)

Loading cifar100, noniid-#label20 for all clients
Files already downloaded and verified
Files already downloaded and verified
K: 100
partition: noniid-#label20
Data statistics Train:
 {0: {0: 39, 3: 39, 4: 39, 8: 24, 10: 28, 14: 27, 19: 21, 23: 24, 36: 25, 37: 17, 42: 25, 51: 27, 53: 27, 58: 24, 66: 27, 71: 23, 72: 28, 78: 18, 83: 18, 90: 24}, 1: {1: 34, 5: 24, 6: 23, 7: 27, 9: 32, 18: 23, 20: 23, 36: 25, 38: 39, 39: 27, 40: 23, 44: 25, 45: 22, 46: 22, 50: 20, 53: 27, 62: 23, 80: 34, 84: 25, 86: 21}, 2: {2: 39, 8: 24, 12: 25, 18: 23, 27: 25, 29: 27, 30: 27, 39: 27, 50: 20, 52: 27, 56: 22, 61: 27, 62: 23, 68: 24, 72: 28, 74: 24, 81: 39, 82: 25, 90: 24, 94: 22}, 3: {3: 39, 6: 23, 16: 28, 18: 23, 21: 17, 23: 24, 24: 28, 31: 20, 35: 19, 37: 17, 53: 27, 57: 22, 66: 27, 78: 18, 79: 25, 82: 25, 83: 18, 86: 21, 88: 24, 95: 28}, 4: {4: 39, 9: 32, 12: 25, 13: 20, 19: 21, 30: 27, 32: 22, 33: 22, 35: 19, 37: 17, 47: 23, 49: 27, 60: 27, 62: 23, 71: 23, 73: 32, 82: 25, 85: 25, 87: 32, 91: 25}, 5: {5

Files already downloaded and verified
Files already downloaded and verified


In [68]:
users_model, net_glob, initial_state_dict, server_state_dict = init_nets(args, dropout_p=0.5, bias=args.bias)

# Print every trainable tensor of the global model, then the total parameter count.
param_shapes = [(name, param.size()) for name, param in net_glob.named_parameters()]
for name, shape in param_shapes:
    print(name, shape)
total = sum(np.prod(shape) for _, shape in param_shapes)
print(total)

prep.0.weight torch.Size([64, 3, 3, 3])
prep.0.bias torch.Size([64])
prep.1.weight torch.Size([64])
prep.1.bias torch.Size([64])
layer1_head.0.weight torch.Size([128, 64, 3, 3])
layer1_head.0.bias torch.Size([128])
layer1_head.1.weight torch.Size([128])
layer1_head.1.bias torch.Size([128])
layer1_residual.0.0.weight torch.Size([128, 128, 3, 3])
layer1_residual.0.0.bias torch.Size([128])
layer1_residual.0.1.weight torch.Size([128])
layer1_residual.0.1.bias torch.Size([128])
layer1_residual.1.0.weight torch.Size([128, 128, 3, 3])
layer1_residual.1.0.bias torch.Size([128])
layer1_residual.1.1.weight torch.Size([128])
layer1_residual.1.1.bias torch.Size([128])
layer2.0.weight torch.Size([256, 128, 3, 3])
layer2.0.bias torch.Size([256])
layer2.1.weight torch.Size([256])
layer2.1.bias torch.Size([256])
layer3_head.0.weight torch.Size([512, 256, 3, 3])
layer3_head.0.bias torch.Size([512])
layer3_head.1.weight torch.Size([512])
layer3_head.1.bias torch.Size([512])
layer3_residual.0.0.weight to

In [69]:
# Persist the initial weights so later runs can start from them (via args.load_initial).
# Name the file after the architecture so a non-resnet9 run does not write
# mismatched weights into 'resnet9-init.pth'.
torch.save(initial_state_dict, f'{args.model}-init.pth')

In [None]:
# Reset the global model, then every client model, to the shared initial weights.
for model in [net_glob, *users_model]:
    model.load_state_dict(initial_state_dict)

In [None]:
# Build one FedAvg client per user, each with its own data split and model copy.
clients = []
for idx in range(args.num_users):

    dataidxs = net_dataidx_map[idx]
    # Always (re)bind dataidxs_test for this client: a misspelled assignment here
    # would leave it undefined, or holding the previous client's indices,
    # whenever the partition provides no per-client test map.
    dataidxs_test = None if net_dataidx_map_test is None else net_dataidx_map_test[idx]

    noise_level = 0  # passed through to get_dataloader; 0 in these runs

    train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(args.dataset,
                                                                   args.datadir, args.local_bs, 64,
                                                                   dataidxs, noise_level,
                                                                   dataidxs_test=dataidxs_test)

    clients.append(Client_FedAvg(idx, copy.deepcopy(users_model[idx]), args.local_bs, args.local_ep,
               args.lr, args.momentum, args.device, train_dl_local, test_dl_local))

    gc.collect()

In [None]:
# time and sys are not imported in the first cell (they may or may not be
# re-exported by the src star-imports), so import them explicitly here.
import sys
import time

print('Starting FL')
print('-'*40)
start = time.time()

loss_train = []
clients_local_acc = {i:[] for i in range(args.num_users)}
w_locals, loss_locals = [], []
glob_acc = []

# Number of clients sampled per round (at least one); constant across rounds.
m = max(int(args.frac * args.num_users), 1)

w_glob = copy.deepcopy(initial_state_dict)
for iteration in range(args.rounds):

    idxs_users = np.random.choice(range(args.num_users), m, replace=False)

    print(f'----- ROUND {iteration+1} -----')
    print(idxs_users)
    sys.stdout.flush()

    # Local training: each sampled client starts from the current global weights.
    for idx in idxs_users:
        clients[idx].set_state_dict(copy.deepcopy(w_glob))

        loss = clients[idx].train( is_print=False)
        loss_locals.append(copy.deepcopy(loss))

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    template = '-- Average Train loss {:.3f}'
    print(template.format(loss_avg))

    ####### FedAvg ####### START
    # Weight each client's model by its share of the sampled clients' training samples.
    total_data_points = sum([len(net_dataidx_map[r]) for r in idxs_users])
    fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in idxs_users]
    w_locals = []
    for idx in idxs_users:
        w_locals.append(copy.deepcopy(clients[idx].get_state_dict()))

    ww = FedAvg(w_locals, weight_avg=fed_avg_freqs)
    w_glob = copy.deepcopy(ww)
    net_glob.load_state_dict(copy.deepcopy(ww))
    ####### FedAvg ####### END
    _, acc = eval_test(net_glob, args, test_dl_global)

    glob_acc.append(acc)
    template = "-- Global Acc: {:.3f}, Global Best Acc: {:.3f}"
    print(template.format(glob_acc[-1], np.max(glob_acc)))

    loss_train.append(loss_avg)

    ## clear the placeholders for the next round
    loss_locals.clear()

    ## calling garbage collector
    gc.collect()

end = time.time()
duration = end-start
print('-'*40)
print(f'Total training time: {duration:.1f} s')

In [None]:
# Average global test accuracy over the last (up to) 10 rounds.
last_rounds_acc = glob_acc[-10:]
np.mean(last_rounds_acc)

In [None]:
# Evaluate every client with Client.eval_test (its second return value is the
# accuracy), print it, then report the mean over all clients.
local_acc = []
for i in range(args.num_users):
    acc = clients[i].eval_test()[1]
    local_acc.append(acc)
    print(f'Client {i}, Acc {acc:.2f}')

print(np.mean(local_acc))