# Import

In [3]:
import io
import json
import logging
import os
import random
import re
import sys
import time
from copy import deepcopy
from statistics import mean

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import drive, files
from torch.backends import cudnn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Subset, TensorDataset

# Parameters

In [4]:
# Constants for FL training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

drive.mount('/content/drive')
NUM_CLIENTS = 1129  # Total number of clients in the federation (re-set later from the loaded data)
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_STEPS = 100  # Number of local steps (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

criterion = nn.CrossEntropyLoss()

BATCH_SIZE = 100  # Batch size for local training
MOMENTUM = 0.0  # Momentum for SGD optimizer
CHECKPOINT_DIR = '/content/drive/MyDrive/colab_checkpoints/'
LOG_FREQUENCY = 10  # Frequency of logging training progress

# Seed every RNG the notebook may rely on (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

cuda
Mounted at /content/drive


## Remove any existing checkpoint

In [5]:
# WARNING: this deletes the whole checkpoint folder on Drive, so load_checkpoint() will not
# find the checkpoints of previous runs. Skip this cell if you want to resume a run.
!rm -r {CHECKPOINT_DIR}

# Model

In [6]:
class CharLSTM(nn.Module):
    """
    Character-level LSTM model for next-character prediction.

    Architecture, as described in [2] (Sashank Reddi, Zachary Charles, Manzil Zaheer,
    Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan;
    Adaptive Federated Optimization, 2021):
    - an embedding layer of size 8;
    - 2 stacked LSTM layers with 256 units each;
    - a dense output layer with one logit per vocabulary entry.

    No softmax is applied inside the model: the logits are meant to be fed to
    nn.CrossEntropyLoss, which applies log-softmax internally.

    The two LSTM layers are separate single-layer modules (``lstm1``, ``lstm2``);
    their recurrent states are packed together in the layout returned by
    ``init_hidden`` (index 0 -> lstm1, index 1 -> lstm2).
    """
    def __init__(self, vocab_size=70, embedding_size=8, lstm_hidden_dim=256, seq_length=80):
        super(CharLSTM, self).__init__()
        self.seq_length = seq_length  # stored for reference only; forward() accepts any length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.lstm_hidden_dim = lstm_hidden_dim
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.lstm1 = nn.LSTM(input_size=embedding_size, hidden_size=lstm_hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=lstm_hidden_dim, hidden_size=lstm_hidden_dim, batch_first=True)
        self.fc = nn.Linear(lstm_hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """
        Forward pass through the model.

        Args:
            x: LongTensor of character indices, shape (batch_size, seq_length).
            hidden: optional (h, c) tuple, each of shape (2, batch_size, lstm_hidden_dim),
                as returned by init_hidden(). Row 0 initialises lstm1 and row 1 lstm2.
                None means zero initial states.

        Returns:
            The logits of the last time step, shape (batch_size, vocab_size), and the
            final (h, c) tuple in the same (2, batch_size, lstm_hidden_dim) layout.
        """
        if hidden is None:
            state1 = state2 = None
        else:
            h0, c0 = hidden
            # Each single-layer LSTM expects states of shape (1, batch, hidden).
            state1 = (h0[0:1].contiguous(), c0[0:1].contiguous())
            state2 = (h0[1:2].contiguous(), c0[1:2].contiguous())

        x = self.embedding(x)                  # (batch_size, seq_length, embedding_size)
        x, (h1, c1) = self.lstm1(x, state1)    # (batch_size, seq_length, lstm_hidden_dim)
        x, (h2, c2) = self.lstm2(x, state2)    # (batch_size, seq_length, lstm_hidden_dim)
        x = self.fc(x)                         # (batch_size, seq_length, vocab_size)

        hidden_out = (torch.cat([h1, h2], dim=0), torch.cat([c1, c2], dim=0))
        return x[:, -1, :], hidden_out

    def init_hidden(self, batch_size, device=None):
        """
        Return zero (h, c) states, each of shape (2, batch_size, lstm_hidden_dim).

        The leading 2 is the number of stacked LSTM layers (lstm1, lstm2).
        ``device`` defaults to None, i.e. PyTorch's default device, as before.
        """
        shape = (2, batch_size, self.lstm_hidden_dim)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))



# Checkpointing functions

In [7]:
# Ensure the checkpoint directory exists, creating it if necessary
# (the "Remove any existing checkpoint" cell may have deleted it with `rm -r`).
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, hyperparameters, subfolder="", checkpoint_data=None):
    """
    Save the checkpoint of ``epoch`` and remove the one of ``epoch - 1`` if it exists.

    The new files are written before the previous ones are deleted, so a failure
    while saving never leaves the folder without any checkpoint.

    Arguments:
    model -- The model whose state is to be saved.
    optimizer -- The optimizer whose state is to be saved (can be None).
    epoch -- The current epoch of the training process.
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within CHECKPOINT_DIR to save the checkpoint in.
    checkpoint_data -- JSON-serializable data (e.g., training logs) saved next to the
                       checkpoint; nothing is written when it is None or empty.
    """
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    os.makedirs(subfolder_path, exist_ok=True)

    def _paths(ep):
        """Return the (.pth, .json) paths used for epoch ``ep``."""
        base = os.path.join(subfolder_path, f"model_epoch_{ep}_params_{hyperparameters}")
        return base + ".pth", base + ".json"

    filepath, filepath_json = _paths(epoch)

    checkpoint = {'model_state_dict': model.state_dict(), 'epoch': epoch}
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")

    if checkpoint_data:
        with open(filepath_json, 'w') as json_file:
            json.dump(checkpoint_data, json_file, indent=4)

    # Only now that the new checkpoint is on disk, drop the previous one. The JSON file is
    # optional (written only when checkpoint_data was given), so each file is checked separately.
    if epoch > 1:
        for old_path in _paths(epoch - 1):
            if os.path.exists(old_path):
                os.remove(old_path)

