# Federated Learning, Local Implementation

## Preliminaries

In [None]:
# Libraries to be imported, please make sure you have them installed

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, F1Score
from torchvision import datasets, transforms
from copy import deepcopy
import numpy as np
import torchvision.models as models
from tqdm.autonotebook import tqdm


: 

In PyTorch, we can use the GPU to accelerate the training process. We check whether a GPU is available and set the device accordingly, which allows us to move the data and the model to the GPU. We also set the random seeds for reproducibility.



In [2]:
# Pick the compute device: CUDA when a GPU is available, the CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Ask cuDNN for deterministic algorithms and switch off its benchmarking mode
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Fix the seeds of NumPy's and PyTorch's random number generators
np.random.seed(0)
torch.manual_seed(0)

In [None]:
def resnet18(num_classes, **kwargs):
    """Build a torchvision ResNet-18 with a replaced stem and output layer.

    Three parts of the stock network are swapped out: the first convolution
    becomes a 3x3, stride-1 convolution, the max-pooling layer becomes a
    no-op, and the fully connected layer becomes a new linear layer with
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the new final linear layer.
        **kwargs: Forwarded unchanged to ``models.resnet18``.

    Returns:
        The modified ResNet-18 module.
    """
    net = models.resnet18(**kwargs)

    # New first convolution: 3 -> 64 channels, 3x3 kernel, stride 1,
    # padding 1, no bias term
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = stem

    # An identity module takes the place of the max-pooling layer
    net.maxpool = nn.Identity()

    # New output layer: 512 input features -> num_classes outputs
    head = nn.Linear(512, num_classes)
    net.fc = head

    return net

In [None]:
def uniform_allocation(Y, num_clients):
    """Randomly split the sample indices of ``Y`` into near-equal client shares.

    The indices ``0 .. len(Y) - 1`` are shuffled with NumPy's global random
    generator (so the result is reproducible after ``np.random.seed``) and
    then cut into ``num_clients`` consecutive chunks whose sizes differ by at
    most one.

    Args:
        Y: Sized collection with one entry per sample; only ``len(Y)`` is used.
        num_clients: Number of chunks (clients) to create.

    Returns:
        A list of ``num_clients`` lists of plain Python ``int`` indices. Every
        index appears in exactly one of the lists.

    Raises:
        ValueError: If ``num_clients`` is not positive (raised by
            ``np.array_split``).
    """
    # Randomly shuffle indices (in place, using the global NumPy RNG)
    indices = np.arange(len(Y))
    np.random.shuffle(indices)

    # `tolist()` yields built-in ints rather than NumPy scalar objects
    return [chunk.tolist() for chunk in np.array_split(indices, num_clients)]

In [None]:
# Hyperparameters

# Loss and optimizer settings
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2  # Learning rate for the optimizer

# Data loading
batch_size = 64  # Batch size for training and testing

# Federated setup
num_clients = 10  # Number of clients
local_epochs = 3  # Number of local epochs (round size)
global_epochs = 30  # Number of global epochs

In [None]:
# Model
# Global model; each client starts from an identical copy of its initial weights
model = resnet18(10)
client_models = [
    deepcopy(model).to(device) for _ in range(num_clients)
]
# One SGD optimizer per client, bound to that client's own parameters.
# Building them from `model.parameters()` would make every optimizer step the
# global model instead and leave the client copies untouched.
client_optims = [optim.SGD(cm.parameters(), lr=learning_rate) for cm in client_models]

# CIFAR-10 Dataset and Dataloaders

# Preprocessing applied to every image: conversion to a tensor, followed by
# per-channel normalization with the given means and standard deviations
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# Training split, stored under ./data (download=True)
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

# One Subset per client, built from a random, near-equal split of the
# training indices
train_subsets = [
    torch.utils.data.Subset(train_dataset, client_indices)
    for client_indices in uniform_allocation(train_dataset.targets, num_clients)
]

# One shuffling DataLoader per client subset
train_subset_dataloaders = [
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in train_subsets
]

# Test split, stored under ./data and iterated without shuffling
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



# ToDo: Federated Learning with FedAvg

### Tasks

1. Implement the FedAvg algorithm for federated learning. We do not want the variant that uses a step size, i.e., one that averages the models after a fixed number of SGD steps. Instead, we want the variant that uses a fixed number of communication rounds, where each round consists of whole local epochs on every client. We use `local_epochs` (also referred to as "round size") to denote the number of local epochs per round.
2. After each communication round, i.e., after averaging the client models, evaluate the global model on the test set. Store the test loss, test accuracy, and test F1 score for each communication round and plot them at the end.
3. Use different values for `num_clients` and `local_epochs` and compare the results. What happens if you increase the number of clients? What happens if you increase the number of local epochs? Consider both the performance and the training time.