In [32]:
%reload_ext autoreload
%autoreload 2

import sys
import time
from datetime import datetime
from copy import deepcopy
from queue import Queue

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np

sys.path.append('../')
import models
from models import CF10Net
from devices import *
from data_utils import split_data, CustomSubset
from server import *
from client import *
from leader import *

In [23]:
def server_prepare_data():
    """Load the CIFAR-10 test split for server-side evaluation.

    Returns:
        tuple: ``(subset, loader)`` where
            * ``subset`` is a ``CustomSubset`` built from every test index in a
              random (unseeded) order, with a ToTensor-only transform passed
              as its third argument;
            * ``loader`` is an unshuffled ``DataLoader`` over the test split
              (batch size 128, one worker process).
    """
    print("--> Preparing data...")

    # ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps each
    # RGB channel to [-1, 1].
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    cifar_test = torchvision.datasets.CIFAR10(
        root='~/data',
        train=False,
        download=True,
        transform=normalize,
    )

    loader = torch.utils.data.DataLoader(
        cifar_test,
        batch_size=128,
        shuffle=False,
        num_workers=1,
    )

    # Every test index exactly once, in a random order drawn from numpy's
    # global RNG.
    shuffled = np.random.permutation(len(cifar_test))

    # NOTE(review): cifar_test already applies ToTensor+Normalize; confirm how
    # CustomSubset combines this extra ToTensor transform with the dataset's.
    subset_transform = transforms.Compose([transforms.ToTensor()])
    return CustomSubset(cifar_test, shuffled, subset_transform), loader

In [27]:
def client_prepare_data(N_CLIENTS, alpha=10):
    """Load the CIFAR-10 training split and partition it across clients.

    The partition itself is delegated to ``split_data`` with a Dirichlet
    concentration ``alpha``. Presumably, smaller values give more skewed
    (non-IID) label mixes per client; confirm against ``split_data``.

    Args:
        N_CLIENTS: Number of client partitions to request (must be >= 1).
        alpha: Dirichlet concentration forwarded to ``split_data``. Defaults
            to 10, the value that was previously hard-coded here.

    Returns:
        list: One ``CustomSubset`` of the (normalized) training set per index
        group returned by ``split_data``.

    Raises:
        ValueError: If ``N_CLIENTS`` is less than 1.
    """
    if N_CLIENTS < 1:
        raise ValueError(f"N_CLIENTS must be >= 1, got {N_CLIENTS}")

    print("--> Preparing and splitting data...")
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=True, transform=transform)

    # All training indices in a random order drawn from numpy's global RNG.
    train_idcs = np.random.permutation(len(trainset))

    # Labels in dataset order, so train_labels[i] is the label of sample i.
    # Reading CIFAR10's `targets` list gives the same values as trainset[i][1]
    # (no target_transform is set). It avoids decoding and normalizing every
    # image just to get its label. It also drops the old reshape, which
    # crashed when len(trainset) was not divisible by N_CLIENTS.
    train_labels = np.array(trainset.targets)

    client_idcs = split_data(train_idcs, train_labels, alpha=alpha, n_clients=N_CLIENTS)

    return [CustomSubset(trainset, idcs) for idcs in client_idcs]

In [33]:
N_LEADERS = 1
N_CLIENTS = 4
SEED = 0  # change to draw a different (but reproducible) data split

# Seed before any data preparation. np.random.permutation in both prepare
# functions draws from numpy's global RNG, so Restart & Run All then
# reproduces the same test order and client split. torch is seeded too,
# presumably for model initialisation; confirm where torch's RNG is used.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Server
test_data, testloader = server_prepare_data()
server = Server(CF10Net, test_data, testloader)

# Clients: each gets its own data partition and an SGD optimizer factory
# (lr=0.001, momentum=0.9) applied to the parameters it is given.
client_datas = client_prepare_data(N_CLIENTS)
client_list = [
    Client(CF10Net, lambda params: torch.optim.SGD(params, lr=0.001, momentum=0.9), data, id=i)
    for i, data in enumerate(client_datas)
]

# Leaders: split clients into N_LEADERS contiguous, equally sized groups.
leader_list = []
if N_LEADERS > 0:
    # With a remainder, the trailing clients would silently belong to no leader.
    assert N_CLIENTS % N_LEADERS == 0, (
        f"N_CLIENTS ({N_CLIENTS}) must be divisible by N_LEADERS ({N_LEADERS})"
    )
    group_size = N_CLIENTS // N_LEADERS
    for i in range(N_LEADERS):
        leader = Leader(CF10Net, i)
        leader.client_list.extend(client_list[group_size * i:group_size * (i + 1)])
        leader_list.append(leader)


--> Preparing data...
Files already downloaded and verified
--> Preparing and splitting data...
Files already downloaded and verified


[Server - eval] rd = 1, acc = [ 0.1 ]