def load_checkpoint(model, optimizer, hyperparameters, subfolder=""):
    """
    Load the latest checkpoint saved with exactly these hyperparameters.

    Arguments:
    model -- The model whose state will be updated from the checkpoint.
    optimizer -- The optimizer whose state will be updated from the checkpoint (can be None).
                 It is left untouched when the checkpoint holds no optimizer state.
    hyperparameters -- The same string that was passed to save_checkpoint for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to look for checkpoints.

    Returns:
    (next_epoch, json_data) -- the epoch to resume from (1 when no checkpoint exists) and the
    contents of the companion JSON file, or None when there is none.
    """
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    if not os.path.exists(subfolder_path):
        print("No checkpoint found, starting from epoch 1.")
        return 1, None  # Epoch starts from 1

    # Match the whole file name: a plain substring test would let "lr_0.1" also pick up
    # the files of "lr_0.15".
    pattern = re.compile(rf"model_epoch_(\d+)_params_{re.escape(str(hyperparameters))}\.pth")
    candidates = {}
    for name in os.listdir(subfolder_path):
        match = pattern.fullmatch(name)
        if match:
            candidates[int(match.group(1))] = name

    if not candidates:
        print("No checkpoint found, starting from epoch 1..\n\n")
        return 1, None  # Epoch starts from 1

    latest_file = candidates[max(candidates)]
    filepath = os.path.join(subfolder_path, latest_file)
    # map_location="cpu": load_state_dict copies the tensors onto the model's / optimizer's own
    # device, so a checkpoint written on GPU can also be resumed on a CPU-only runtime.
    checkpoint = torch.load(filepath, map_location="cpu", weights_only=True)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Checkpoints saved with optimizer=None have no optimizer state to restore.
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    json_filepath = os.path.splitext(filepath)[0] + '.json'
    json_data = None
    if os.path.exists(json_filepath):
        with open(json_filepath, 'r') as json_file:
            json_data = json.load(json_file)
        print("Data loaded!")
    else:
        print("No data found")

    print(f"Checkpoint found: Resuming from epoch {checkpoint['epoch'] + 1}\n\n")
    return checkpoint['epoch'] + 1, json_data

def delete_existing_checkpoints(subfolder=""):
    """
    Remove every regular file (checkpoints and their JSON logs) from a checkpoint folder.

    Sub-directories are left untouched.

    Arguments:
    subfolder -- Optional subfolder within CHECKPOINT_DIR to clean; defaults to CHECKPOINT_DIR itself.
    """
    target_dir = os.path.join(CHECKPOINT_DIR, subfolder)
    if not os.path.exists(target_dir):
        print(f"No checkpoint folder found at {target_dir}.")
        return

    entries = (os.path.join(target_dir, name) for name in os.listdir(target_dir))
    for path in filter(os.path.isfile, entries):
        os.remove(path)
    print(f"All existing checkpoints in {target_dir} have been deleted.")

# Data loading

The dataset has to be uploaded manually because it comes from the LEAF benchmark.

To generate it, go to LEAF's `data/shakespeare` folder and:
1. run `./get_data.sh` inside the `preprocess` folder;
2. run `./data_to_json.sh` (still inside `preprocess`);
3. `cd ..`;
4. run `./preprocess.sh -s niid --sf 0.2 -k 0 -t sample -tf 0.8` (adjust the flags to your needs).

In [8]:
uploaded = files.upload()
# files.upload() maps each uploaded file name to its raw bytes; an empty result means the
# dialog was cancelled, which would otherwise surface as a bare StopIteration below.
if not uploaded:
    raise RuntimeError("No training file uploaded: please select the LEAF *_train_*.json file.")
file_train = next(iter(uploaded))

Saving all_data_niid_1_keep_0_train_9.json to all_data_niid_1_keep_0_train_9.json


In [9]:
uploaded2 = files.upload()
# files.upload() maps each uploaded file name to its raw bytes; an empty result means the
# dialog was cancelled, which would otherwise surface as a bare StopIteration below.
if not uploaded2:
    raise RuntimeError("No test file uploaded: please select the LEAF *_test_*.json file.")
file_test = next(iter(uploaded2))

Saving all_data_niid_1_keep_0_test_9.json to all_data_niid_1_keep_0_test_9.json


In [10]:
data = json.loads(uploaded[file_train])  # json.loads accepts the uploaded bytes directly

In [11]:
test_data = json.loads(uploaded2[file_test])  # json.loads accepts the uploaded bytes directly

In [12]:
# Shortcuts into the *test* split. They get their own names so they are not confused with
# (or silently overwritten by) the training-split `users` / `num_samples` / `user_data`.
test_users = test_data['users']
test_num_samples = test_data['num_samples']
test_user_data = test_data['user_data']

In [13]:
# Load the training JSON from the copy that files.upload() saved in the working directory.
# The encoding is explicit because open() otherwise falls back to the locale encoding.
with open(file_train, 'r', encoding='utf-8') as file:
    data = json.load(file)

In [14]:
# Load the test JSON from the copy that files.upload() saved in the working directory.
# The encoding is explicit because open() otherwise falls back to the locale encoding.
with open(file_test, 'r', encoding='utf-8') as f:
    test_data = json.load(f)

