In [None]:
import copy
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from models.resnet import ResNet50
from utils.reproducibility import make_it_reproducible, seed_worker
from utils.datasets import get_datasets
from utils.sampling import get_user_groups
from fedavg.utils import average_weights
from fedavg.client import LocalUpdate

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Reproducibility settings.
# Base seed for the run; the author also noted 128 and 479 as alternatives.
seed = 0
# Torch RNG object that is handed to the data loaders further down.
g = torch.Generator()

In [None]:
# Experiment hyper-parameters.

# Federation-level settings
ROUNDS = 50               # number of communication rounds per configuration
tot_users = 100           # total number of simulated users
selection_fraction = 0.1  # fraction of users sampled in each round

# Client-level settings (forwarded to every local update)
local_epochs = 1
local_batch_size = 10

In [None]:
# Load the train/test datasets once, with augmentation switched on; both are
# reused by every experiment configuration below.
trainset, testset = get_datasets(
    augmentation=True,
)

In [None]:
# FedAvg baseline: for every (iid, unbalanced, norm) configuration, train a
# global ResNet50 for ROUNDS communication rounds and record per-round metrics.
loss_fn = torch.nn.CrossEntropyLoss()

# Make sure the output folder exists *before* training starts, so a missing
# directory cannot discard the results at the final to_csv call.
results_path = Path("./results/federated_baseline/fedavg_results.csv")
results_path.parent.mkdir(parents=True, exist_ok=True)

# Number of users sampled per round (at least one); constant across rounds.
m = max(int(selection_fraction * tot_users), 1)

metrics = []
for iid, unbalanced, norm in [
        [True, False, "Batch Norm"], [False, True, "Batch Norm"], [False, False, "Batch Norm"],
        [True, False, "Group Norm"], [False, True, "Group Norm"], [False, False, "Group Norm"]]:
    # Reset every RNG so each configuration starts from the same state.
    make_it_reproducible(seed)
    g.manual_seed(seed)

    testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128, shuffle=False, num_workers=2,
                                         worker_init_fn=seed_worker, generator=g)
    # Mapping user index -> sample indices of that user's local dataset
    # (second return value is unused here).
    user_groups, _ = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=tot_users)

    global_net = ResNet50(norm)
    global_net.to(device)
    global_net.train()
    global_weights = global_net.state_dict()

    # `round_idx` instead of `round`: avoids shadowing the builtin for the rest
    # of the notebook session.
    for round_idx in range(ROUNDS):
        local_weights, counts, local_losses, global_losses = [], [], [], []

        global_net.train()
        selected_users = np.random.choice(range(tot_users), m, replace=False)

        # --- local training on each selected user ---
        for idx in selected_users:
            # Pass the seed_worker *callable* (as done for the test loader).
            # The previous `seed_worker(seed)` invoked it immediately in the main
            # process and forwarded its return value instead of the function.
            # NOTE(review): if seed_worker re-seeds the global RNGs (as in the
            # PyTorch reproducibility recipe), that stray call may also have
            # affected the np.random.choice user selection - confirm.
            local_net = LocalUpdate(dataset=trainset, idxs=user_groups[idx], local_batch_size=local_batch_size,
                local_epochs=local_epochs, worker_init_fn=seed_worker, generator=g, device=device)
            # Train on a private copy so the global model is untouched until aggregation.
            w, loss = local_net.update_weights(model=copy.deepcopy(global_net))
            counts.append(len(user_groups[idx]))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))
        train_loss_avg = sum(local_losses) / len(local_losses)

        # --- aggregation ---
        # Local sample counts are passed alongside the weights
        # (presumably used as averaging weights; see fedavg.utils.average_weights).
        global_weights = average_weights(local_weights, counts)
        global_net.load_state_dict(global_weights)

        # --- evaluation of the aggregated model on the test set ---
        global_net.eval()
        total, correct = 0, 0
        with torch.no_grad():
            for x, y in testloader:
                x, y = x.to(device), y.to(device)
                yhat = global_net(x)
                _, predicted = torch.max(yhat.data, 1)
                global_losses.append(loss_fn(yhat, y).item())
                total += y.size(0)
                correct += (predicted == y).sum().item()
        # Mean of the per-batch mean losses.
        test_loss_avg = sum(global_losses) / len(global_losses)
        test_accuracy = correct / total

        metrics.append({
            "norm": "BN" if norm == "Batch Norm" else "GN",
            "independence": "iid" if iid else "noniid",
            "balancement": "unbalanced" if unbalanced else "balanced",
            "seed": seed,
            "round": round_idx + 1,
            "test_accuracy": test_accuracy,
            "test_loss": test_loss_avg,
            "train_loss": train_loss_avg
        })
        print(round_idx, test_accuracy, train_loss_avg, test_loss_avg)

# One row per (configuration, round).
df = pd.DataFrame(metrics)
df.to_csv(results_path, index=False)