In [104]:
%reload_ext autoreload
%autoreload 2

import random
import sys
import time
from copy import deepcopy
from datetime import datetime
from queue import Queue

import numpy as np
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.append('../')
import models
from models import CF10Net
from devices import *
from data_utils import split_data, CustomSubset
from server import *
from client import *
from leader import *

In [86]:
def server_prepare_data():
    """Build the server-side CIFAR-10 evaluation data.

    Loads the CIFAR-10 test split from ``~/data`` (``download=True``) with a
    ToTensor + Normalize(0.5, 0.5) pipeline and exposes it in two forms.

    Returns:
        tuple: ``(subset, loader)`` where ``subset`` is a ``CustomSubset`` over
        a random permutation of every test index (its third argument is a
        bare ``ToTensor`` compose) and ``loader`` is an unshuffled
        ``DataLoader`` over the full test set with batch size 128.
    """
    print("--> Preparing data...")

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    cifar_test = torchvision.datasets.CIFAR10(
        root='~/data', train=False, download=True, transform=pipeline
    )

    loader = torch.utils.data.DataLoader(
        cifar_test, batch_size=128, shuffle=False, num_workers=1
    )

    # Permutation drawn from NumPy's global RNG; no seed is set in this cell.
    shuffled_idcs = np.random.permutation(len(cifar_test))
    subset_transform = transforms.Compose([transforms.ToTensor()])
    subset = CustomSubset(cifar_test, shuffled_idcs, subset_transform)

    return subset, loader

In [87]:
def client_prepare_data(N_CLIENTS, alpha=10):
    """Load the CIFAR-10 training split and partition it across clients.

    Args:
        N_CLIENTS: Number of client partitions to create; forwarded to
            ``split_data`` as ``n_clients``.
        alpha: Concentration value forwarded to ``split_data`` as ``alpha``.
            Defaults to 10, the value previously hard-coded here.

    Returns:
        list: One ``CustomSubset`` of the training set per index group
        returned by ``split_data``.
    """
    print("--> Preparing and splitting data...")
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=True, transform=transform)

    # Permutation drawn from NumPy's global RNG; no seed is set in this cell.
    train_idcs = np.random.permutation(len(trainset))

    # Labels in dataset order (index 0 .. len-1). Reading `targets` directly
    # avoids running the image transform on every sample just to fetch its
    # label, and drops the reshape scaffolding that raised when N_CLIENTS did
    # not evenly divide the dataset size.
    train_labels = np.array(trainset.targets)

    client_idcs = split_data(train_idcs, train_labels, alpha=alpha, n_clients=N_CLIENTS)

    return [CustomSubset(trainset, idcs) for idcs in client_idcs]

In [105]:
# Experiment configuration.
N_LEADERS = 2
N_CLIENTS = 10
ROUNDS = 10
lr = 0.001
l2_lambda = 0.01
beta = 5
select_rate = 0.9

# Clients are attached to leaders in contiguous, equally sized groups. Fail
# fast when the split is uneven: otherwise the trailing clients would be left
# without a leader (their `parent` is only assigned in the leader loop below).
if N_LEADERS > 0 and N_CLIENTS % N_LEADERS != 0:
    raise ValueError(
        f"N_CLIENTS ({N_CLIENTS}) must be divisible by N_LEADERS ({N_LEADERS})"
    )
group_size = N_CLIENTS // N_LEADERS if N_LEADERS > 0 else 0

# Server
test_data, testloader = server_prepare_data()
server = Server(CF10Net, test_data, testloader, lr=lr, N=N_CLIENTS, beta=beta)

# Client: one Client per data partition, each given an SGD optimizer factory.
client_datas = client_prepare_data(N_CLIENTS)
client_list = [
    Client(CF10Net, lambda x: torch.optim.SGD(x, lr=lr, momentum=0.9), data, id=i, l2_lambda=l2_lambda)
    for i, data in enumerate(client_datas)
]

