# Test1 federeted
Implementa l'algoritmo FedAvg.

Fissa K=100, C=0.1, e adotta una partizione iid del dataset di addestramento.

Esegui FedAvg su Shakespeare per 200 round senza alcun learning rate schedule.

# Import

In [1]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
import random
from torch.utils.data import Subset
from statistics import mean
import numpy as np
from collections import defaultdict
import re
from torchvision import datasets, transforms
import kagglehub

# Parameters

In [2]:
# Constants for FL training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

NUM_CLIENTS = 100  # Total number of clients in the federation
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_STEPS = 4  # Number of local steps (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

BATCH_SIZE = 4  # Batch size for local training
LR = 1e-2  # Initial learning rate for local optimizers
MOMENTUM = 0.9  # Momentum for SGD optimizer
WEIGHT_DECAY = 1e-4  # Regularization term for local training

LOG_FREQUENCY = 10  # Frequency of logging training progress

TRAIN_FRACTION = 0.9  # percentage of the training data
SEQ_LEN = 80  # length of the sequence for the model
BATCH_SIZE = 4
N_VOCAB = 90  # Numero di caratteri nel vocabolario (ASCII)
EPOCHS = 200
LEARNING_RATE = 0.01
EMBEDDING_SIZE = 8
LSTM_HIDDEN_DIM = 256
SEQ_LENGTH = 80

cuda


# Utility code

In [3]:
import torch
from statistics import mean
import torch.nn as nn

"""
Utility function used both in the centralized and federated learning
Computes the accuracy and the loss on the validation/test set depending on the dataloader passed
"""

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.NLLLoss()# our loss function for classification tasks on CIFAR-100
def evaluate(model, dataloader):
    with torch.no_grad():
        model.train(False) # Set Network to evaluation mode
        total_predictions = 0  # Track total predictions for normalization
        running_corrects = 0
        losses = []
        for data, targets in dataloader:
            data = data.to(DEVICE)        # Move the data to the GPU
            targets = targets.to(DEVICE)  # Move the targets to the GPU
            # Forward Pass
            #state = model.init_hidden(data.size(0))
            # outputs is a tuple: (logits, hidden_state)
            outputs, _ = model(data) # unpack the tuple and get only the output (predictions)
            # Reshape the outputs for CrossEntropyLoss
            outputs = outputs.view(-1, model.vocab_size)
            targets = targets.view(-1)
            loss = criterion(outputs, targets)
            losses.append(loss.item())
            # Get predictions
            _, preds = torch.max(outputs.data, 1)
            # Update Corrects (element-wise comparison for accuracy)
            running_corrects += (preds == targets).sum().item()
            total_predictions += targets.size(0)  # Update total prediction count

        # Calculate Accuracy (divide by total predictions)
        accuracy = (running_corrects / total_predictions) * 100

    return accuracy, mean(losses)

In [4]:
import os
import torch
import json
CHECKPOINT_DIR = "cartellaCheckpoint"
# Ensure the checkpoint directory exists, creating it if necessary
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, hyperparameters, subfolder="", checkpoint_data=None):
    """
    Saves the model checkpoint and removes the previous one if it exists.

    Arguments:
    model -- The model whose state is to be saved.
    optimizer -- The optimizer whose state is to be saved (can be None).
    epoch -- The current epoch of the training process.
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to save the checkpoint.
    checkpoint_data -- Data to save in a JSON file (e.g., training logs).
    """
    # Define the path for the subfolder where checkpoints will be stored
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    # Create the subfolder if it doesn't exist
    os.makedirs(subfolder_path, exist_ok=True)

    # Construct filenames for both the model checkpoint and the associated JSON file
    filename = f"model_epoch_{epoch}_params_{hyperparameters}.pth"
    filepath = os.path.join(subfolder_path, filename)
    filename_json = f"model_epoch_{epoch}_params_{hyperparameters}.json"
    filepath_json = os.path.join(subfolder_path, filename_json)

    # Define the filenames for the previous checkpoint files, to remove them if necessary
    previous_filepath = os.path.join(subfolder_path, f"model_epoch_{epoch - 1}_params_{hyperparameters}.pth")
    previous_filepath_json = os.path.join(subfolder_path, f"model_epoch_{epoch - 1}_params_{hyperparameters}.json")

    # Remove the previous checkpoint if it exists, but only for epochs greater than 1
    if epoch > 1 and os.path.exists(previous_filepath):
        os.remove(previous_filepath)
        os.remove(previous_filepath_json)

    # Prepare the checkpoint data dictionary
    checkpoint = {'model_state_dict': model.state_dict(), 'epoch': epoch}
    # If an optimizer is provided, save its state as well
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    # Save the model and optimizer (if provided) state dictionary to the checkpoint file
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")

    # If additional data (e.g., training logs) is provided, save it to a JSON file
    if checkpoint_data:
        with open(filepath_json, 'w') as json_file:
            json.dump(checkpoint_data, json_file, indent=4)

