In [3]:
import torch
import numpy as np
import pandas as pd

from models.resnet import ResNet50
from utils.datasets import get_datasets
from utils.sampling import get_user_groups
from utils.reproducibility import make_it_reproducible
from feddyn.components import FedDynServer, FedDynClient

In [None]:
# Select the GPU when one is available, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy,
# which would force 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [26]:
# --- Experiment configuration -------------------------------------------
# Federation
ROUNDS = 200            # number of communication rounds run by the training loop
tot_clients = 100       # size of the client pool
participation = 0.1     # fraction of the pool sampled each round
alpha = 0.01            # handed to both FedDynServer and FedDynClient (presumably the FedDyn regulariser weight)
cuda = (device == "cuda")

# Model
norm = "Batch Norm"     # forwarded to ResNet50(...)

# Data partitioning across clients
iid = True
unbalanced = False
seed = 0

# Local (client-side) optimisation
local_epochs = 5
lr = 0.1
weight_decay = 1e-3
momentum = 0
clip_value = 10

In [27]:
# Seed the RNGs before building datasets, models and the client sampler.
# NOTE(review): per-round client sampling in the training loop draws from NumPy's
# global RNG (np.random.choice); confirm make_it_reproducible seeds numpy as well as torch.
make_it_reproducible(seed)

In [28]:
# Load the train/test splits. NOTE(review): augmentation presumably affects only the
# training transforms — TODO confirm in utils.datasets.get_datasets.
trainset, testset = get_datasets(augmentation=True)
# Partition the training set across tot_clients clients according to the iid/unbalanced flags.
# user_groups is indexed by client id when building the clients (user_groups[cid]);
# the second return value is not used in this notebook.
user_groups, _ = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=tot_clients)

Files already downloaded and verified
Files already downloaded and verified


In [29]:
# The server owns the global ResNet50 and receives the test set
# (presumably used by server.evaluate in the training loop).
server = FedDynServer(ResNet50(norm), alpha, tot_clients, device, testset, seed)

# One client per id, in id order; each is handed its own partition user_groups[cid]
# (presumably a collection of trainset indices — confirm in utils.sampling).
clients = [
    FedDynClient(device, lr, weight_decay, momentum, alpha, cid, local_epochs,
                 trainset, user_groups[cid], clip_value)
    for cid in range(tot_clients)
]

In [None]:
# FedDyn training loop. Each round: sample a subset of clients, train each one
# locally starting from the current server state, aggregate the returned states on
# the server, then evaluate. Metrics are checkpointed to CSV every SAVE_EVERY rounds
# and always after the final round.
SAVE_EVERY = 5

# Shared filename suffix for both metric files. Earlier versions wrote the test
# metrics to "feddyn_test__..." (double underscore); both files now use one separator.
run_tag = f"{'iid' if iid else 'noniid'}_{'unbalanced' if unbalanced else 'balanced'}_{seed}"

# Clients per round (at least one). This does not change between rounds, so compute it once.
m = int(max(1, tot_clients * participation))

train, test = [], []
for com_round in range(1, ROUNDS + 1):
    print(f"Running communication round {com_round}...")

    server_state_dict = server.get_server_state()

    # Sample m distinct client ids from NumPy's global RNG.
    chosen_users = np.random.choice(range(tot_clients), m, replace=False)

    active_clients_models = []
    for idx in chosen_users:
        # Each selected client gets a freshly built ResNet50 plus the server state
        # (presumably loaded into the model inside FedDynClient.train).
        state, metric = clients[idx].train(ResNet50(norm), server_state_dict, com_round)
        active_clients_models.append(state)
        train.append(metric)

    server.update_model(active_clients_models)
    test.append(server.evaluate(com_round))
    print("\n")

    # Also save after the last round, so trailing rounds are not lost when
    # ROUNDS is not a multiple of SAVE_EVERY.
    if com_round % SAVE_EVERY == 0 or com_round == ROUNDS:
        df_train = pd.DataFrame(train)
        df_train.to_csv(f"feddyn_train_{run_tag}.csv", index=False)
        df_test = pd.DataFrame(test)
        df_test.to_csv(f"feddyn_test_{run_tag}.csv", index=False)