In [15]:
# The number of federated clients follows the number of users in the training split.
NUM_CLIENTS = num_clients = len(data['users'])
print("Number of clients:", num_clients)

Number of clients: 63


In [16]:
# Unpack the three top-level fields of the training split.
users, num_samples, user_data = (data[key] for key in ('users', 'num_samples', 'user_data'))

In [17]:
# Character vocabulary built from every training input sequence.
all_texts = ''.join(''.join(seq) for user in users for seq in user_data[user]['x'])
chars = sorted(set(all_texts))
char_to_idx = dict(zip(chars, range(len(chars))))

# Extra index used for padding (char_to_tensor also maps unknown characters to it)
char_to_idx['<pad>'] = len(chars)
idx_to_char = {position: symbol for symbol, position in char_to_idx.items()}

## Convert data into indices

In [18]:
# Index-encode only the first (x, y) sample of every user.
inputs = [list(map(char_to_idx.__getitem__, user_data[u]['x'][0])) for u in users]
targets = [list(map(char_to_idx.__getitem__, user_data[u]['y'][0])) for u in users]

## Creation of TensorDataset and DataLoader

In [19]:
# Small DataLoader over the first sample of each user (the full training loader is
# `data_loader`, built further below).
input_tensors = [torch.tensor(seq) for seq in inputs]
# One class index per sample -> shape (num_users,), which is what nn.CrossEntropyLoss
# expects next to the (batch, vocab_size) logits of CharLSTM. Assumes each target is
# a single character, as in the LEAF Shakespeare data.
target_tensors = torch.cat([torch.tensor(seq) for seq in targets])

padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=char_to_idx['<pad>'])

dataset = TensorDataset(padded_inputs, target_tensors)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [20]:
# Debugging helper (for testing purposes)
def tensor_to_string(tensor, idx_to_char):
    """Decode a 1-D tensor of character indices back into text using ``idx_to_char``."""
    decoded = [idx_to_char[element.item()] for element in tensor]
    return ''.join(decoded)

In [21]:
# Function to convert characters into a tensor of vocabulary indices
def char_to_tensor(characters, vocab=None):
    """
    Encode a sequence of characters as a 1-D LongTensor of vocabulary indices.

    Args:
        characters: iterable of characters (e.g. a string).
        vocab: char -> index mapping that contains '<pad>'; defaults to the
            notebook-level ``char_to_idx``.

    Characters that are not in the vocabulary are mapped to the '<pad>' index.
    """
    if vocab is None:
        vocab = char_to_idx
    pad_idx = vocab['<pad>']  # looked up once instead of once per character
    return torch.tensor([vocab.get(ch, pad_idx) for ch in characters], dtype=torch.long)

# Build the training DataLoader: one (x, y) pair for every sample of every user.
input_tensors = []
target_tensors = []
for user in data['users']:
    for entry, target in zip(data['user_data'][user]['x'], data['user_data'][user]['y']):
        input_tensors.append(char_to_tensor(entry))  # full input sequence x
        target_tensors.append(char_to_tensor(target))  # corresponding target y

# Pad the inputs to a common length and concatenate the targets into one 1-D tensor
padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=char_to_idx['<pad>'])
targets = torch.cat(target_tensors)
dataset = TensorDataset(padded_inputs, targets)

# NOTE(review): shuffle=False keeps samples ordered user by user, presumably so that
# per-client subsets can be taken later; confirm before switching to shuffle=True.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [22]:
# Number of samples in the training split
print("Training samples:", len(data_loader.dataset))

368469


In [23]:
# Build the test DataLoader with the same encoding used for the training one.
input_tensors = []
target_tensors = []
for user in test_data['users']:
    for entry, target in zip(test_data['user_data'][user]['x'], test_data['user_data'][user]['y']):
        input_tensors.append(char_to_tensor(entry))  # full input sequence x
        target_tensors.append(char_to_tensor(target))  # corresponding target y

# Pad the inputs to a common length and concatenate the targets into one 1-D tensor
padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=char_to_idx['<pad>'])
targets = torch.cat(target_tensors)
dataset = TensorDataset(padded_inputs, targets)

test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [24]:
# Number of samples in the test split
print("Test samples:", len(test_loader.dataset))

35998


# Model initialization

In [25]:
# The model outputs raw logits, so cross-entropy is applied on top of them.
criterion = nn.CrossEntropyLoss()
# One output logit per vocabulary entry, '<pad>' included.
global_model = CharLSTM(vocab_size=len(char_to_idx))

# Client class