def load_checkpoint(model, optimizer, hyperparameters, subfolder=""):
    """
    Loads the latest checkpoint available based on the specified hyperparameters.

    Arguments:
    model -- The model whose state will be updated from the checkpoint.
    optimizer -- The optimizer whose state will be updated from the checkpoint (can be None).
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to look for checkpoints.

    Returns:
    The next epoch to resume from and the associated JSON data if available.
    """
    # Define the path to the subfolder where checkpoints are stored
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    # If the subfolder doesn't exist, print a message and start from epoch 1
    if not os.path.exists(subfolder_path):
        print("No checkpoint found, starting from epoch 1.")
        return 1, None  # Epoch starts from 1

    # Search for checkpoint files in the subfolder that match the hyperparameters
    files = [f for f in os.listdir(subfolder_path) if f"params_{hyperparameters}" in f and f.endswith('.pth')]

    # If checkpoint files are found, load the one with the highest epoch number
    if files:
        latest_file = max(files, key=lambda x: int(x.split('_')[2]))  # Find the latest epoch file
        filepath = os.path.join(subfolder_path, latest_file)
        checkpoint = torch.load(filepath, weights_only=True)

        # Load the model state from the checkpoint
        model.load_state_dict(checkpoint['model_state_dict'])
        # If an optimizer is provided, load its state as well
        if optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Try to load the associated JSON file if available
        json_filepath = os.path.join(subfolder_path, latest_file.replace('.pth', '.json'))
        json_data = None
        if os.path.exists(json_filepath):
            # If the JSON file exists, load its contents
            with open(json_filepath, 'r') as json_file:
                json_data = json.load(json_file)
            print("Data loaded!")
        else:
            # If no JSON file exists, print a message
            print("No data found")

        # Print the epoch from which the model is resuming
        print(f"Checkpoint found: Resuming from epoch {checkpoint['epoch'] + 1}\n\n")
        return checkpoint['epoch'] + 1, json_data

    # If no checkpoint is found, print a message and start from epoch 1
    print("No checkpoint found, starting from epoch 1..\n\n")
    return 1, None  # Epoch starts from 1

def delete_existing_checkpoints(subfolder=""):
    """
    Deletes all existing checkpoints in the specified subfolder.

    Arguments:
    subfolder -- Optional subfolder within the checkpoint directory to delete checkpoints from.
    """
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    if os.path.exists(subfolder_path):
        for file_name in os.listdir(subfolder_path):
            file_path = os.path.join(subfolder_path, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)
        print(f"All existing checkpoints in {subfolder_path} have been deleted.")
    else:
        print(f"No checkpoint folder found at {subfolder_path}.")

## Code for data loading

In [5]:
CHARACTER_RE = re.compile(r'^  ([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Matches character lines
CONT_RE = re.compile(r'^    (.*)')  # Matches continuation lines
COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Special regex for Comedy of Errors
COE_CONT_RE = re.compile(r'^(.*)')  # Continuation for Comedy of Errors

def parse_shakespeare_file(filepath):
    """
    Reads and splits Shakespeare's text into plays, characters, and their dialogues.
    Returns training and test datasets based on the specified fraction.
    """
    with open(filepath, "r") as f:
        content = f.read()
    plays, _ = _split_into_plays(content)  # Split the text into plays
    _, train_examples, test_examples = _get_train_test_by_character(
        plays, test_fraction=1 - TRAIN_FRACTION
    )
    return train_examples, test_examples

def _split_into_plays(shakespeare_full):
    """
    Splits the full Shakespeare text into individual plays and characters' dialogues.
    Handles special parsing for "The Comedy of Errors".
    """
    plays = []
    slines = shakespeare_full.splitlines(True)[1:]  # Skip the first line (title/header)
    current_character = None
    comedy_of_errors = False

    for i, line in enumerate(slines):
        # Detect play titles and initialize character dictionary
        if "by William Shakespeare" in line:
            current_character = None
            characters = defaultdict(list)
            title = slines[i - 2].strip() if slines[i - 2].strip() else slines[i - 3].strip()
            comedy_of_errors = title == "THE COMEDY OF ERRORS"
            plays.append((title, characters))
            continue

        # Match character lines or continuation lines
        match = _match_character_regex(line, comedy_of_errors)
        if match:
            character, snippet = match.group(1).upper(), match.group(2)
            if not (comedy_of_errors and character.startswith("ACT ")):
                characters[character].append(snippet)
                current_character = character
        elif current_character:
            match = _match_continuation_regex(line, comedy_of_errors)
            if match:
                characters[current_character].append(match.group(1))

    # Filter out plays with insufficient dialogue data
    return [play for play in plays if len(play[1]) > 1], []

def _match_character_regex(line, comedy_of_errors=False):
    """Matches character dialogues, with special handling for 'The Comedy of Errors'."""
    return COE_CHARACTER_RE.match(line) if comedy_of_errors else CHARACTER_RE.match(line)

def _match_continuation_regex(line, comedy_of_errors=False):
    """Matches continuation lines of dialogues."""
    return COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line)

