# Project Lama Into the Wild: Federated Learning
<b>Group Number: 13</b><br>
<b>Name Group Member 1: Léo Brucker</b><br>
<b>u-Kürzel Group Member 1: uhugu</b><br>
<b>Name Group Member 2: Cyril Rudolph</b><br>
<b>u-Kürzel Group Member 2: udjvh</b>

This file is a template for creating a new test scenario.

Just change the parameters and run the simulation; the output and parameters will
automatically be saved to an Excel sheet (you will have to change the location
and create the Excel file before this works). The next step would be to automate
the parameter optimization.

# Preparation

## Installs

In [None]:
!pip install -q flwr[simulation] torch torchvision flwr_datasets
!pip install --upgrade flwr jax
!pip install openpyxl

## Imports

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import math
import random
from google.colab import drive
from datetime import datetime
import copy
from torch.utils.data import Subset
import openpyxl
from openpyxl import load_workbook


from collections import OrderedDict
from typing import Dict, Tuple
from flwr.common import NDArrays, Scalar
from collections import Counter

import flwr as fl
from flwr_datasets import FederatedDataset

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}")

# Initial set-up for the Google Drive sign-in (results are written to the mounted drive)
drive.mount('/content/drive')

# Parameters
Change the following parameters to create your own simulation.

In [None]:
# Number of classifier outputs needed for each torchvision EMNIST split.
EMNIST_SPLITS = {
    "balanced": 47,
    "bymerge": 47,
    "byclass": 62,
    # torchvision's "letters" split lists 27 classes ("N/A" + a-z) and uses
    # targets 1..26, so 27 outputs are required for label 26 to be valid.
    "letters": 27,
    "digits": 10,
    "mnist": 10
}

##   MAIN  PARAMS   ##
######################
CENTRALIZED = False  ## False if Federated
SPLIT = "digits" # Choose from the keys of EMNIST_SPLITS ("digits", "balanced", ...)
MODEL = "internet" # Choose from "chatgpt", "internet", "tutorial"
######################

## OTHER GLOBAL PARAMS ##
PROGRESS_BAR = CENTRALIZED # NOTE: currently not read anywhere in this notebook
USE_GPU = True # False for CPU
VALIDATION_SPLIT = 0.1
OPTIMIZER = "adam" # Choose from "sgd" and "adam"
CRITERION = nn.CrossEntropyLoss()
VISUALIZE_N = 8 # Amount of visualized images

### CENTRALIZED PARAMS ###
LR_CENTRALIZED = 0.001  # Learning rate
MOM_CENTRALIZED = 0.9 # Momentum (only used by the "sgd" optimizer)
EPOCHS_CENTRALIZED = 50 # Maximum number of epochs (early stopping may end sooner)
BATCH_SIZE_CENTRALIZED = 256
PATIENCE = 3 # Epochs without validation improvement before early stopping

### FEDERATED PARAMS ###
IMBALANCE_PERCENTAGE = 0  # set to 0 to use original dataset
KEEP_INTACT_PERCENTAGE = 0.3 # NOTE: currently not read anywhere in this notebook

LR_FEDERATED = 0.001  # Learning rate
MOM_FEDERATED = 0.9 # Momentum (only used by the "sgd" optimizer)
EPOCHS_FEDERATED = 1 # Number of federated rounds (clients train one local epoch per round)
BATCH_SIZE_FEDERATED = 128
NUM_CLIENTS = 100
FRACTION_TRAIN = 0.2 # Percentage of clients chosen to train each round
FRACTION_VALIDATE = 0.2 # Percentage of clients chosen to validate each round
STRATEGY = "FedAvg" # Choose from "FedAvg", "FedAdam", "FedProx"


# ---------- DO NOT CHANGE ------------ #
NUM_CLASSES = EMNIST_SPLITS[SPLIT]
now = datetime.now()
# Year-month-day ordering so that saved files sort chronologically.
SAVE_DATE = now.strftime("%Y%m%d_%H%M")

# Functions

## Dataset

### Raw dataset

In [None]:
def get_dataset():
    """Load the EMNIST train and test splits for the configured ``SPLIT``.

    Normalisation statistics (mean/std of the raw pixel values rescaled to
    [0, 1], matching the output range of ``ToTensor``) are computed on the
    training split only and applied to both splits.

    Returns:
        (trainset, testset): torchvision EMNIST datasets whose images are
        transformed with ``ToTensor`` followed by ``Normalize(mean, std)``.
    """
    data_path = "./data"
    # Load the training split once without a transform to compute statistics.
    trainset = EMNIST(root=data_path, split=SPLIT, train=True, download=True)
    raw_pixels = trainset.data.float()
    mean = (raw_pixels.mean() / 255).item()
    std = (raw_pixels.std() / 255).item()
    del raw_pixels  # release the large float copy before loading the test split

    emnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
    # Reuse the already-loaded training split instead of reading it from disk
    # again; EMNIST applies ``self.transform`` lazily in ``__getitem__``.
    trainset.transform = emnist_transform
    testset = EMNIST(data_path, split=SPLIT, train=False, download=True, transform=emnist_transform)

    return trainset, testset

def get_testloader():
    """Return a non-shuffled DataLoader over the EMNIST test split.

    The batch size is ``BATCH_SIZE_FEDERATED`` regardless of the run mode.
    """
    test_split = get_dataset()[1]
    return DataLoader(test_split, batch_size=BATCH_SIZE_FEDERATED)

### Federated Dataset

In [None]:
# Create Federated Dataset
def get_federated_loaders():
    """Partition the EMNIST training set across ``NUM_CLIENTS`` clients.

    Every client receives ``len(trainset) // NUM_CLIENTS`` samples (the
    remainder is dropped). ``imbalance_sets`` is then applied, and each
    client's share is split into train/validation parts using
    ``VALIDATION_SPLIT``.

    Side effects:
        Sets the globals ``num_images`` (samples per client before
        imbalancing) and ``TRAINSETS`` (the per-client subsets, later read by
        ``get_distributions``).

    Returns:
        (trainloaders, valloaders, testloader): one train and one validation
        DataLoader per client, plus a DataLoader over the full test split.
    """
    global num_images, TRAINSETS
    full_train, testset = get_dataset()

    num_images = len(full_train) // NUM_CLIENTS
    usable = NUM_CLIENTS * num_images
    # Keep only the first `usable` samples so every client gets an equal share.
    trimmed = Subset(full_train, range(usable))

    print("Number of Images per client before imblaning:", num_images)
    print("Total images among clients:", usable)

    client_sets = imbalance_sets(random_split(trimmed, [num_images] * NUM_CLIENTS))
    TRAINSETS = client_sets

    trainloaders, valloaders = [], []
    for client_set in client_sets:
        n_val = int(VALIDATION_SPLIT * len(client_set))
        train_part, val_part = random_split(client_set, [len(client_set) - n_val, n_val])
        trainloaders.append(DataLoader(train_part, batch_size=BATCH_SIZE_FEDERATED, shuffle=True, num_workers=2))
        valloaders.append(DataLoader(val_part, batch_size=BATCH_SIZE_FEDERATED, shuffle=False, num_workers=2))

    return trainloaders, valloaders, DataLoader(testset, batch_size=BATCH_SIZE_FEDERATED)

def imbalance_sets(sets, percentage=None):
    """Randomly remove samples from every client subset, in place.

    Args:
        sets: list of ``torch.utils.data.Subset`` objects (as returned by
            ``random_split``). Their ``indices`` are shrunk in place.
        percentage: fraction used to compute how many samples to drop.
            Defaults to the global ``IMBALANCE_PERCENTAGE``; 0 leaves the
            subsets unchanged.

    The number of samples removed from EACH subset is
    ``int(len(sets) * percentage)``, i.e. it scales with the number of
    clients rather than with the subset size. The count is capped at the
    subset's length.

    Returns:
        The same list ``sets``, for convenience.
    """
    if percentage is None:
        percentage = IMBALANCE_PERCENTAGE
    amount_to_remove = int(len(sets) * percentage)
    # NOTE(review): the same count is removed from every client, so clients
    # stay equally sized; confirm whether a per-client (or
    # KEEP_INTACT_PERCENTAGE-based) scheme was intended.
    for subset in sets:
        # Subset has no pop(); operate on (a list copy of) its index list.
        indices = list(subset.indices)
        for _ in range(min(amount_to_remove, len(indices))):
            indices.pop(random.randrange(len(indices)))
        subset.indices = indices
    return sets


## CNN Models


In [None]:
class TutorialNet(nn.Module):
    """LeNet-style CNN: two conv(5x5) + max-pool(2x2) stages, then three dense layers.

    The flatten size 16 * 4 * 4 corresponds to single-channel 28x28 input
    (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order defines the state_dict order exchanged with Flower.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 4 * 4)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

    def __repr__(self):
        return "TutorialNet"

class InternetNet(nn.Module):
    """Two 'same'-padded conv(5x5) + LeakyReLU + max-pool(2) stages, then two dense layers.

    Args:
        fmaps1: output channels of the first conv stage.
        fmaps2: output channels of the second conv stage.
        dense: width of the hidden dense layer.
        dropout: dropout probability applied after the hidden dense layer.

    ``49 * fmaps2`` assumes 28x28 input: 'same' padding keeps the size and
    the two poolings reduce it to 7x7.
    """

    def __init__(self, fmaps1=40, fmaps2=160, dense=200, dropout=0.4):
        super().__init__()

        def conv_stage(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=1, padding='same'),
                nn.LeakyReLU(),
                nn.MaxPool2d(kernel_size=2),
            )

        self.conv1 = conv_stage(1, fmaps1)
        self.conv2 = conv_stage(fmaps1, fmaps2)
        self.fcon1 = nn.Sequential(nn.Linear(49 * fmaps2, dense), nn.LeakyReLU())
        self.fcon2 = nn.Linear(dense, NUM_CLASSES)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(self.fcon1(flat))
        return self.fcon2(hidden)

    def __repr__(self):
        return "InternetNet"


class ChatGPTNet(nn.Module):
    """Two conv(3x3, padding 1) + ReLU + max-pool(2) stages, then two dense layers.

    ``64 * 7 * 7`` assumes 28x28 input (halved twice by the poolings).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, NUM_CLASSES)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.maxpool(self.relu(conv(x)))
        x = torch.flatten(x, start_dim=1)
        return self.fc2(self.relu(self.fc1(x)))

    def __repr__(self):
        return "CHA-GPTNet"

def get_model():
    """Instantiate a fresh network chosen by the global ``MODEL`` setting.

    Returns:
        A new ``ChatGPTNet`` ("chatgpt"), ``InternetNet`` ("internet") or
        ``TutorialNet`` ("tutorial").

    Raises:
        ValueError: if ``MODEL`` is none of the names above. Failing here
            avoids a confusing error later when callers use the model
            (e.g. ``.to(DEVICE)`` on ``None``).
    """
    factories = {
        "chatgpt": ChatGPTNet,
        "internet": InternetNet,
        "tutorial": TutorialNet,
    }
    try:
        factory = factories[MODEL]
    except KeyError:
        raise ValueError(f"Unknown MODEL {MODEL!r}; expected one of {sorted(factories)}") from None
    return factory()

## Train and Test Loops


In [None]:
def train_centralized(net, optimizer, trainloader, valloader, epochs):
    """Train ``net`` centrally with per-epoch validation and early stopping.

    Args:
        net: model to train (moved to ``DEVICE`` by the caller if ``USE_GPU``).
        optimizer: optimizer over ``net``'s parameters.
        trainloader: training batches.
        valloader: validation batches, evaluated with ``test`` after each epoch.
        epochs: maximum number of epochs.

    Training stops early once validation accuracy has not improved for
    ``PATIENCE`` consecutive epochs.

    Returns:
        (best_model, (valid_accuracy_hist, loss_hist)): ``best_model`` is a
        deep copy of ``net`` at its best validation accuracy. ``loss_hist``
        holds the summed training loss of each epoch.
    """
    print("----- * * * * Training Starting * * * * -----")
    valid_accuracy_hist, loss_hist = [], []
    best_accuracy = 0.0
    best_model = None
    no_improve_epochs = 0

    for epoch in range(epochs):
        # test() puts the model in eval mode, so training mode must be
        # re-enabled every epoch (otherwise dropout stays off after epoch 1).
        net.train()
        running_loss = 0.0
        tqdm_trainloader = tqdm(trainloader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")
        for i, (images, labels) in enumerate(tqdm_trainloader):
            if USE_GPU:
                images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = net(images)
            loss = CRITERION(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Every 5th batch, show the mean loss over the batches seen so far.
            if i % 5 == 0:
                tqdm_trainloader.set_postfix({'loss': running_loss / (i + 1)})

        _, running_accuracy = test(net, valloader)
        valid_accuracy_hist.append(running_accuracy)
        loss_hist.append(running_loss)
        print(f"Epoch {epoch + 1} completed. Loss: {running_loss}, Accuracy: {running_accuracy}")

        # Keep the best snapshot; the first epoch always counts so that a
        # model is returned even if validation accuracy never exceeds 0.
        if best_model is None or running_accuracy > best_accuracy:
            best_accuracy = running_accuracy
            best_model = copy.deepcopy(net)
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        # Early stopping
        if no_improve_epochs >= PATIENCE:
            print("Early stopping triggered.")
            break

    history = (valid_accuracy_hist, loss_hist)
    print("----- * * * * Training Finished * * * * -----")
    return best_model, history


def train_federated(net, optimizer, trainloader):
    """Run a single local pass over ``trainloader`` and return ``net``.

    The model is updated in place; the returned object is ``net`` itself.
    """
    net.train()
    for batch_images, batch_labels in trainloader:
        if USE_GPU:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
        optimizer.zero_grad()
        batch_loss = CRITERION(net(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()
    return net

def test(net, testloader):
    """Evaluate ``net`` on ``testloader`` without tracking gradients.

    Returns:
        (loss, accuracy): ``loss`` is the SUM of the per-batch values
        returned by ``CRITERION`` (not averaged over batches). ``accuracy``
        is the fraction of correctly classified samples in
        ``testloader.dataset``.
    """
    net.eval()
    total_loss = 0.0
    n_correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            if USE_GPU:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = net(images)
            total_loss += CRITERION(logits, labels).item()
            predicted = torch.max(logits, 1).indices
            n_correct += (predicted == labels).sum().item()
    return total_loss, n_correct / len(testloader.dataset)

## Other Helper Functions

In [None]:
def get_distributions():
    """Count the class labels held by each client in the global ``TRAINSETS``.

    The subset indices are used to index ``trainset.targets`` directly. This
    relies on the client subsets addressing the first samples of the training
    set in order, which holds for the ``Subset(trainset, range(...))`` wrapper
    built in ``get_federated_loaders``.

    Returns:
        (global_distribution, local_distributions): float tensors of length
        ``len(trainset.classes)``. The global one is the element-wise sum of
        the per-client counts.
    """
    full_train, _ = get_dataset()
    n_classes = len(full_train.classes)

    local_distributions = []
    for client_subset in TRAINSETS:
        client_labels = full_train.targets[client_subset.indices]
        counts = torch.bincount(client_labels, minlength=n_classes)
        local_distributions.append(counts.float())

    global_distribution = torch.stack(local_distributions).sum(dim=0).float()
    return global_distribution, local_distributions


In [None]:
def get_optim(model):
    """Build the optimizer chosen by the global ``OPTIMIZER`` for ``model``.

    Learning rate and momentum come from the centralized or the federated
    parameter set, depending on ``CENTRALIZED``. Momentum is only used by SGD.

    Raises:
        ValueError: if ``OPTIMIZER`` is neither "sgd" nor "adam" (previously
            ``None`` was returned silently).
    """
    if CENTRALIZED:
        lr, mom = LR_CENTRALIZED, MOM_CENTRALIZED
    else:
        lr, mom = LR_FEDERATED, MOM_FEDERATED
    if OPTIMIZER == "sgd":
        return optim.SGD(model.parameters(), lr=lr, momentum=mom)
    if OPTIMIZER == "adam":
        return optim.Adam(model.parameters(), lr=lr)
    raise ValueError(f"Unknown OPTIMIZER {OPTIMIZER!r}; expected 'sgd' or 'adam'")

In [None]:
# MID is computed outside the model: it only depends on the label distribution
# of the data split, not on any individual client or training state.
def calculate_MID(global_distribution, num_classes=None):
    """Compute the MID score of a class-count distribution.

    MID = sum_c (n_c / N) * log_C(C * n_c / N)

    where n_c is the count of class c, N is the total count and C is
    ``num_classes``. The score is 0 for a perfectly uniform distribution and
    1 when all samples belong to one class.

    Args:
        global_distribution: per-class counts (tensor or sequence). The
            distribution passed in is used as-is.
        num_classes: C in the formula. Defaults to the global ``NUM_CLASSES``.
            Only the first ``num_classes`` entries are scored; N is the sum
            over all entries.

    Returns:
        A 0-d float tensor, so callers can use ``.item()``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    counts = torch.as_tensor(global_distribution).float()
    total = float(counts.sum())
    probs = counts[:num_classes] / total
    # Empty classes contribute 0 (convention 0 * log 0 = 0); evaluating
    # log(0) for them would otherwise raise a math domain error.
    probs = probs[probs > 0]
    return (probs * torch.log(num_classes * probs)).sum() / math.log(num_classes)

In [None]:
# WCS is computed outside the evaluation function: it depends only on how the
# dataset is split across clients, not on any individual client's model.
def calculate_WCS(global_distribution, local_distributions):
    """Weighted cosine similarity between client label distributions and the global one.

    WCS = 1 / (||g||_1 * ||g||_2) * sum_k (||l_k||_1 / ||l_k||_2) * (g . l_k)

    For non-negative counts this equals sum_k (n_k / N) * cos(g, l_k), where
    n_k = ||l_k||_1 and N = ||g||_1.

    Args:
        global_distribution: per-class counts over all clients (tensor).
        local_distributions: iterable of per-class count tensors, one per client.

    Returns:
        A 0-d float tensor.
    """
    g = global_distribution.float()
    g_l1 = torch.linalg.norm(g, dim=0, ord=1)
    g_l2 = torch.linalg.norm(g, dim=0, ord=2)

    weighted_sum = 0
    for local in (dist.float() for dist in local_distributions):
        size_ratio = torch.linalg.norm(local, dim=0, ord=1) / torch.linalg.norm(local, dim=0, ord=2)
        weighted_sum = weighted_sum + size_ratio * g.dot(local)

    return (1 / (g_l1 * g_l2)) * weighted_sum

## Centralized

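Format for history dict:
- model: final model (the best-validation model for centralized runs, the global model of the last round for federated runs)
- accuracies_centralized
- accuracies_federated
- losses_centralized
- losses_federated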

In [None]:
def run_centralized():
    """Train one model on the full EMNIST training split (no federation).

    The training split is divided into train/validation parts according to
    ``VALIDATION_SPLIT``. The validation part drives early stopping in
    ``train_centralized``.

    Returns:
        dict with keys "model", "accuracies_centralized", "accuracies_federated"
        (empty), "losses_centralized" and "losses_federated" (empty).
    """
    model = get_model()
    if USE_GPU:
        model.to(DEVICE)

    # Named `optimizer` so the local does not shadow the `optim` module.
    optimizer = get_optim(model)

    full_train, _ = get_dataset()

    train_size = int((1-VALIDATION_SPLIT) * len(full_train))
    val_size = len(full_train) - train_size

    train_part, val_part = random_split(full_train, [train_size, val_size])

    trainloader = DataLoader(train_part, batch_size=BATCH_SIZE_CENTRALIZED, shuffle=True, num_workers=2)
    valloader = DataLoader(val_part, batch_size=BATCH_SIZE_CENTRALIZED)

    trained_model, history = train_centralized(model, optimizer, trainloader, valloader, EPOCHS_CENTRALIZED)

    return {"model": trained_model, "accuracies_centralized": history[0], "accuracies_federated": [],
            "losses_centralized": history[1], "losses_federated": []}

## Federated

### Flower Client

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client for one simulated federated participant.

    Owns its own model instance (built with ``get_model`` and moved to
    ``DEVICE``) plus this client's train and validation DataLoaders.
    """

    def __init__(self, trainloader, valloader) -> None:
        super().__init__()
        self.trainloader = trainloader
        self.valloader = valloader
        self.model = get_model().to(DEVICE)

    def set_parameters(self, parameters):
        """Load ``parameters`` (NumPy arrays in state_dict order) into the local model."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the local model weights as NumPy arrays in state_dict order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """Train one local pass starting from the global ``parameters``.

        A fresh optimizer is created every round, so optimizer state does not
        carry over between rounds.

        Returns:
            (weights, num_examples, {}): ``num_examples`` is the number of
            training samples (not batches); strategies such as FedAvg use it
            as the aggregation weight.
        """
        self.set_parameters(parameters)
        optimizer = get_optim(self.model)
        train_federated(self.model, optimizer, self.trainloader)
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the global ``parameters`` on this client's validation data.

        Returns:
            (loss, num_examples, {"accuracy": accuracy}): ``num_examples`` is
            the number of validation samples. It is used as the weight in
            ``aggregate_evaluate_metrics``.
        """
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": accuracy}


### Other Federated Functionality


In [None]:
def get_evalulate_fn(testloader):
    """Build the server-side evaluation callback for a Flower strategy.

    The returned ``evaluate_fn(server_round, parameters, config)`` is called
    by the strategy to evaluate the current global model on ``testloader``.
    """

    def evaluate_fn(server_round: int, parameters, config):
        """Load the global ``parameters`` into a fresh model and test it.

        Returns:
            (loss, metrics): ``metrics`` holds the accuracy and also the
            evaluated model object itself. ``get_metrics`` later reads that
            model back from ``history.metrics_centralized['model']``.
        """
        model = get_model()
        model.to(DEVICE)

        names = model.state_dict().keys()
        weights = OrderedDict((name, torch.Tensor(array)) for name, array in zip(names, parameters))
        model.load_state_dict(weights, strict=True)

        loss, accuracy = test(model, testloader)
        return loss, {"accuracy": accuracy, "model": model}

    return evaluate_fn

def generate_client_fn(trainloaders, valloaders):
    """Return a Flower ``client_fn`` that builds the client for a given id.

    The string client id ``cid`` is converted to an int and used as an index
    into ``trainloaders`` and ``valloaders``.
    """
    def client_fn(cid: str):
        idx = int(cid)
        numpy_client = FlowerClient(trainloader=trainloaders[idx], valloader=valloaders[idx])
        return numpy_client.to_client()

    return client_fn

def aggregate_evaluate_metrics(metrics):
    """Average client accuracies, weighted by their number of examples.

    Args:
        metrics: list of ``(num_examples, metric_dict)`` tuples, one per
            client. Each ``metric_dict`` contains an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when the clients
        reported no examples at all (which would otherwise divide by zero).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    weighted_sum = sum(num_examples * client_metrics["accuracy"] for num_examples, client_metrics in metrics)
    return {"accuracy": weighted_sum / total_examples}


def get_strategy(testloader, proximal_mu=0.1):
    """Create the Flower server strategy selected by the global ``STRATEGY``.

    All strategies share the same configuration:
    - ``FRACTION_TRAIN``/``FRACTION_VALIDATE`` of the clients are sampled per round.
    - All ``NUM_CLIENTS`` clients must be available.
    - The global model is evaluated on ``testloader`` at the server.
    - It starts from the weights of a freshly built model.
    - Client accuracies are aggregated with ``aggregate_evaluate_metrics``.

    Args:
        testloader: centralized test data for the server-side evaluate_fn.
        proximal_mu: weight of the proximal term, passed to FedProx only
            (Flower's FedProx requires it). NOTE(review): ``train_federated``
            does not add a proximal term, so clients currently train exactly
            as under FedAvg.

    Raises:
        ValueError: if ``STRATEGY`` is not "FedAvg", "FedAdam" or "FedProx".
    """
    strategies = {
        "FedAvg": (fl.server.strategy.FedAvg, {}),
        "FedAdam": (fl.server.strategy.FedAdam, {}),
        "FedProx": (fl.server.strategy.FedProx, {"proximal_mu": proximal_mu}),
    }
    if STRATEGY not in strategies:
        raise ValueError(f"Unknown STRATEGY {STRATEGY!r}; expected one of {sorted(strategies)}")
    strategy_cls, extra_kwargs = strategies[STRATEGY]

    return strategy_cls(
        fraction_fit=FRACTION_TRAIN,
        fraction_evaluate=FRACTION_VALIDATE,
        min_available_clients=NUM_CLIENTS,  # total number of clients available in the experiment
        evaluate_fn=get_evalulate_fn(testloader),  # server-side evaluation of the global model
        initial_parameters=fl.common.ndarrays_to_parameters(
            [val.cpu().numpy() for _, val in get_model().state_dict().items()]
        ),
        evaluate_metrics_aggregation_fn=aggregate_evaluate_metrics,
        **extra_kwargs,
    )

def get_metrics(history):
    """Convert a Flower simulation ``History`` into the notebook's history dict.

    Reads:
    - ``metrics_centralized['model']``: models stored by the server-side
      evaluate_fn. The one from the last entry is returned.
    - ``metrics_centralized['accuracy']`` and ``losses_centralized``:
      server-side test results.
    - ``metrics_distributed['accuracy']`` and ``losses_distributed``:
      aggregated client validation results.

    The first centralized accuracy and loss are prepended to the federated
    lists. Presumably this is because distributed evaluation has no entry for
    the initial parameters, and it keeps the curves equally long (TODO
    confirm). Federated losses are multiplied by ``NUM_CLIENTS``.
    NOTE(review): the rationale for this scaling is not documented; confirm.
    """
    centralized = history.metrics_centralized

    final_model = centralized['model'][-1][1]
    acc_central = [entry[1] for entry in centralized['accuracy']]
    acc_fed = [centralized['accuracy'][0][1]]
    acc_fed.extend(entry[1] for entry in history.metrics_distributed['accuracy'])
    loss_central = [entry[1] for entry in history.losses_centralized]
    loss_fed = [history.losses_centralized[0][1]]
    loss_fed.extend(entry[1] * NUM_CLIENTS for entry in history.losses_distributed)

    return {
        "model": final_model,
        "accuracies_centralized": acc_central,
        "accuracies_federated": acc_fed,
        "losses_centralized": loss_central,
        "losses_federated": loss_fed,
    }



### Run Federated

In [None]:
def run_federated():
    """Run the Flower simulation and return its results as a history dict.

    ``EPOCHS_FEDERATED`` is used as the number of federated rounds.
    """
    trainloaders, valloaders, testloader = get_federated_loaders()
    client_fn_callback = generate_client_fn(trainloaders, valloaders)
    strategy = get_strategy(testloader)

    # Request one GPU per simulated client when CUDA is in use; otherwise
    # leave the resource allocation to Flower's defaults.
    client_resources = {"num_gpus": 1} if DEVICE.type == "cuda" else None

    history = fl.simulation.start_simulation(
        client_fn=client_fn_callback,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=EPOCHS_FEDERATED),
        strategy=strategy,
        client_resources=client_resources,
    )
    print(history)
    return get_metrics(history)


## Plots


In [None]:
def visualise_n_random_examples(testset, model):
    """Show ``VISUALIZE_N`` random samples of ``testset`` with the model's predictions.

    The images are drawn on an 8-column grid. Each title shows the true and
    predicted label plus the softmax confidence of the prediction, in green
    when correct and red otherwise. Does nothing when ``VISUALIZE_N`` is 0.
    """
    if VISUALIZE_N == 0:
        return

    order = list(range(len(testset)))
    random.shuffle(order)
    picked = order[:VISUALIZE_N]

    cols = 8
    rows = int(np.ceil(len(picked) / cols))
    fig, axs = plt.subplots(figsize=(16, 6), nrows=rows, ncols=cols)

    for slot, sample_idx in enumerate(picked):
        image, label = testset[sample_idx]
        pixels = np.squeeze(image)

        model.eval()
        with torch.no_grad():
            logits = model(image.unsqueeze(0).to(DEVICE))
            prediction = logits.argmax(dim=1).item()
            confidence = F.softmax(logits, dim=1)[0][prediction].item()

        title_color = 'green' if label == prediction else 'red'
        ax = axs.flat[slot]
        ax.imshow(pixels, cmap="gray")
        ax.set_title(f'True/Pred: [{label}/{prediction}] ({confidence*100:.0f}%)', color=title_color)

    plt.tight_layout()
    plt.show()

def save_params(test_results):
    """Write the test results and the run's hyper-parameters to a text file.

    Args:
        test_results: ``(loss, accuracy)`` tuple as returned by ``test``.

    For federated runs, the MID and WCS of the client label distributions
    are also computed, printed and written. The file is created in
    '/content/drive/My Drive/LAMA_ItW/' and named after the run.
    """
    run_name = f'{SAVE_DATE}_Params_{MODEL}{SPLIT}NC{NUM_CLIENTS}BS{BATCH_SIZE_FEDERATED}EP{EPOCHS_FEDERATED}'
    if not CENTRALIZED:
        global_d, local_d = get_distributions()
        MID_result = calculate_MID(global_d)
        WCS_result = calculate_WCS(global_d,local_d)
        print(f"MID: {MID_result}, WCS: {WCS_result}")
        print("Images per client: " + str(num_images))
    # The with-statement closes the file on exit; no explicit close() needed.
    with open(f'/content/drive/My Drive/LAMA_ItW/{run_name}.txt', 'w') as f:
        text_1 = f'''
        Filename: {run_name}
        ## Test results ##
        ACCURACY = {test_results[1]}
        LOSS = {test_results[0]}

        ##  MAIN  PARAMS ##
        CENTRALIZED = {CENTRALIZED} ## False if Federated
        SPLIT = {SPLIT}
        MODEL = {MODEL}

        ## OTHER GLOBAL PARAMS ##
        VALIDATION_SPLIT = {VALIDATION_SPLIT}
        OPTIMIZER = {OPTIMIZER}
        '''
        if CENTRALIZED:
            text_2 = f'''
        ### CENTRALIZED PARAMS ###
        LR_CENTRALIZED = {LR_CENTRALIZED}  # Learning rate
        MOM_CENTRALIZED = {MOM_CENTRALIZED} # Momentum
        EPOCHS_CENTRALIZED = {EPOCHS_CENTRALIZED} # Epochs
        BATCH_SIZE_CENTRALIZED = {BATCH_SIZE_CENTRALIZED}
            '''
        else:
            text_2 = f'''
        ## MID & WCS results ##
        MID = {MID_result}
        WCS = {WCS_result}
        ### FEDERATED PARAMS ###
        LR_FEDERATED = {LR_FEDERATED}  # Learning rate
        MOM_FEDERATED = {MOM_FEDERATED} # Momentum
        EPOCHS_FEDERATED = {EPOCHS_FEDERATED} # Epochs
        BATCH_SIZE_FEDERATED = {BATCH_SIZE_FEDERATED}
        NUM_CLIENTS = {NUM_CLIENTS}
        IMG_PER_CLIENT = {num_images}
        FRACTION_TRAIN = {FRACTION_TRAIN}
        FRACTION_VALIDATE = {FRACTION_VALIDATE}
            '''
        f.write(text_1)
        f.write(text_2)

def save_params_excel(test_results, acc_hist):
    """Append one row describing this run to the 'results' sheet of Results.xlsx.

    Args:
        test_results: ``(loss, accuracy)`` tuple as returned by ``test``.
        acc_hist: accuracy history of the run, stored as its ``str()``.

    The workbook '/content/drive/My Drive/LAMA_ItW/Results.xlsx' must already
    exist and contain a sheet named 'results'. Columns follow the key order of
    ``res``. For centralized runs the client-related columns are 0 and the
    MID/WCS cells are left empty, since those metrics describe a federated
    client split.
    """
    results_path = '/content/drive/My Drive/LAMA_ItW/Results.xlsx'
    if CENTRALIZED:
        MID_result = WCS_result = None
        learning_rate, momentum = LR_CENTRALIZED, MOM_CENTRALIZED
        epochs, batchsize = EPOCHS_CENTRALIZED, BATCH_SIZE_CENTRALIZED
        num_clients = img_per_client = fraction_train = fraction_val = 0
    else:
        global_d, local_d = get_distributions()
        MID_result = calculate_MID(global_d).item()
        WCS_result = calculate_WCS(global_d,local_d).item()
        print(f"MID: {MID_result}, WCS: {WCS_result}")
        print("Images per client: " + str(num_images))
        learning_rate, momentum = LR_FEDERATED, MOM_FEDERATED
        epochs, batchsize = EPOCHS_FEDERATED, BATCH_SIZE_FEDERATED
        num_clients, img_per_client = NUM_CLIENTS, num_images
        fraction_train, fraction_val = FRACTION_TRAIN, FRACTION_VALIDATE

    # Key order defines the spreadsheet column order.
    res = {'name': f'{SAVE_DATE}_Params_{MODEL}{SPLIT}NC{NUM_CLIENTS}BS{BATCH_SIZE_FEDERATED}EP{EPOCHS_FEDERATED}',
           'centralized': CENTRALIZED,
           'accuracy': test_results[1],  # test_results is (loss, accuracy)
           'loss': test_results[0],
           'split': SPLIT, 'model': MODEL,
           'val_split': VALIDATION_SPLIT, 'optimizer': OPTIMIZER,
           'learning_rate': learning_rate, 'momentum': momentum,
           'epochs': epochs, 'batchsize': batchsize,
           'mid': MID_result, 'wcs': WCS_result,
           'num_clients': num_clients, 'img_per_client': img_per_client,
           'fraction_train': fraction_train, 'fraction_val': fraction_val,
           'acc_hist': str(acc_hist)
           }

    # Load the workbook, append the row to the 'results' sheet and save.
    wb = load_workbook(filename=results_path)
    data_to_append = list(res.values())
    print(data_to_append)
    sheet = wb['results']
    sheet.append(data_to_append)
    wb.save(filename=results_path)

def plot_metrics(history):
    """Evaluate the final model, save the run's parameters and plot its curves.

    Args:
        history: dict with keys "model", "accuracies_centralized",
            "accuracies_federated", "losses_centralized", "losses_federated".

    Steps:
    1. Test ``history["model"]`` on the EMNIST test split and print the accuracy.
    2. Write the parameters/results via ``save_params`` and ``save_params_excel``.
    3. Plot accuracy and loss curves and save them to Drive: ``*_Fed_*`` plots
       when federated metrics are present, otherwise ``*_Cent_*`` plots.
    4. Show random test predictions.
    """
    testloader = get_testloader()
    _, testset = get_dataset()
    model = history["model"]

    test_results = test(model, testloader)
    print(f'Final accuracy: {test_results[1]*100:.1f}')

    save_params(test_results)
    save_params_excel(test_results, history['accuracies_centralized'])

    run_tag = f'{MODEL}{SPLIT}NC{NUM_CLIENTS}BS{BATCH_SIZE_FEDERATED}EP{EPOCHS_FEDERATED}'

    def draw(curves, title, ylabel, file_kind, tenth_ticks):
        # One figure per metric: plot every curve, label, save to Drive, show.
        plt.figure(figsize=(10, 4))
        for values, label in curves:
            plt.plot(values, label=label)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        if tenth_ticks:
            plt.yticks([i/10 for i in range(11)])  # ticks every 0.1
        plt.legend()
        plt.savefig(f'/content/drive/My Drive/LAMA_ItW/{SAVE_DATE}_{file_kind}_{run_tag}.png')
        plt.show()

    has_federated = history['accuracies_federated'] and history['losses_federated']
    if has_federated:
        draw([(history['accuracies_federated'], 'Federated Accuracy'),
              (history['accuracies_centralized'], 'Centralized Accuracy')],
             'Model Accuracy', 'Accuracy', 'Fed_Acc', True)
        draw([(history['losses_federated'], 'Federated Loss'),
              (history['losses_centralized'], 'Centralized Loss')],
             'Model Loss', 'Loss', 'Fed_Loss', False)
    else:
        draw([(history['accuracies_centralized'], 'Centralized Accuracy')],
             'Model Accuracy', 'Accuracy', 'Cent_Acc', True)
        draw([(history['losses_centralized'], 'Centralized Loss')],
             'Model Loss', 'Loss', 'Cent_Loss', False)

    visualise_n_random_examples(testset, model)



# Simulation

In [None]:
# Train in the mode selected by CENTRALIZED; both runners return the same
# history-dict format that plot_metrics consumes.
history = run_centralized() if CENTRALIZED else run_federated()

# Results


In [None]:
# Evaluate the final model on the test set, save the parameters/results to
# Drive and plot the training curves.
plot_metrics(history)