In [26]:
class Client:
    """A federated-learning client holding a local model and its own data."""

    def __init__(self, client_id, data_loader, model, device, char_to_idx):
        """
        Initializes a federated learning client.
        :param client_id: Unique identifier for the client.
        :param data_loader: Data loader specific to the client.
        :param model: The model instance used by the client (moved to ``device`` in place).
        :param device: The device (CPU/GPU) to perform computations.
        :param char_to_idx: Character vocabulary; its size is the number of output classes.
        """
        self.client_id = client_id
        self.data_loader = data_loader
        self.model = model.to(device)
        self.device = device
        self.char_to_idx = char_to_idx

    def client_update(self, client_data, criterion, optimizer, local_steps=4, detailed_print=False):
        """
        Train the local model for exactly ``local_steps`` optimizer steps.

        Each step consumes one batch; ``client_data`` is iterated again from the
        start as many times as needed to reach ``local_steps`` batches.

        Args:
            client_data (DataLoader): The data loader for the client's dataset.
            criterion (Loss): Loss function with mean reduction (e.g., CrossEntropyLoss).
            optimizer (Optimizer): Optimizer over ``self.model``'s parameters (e.g., SGD).
            local_steps (int): Number of local optimizer steps; must be >= 1.
            detailed_print (bool): If True, prints the loss of the last step.

        Returns:
            tuple: (state dict of the updated model,
                    sample-weighted average training loss over all steps,
                    training accuracy in percent over all steps).

        Raises:
            ValueError: if ``local_steps`` < 1, or if ``client_data`` yields no batches
                (which would otherwise make the training loop spin forever).
        """
        if local_steps < 1:
            raise ValueError(f"local_steps must be >= 1, got {local_steps}")

        # Let cuDNN benchmark and cache the fastest kernels for the input shapes.
        # (Only reading the attribute, `cudnn.benchmark`, has no effect.)
        cudnn.benchmark = True

        self.model.train()
        vocab_size = len(self.char_to_idx)
        step_count = 0
        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        loss = None

        while step_count < local_steps:
            batches_this_pass = 0
            for data, targets in client_data:
                batches_this_pass += 1
                data, targets = data.to(self.device), targets.to(self.device)

                optimizer.zero_grad()

                # Fresh zero states, placed on this client's device.
                h0, c0 = self.model.init_hidden(data.size(0))
                hidden = (h0.to(self.device), c0.to(self.device))

                outputs, _ = self.model(data, hidden)
                output_flat = outputs.view(-1, vocab_size)
                targets_flat = targets.view(-1)

                loss = criterion(output_flat, targets_flat)
                loss.backward()
                optimizer.step()

                # Weight the batch-mean loss by the batch size so that a smaller final
                # batch contributes proportionally to the average.
                total_loss += loss.item() * data.size(0)
                _, predicted = output_flat.max(1)
                # Compare with the flattened targets: (B,) vs (B, 1) would broadcast to (B, B).
                correct_predictions += predicted.eq(targets_flat).sum().item()
                total_samples += data.size(0)

                step_count += 1
                if step_count >= local_steps:
                    break

            if batches_this_pass == 0:
                raise ValueError(f"Client {self.client_id}: client_data yielded no batches")

        avg_loss = total_loss / total_samples
        avg_accuracy = correct_predictions / total_samples * 100

        if detailed_print:
            print(f'Client {self.client_id} --> Final Loss (Step {step_count}/{local_steps}): {loss.item()}')

        return self.model.state_dict(), avg_loss, avg_accuracy

# Utility functions

In [27]:
# Base folder on the mounted Google Drive for the plots and the saved models written below.
DIR = '/content/drive/MyDrive/colab_plots/'

def plot_client_selection(client_selection_count, file_name):
    """
    Save a bar chart showing how many times each client was selected.

    Args:
        client_selection_count (list): selection count of every client, indexed by client ID.
        file_name (str): name of the PNG file, written under DIR/plots_federated/.
    """
    out_dir = DIR +  'plots_federated/'
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, file_name)

    client_ids = range(len(client_selection_count))
    # Rotate the tick labels only when there are too many clients to fit horizontally.
    tick_rotation = 90 if len(client_ids) > 20 else 0

    plt.figure(figsize=(10, 6))
    plt.bar(client_ids, client_selection_count, alpha=0.7, edgecolor='black')
    plt.xlabel("Client ID", fontsize=14)
    plt.ylabel("Selection Count", fontsize=14)
    plt.title("Client Selection Frequency", fontsize=16)
    plt.xticks(client_ids, fontsize=10, rotation=tick_rotation)
    plt.tight_layout()
    plt.savefig(out_path, format="png", dpi=300)
    plt.close()