def _get_train_test_by_character(plays, test_fraction=0.2):
    """
    Splits dialogues by characters into training and testing datasets.
    Ensures each character has at least one example in the training set.
    """
    all_train_examples = defaultdict(list)
    all_test_examples = defaultdict(list)

    def add_examples(example_dict, example_tuple_list):
        """Adds examples to the respective dataset dictionary."""
        for play, character, sound_bite in example_tuple_list:
            example_dict[f"{play}_{character}".replace(" ", "_")].append(sound_bite)

    for play, characters in plays:
        for character, sound_bites in characters.items():
            examples = [(play, character, sound_bite) for sound_bite in sound_bites]
            if len(examples) <= 2:
                continue

            # Calculate the number of test samples
            num_test = max(1, int(len(examples) * test_fraction))
            num_test = min(num_test, len(examples) - 1)  # Ensure at least one training example

            # Split into train and test sets
            train_examples = examples[:-num_test]
            test_examples = examples[-num_test:]

            add_examples(all_train_examples, train_examples)
            add_examples(all_test_examples, test_examples)

    return {}, all_train_examples, all_test_examples


def letter_to_vec(c, n_vocab=128):
    """Converts a single character to a vector index based on the vocabulary size."""
    return ord(c) % n_vocab

def word_to_indices(word, n_vocab=128):
    """
    Converts a word or list of words into a list of indices.
    Each character is mapped to an index based on the vocabulary size.
    """
    if isinstance(word, list):  # If input is a list of words
        res = []
        for stringa in word:
            res.extend([ord(c) % n_vocab for c in stringa])  # Convert each word to indices
        return res
    else:  # If input is a single word
        return [ord(c) % n_vocab for c in word]

def process_x(raw_x_batch, seq_len, n_vocab):
    """
    Processes raw input data into padded sequences of indices.
    Ensures all sequences are of uniform length.
    """
    x_batch = [word_to_indices(word, n_vocab) for word in raw_x_batch]
    x_batch = [x[:seq_len] + [0] * (seq_len - len(x)) for x in x_batch]
    return torch.tensor(x_batch, dtype=torch.long)


def process_y(raw_y_batch, seq_len, n_vocab):
    """
    Processes raw target data into padded sequences of indices.
    Shifts the sequence by one character to the right.
    y[1:seq_len + 1] takes the input data, right shift of an
    element and uses the next element of the sequence to fill
    and at the end (with [0]) final padding (zeros) are (eventually)
    added to reach the desired sequence length.
    """
    y_batch = [word_to_indices(word, n_vocab) for word in raw_y_batch]
    y_batch = [y[1:seq_len + 1] + [0] * (seq_len - len(y[1:seq_len + 1])) for y in y_batch]  # Shifting and final padding
    return torch.tensor(y_batch, dtype=torch.long)

def create_batches(data, batch_size, seq_len, n_vocab):
    """
    Creates batches of input and target data from dialogues.
    Each batch contains sequences of uniform length.
    """
    x_batches = []
    y_batches = []
    dialogues = list(data.values())
    random.shuffle(dialogues)  # Shuffle to ensure randomness in batches

    batch = []
    for dialogue in dialogues:
        batch.append(dialogue)
        if len(batch) == batch_size:
            x_batch = process_x(batch, seq_len, n_vocab)
            y_batch = process_y(batch, seq_len, n_vocab)
            x_batches.append(x_batch)
            y_batches.append(y_batch)
            batch = []

    # Add the last batch if it's not full
    if batch:
        x_batch = process_x(batch, seq_len, n_vocab)
        y_batch = process_y(batch, seq_len, n_vocab)
        x_batches.append(x_batch)
        y_batches.append(y_batch)

    return x_batches, y_batches



# Client class

In [6]:
from torch.backends import cudnn
import time