# Leader: leader i is wired to the server and to clients
# [group_size * i, group_size * (i + 1)).
leader_list = []
for i in range(N_LEADERS):
    leader = Leader(CF10Net, i)
    leader.server = server
    server.child_list.append(leader)
    for j in range(group_size * i, group_size * (i + 1)):
        leader.child_list.append(client_list[j])
        client_list[j].parent = leader
    leader_list.append(leader)


--> Preparing data...
Files already downloaded and verified
--> Preparing and splitting data...
Files already downloaded and verified


In [107]:

# Number of clients sampled each round; loop-invariant, so computed once.
n_selected = int(N_CLIENTS * select_rate)

for round_idx in range(ROUNDS):
    # One round: server -> leaders -> sampled clients -> leaders -> server.
    # NOTE(review): what each call transfers is defined in server.py /
    # leader.py / client.py, not in this notebook — confirm there.
    server.send()

    for leader in leader_list:
        leader.pass_W()

    # Sample without replacement from all clients, regardless of leader group.
    selected_client_list = random.sample(client_list, n_selected)

    for client in selected_client_list:
        client.train()
        client.send()

    for leader in leader_list:
        leader.compute_dW()
        leader.send_dW()

    server.update()
    server.eval()

ROUNDS =  10 , selected =  [<client.Client object at 0x7fd1b07399d0>, <client.Client object at 0x7fd179e8f590>, <client.Client object at 0x7fd1a160f910>, <client.Client object at 0x7fd1e0b094d0>, <client.Client object at 0x7fd1b05df490>, <client.Client object at 0x7fd1b0617750>, <client.Client object at 0x7fd1b060f410>, <client.Client object at 0x7fd1b064b3d0>, <client.Client object at 0x7fd1b0733e50>]
[Client - 2 - trn] TIME = 0 - train done
[Client - 8 - trn] TIME = 0 - train done
[Client - 1 - trn] TIME = 0 - train done
[Client - 3 - trn] TIME = 0 - train done
[Client - 5 - trn] TIME = 0 - train done
[Client - 6 - trn] TIME = 0 - train done
[Client - 7 - trn] TIME = 0 - train done
[Client - 4 - trn] TIME = 0 - train done
[Client - 0 - trn] TIME = 0 - train done
[Server - upd]: Updated model with T = 1, t = 0, num = 5, alpha = 0.016
[Server - upd]: Updated model with T = 1, t = 0, num = 4, alpha = 0.0128
[Server - eval] TIME = 1, acc = [ 0.1069 ]
[Server - eval] - error
ROUNDS =  10 

[Client - 2 - trn] TIME = 8 - train done
[Client - 1 - trn] TIME = 8 - train done
[Client - 9 - trn] TIME = 8 - train done
[Client - 8 - trn] TIME = 8 - train done
[Client - 4 - trn] TIME = 8 - train done
[Client - 3 - trn] TIME = 8 - train done
[Client - 5 - trn] TIME = 8 - train done
[Client - 6 - trn] TIME = 8 - train done
[Client - 7 - trn] TIME = 8 - train done
[Server - upd]: Updated model with T = 9, t = 8, num = 4, alpha = 0.0128
[Server - upd]: Updated model with T = 9, t = 8, num = 5, alpha = 0.016
[Server - eval] TIME = 9, acc = [ 0.1086 ]
[Server - eval] - error
ROUNDS =  10 , selected =  [<client.Client object at 0x7fd1b060f410>, <client.Client object at 0x7fd1b05df150>, <client.Client object at 0x7fd1b05df490>, <client.Client object at 0x7fd1b0617750>, <client.Client object at 0x7fd179e8f590>, <client.Client object at 0x7fd1b064b3d0>, <client.Client object at 0x7fd1a160f910>, <client.Client object at 0x7fd1e0b094d0>, <client.Client object at 0x7fd1b07399d0>]


KeyboardInterrupt: 