def evaluate(model, dataloader, criterion, DEVICE):
    """
    Compute the accuracy and the average loss of ``model`` on ``dataloader``.

    Utility shared by centralized and federated learning: pass the validation or
    the test loader depending on what has to be measured.

    Args:
        model: model exposing ``vocab_size``, ``init_hidden(batch_size)`` and
            ``forward(x, hidden) -> (logits, hidden)``.
        dataloader: yields (inputs, targets) batches.
        criterion: loss with mean reduction (e.g. nn.CrossEntropyLoss()).
        DEVICE: device on which the evaluation runs.

    Returns:
        (accuracy in percent, per-sample average loss).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    running_corrects = 0
    total_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(DEVICE)
            targets = targets.to(DEVICE)
            hidden = model.init_hidden(data.size(0))
            hidden = (hidden[0].to(DEVICE), hidden[1].to(DEVICE))
            outputs, _ = model(data, hidden)
            outputs_flat = outputs.view(-1, model.vocab_size)
            targets_flat = targets.view(-1)

            batch_size = targets_flat.size(0)
            # criterion returns the batch mean; re-weight it by the batch size so that a
            # short last batch does not count as much as a full one.
            total_loss += criterion(outputs_flat, targets_flat).item() * batch_size

            _, preds = outputs_flat.max(1)
            running_corrects += (preds == targets_flat).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("evaluate() received an empty dataloader")

    accuracy = (running_corrects / total_samples) * 100
    return accuracy, total_loss / total_samples


def test(global_model, test_loader, criterion, DEVICE):
    """
    Measure the global model on the test split (thin wrapper around ``evaluate``).

    Args:
        global_model (nn.Module): The model to evaluate.
        test_loader (DataLoader): DataLoader over the test dataset.
        criterion: Loss function used to compute the test loss.
        DEVICE: Device on which the evaluation runs.

    Returns:
        tuple: (test accuracy, test loss), exactly as produced by ``evaluate``.
    """
    return evaluate(global_model, test_loader, criterion, DEVICE)

def plot_metrics(train_accuracies, train_losses, file_name):
    """
    Save one line plot of the training loss and one of the training accuracy per round.

    The files are written under DIR/plots_federated/ as ``<stem>_loss.png`` and
    ``<stem>_accuracy.png``, where ``<stem>`` is ``file_name`` without its extension.

    Args:
        train_accuracies (list): List of training accuracies, one per round.
        train_losses (list): List of training losses, one per round.
        file_name (str): Base name of the files to save (e.g. "run.png").
    """
    directory = os.path.join(DIR, 'plots_federated')
    os.makedirs(directory, exist_ok=True)

    # splitext instead of str.replace('.png', ...): without a '.png' suffix the replace left the
    # name unchanged and the accuracy plot overwrote the loss plot; it could also hit a '.png'
    # in the middle of the name.
    stem = os.path.splitext(os.path.join(directory, file_name))[0]

    plots = (
        (train_losses, 'Train Loss', 'Loss', 'Training Loss', '_loss.png'),
        (train_accuracies, 'Train Accuracy', 'Accuracy', 'Training Accuracy', '_accuracy.png'),
    )
    for values, label, ylabel, title, suffix in plots:
        rounds = list(range(1, len(values) + 1))  # x-axis sized on each series
        plt.figure(figsize=(10, 6))
        plt.plot(rounds, values, label=label, color='blue')
        plt.xlabel('Rounds', fontsize=14)
        plt.ylabel(ylabel, fontsize=14)
        plt.title(title, fontsize=16)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(stem + suffix, format='png', dpi=300)
        plt.close()


def save_data(global_model, train_accuracies, train_losses, client_count, file_name):
    """
    Save the global model weights together with the training metrics.

    The dictionary is written with torch.save to DIR/trained_models/<file_name>; the path is
    built with os.path.join, so it is the same one load_data() reads.

    Args:
        global_model (nn.Module): The model whose state_dict is saved.
        train_accuracies (list): List of training accuracies.
        train_losses (list): List of training losses.
        client_count: Client selection statistics (presumably the per-client counts that
            plot_client_selection draws).
        file_name (str): Name of the file to save the data.
    """
    directory = os.path.join(DIR, 'trained_models')
    os.makedirs(directory, exist_ok=True)
    file_path = os.path.join(directory, file_name)

    save_dict = {
        'model_state': global_model.state_dict(),
        'train_accuracies': train_accuracies,
        'train_losses': train_losses,
        'client_count': client_count
    }

    torch.save(save_dict, file_path)
    print(f"Data saved successfully to {file_path}")

def load_data(model, file_name):
    """
    Load the model weights and the training metrics written by save_data().

    Args:
        model (nn.Module): The model to load the weights into (updated in place).
        file_name (str): Name of the file under DIR/trained_models/.

    Returns:
        tuple: (model, train_accuracies, train_losses, client_count); ``model`` is the
        object that was passed in.
    """
    # os.path.join works whether or not DIR ends with a separator
    # (plain concatenation broke the path when it did not).
    file_path = os.path.join(DIR, 'trained_models', file_name)

    # map_location='cpu' lets a file saved during a GPU run be opened on a CPU-only runtime;
    # load_state_dict then copies the weights onto the model's own device.
    # NOTE(review): torch.load unpickles the file; only load files produced by save_data().
    save_dict = torch.load(file_path, map_location='cpu')

    model.load_state_dict(save_dict['model_state'])

    train_accuracies = save_dict['train_accuracies']
    train_losses = save_dict['train_losses']
    client_count = save_dict['client_count']

    print(f"Data loaded successfully from {file_path}")

    return model, train_accuracies, train_losses, client_count

# Server class

In [28]:
# Module-level logger. NOTE(review): not referenced in the visible part of the notebook;
# presumably used by the Server class below — confirm or remove.
log = logging.getLogger(__name__)

class Server:
    def __init__(self, global_model, device, char_to_idx, CHECKPOINT_DIR):
        self.global_model = global_model
        self.device = device
        self.char_to_idx = char_to_idx
        self.CHECKPOINT_DIR = CHECKPOINT_DIR
        # Ensure the checkpoint directory exists, creating it if necessary
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    def save_checkpoint(self, model, optimizer, epoch, hyperparameters, subfolder="", checkpoint_data=None):
        """
        Saves the model checkpoint and removes the previous one if it exists.

        Arguments:
        model -- The model whose state is to be saved.
        optimizer -- The optimizer whose state is to be saved (can be None).
        epoch -- The current epoch of the training process.
        hyperparameters -- A string representing the model's hyperparameters for file naming.
        subfolder -- Optional subfolder within the checkpoint directory to save the checkpoint.
        checkpoint_data -- Data to save in a JSON file (e.g., training logs).
        """
        # Define the path for the subfolder where checkpoints will be stored
        subfolder_path = os.path.join(self.CHECKPOINT_DIR, subfolder)
        # Create the subfolder if it doesn't exist
        os.makedirs(subfolder_path, exist_ok=True)

        # Construct filenames for both the model checkpoint and the associated JSON file
        filename = f"model_epoch_{epoch}_params_{hyperparameters}.pth"
        filepath = os.path.join(subfolder_path, filename)
        filename_json = f"model_epoch_{epoch}_params_{hyperparameters}.json"
        filepath_json = os.path.join(subfolder_path, filename_json)

        # Define the filenames for the previous checkpoint files, to remove them if necessary
        previous_filepath = os.path.join(subfolder_path, f"model_epoch_{epoch - 1}_params_{hyperparameters}.pth")
        previous_filepath_json = os.path.join(subfolder_path, f"model_epoch_{epoch - 1}_params_{hyperparameters}.json")

        # Remove the previous checkpoint if it exists, but only for epochs greater than 1
        if epoch > 1 and os.path.exists(previous_filepath):
            os.remove(previous_filepath)
            os.remove(previous_filepath_json)

        # Prepare the checkpoint data dictionary
        checkpoint = {'model_state_dict': model.state_dict(), 'epoch': epoch}
        # If an optimizer is provided, save its state as well
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()

        # Save the model and optimizer (if provided) state dictionary to the checkpoint file
        torch.save(checkpoint, filepath)
        print(f"Checkpoint saved: {filepath}")

        # If additional data (e.g., training logs) is provided, save it to a JSON file
        if checkpoint_data:
            with open(filepath_json, 'w') as json_file:
                json.dump(checkpoint_data, json_file, indent=4)

    def load_checkpoint(self, model, optimizer, hyperparameters, subfolder=""):
        """
        Loads the latest checkpoint available based on the specified hyperparameters.

        Arguments:
        model -- The model whose state will be updated from the checkpoint.
        optimizer -- The optimizer whose state will be updated from the checkpoint (can be None).
        hyperparameters -- A string representing the model's hyperparameters for file naming.
        subfolder -- Optional subfolder within the checkpoint directory to look for checkpoints.

        Returns:
        The next epoch to resume from and the associated JSON data if available.
        """
        # Define the path to the subfolder where checkpoints are stored
        subfolder_path = os.path.join(self.CHECKPOINT_DIR, subfolder)

        # If the subfolder doesn't exist, print a message and start from epoch 1
        if not os.path.exists(subfolder_path):
            print("No checkpoint found, starting from epoch 1.")
            return 1, None  # Epoch starts from 1

        # Search for checkpoint files in the subfolder that match the hyperparameters
        files = [f for f in os.listdir(subfolder_path) if f"params_{hyperparameters}" in f and f.endswith('.pth')]

        # If checkpoint files are found, load the one with the highest epoch number
        if files:
            latest_file = max(files, key=lambda x: int(x.split('_')[2]))  # Find the latest epoch file
            filepath = os.path.join(subfolder_path, latest_file)
            checkpoint = torch.load(filepath, weights_only=True)

            # Load the model state from the checkpoint
            model.load_state_dict(checkpoint['model_state_dict'])
            # If an optimizer is provided, load its state as well
            if optimizer:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Try to load the associated JSON file if available
            json_filepath = os.path.join(subfolder_path, latest_file.replace('.pth', '.json'))
            json_data = None
            if os.path.exists(json_filepath):
                # If the JSON file exists, load its contents
                with open(json_filepath, 'r') as json_file:
                    json_data = json.load(json_file)
                print("Data loaded!")
            else:
                # If no JSON file exists, print a message
                print("No data found")

            # Print the epoch from which the model is resuming
            print(f"Checkpoint found: Resuming from epoch {checkpoint['epoch'] + 1}\n\n")
            return checkpoint['epoch'] + 1, json_data

        # If no checkpoint is found, print a message and start from epoch 1
        print("No checkpoint found, starting from epoch 1..\n\n")
        return 1, None  # Epoch starts from 1

    def delete_existing_checkpoints(self, subfolder=""):
        """
        Deletes all existing checkpoints in the specified subfolder.

        Arguments:
        subfolder -- Optional subfolder within the checkpoint directory to delete checkpoints from.
        """
        subfolder_path = os.path.join(self.CHECKPOINT_DIR, subfolder)
        if os.path.exists(subfolder_path):
            for file_name in os.listdir(subfolder_path):
                file_path = os.path.join(subfolder_path, file_name)
                if os.path.isfile(file_path):
                    os.remove(file_path)
            print(f"All existing checkpoints in {subfolder_path} have been deleted.")
        else:
            print(f"No checkpoint folder found at {subfolder_path}.")

    def char_to_tensor(self, characters):
        indices = [self.char_to_idx.get(char, self.char_to_idx['<pad>']) for char in characters] # Get the index for the character. If not found, use the index for padding.
        return torch.tensor(indices, dtype=torch.long)

    def fedavg_aggregate(self, client_states, client_sizes, client_avg_losses, client_avg_accuracies):
        """
        Aggregates model updates and client metrics from selected clients using the Federated Averaging (FedAvg) algorithm.
        The updates and metrics are weighted by the size of each client's dataset.

        Args:
            global_model (nn.Module): The global model whose structure is used for aggregation.
            client_states (list[dict]): A list of state dictionaries (model weights) from participating clients.
            client_sizes (list[int]): A list of dataset sizes for the participating clients.
            client_avg_losses (list[float]): A list of average losses for the participating clients.
            client_avg_accuracies (list[float]): A list of average accuracies for the participating clients.

        Returns:
            tuple: The aggregated state dictionary with updated model parameters, global average loss, and global average accuracy.
        """
        # Copy the global model's state dictionary for aggregation
        new_state = deepcopy(self.global_model.state_dict())

        # Calculate the total number of samples across all participating clients
        total_samples = sum(client_sizes)

        # Initialize all parameters in the new state to zero
        for key in new_state:
            new_state[key] = torch.zeros_like(new_state[key])

        # Initialize metrics
        total_loss = 0.0
        total_accuracy = 0.0

        # Perform a weighted average of client updates and metrics
        for state, size, avg_loss, avg_accuracy in zip(client_states, client_sizes, client_avg_losses, client_avg_accuracies):
            for key in new_state:
                # Add the weighted contribution of each client's parameters
                new_state[key] += (state[key] * size / total_samples)
            total_loss += avg_loss * size
            total_accuracy += avg_accuracy * size

        # Calculate global metrics
        global_avg_loss = total_loss / total_samples
        global_avg_accuracy = total_accuracy / total_samples

        # Return the aggregated state dictionary with updated weights and global metrics
        return new_state, global_avg_loss, global_avg_accuracy


    # Federated Learning Training Loop
    def train_federated(self, criterion, raw_data, num_clients, rounds, lr, momentum, batchsize, wd, C=0.1, local_steps=4, log_freq=10, detailed_print=True,gamma=None):
        # val_accuracies = []
        # val_losses = []
        train_accuracies = []
        train_losses = []
        #print("num clients: ",num_clients)
        best_model_state = None  # The model with the best accuracy
        client_selection_count = [0] * num_clients #Count how many times a client has been selected
        #best_val_acc = 0.0
        best_train_loss = float('inf')

        shards = self.sharding(raw_data) #each shard represent the training data for one client
        client_sizes = [len(shard) for shard in shards]
        #print("client sizes: ", client_sizes)

        self.global_model.to(self.device) #as alwayse, we move the global model to the specified device (CPU or GPU)

        #loading checkpoint if it exists
        # checkpoint_start_step, data_to_load = load_checkpoint(model=self.global_model,optimizer=None,hyperparameters=f"LR{lr}_WD{wd}", subfolder="Federated/")
        # if data_to_load is not None:
        #   train_accuracies = data_to_load['train_accuracies']
        #   train_losses = data_to_load['train_losses']
        #   client_selection_count = data_to_load['client_selection_count']
        probabilities = None
        if gamma is not None:
            probabilities = self.skewed_probabilities(num_clients, gamma)

        # ********************* HOW IT WORKS ***************************************
        # The training runs for rounds iterations (GLOBAL_ROUNDS=200)
        # Each round simulates one communication step in federated learning, including:
        # 1) client selection
        # 2) local training (of each client)
        # 3) central aggregation
        for round_num in range(rounds):
            if (round_num+1) % log_freq == 0 and detailed_print:
              print(f"------------------------------------- Round {round_num+1} ------------------------------------------------" )

            # 1) client selection: In each round, a fraction C (e.g., 10%) of clients is randomly selected to participate.
            #     This reduces computation costs and mimics real-world scenarios where not all devices are active.
            selected_clients = self.client_selection(num_clients, C,probabilities)
            client_states = []
            client_avg_losses = []
            client_avg_accuracies = []
            for client_id in selected_clients:
                client_selection_count[client_id] += 1

            # 2) local training: for each client updates the model using the client's data for local_steps epochs
            for client_id in selected_clients:
                local_model = deepcopy(self.global_model) #it creates a local copy of the global model
                optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
                client_loader = DataLoader(shards[client_id], batch_size=batchsize, shuffle=True)

                print_log =  (round_num+1) % log_freq == 0 and detailed_print
                client = Client(client_id, client_loader, local_model, self.device, self.char_to_idx)
                client_local_state, client_avg_loss, client_avg_accuracy  = client.client_update(client_loader, criterion, optimizer, local_steps, print_log)

                client_states.append(client_local_state)
                client_avg_losses.append(client_avg_loss)
                client_avg_accuracies.append(client_avg_accuracy)


            # 3) central aggregation: aggregates participating client updates using fedavg_aggregate
            #    and replaces the current parameters of global_model with the returned ones.
            aggregated_state, train_loss, train_accuracy = self.fedavg_aggregate(client_states, [client_sizes[i] for i in selected_clients], client_avg_losses, client_avg_accuracies)

            self.global_model.load_state_dict(aggregated_state)

            train_accuracies.append(train_accuracy)
            train_losses.append(train_loss)
            # #Validation at the server
            # val_accuracy, val_loss = evaluate(self.global_model, validloader)
            # val_accuracies.append(val_accuracy)
            # val_losses.append(val_loss)
            if train_loss < best_train_loss:
                best_train_loss = train_loss
                best_model_state = deepcopy(self.global_model.state_dict())

            if (round_num+1) % log_freq == 0:
                if detailed_print:
                    print(f"-->training accuracy: {train_accuracy:.2f}")
                    print(f"-->training loss: {train_loss:.4f}")

                # checkpointing
                checkpoint_data = {
                    'train_accuracies': train_accuracies,
                    'train_losses': train_losses,
                    'client_selection_count': client_selection_count
                }
                self.save_checkpoint(self.global_model,optimizer=None, epoch=round_num, hyperparameters=f"LR{lr}_WD{wd}", subfolder="Federated/", checkpoint_data=checkpoint_data)
                if detailed_print:
                    print(f"------------------------------ Round {round_num+1} terminated: model updated -----------------------------\n\n" )

        self.global_model.load_state_dict(best_model_state)

        return self.global_model, train_accuracies, train_losses, client_selection_count

    def skewed_probabilities(self, number_of_clients, gamma=0.5):
            # Generate skewed probabilities using a Dirichlet distribution
            probabilities = np.random.dirichlet(np.ones(number_of_clients) * gamma)
            return probabilities

    def client_selection(self,number_of_clients, clients_fraction, probabilities=None):
        """
        Selects a subset of clients based on uniform or skewed distribution.

        Args:
        number_of_clients (int): Total number of clients.
        clients_fraction (float): Fraction of clients to be selected.
        uniform (bool): If True, selects clients uniformly. If False, selects clients based on a skewed distribution.
        gamma (float): Hyperparameter for the Dirichlet distribution controlling the skewness (only used if uniform=False).

        Returns:
        list: List of selected client indices.
        """
        num_clients_to_select = int(number_of_clients * clients_fraction)

        if probabilities is None:
            # Uniformly select clients without replacement
            selected_clients = np.random.choice(number_of_clients, num_clients_to_select, replace=False)
        else:
            selected_clients = np.random.choice(number_of_clients, num_clients_to_select, replace=False, p=probabilities)

        return selected_clients



    def sharding(self, data):
        """
        Prepares individual shards for each user, returning a Subset for each.

        Args:
        data (dict): Dataset containing user data.
        char_to_idx (dict): Character to index mapping dictionary for character conversion.

        Returns:
        list: List of Subsets, one for each user.
        """
        subsets = []

        for user in data['users']:
            input_tensors = []
            target_tensors = []

            for entry, target in zip(data['user_data'][user]['x'], data['user_data'][user]['y']):
              input_tensors.append(self.char_to_tensor(entry))  # Use the full sequence of x
              target_tensors.append(self.char_to_tensor(target))  # Directly use the corresponding y as target

            # Padding inputs to ensure all inputs in a batch have the same length
            padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=self.char_to_idx['<pad>'])
            targets = torch.cat(target_tensors)

            # Creating the TensorDataset for the user
            dataset = TensorDataset(padded_inputs, targets)

            # Since each user is treated as a separate "client", we create a Subset for each
            subsets.append(Subset(dataset, torch.arange(len(targets))))

        return subsets

# Training cycle and testing results

In [29]:
# Show the loss function that will be handed to train_federated.
print(f"{criterion}")

CrossEntropyLoss()


In [30]:
# NOTE(review): `num_clients` is not defined in this section; it comes from an earlier cell
# (hidden kernel state). It printed 63, while the config cell sets NUM_CLIENTS = 1129 —
# confirm which value train_federated is meant to receive.
print(f"{num_clients}")

63


In [31]:
# Local mini-batch size used by every client.
print(f"{BATCH_SIZE}")

100


In [32]:
# Fraction C of clients sampled in each round.
print(f"{FRACTION_CLIENTS}")

0.1


In [33]:
# SGD momentum used for local training.
print(f"{MOMENTUM}")

0.0


In [42]:
# Hyperparameters taken from:
# Acar, Durmus Alp Emre, et al. "Federated learning based on dynamic regularization."
# arXiv preprint arXiv:2111.04263 (2021).
# The LEAF version of the Shakespeare dataset does not come with a validation split,
# so they could not be tuned here and are reused as published.
lr = 1.0
wd = 0.0001

# Seed every source of randomness used below (np.random drives client selection,
# torch drives model initialisation and DataLoader shuffling) so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

local_steps = 4
LOCAL_STEPS_VALUES = [4, 8, 16]  # Values for J (number of local steps)
# Rounds per J, chosen so that rounds * J stays constant (800 local steps in total).
NUM_ROUNDS = {4: 200, 8: 100, 16: 50}
NUM_RUNDS = NUM_ROUNDS  # alias kept for cells that still use the old (misspelled) name

print(f"Running experiment: local_steps={local_steps}")
global_model = CharLSTM(vocab_size=len(char_to_idx))
server = Server(global_model, DEVICE, char_to_idx, CHECKPOINT_DIR)

# NOTE(review): NUM_CLIENTS must not exceed the number of users in `data` (one shard per
# user); the config cell sets 1129 while `num_clients` printed 63 — confirm the intended value.
global_model, train_accuracies, train_losses, client_selection_count = server.train_federated(
    criterion, data,
    num_clients=NUM_CLIENTS,
    rounds=NUM_ROUNDS[local_steps], lr=lr, momentum=MOMENTUM,
    batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS,
    local_steps=local_steps, log_freq=100,
    detailed_print=True, gamma=None  # No skewed sampling for this experiment
)

# Testing and plotting; file names include J so runs with different local_steps
# do not overwrite each other's plots (matching the save_data file name below).
test_accuracy, test_loss = test(global_model, test_loader, criterion, DEVICE)
plot_metrics(train_accuracies, train_losses, f"Federated_scaled_LR_{lr}_WD_{wd}_j_{local_steps}.png")
print(f"Test accuracy for local_steps={local_steps}: {test_accuracy}")

# Save data for future analysis
save_data(global_model, train_accuracies, train_losses, client_selection_count, f"Federated_LR_{lr}_WD_{wd}_j_{local_steps}.pth")

Running experiment: local_steps=4
------------------------------------- Round 100 ------------------------------------------------
Client 0 --> Final Loss (Step 4/4): 3.0484912395477295
Client 44 --> Final Loss (Step 4/4): 2.900583267211914
Client 43 --> Final Loss (Step 4/4): 2.853752851486206
Client 54 --> Final Loss (Step 4/4): 2.8242075443267822
Client 59 --> Final Loss (Step 4/4): 2.807734966278076
Client 14 --> Final Loss (Step 4/4): 3.1103081703186035
-->training accuracy: 25.16
-->training loss: 2.7717
Checkpoint saved: /content/drive/MyDrive/colab_checkpoints/Federated/model_epoch_99_params_LR1.0_WD0.0001.pth
------------------------------ Round 100 terminated: model updated -----------------------------


------------------------------------- Round 200 ------------------------------------------------
Client 53 --> Final Loss (Step 4/4): 2.285294532775879
Client 49 --> Final Loss (Step 4/4): 2.6715526580810547
Client 17 --> Final Loss (Step 4/4): 2.417442798614502
Client 35 --