class Client:
    def __init__(self, client_id, data_loader, model, device):
        """
        Initializes a federated learning client.
        :param client_id: Unique identifier for the client.
        :param data_loader: Data loader specific to the client.
        :param model: The model class to be used by the client.
        :param device: The device (CPU/GPU) to perform computations.
        """
        self.client_id = client_id
        self.data_loader = data_loader
        self.model = model.to(device)
        self.device = device

    def client_update(self, client_data, criterion, optimizer, local_steps=4, detailed_print=False):
        """
        Trains a given client's local model on its dataset for a fixed number of steps (`local_steps`).

        Args:
            model (nn.Module): The local model to be updated.
            client_id (int): Identifier for the client (used for logging/debugging purposes).
            client_data (DataLoader): The data loader for the client's dataset.
            criterion (Loss): The loss function used for training (e.g., CrossEntropyLoss).
            optimizer (Optimizer): The optimizer used for updating model parameters (e.g., SGD).
            local_steps (int): Number of local epochs to train on the client's dataset.
            detailed_print (bool): If True, logs the final loss after training.

        Returns:
            dict: The state dictionary of the updated model.
        """


        cudnn.benchmark  # Calling this optimizes runtime

        self.model.train()  # Set the model to training mode
        step_count = 0
        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        while step_count < local_steps:
            x_batches, y_batches = create_batches(train_data, BATCH_SIZE, SEQ_LEN, N_VOCAB)
            for x_batch, y_batch in zip(x_batches, y_batches):

                x_batch = x_batch.to(DEVICE)  # Move the data to the GPU
                y_batch = y_batch.to(DEVICE)  # Move the targets to the GPU
                # Move data and targets to the specified device (e.g., GPU or CPU)
                #data, targets = data.to(self.device), targets.to(self.device)


                start_time = time.time()  # for testing-----------------------------

                # Reset the gradients before backpropagation
                optimizer.zero_grad()

                # Forward pass: compute model predictions
                logits, _ = self.model(x_batch)

                logits = logits.view(-1, N_VOCAB)  # Reshape to (batch_size * seq_length, vocab_size)
                y_batch = y_batch.view(-1)  # Reshape to (batch_size * seq_length)

                # Compute the loss
                loss = criterion(logits, y_batch)

                # Backward pass: compute gradients and update weights
                loss.backward()
                optimizer.step()

                #---------- Accumulate metrics
                #  Accumulates the weighted loss for the number of samples in the batch to account for any variation in
                #  batch size due to, for example, the smaller final batch. A little too precise? :)
                total_loss += loss.item() * x_batch.size(0)
                _, predicted = logits.max(1)
                correct_predictions += predicted.eq(y_batch).sum().item()
                total_samples += x_batch.size(0)

                step_count +=1
                if step_count >= local_steps:
                  break

        #---------- Compute averaged metrics
        avg_loss = total_loss / total_samples
        avg_accuracy = correct_predictions / total_samples * 100

        # Optionally, print the loss for the last epoch of training
        if detailed_print:
          print(f'Client {self.client_id} --> Final Loss (Step {step_count}/{local_steps}): {loss.item()}')


        # Return the updated model's state dictionary (weights)
        return self.model.state_dict(), avg_loss, avg_accuracy

In [7]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

DIR = '/content/drive/MyDrive/colab_plots/'

def client_selection(number_of_clients, clients_fraction, gamma=None):
    """
    Selects a subset of clients based on uniform or skewed distribution.

    Args:
    number_of_clients (int): Total number of clients.
    clients_fraction (float): Fraction of clients to be selected.
    uniform (bool): If True, selects clients uniformly. If False, selects clients based on a skewed distribution.
    gamma (float): Hyperparameter for the Dirichlet distribution controlling the skewness (only used if uniform=False).

    Returns:
    list: List of selected client indices.
    """
    num_clients_to_select = int(number_of_clients * clients_fraction)

    if gamma is None:
        # Uniformly select clients without replacement
        selected_clients = np.random.choice(number_of_clients, num_clients_to_select, replace=False)
    else:
        # Generate skewed probabilities using a Dirichlet distribution
        probabilities = np.random.dirichlet(np.ones(number_of_clients) * gamma)
        selected_clients = np.random.choice(number_of_clients, num_clients_to_select, replace=False, p=probabilities)

    return selected_clients


def plot_client_selection(client_selection_count, file_name):
    """
    Bar plot to visualize the frequency of client selections in a federated learning simulation.

    Args:
        client_selection_count (list): list containing the number of times each client was selected.
        file_name (str): name of the file to save the plot.
    """
    # Fixed base directory
    directory = DIR +  'plots_federated/'
    # Ensure the base directory exists
    os.makedirs(directory, exist_ok=True)

    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    num_clients = len(client_selection_count)
    plt.figure(figsize=(10, 6))
    plt.bar(range(num_clients), client_selection_count, alpha=0.7, edgecolor='black')
    plt.xlabel("Client ID", fontsize=14)
    plt.ylabel("Selection Count", fontsize=14)
    plt.title("Client Selection Frequency", fontsize=16)
    plt.xticks(range(num_clients), fontsize=10, rotation=90 if num_clients > 20 else 0)
    plt.tight_layout()
    plt.savefig(file_path, format="png", dpi=300)
    plt.close()

def test(global_model, test_loader):
    """
    Evaluate the global model on the test dataset.

    Args:
        global_model (nn.Module): The global model to be evaluated.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        float: The accuracy of the model on the test dataset.
    """
    test_accuracy, _ = evaluate(global_model, test_loader)
    return test_accuracy

def plot_metrics(train_accuracies, train_losses, val_accuracies,val_losses, file_name):
    """
    Plot the training and validation metrics for a federated learning simulation.

    Args:
        train_accuracies (list): List of training accuracies.
        train_losses (list): List of training losses.
        val_accuracies (list): List of validation accuracies.
        val_losses (list): List of validation losses.
        file_name (str): Name of the file to save the plot.
    """
    # Fixed base directory
    directory = DIR + '/plots_federated/'
    # Ensure the base directory exists
    os.makedirs(directory, exist_ok=True)

    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # Create a list of epochs for the x-axis
    epochs = list(range(1, len(train_losses) + 1))

    # Plot the training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, label='Train Loss', color='blue')
    plt.plot(epochs, val_losses, label='Validation Loss', color='red')
    plt.xlabel('Rounds', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.title('Training and Validation Loss', fontsize=16)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(file_path.replace('.png', '_loss.png'), format='png', dpi=300)
    plt.close()

    # Plot the training and validation accuracies
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_accuracies, label='Train Accuracy', color='blue')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy', color='red')
    plt.xlabel('Rounds', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title('Training and Validation Accuracy', fontsize=16)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(file_path.replace('.png', '_accuracy.png'), format='png', dpi=300)
    plt.close()


def save_data(global_model, val_accuracies, val_losses, train_accuracies, train_losses,client_count, file_name):
    """
    Save the global model, val_accuracies, val_losses, train_accuracies,train_losses and client_count to a file.

    Args:
        global_model (nn.Module): The global model to be saved.
        val_accuracies (list): List of validation accuracies.
        val_losses (list): List of validation losses.
        train_accuracies (list): List of training accuracies.
        train_losses (list): List of training losses.
        file_name (str): Name of the file to save the data.
    """
    # Fixed base directory
    directory = DIR + '/trained_models/'
    # Ensure the base directory exists
    os.makedirs(directory, exist_ok=True)

    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # Save all data (model state and metrics) into a dictionary
    save_dict = {
        'model_state': global_model.state_dict(),
        'val_accuracies': val_accuracies,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'train_losses': train_losses,
        'client_count': client_count
    }

    # Save the dictionary to the specified file
    torch.save(save_dict, file_path)
    print(f"Data saved successfully to {file_path}")

def load_data(model, file_name):
    """
    Load the model weights and metrics from a file.

    Args:
        model (nn.Module): The model to load the weights into.
        file_name (str): Name of the file to load the data from.

    Returns:
        tuple: A tuple containing the model, val_accuracies, val_losses, train_accuracies train_losses and client_count.
    """
    # Fixed base directory
    directory = DIR+ 'trained_models/'
    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # Load the saved data from the specified file
    save_dict = torch.load(file_path)

    # Load the model state
    model.load_state_dict(save_dict['model_state'])

    # Extract the metrics
    val_accuracies = save_dict['val_accuracies']
    val_losses = save_dict['val_losses']
    train_accuracies = save_dict['train_accuracies']
    train_losses = save_dict['train_losses']
    client_count = save_dict['client_count']

    print(f"Data loaded successfully from {file_path}")

    return model, val_accuracies, val_losses, train_accuracies, train_losses,client_count

# Server class

In [8]:
import torch
import torch.optim as optim
from copy import deepcopy
import numpy as np
from torch.utils.data import Subset
import os
from torch.utils.data import DataLoader, Subset


class Server:
    def __init__(self, global_model, device, CHECKPOINT_DIR):
        self.global_model = global_model
        self.device = device
        self.CHECKPOINT_DIR = CHECKPOINT_DIR
        # Ensure the checkpoint directory exists, creating it if necessary
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    def fedavg_aggregate(self, client_states, client_sizes, client_avg_losses, client_avg_accuracies):
        """
        Aggregates model updates and client metrics from selected clients using the Federated Averaging (FedAvg) algorithm.
        The updates and metrics are weighted by the size of each client's dataset.

        Args:
            global_model (nn.Module): The global model whose structure is used for aggregation.
            client_states (list[dict]): A list of state dictionaries (model weights) from participating clients.
            client_sizes (list[int]): A list of dataset sizes for the participating clients.
            client_avg_losses (list[float]): A list of average losses for the participating clients.
            client_avg_accuracies (list[float]): A list of average accuracies for the participating clients.

        Returns:
            tuple: The aggregated state dictionary with updated model parameters, global average loss, and global average accuracy.
        """
        # Copy the global model's state dictionary for aggregation
        new_state = deepcopy(self.global_model.state_dict())

        # Calculate the total number of samples across all participating clients
        total_samples = sum(client_sizes)

        # Initialize all parameters in the new state to zero
        for key in new_state:
            new_state[key] = torch.zeros_like(new_state[key])

        # Initialize metrics
        total_loss = 0.0
        total_accuracy = 0.0

        # Perform a weighted average of client updates and metrics
        for state, size, avg_loss, avg_accuracy in zip(client_states, client_sizes, client_avg_losses, client_avg_accuracies):
            for key in new_state:
                # Add the weighted contribution of each client's parameters
                new_state[key] += (state[key] * size / total_samples)
            total_loss += avg_loss * size
            total_accuracy += avg_accuracy * size

        # Calculate global metrics
        global_avg_loss = total_loss / total_samples
        global_avg_accuracy = total_accuracy / total_samples

        # Return the aggregated state dictionary with updated weights and global metrics
        return new_state, global_avg_loss, global_avg_accuracy


    # Federated Learning Training Loop
    def train_federated(self, criterion, trainloader, validloader, num_clients, num_classes, rounds, lr, momentum, batchsize, wd, C=0.1, local_steps=4, log_freq=100, detailed_print=False,gamma=None):
        val_accuracies = []
        val_losses = []
        train_accuracies = []
        train_losses = []
        best_model_state = None  # The model with the best accuracy
        client_selection_count = [0] * num_clients #Count how many times a client has been selected
        best_val_acc = 0.0

        shards = self.sharding(trainloader.dataset, num_clients, num_classes) #each shard represent the training data for one client
        client_sizes = [len(shard) for shard in shards]

        self.global_model.to(self.device) #as alwayse, we move the global model to the specified device (CPU or GPU)

        #loading checkpoint if it exists
        checkpoint_start_step, data_to_load = load_checkpoint(model=self.global_model,optimizer=None,hyperparameters=f"LR{lr}_WD{wd}", subfolder="Federated/")
        if data_to_load is not None:
          val_accuracies = data_to_load['val_accuracies']
          val_losses = data_to_load['val_losses']
          train_accuracies = data_to_load['train_accuracies']
          train_losses = data_to_load['train_losses']
          client_selection_count = data_to_load['client_selection_count']


        # ********************* HOW IT WORKS ***************************************
        # The training runs for rounds iterations (GLOBAL_ROUNDS=2000)
        # Each round simulates one communication step in federated learning, including:
        # 1) client selection
        # 2) local training (of each client)
        # 3) central aggregation
        for round_num in range(checkpoint_start_step, rounds):
            if (round_num+1) % log_freq == 0 and detailed_print:
              print(f"------------------------------------- Round {round_num+1} ------------------------------------------------" )

            # 1) client selection: In each round, a fraction C (e.g., 10%) of clients is randomly selected to participate.
            #     This reduces computation costs and mimics real-world scenarios where not all devices are active.
            selected_clients = client_selection(num_clients, C,gamma)
            client_states = []
            client_avg_losses = []
            client_avg_accuracies = []
            for client_id in selected_clients:
                client_selection_count[client_id] += 1

            # 2) local training: for each client updates the model using the client's data for local_steps epochs
            for client_id in selected_clients:
                local_model = deepcopy(self.global_model) #it creates a local copy of the global model
                optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=wd) #same of the centralized version
                client_loader = DataLoader(shards[client_id], batch_size=batchsize, shuffle=True)

                print_log =  (round_num+1) % log_freq == 0 and detailed_print
                client = Client(client_id, client_loader, local_model, self.device)
                client_local_state, client_avg_loss, client_avg_accuracy  = client.client_update(client_loader, criterion, optimizer, local_steps, print_log)

                client_states.append(client_local_state)
                client_avg_losses.append(client_avg_loss)
                client_avg_accuracies.append(client_avg_accuracy)


            # 3) central aggregation: aggregates participating client updates using fedavg_aggregate
            #    and replaces the current parameters of global_model with the returned ones.
            aggregated_state, train_loss, train_accuracy = self.fedavg_aggregate(client_states, [client_sizes[i] for i in selected_clients], client_avg_losses, client_avg_accuracies)

            self.global_model.load_state_dict(aggregated_state)

            train_accuracies.append(train_accuracy)
            train_losses.append(train_loss)
            #Validation at the server
            val_accuracy, val_loss = evaluate(self.global_model, validloader)
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model_state = deepcopy(self.global_model.state_dict())
            if(round_num+1) % 100 == 0:
              print('Reached round '+str(round_num+1))
            if (round_num+1) % log_freq ==0:
                if(detailed_print):
                  print(f"--> best validation accuracy: {best_val_acc:.2f}\n--> training accuracy: {train_accuracy:.2f}")
                  print(f"--> validation loss: {val_loss:.4f}\n--> training loss: {train_loss:.4f}")

                # checkpointing
                checkpoint_data = {
                    'val_accuracies': val_accuracies,
                    'val_losses': val_losses,
                    'train_accuracies': train_accuracies,
                    'train_losses': train_losses,
                    'client_selection_count': client_selection_count
                }
                save_checkpoint(self.global_model,optimizer=None, epoch=round_num, hyperparameters=f"LR{lr}_WD{wd}", subfolder="Federated/", checkpoint_data=checkpoint_data)
                if(detailed_print):
                  print(f"------------------------------ Round {round_num+1} terminated: model updated -----------------------------\n\n" )

        self.global_model.load_state_dict(best_model_state)

        return self.global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count

    def sharding(self, dataset, number_of_clients, number_of_classes=N_VOCAB):
        """
        Function that performs the sharding of the dataset given as input.
        dataset: dataset to be split;
        number_of_clients: the number of partitions we want to obtain;
        number_of_classes: (int) the number of classes inside each partition, or vocab len for IID;
        """

        # Validation of input parameters
        if not (1 <= number_of_classes <= N_VOCAB):
            raise ValueError("number_of_classes should be an integer between 1 and 100")

        # Shuffle dataset indices for randomness
        indices = np.random.permutation(len(dataset))

        # Compute basic partition sizes
        basic_partition_size = len(dataset) // number_of_clients
        remainder = len(dataset) % number_of_clients

        shards = []
        start_idx = 0

        if number_of_classes == N_VOCAB:  # IID Case
            # Equally distribute indices among clients: we can just randomly assign to each client an equal amount of records
            for i in range(number_of_clients):
                end_idx = start_idx + basic_partition_size + (1 if i < remainder else 0)
                shards.append(Subset(dataset, indices[start_idx:end_idx]))
                start_idx = end_idx
        else:  # non-IID Case
            # Count of each class in the dataset
            from collections import Counter
            target_counts = Counter(target for _, target in dataset)

            # Calculate per client class allocation
            class_per_client = np.random.choice(list(target_counts.keys()), size=number_of_classes, replace=False)
            class_idx = {class_: np.where([target == class_ for _, target in dataset])[0] for class_ in class_per_client}

            # Assign class indices evenly to clients
            for i in range(number_of_clients):
                client_indices = np.array([], dtype=int)
                for class_ in class_per_client:
                    n_samples = len(class_idx[class_]) // number_of_clients + (1 if i < remainder else 0)
                    client_indices = np.concatenate((client_indices, class_idx[class_][:n_samples]))
                    class_idx[class_] = np.delete(class_idx[class_], np.arange(n_samples))

                shards.append(Subset(dataset, indices=client_indices))

        return shards

# Model

In [9]:
import torch.nn.functional as F

class CharLSTM(nn.Module):
    """
    Character-level LSTM model for text processing tasks.
    Includes embedding, LSTM, and a fully connected output layer.
    We use:
    - embedding size equal to 8;
    - 2 LSTM layers, each with 256 nodes;
    - densely connected softmax output layer.

    We can avoid to use explicitly the softmax function in the model and
    keep a cross entropy loss function as a loss function.

    as mentioned in paper [2] (Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan; Adaptive Federated Optimization, 2021)
    """
    def __init__(self, vocab_size = 90, embedding_size = 8, lstm_hidden_dim = 256, seq_length=80):
        super(CharLSTM, self).__init__()
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.lstm_hidden_dim = lstm_hidden_dim
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.lstm1 = nn.LSTM(input_size=embedding_size, hidden_size=lstm_hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=lstm_hidden_dim, hidden_size=lstm_hidden_dim, batch_first=True)
        self.fc = nn.Linear(lstm_hidden_dim, vocab_size)

    def forward(self, x):
        """
        Forward pass through the model.
        """
        # Layer 1: Embedding
        x = self.embedding(x)  # Output shape: (batch_size, seq_length, embedding_dim)

        # Layer 2: First LSTM
        x, _ = self.lstm1(x)  # Output shape: (batch_size, seq_length, lstm_hidden_dim)

        # Layer 3: Second LSTM
        x, hidden = self.lstm2(x)  # Output shape: (batch_size, seq_length, lstm_hidden_dim)

        # Layer 4: Fully Connected Layer
        x = self.fc(x)  # Output shape: (batch_size, seq_length, vocab_size)

        # Softmax Activation
        #x = self.softmax(x)  # Output shape: (batch_size, seq_length, vocab_size)
        return x, hidden

    def init_hidden(self, batch_size):
        """Initializes hidden and cell states for the LSTM."""
        return (torch.zeros(2, batch_size),
            torch.zeros(2, batch_size))
        #2 is equal to the number of lstm layers!

# Data Loading

## Data loading

In [10]:
from torch.utils.data import Dataset, DataLoader

class ShakespeareDataset(Dataset):
    """
    Custom PyTorch Dataset for processing Shakespeare dialogues.
    Converts input data into sequences of indices for input and target processing.
    """
    def __init__(self, data, seq_len, n_vocab):
        """
        Initializes the ShakespeareDataset instance.

        Args:
            data: Dictionary containing dialogues (e.g., train_data or test_data).
            seq_len: Length of sequences to generate for the model.
            n_vocab: Size of the vocabulary for mapping characters to indices.
        """
        self.data = list(data.values())  # Convert the dictionary values to a list
        self.seq_len = seq_len  # Sequence length for the model
        self.n_vocab = n_vocab  # Vocabulary size

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of dialogues in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves a single sample (input and target) from the dataset.

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            Tuple: Processed input (x) and target (y) tensors for the model.
        """
        dialogue = self.data[idx]  # Get the dialogue at the specified index
        x = process_x([dialogue], self.seq_len, self.n_vocab)[0]  # Prepare input tensor
        y = process_y([dialogue], self.seq_len, self.n_vocab)[0]  # Prepare target tensor
        return x, y

### Download the dataset

In [11]:
# Download latest version of the shakespeare dataset and save the path
path = kagglehub.dataset_download("kewagbln/shakespeareonline")
print("Path to dataset files:", path)
DATA_PATH = os.path.join(path, "t8.shakespeare.txt")
OUTPUT_DIR = "processed_data/"

Downloading from https://www.kaggle.com/api/v1/datasets/download/kewagbln/shakespeareonline?dataset_version_number=1...


100%|██████████| 1.97M/1.97M [00:00<00:00, 137MB/s]

Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/kewagbln/shakespeareonline/versions/1





In [12]:
from torch.utils.data import random_split

train_data, test_data = parse_shakespeare_file(DATA_PATH)

train_dataset = ShakespeareDataset(train_data, seq_len=SEQ_LEN, n_vocab=N_VOCAB)
test_dataset = ShakespeareDataset(test_data, seq_len=SEQ_LEN, n_vocab=N_VOCAB)

# Split the train dataset into train and validation:
train_size = int(0.9 * len(train_dataset))  # 90%
valid_size = len(train_dataset) - train_size  # 10%
#random split:
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

# Creation of the DataLoaders
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize Model & Loss

In [13]:
global_model = CharLSTM(vocab_size = N_VOCAB, embedding_size = EMBEDDING_SIZE, lstm_hidden_dim = LSTM_HIDDEN_DIM, seq_length = SEQ_LENGTH)
global_model = global_model.to(DEVICE) # Move the model to the device
criterion = nn.CrossEntropyLoss()

## Run the training

In [None]:
""""
gamma = 0.05, first hyperparameter tuning with 100 rounds, then training with 2000 rounds and testing
"""
# Generate 3 values for the learning rate (lr) between 1e-3 and 1e-1 in log-uniform
lr_values = np.logspace(-3, -1, num=3)
# Generate 4 values for the weight decay (lr) between 1e-4 and 1e-1 in log-uniform
wd_values = np.logspace(-4, -1, num=4)
rounds = 100 #fewer communication rounds for hyperparameter tuning
best_val_accuracy = 0

best_setting = None
for lr in lr_values:
    for wd in wd_values:
        print(f"Learning rate: {lr}, Weight decay: {wd}")
        global_model = CharLSTM(vocab_size = N_VOCAB, embedding_size = EMBEDDING_SIZE, lstm_hidden_dim = LSTM_HIDDEN_DIM, seq_length = SEQ_LENGTH)
        global_model = global_model.to(DEVICE) # Move the model to the device
        server = Server(global_model, DEVICE, CHECKPOINT_DIR)
        global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=N_VOCAB, rounds=rounds, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=LOCAL_STEPS,log_freq=100, detailed_print=False,gamma=0.05)
        plot_metrics(train_accuracies, train_losses,val_accuracies, val_losses, f"Federatedgamma005Tuning_lr_{lr}_wd_{wd}.png")
        print(f"Validation accuracy: {val_accuracies[-1]} with lr: {lr} and wd: {wd}")
        max_val_accuracy = max(val_accuracies)
        if max_val_accuracy > best_val_accuracy:
            best_val_accuracy = max_val_accuracy
            best_setting = (lr,wd)
print(f"Best setting: {best_setting} with validation accuracy: {best_val_accuracy}")

Learning rate: 0.001, Weight decay: 0.0001
No checkpoint found, starting from epoch 1.
Reached round 100
Checkpoint saved: cartellaCheckpoint/Federated/model_epoch_99_params_LR0.001_WD0.0001.pth
Validation accuracy: 18.11965811965812 with lr: 0.001 and wd: 0.0001
Learning rate: 0.001, Weight decay: 0.001
No checkpoint found, starting from epoch 1..


Reached round 100
Checkpoint saved: cartellaCheckpoint/Federated/model_epoch_99_params_LR0.001_WD0.001.pth
Validation accuracy: 12.542735042735043 with lr: 0.001 and wd: 0.001
Learning rate: 0.001, Weight decay: 0.01
No checkpoint found, starting from epoch 1..


Reached round 100
Checkpoint saved: cartellaCheckpoint/Federated/model_epoch_99_params_LR0.001_WD0.01.pth
Validation accuracy: 18.11965811965812 with lr: 0.001 and wd: 0.01
Learning rate: 0.001, Weight decay: 0.1
No checkpoint found, starting from epoch 1..


Reached round 100
Checkpoint saved: cartellaCheckpoint/Federated/model_epoch_99_params_LR0.001_WD0.1.pth
Validation accurac