## Import required libraries for dataset management, model building, training, and visualization.

In [85]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import collections
from collections import defaultdict
from json import JSONEncoder
import random
import re
from copy import deepcopy


## Dataset Utilities

In [None]:
def letter_to_vec(c, n_vocab=90):
    """Map one character to its vocabulary index: its code point reduced modulo ``n_vocab``."""
    code_point = ord(c)
    return code_point % n_vocab

def word_to_indices(word, n_vocab=90):
    """
    Turn text into a flat list of vocabulary indices (code point modulo ``n_vocab``).

    ``word`` may be a single string, or a list of strings whose characters are
    concatenated, in order, into one index list.
    """
    # Treat a lone string as a one-piece list so both cases share one path.
    pieces = word if isinstance(word, list) else [word]
    return [ord(ch) % n_vocab for piece in pieces for ch in piece]

def process_x(raw_x_batch, seq_len, n_vocab):
    """
    Encode a batch of texts as a (batch, seq_len) LongTensor of indices.

    Each item is encoded with ``word_to_indices``, truncated to ``seq_len`` and
    right-padded with 0 up to ``seq_len``.

    NOTE(review): the padding value 0 is also the index of any character whose
    code point is a multiple of ``n_vocab`` (e.g. 'Z' when n_vocab=90).
    """
    rows = []
    for item in raw_x_batch:
        encoded = word_to_indices(item, n_vocab)[:seq_len]
        encoded += [0] * (seq_len - len(encoded))
        rows.append(encoded)
    return torch.tensor(rows, dtype=torch.long)

def process_y(raw_y_batch, seq_len, n_vocab):
    """
    Build next-character targets as a (batch, seq_len) LongTensor.

    Each item is encoded with ``word_to_indices``. Its first index is dropped,
    which shifts the sequence one step to the LEFT, so position t holds
    character t+1. At most ``seq_len`` indices are kept, and the row is
    right-padded with 0 to ``seq_len``.
    """
    rows = []
    for item in raw_y_batch:
        shifted = word_to_indices(item, n_vocab)[1:seq_len + 1]
        rows.append(shifted + [0] * (seq_len - len(shifted)))
    return torch.tensor(rows, dtype=torch.long)

def create_batches(data, batch_size, seq_len, n_vocab):
    """
    Split the dialogues in ``data`` (its values) into shuffled batches.

    Returns two parallel lists of tensors: the encoded inputs (``process_x``)
    and the encoded next-character targets (``process_y``). All batches hold
    ``batch_size`` dialogues except possibly the last, which carries the
    remainder.
    """
    shuffled = list(data.values())
    random.shuffle(shuffled)  # Randomize which dialogues end up together

    inputs, targets = [], []

    def flush(group):
        # Encode one group of dialogues and record its input/target tensors.
        inputs.append(process_x(group, seq_len, n_vocab))
        targets.append(process_y(group, seq_len, n_vocab))

    pending = []
    for dialogue in shuffled:
        pending.append(dialogue)
        if len(pending) == batch_size:
            flush(pending)
            pending = []

    # Leftover dialogues form a final, smaller batch
    if pending:
        flush(pending)

    return inputs, targets


## Save Results

In [89]:
# Base directory for outputs: the process's current working directory
# (os.getcwd()), which is not necessarily the folder containing this file.
SCRIPT_DIR = os.getcwd()

# Root folder for saved checkpoints, metrics and plots.
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "processed_data")

# Create the output directory up front; exist_ok makes re-runs harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

class NumpyTensorEncoder(JSONEncoder):
    """
    JSON encoder that additionally serializes NumPy and PyTorch values.

    - ``np.ndarray`` / ``torch.Tensor`` -> nested Python lists (``tolist()``)
    - any NumPy floating scalar (float16/32/64, ...) -> ``float``
    - any NumPy integer scalar (int8..int64, unsigned, ...) -> ``int``
    - ``np.bool_`` -> ``bool``
    Anything else falls through to ``JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        # Abstract NumPy bases cover every width, not just 32/64-bit ones.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

def save_results_federated(model, train_accuracies, train_losses, test_accuracy, test_loss, client_selection, filename):
    """
    Persist the outcome of a federated run under ``OUTPUT_DIR/Federated``.

    Two files share the stem ``filename``:
    - ``<filename>.pth``: ``torch.save`` of a dict with the model state_dict
      and all metrics.
    - ``<filename>.json``: the same dict as JSON, with tensors and arrays
      converted to lists by ``NumpyTensorEncoder`` (this includes the model
      weights).

    Any exception is reported with a printed message and then re-raised.
    """
    try:
        target_dir = os.path.join(OUTPUT_DIR, "Federated")
        os.makedirs(target_dir, exist_ok=True)

        # Keyword order fixes the key order of the JSON output.
        payload = dict(
            model_state=model.state_dict(),
            train_accuracies=train_accuracies,
            train_losses=train_losses,
            test_accuracy=test_accuracy,
            test_loss=test_loss,
            client_count=client_selection,
        )

        torch.save(payload, os.path.join(target_dir, f"{filename}.pth"))

        with open(os.path.join(target_dir, f"{filename}.json"), 'w') as handle:
            json.dump(payload, handle, indent=4, cls=NumpyTensorEncoder)

        print(f"Results saved successfully to {target_dir}")

    except Exception as err:
        print(f"Error saving results: {str(err)}")
        raise

## Plot results

In [90]:
def plot_results_federated(train_losses, train_accuracies, filename):
    """
    Save a two-panel PNG (loss | accuracy per round) to
    ``OUTPUT_DIR/Federated/<filename>.png``.

    The x-axis for both panels is 1..len(train_losses).
    """
    out_dir = os.path.join(OUTPUT_DIR, "Federated")
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, filename)

    rounds_axis = list(range(1, len(train_losses) + 1))

    # (series, legend label, y-axis label, panel title) for each subplot, left to right
    panels = (
        (train_losses, 'Train Loss', 'Loss', 'Federated Training Loss'),
        (train_accuracies, 'Train Accuracy', 'Accuracy', 'Federated Training Accuracy'),
    )

    plt.figure(figsize=(15, 6))
    for position, (series, label, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(rounds_axis, series, label=label, color='blue')
        plt.xlabel('Rounds', fontsize=12)
        plt.ylabel(y_label, fontsize=12)
        plt.title(title, fontsize=14)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{target}.png")
    plt.close()
    
def plot_sampling_distributions(client_sel_count, filename):
    """
    Plot how often each client was selected and save it as
    ``OUTPUT_DIR/Federated/<filename>.png``.

    Args:
    - client_sel_count: sequence of per-client selection counts, indexed by client id.
    - filename: output file stem (".png" is appended).
    """
    subfolder_path = os.path.join(OUTPUT_DIR, "Federated")
    os.makedirs(subfolder_path, exist_ok=True)
    file_path = os.path.join(subfolder_path, filename)

    num_clients = len(client_sel_count)

    # Create exactly one figure and close that same figure. The previous
    # version created two and closed only the current one, so one figure
    # leaked per call.
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.bar(range(num_clients), client_sel_count, alpha=0.7, edgecolor='blue')
        plt.xlabel("Client ID", fontsize=14)
        plt.ylabel("Selection Count", fontsize=14)
        plt.title("Client Selection Distribution", fontsize=16)
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.tight_layout()
        plt.savefig(f"{file_path}.png")
    finally:
        plt.close(fig)  # release the figure even if saving fails



## Shakespeare Dataset

In [91]:
# Class to handle the Shakespeare dataset in a way suitable for PyTorch.
class ShakespeareDataset(Dataset):
    """Map-style dataset yielding (input indices, next-character target indices) per dialogue."""

    def __init__(self, text, clients=None, seq_length=80, n_vocab=90):
        """
        Initialize the dataset from an in-memory mapping of dialogues.

        Args:
        - text: dict whose values are the dialogues. Each value is encoded with
          word_to_indices, so it may be a string or a list of strings (the
          latter are concatenated).
        - clients: currently unused; kept for API compatibility.
        - seq_length: length every input/target sequence is truncated/padded to.
        - n_vocab: vocabulary size used for the modulo character encoding.
        """
        self.seq_length = seq_length  # Sequence length for the model
        self.n_vocab = n_vocab  # Vocabulary size

        # Keep only the dialogues; keys are not needed for indexing
        self.data = list(text.values())

    def __len__(self):
        """Return the number of dialogues in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the (x, y) pair for dialogue ``idx``, each a LongTensor of
        length ``seq_length``. y is x shifted one character to the left.
        """
        diag = self.data[idx]
        # process_x / process_y iterate over their argument as a BATCH of
        # dialogues, so wrap the single dialogue in a one-element batch (as
        # create_batches does). Passing it bare would encode each character
        # (or list element) as its own row.
        x = process_x([diag], self.seq_length, self.n_vocab)
        y = process_y([diag], self.seq_length, self.n_vocab)
        return x[0], y[0]


## LSTM Model

In [92]:
# Define the character-level LSTM model for Shakespeare data.
class CharLSTM(nn.Module):
    """Character-level model: embedding -> two stacked LSTM layers -> linear projection to vocabulary logits."""

    def __init__(self, n_vocab=90, embedding_dim=8, hidden_dim=256, seq_length=80, num_layers=2):
        """
        Initialize the LSTM model.
        Args:
        - n_vocab: Number of character indices (embedding rows and output logits).
        - embedding_dim: Size of the character embedding.
        - hidden_dim: Number of LSTM hidden units per layer.
        - seq_length: Stored for reference; not used by forward.
        - num_layers: Stored and used by init_hidden. The network itself always
          has exactly two LSTM layers (lstm_first, lstm_second).
        """
        super(CharLSTM, self).__init__()
        self.seq_length = seq_length
        self.n_vocab = n_vocab
        self.embedding_size = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Character embedding layer: maps indices to dense vectors.
        self.embedding = nn.Embedding(n_vocab, embedding_dim)

        # Stacked LSTM layers. The second layer consumes the first layer's
        # hidden-state sequence, so its input size is hidden_dim, not
        # embedding_dim. NOTE: checkpoints saved with the previous
        # (embedding_dim-input) second layer will not load into this shape.
        self.lstm_first = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.lstm_second = nn.LSTM(hidden_dim, hidden_dim, num_layers=1, batch_first=True)

        # Fully connected layer: maps LSTM output to vocabulary logits.
        self.fc = nn.Linear(hidden_dim, n_vocab)

    def forward(self, x, hidden=None):
        """
        Forward pass of the model.
        Args:
        - x: LongTensor of character indices, batch-first (both LSTMs use batch_first=True).
        - hidden: optional (h0, c0) tuple laid out like init_hidden's output,
          i.e. (layers, batch, hidden_dim). Layer 0 seeds lstm_first and
          layer 1 seeds lstm_second. None starts both layers from zeros.
        Returns:
        - Logits with last dimension n_vocab (one per time step).
        - (h_n, c_n), each stacking the final states of both layers along dim 0,
          i.e. the same layout as init_hidden, so it can be fed back in.
        Raises:
        - ValueError: if `hidden` does not provide states for both layers.
        """
        # Embedding layer: convert indices to embeddings.
        x = self.embedding(x)

        if hidden is None:
            first_state = second_state = None
        else:
            h0, c0 = hidden
            if h0.size(0) < 2 or c0.size(0) < 2:
                raise ValueError(
                    f"hidden must hold states for 2 layers, got h0={tuple(h0.shape)}, c0={tuple(c0.shape)}"
                )
            first_state = (h0[0:1].contiguous(), c0[0:1].contiguous())
            second_state = (h0[1:2].contiguous(), c0[1:2].contiguous())

        # First LSTM layer over the embeddings, then the second over its outputs.
        output, (h_first, c_first) = self.lstm_first(x, first_state)
        output, (h_second, c_second) = self.lstm_second(output, second_state)

        # Fully connected layer: logits for each character position.
        output = self.fc(output)

        hidden = (torch.cat((h_first, h_second), dim=0), torch.cat((c_first, c_second), dim=0))

        # Note: Softmax is not applied here because CrossEntropyLoss in PyTorch
        # combines the softmax operation with the computation of the loss.
        # Adding softmax here would be redundant and could introduce numerical instability.
        return output, hidden

    def init_hidden(self, batch_size, device):
        """
        Initializes hidden and cell states for the LSTM.
        Args:
        - batch_size: Number of sequences in the batch.
        - device: Device on which to allocate the tensors.
        Returns:
        - A tuple (h, c) of zero tensors, each (num_layers, batch_size, hidden_dim).
        """
        return (torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device),
                torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device))

# Evaluate model performance on a dataset.
def evaluate_model(model, data_loader, criterion, device):
    """
    Evaluate the model on a given dataset.

    Two target layouts are supported:
    - 1-D targets (one next-character label per sequence, as produced by
      torch.cat of per-sample targets): only the logits of the LAST time step
      are scored. Flattening all time steps against these targets would
      misalign their sizes.
    - Targets matching the per-position logits (e.g. (batch, seq)): every
      position is scored.

    Args:
    - model: Model returning (logits, hidden) and exposing `n_vocab`.
    - data_loader: DataLoader for the evaluation dataset.
    - criterion: Loss function; assumed to use mean reduction (the default),
      since batch losses are re-weighted by batch size.
    - device: Device to evaluate on (CPU/GPU).
    Returns:
    - (avg_loss, accuracy): loss averaged over scored predictions, and accuracy in percent.
    Raises:
    - ValueError: if data_loader yields no samples.
    """
    was_training = model.training
    model.eval()  # inference mode for layers like dropout; restored below

    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    try:
        with torch.no_grad():  # Disable gradient computation for evaluation.
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = model(inputs)

                if targets.dim() == 1 and outputs.dim() == 3:
                    # One label per sequence: score the prediction made after the full input.
                    outputs = outputs[:, -1, :]

                outputs = outputs.reshape(-1, model.n_vocab)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                batch_count = targets.size(0)
                total_loss += loss.item() * batch_count  # weight by batch so a short last batch isn't over-counted
                predictions = outputs.argmax(dim=1)
                correct_predictions += (predictions == targets).sum().item()
                total_samples += batch_count
    finally:
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples to evaluate")

    avg_loss = total_loss / total_samples
    accuracy = (correct_predictions / total_samples) * 100
    return avg_loss, accuracy


## Federated Training 

### Sample Clients and Create Shards

In [93]:
# Sample clients uniformly for a round of training.
def sample_clients_uniform(num_clients, fraction):
    """
    Sample a fraction of clients uniformly at random, without replacement.
    Args:
    - num_clients: Total number of clients; clients are the indices 0..num_clients-1.
    - fraction: Fraction of clients to sample.
    Returns:
    - A list of distinct selected client indices. The count is
      int(fraction * num_clients), clamped to [1, num_clients] (as in FedAvg,
      m = max(C*K, 1)). A small fraction therefore never yields an empty round,
      and fraction > 1 selects everyone instead of raising. Returns [] only
      when num_clients is 0.
    """
    num_selected = min(num_clients, max(1, int(fraction * num_clients)))
    selected = np.random.choice(num_clients, num_selected, replace=False)
    return selected.tolist()  # Convert to list for consistent indexing



# Sample clients skewed using Dirichlet distribution.
def sample_clients_skewed(num_clients, fraction, gamma):
    """
    Sample a fraction of clients with probabilities drawn from a symmetric Dirichlet.
    Args:
    - num_clients: Total number of clients; clients are the indices 0..num_clients-1.
    - fraction: Fraction of clients to sample.
    - gamma: Dirichlet concentration; smaller values give more skewed probabilities.
    Returns:
    - (selected, probabilities): a list of distinct client indices, and the
      per-client probability vector (np.ndarray summing to 1) used for this
      draw. The count is int(fraction * num_clients) clamped to
      [1, num_clients], so a small fraction never yields an empty round.
    """
    num_selected = min(num_clients, max(1, int(fraction * num_clients)))
    probabilities = np.random.dirichlet([gamma] * num_clients)  # Generate skewed probabilities.
    selected_indices = np.random.choice(num_clients, num_selected, replace=False, p=probabilities)
    return selected_indices.tolist(), probabilities



### Client

In [94]:
class Client:
    """One federated participant: holds a local model copy and trains it on its own data."""

    def __init__(self, data_loader, id_client, model, char_to_idx, device):
        self.data = data_loader  # DataLoader over this client's shard (train_local_model takes its loader explicitly)
        self.id_client = id_client  # client identifier, used in log messages
        self.model = model.to(device)  # local model copy, moved to the target device
        self.char_to_idx = char_to_idx  # character -> index mapping (stored; not used by train_local_model)
        self.device = device

    def train_local_model(self, data_loader, criterion, optimizer, local_steps, device):
        """
        Train the local model for `local_steps` full passes over `data_loader`.

        Target layouts:
        - 1-D targets (one next-character label per sequence): only the logits
          of the last time step are trained and scored. That step is the only
          one that has seen the whole input; training earlier positions on the
          same label would teach them to predict a character outside their
          context.
        - Otherwise the per-position logits and targets are flattened together.

        Args:
        - data_loader: yields (inputs, targets) batches.
        - criterion: loss function; assumed mean-reduced (batch losses are re-weighted by count).
        - optimizer: optimizer over self.model's parameters.
        - local_steps: number of passes (epochs) over data_loader.
        - device: device to move batches to.
        Returns:
        - (state_dict, avg_loss, accuracy_percent) on success. Metrics are
          averaged over scored predictions and are 0.0 when no data was seen.
        - (None, inf, 0.0) if a RuntimeError (e.g. CUDA out-of-memory) interrupted training.
        Raises:
        - ValueError: if logits and targets cannot be aligned.
        """
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        try:
            for _ in range(local_steps):
                for inputs, targets in data_loader:
                    inputs = inputs.to(device)
                    targets = targets.to(device)

                    optimizer.zero_grad()

                    # Start every batch from a zero hidden state
                    state = self.model.init_hidden(inputs.size(0), device)
                    state = tuple(s.to(device) for s in state)
                    outputs, _ = self.model(inputs, state)

                    if targets.dim() == 1 and outputs.dim() == 3:
                        outputs = outputs[:, -1, :]  # one label per sequence -> last step only

                    outputs = outputs.reshape(-1, outputs.size(-1))
                    targets = targets.reshape(-1)
                    if outputs.size(0) != targets.size(0):
                        raise ValueError(f"Shape mismatch: outputs={outputs.shape}, targets={targets.shape}")

                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    # Accumulate metrics per scored prediction, not per batch
                    batch_count = targets.numel()
                    total_loss += loss.item() * batch_count
                    correct_predictions += (outputs.argmax(dim=1) == targets).sum().item()
                    total_samples += batch_count
        except RuntimeError as e:
            print(f"Error training client {self.id_client}: {str(e)}")
            return None, float('inf'), 0.0

        if total_samples == 0:
            # Empty shard or local_steps == 0: nothing was trained
            avg_loss, accuracy = 0.0, 0.0
        else:
            avg_loss = total_loss / total_samples
            accuracy = (correct_predictions / total_samples) * 100

        print(f"Client {self.id_client}: Loss={avg_loss:.4f}, Acc={accuracy:.4f}")
        return self.model.state_dict(), avg_loss, accuracy


### Server

In [95]:
class Server:
    """
    Federated-learning coordinator. It shards the raw training data per user,
    trains sampled clients locally each round, and merges their weights into
    the global model with FedAvg.
    """

    def __init__(self, test_data, val_data, global_model, char_to_idx, device):
        self.test_data = test_data  # stored; not used by the methods below
        self.val_data = val_data  # stored; not used by the methods below
        self.clients = None
        self.global_model = global_model
        self.char_to_idx = char_to_idx  # character -> index mapping; '' is the padding/unknown entry
        self.device = device
        self.losses_round = []
        self.accuracies_round = []
        self.client_selected = []
        self.test_losses = []
        self.test_accuracies = []

    # Federated training with FedAvg.
    def train_federated(self, train_loader, criterion, rounds, num_classes, num_clients, fraction, device, lr, momentum, batch_size, wd, seq_length, C=0.1, local_steps=4, iid=True, participation="uniform", gamma=None):
        """
        Train the global model using federated averaging (FedAvg).
        Args:
        - train_loader: raw LEAF-style dict with 'users' and 'user_data' (passed to self.sharding).
        - criterion: Loss function.
        - rounds: Number of communication rounds.
        - num_clients: Number of clients; must equal the number of users in train_loader.
        - fraction: Fraction of clients to select in each round.
        - device: Device batches are moved to during local training.
        - lr, momentum, wd: SGD hyper-parameters for each client's optimizer.
        - batch_size: Local mini-batch size.
        - local_steps: Number of local passes over each client's shard.
        - participation: Participation scheme ('uniform' or 'skewed').
        - gamma: Dirichlet concentration (required if 'skewed').
        - num_classes, seq_length, C, iid: accepted for API compatibility; not used here.
        Returns:
        - (global_model, train_accuracies, train_losses, client_sel_count). The
          model is loaded with the weights from the round with the lowest
          weighted training loss, if any round aggregated updates.
        Raises:
        - ValueError: on an unknown participation scheme, missing gamma for
          'skewed', or a shard count different from num_clients.
        """
        if participation not in ("uniform", "skewed"):
            raise ValueError(f"Unknown participation scheme {participation!r}; expected 'uniform' or 'skewed'")
        if participation == "skewed" and gamma is None:
            raise ValueError("gamma must be provided for skewed participation")

        shards = self.sharding(train_loader)  # one Subset per user
        if len(shards) != num_clients:
            raise ValueError(f"Expected {num_clients} shards, got {len(shards)}")
        client_sizes = [len(shard) for shard in shards]

        self.global_model.to(self.device)

        sampling_distributions = []  # per-round sampling probabilities (kept for inspection; not returned)
        train_losses = []
        train_accuracies = []
        client_sel_count = np.zeros(num_clients)
        best_model = None
        best_loss = float('inf')

        for round_num in range(rounds):
            print(f"Round {round_num + 1}/{rounds}")
            if participation == "uniform":
                selected_clients = sample_clients_uniform(num_clients, fraction)
                sampling_distributions.append([1 / num_clients] * num_clients)
            else:
                selected_clients, probabilities = sample_clients_skewed(num_clients, fraction, gamma)
                sampling_distributions.append(probabilities)

            updates = self._train_selected_clients(
                selected_clients, shards, criterion, lr, momentum, wd, batch_size, local_steps, device, client_sel_count)
            aggregated = self._fedavg(updates, client_sizes)

            if aggregated is None:
                print(f"Round {round_num + 1} - no client updates to aggregate; global model unchanged")
            else:
                weighted_loss, weighted_accuracy = aggregated
                train_losses.append(weighted_loss)
                train_accuracies.append(weighted_accuracy)
                print(f"Round {round_num + 1} - Global Loss: {weighted_loss:.4f}, Accuracy: {weighted_accuracy:.2f}%")

                # Track the round with the lowest weighted *training* loss
                if weighted_loss < best_loss:
                    best_loss = weighted_loss
                    best_model = deepcopy(self.global_model.state_dict())

            torch.cuda.empty_cache()

        if best_model is not None:
            self.global_model.load_state_dict(best_model)

        return self.global_model, train_accuracies, train_losses, client_sel_count

    def _train_selected_clients(self, selected_clients, shards, criterion, lr, momentum, wd, batch_size, local_steps, device, client_sel_count):
        """Train a fresh copy of the global model on each selected client; return (id, state, loss, acc) for successful runs."""
        updates = []
        for id_client in selected_clients:
            client_sel_count[id_client] += 1

            if len(shards[id_client]) == 0:
                # DataLoader(shuffle=True) rejects empty datasets, and an empty client has zero FedAvg weight anyway.
                print(f"Skipping client {id_client}: no local data")
                continue

            local_model = deepcopy(self.global_model)
            optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
            client_loader = DataLoader(shards[id_client], batch_size, shuffle=True)
            client = Client(client_loader, id_client, local_model, self.char_to_idx, self.device)

            state, loss, accuracy = client.train_local_model(client_loader, criterion, optimizer, local_steps, device)
            if state is None:
                # train_local_model signals a failed run with a None state; exclude it from the average.
                print(f"Skipping client {id_client}: local training failed")
            else:
                updates.append((id_client, state, loss, accuracy))
            torch.cuda.empty_cache()
        return updates

    def _fedavg(self, updates, client_sizes):
        """Load the data-size-weighted average of `updates` into the global model; return (loss, accuracy) or None if nothing to average."""
        total_samples = sum(client_sizes[id_client] for id_client, _, _, _ in updates)
        if total_samples == 0:
            return None

        global_dict = {k: torch.zeros_like(v) for k, v in self.global_model.state_dict().items()}
        weighted_loss = 0.0
        weighted_accuracy = 0.0
        for id_client, state, loss, accuracy in updates:
            weight = client_sizes[id_client] / total_samples
            for k in global_dict:
                global_dict[k] += state[k] * weight
            weighted_loss += loss * weight
            weighted_accuracy += accuracy * weight

        self.global_model.load_state_dict(global_dict)
        return weighted_loss, weighted_accuracy

    def char_to_tensor(self, characters):
        """Encode `characters` as a LongTensor via char_to_idx; unknown characters map to the '' (padding) index."""
        indices = [self.char_to_idx.get(char, self.char_to_idx['']) for char in characters]
        return torch.tensor(indices, dtype=torch.long)

    def sharding(self, data):
        """
        Build one shard per user from a LEAF-style dict.

        Args:
            data: dict with 'users' (list of user ids) and
                'user_data'[user]['x'/'y'] (parallel lists of strings).

        Returns:
            List of Subsets, one per user in data['users'] order. Inputs are
            padded to the user's longest sequence with the '' index; targets
            are concatenated. A user with no samples gets an empty Subset.
        """
        pad_idx = self.char_to_idx['']
        subsets = []

        for user in data['users']:
            samples = data['user_data'][user]
            pairs = list(zip(samples['x'], samples['y']))

            if not pairs:
                # pad_sequence / torch.cat reject empty lists; build an empty dataset instead.
                dataset = TensorDataset(torch.empty((0, 0), dtype=torch.long), torch.empty(0, dtype=torch.long))
                subsets.append(Subset(dataset, []))
                continue

            input_tensors = [self.char_to_tensor(entry) for entry, _ in pairs]
            target_tensors = [self.char_to_tensor(target) for _, target in pairs]

            padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=pad_idx)
            targets = torch.cat(target_tensors)

            dataset = TensorDataset(padded_inputs, targets)
            subsets.append(Subset(dataset, list(range(len(targets)))))

        return subsets

## Main

In [96]:
def main():
    # Dataset and training configurations
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available
    epochs = 20  # Number of epochs for centralized training
    fraction = 0.1  # Fraction of clients to select each round
    seq_length = 80  # Sequence length for LSTM inputs   
    batch_size = 4 # For local training
    n_vocab = 90 # Character number in vobulary (ASCII)
    embedding_size = 8
    hidden_dim = 256
    train_split = 0.8 # In LEAF Dataset the common split used is 80/20
    momentum = 0
    learning_rate = 0.1
    weight_decay = 0.0001
    C = 0.1

    # Load data
    base_path = os.path.join('leaf', 'data', 'shakespeare', 'data')
    train_path = os.path.join(base_path, 'train', 'all_data_iid_01_1_keep_0_train_9.json')
    test_path = os.path.join(base_path, 'test', 'all_data_iid_01_1_keep_0_test_9.json')

    # Load JSON data
    with open(train_path, 'r') as f:
        train_dataset = json.load(f)
    with open(test_path, 'r') as f:
        test_dataset = json.load(f)

    num_clients = len(train_dataset['users'])
    print("Number of clients:", num_clients) 
    users = train_dataset['users']
    user_data = train_dataset['user_data']

    all_texts = ''.join([''.join(seq) for user in users for seq in user_data[user]['x']])
    chars = sorted(set(all_texts))
    char_to_idx = {ch: idx for idx, ch in enumerate(chars)}

    # Padding character
    char_to_idx[''] = len(char_to_idx)

    # Function to convert character values into indices
    def char_to_tensor(characters):
        indices = [char_to_idx.get(char, char_to_idx['']) for char in characters] # Get the index for the character. If not found, use the index for padding.
        return torch.tensor(indices, dtype=torch.long)

    # Prepare the training data_loader
    input_tensors = []
    target_tensors = []
    for user in train_dataset['users']:
        for entry, target in zip(train_dataset['user_data'][user]['x'], train_dataset['user_data'][user]['y']):
            input_tensors.append(char_to_tensor(entry))  # Use the full sequence of x
            target_tensors.append(char_to_tensor(target))  # Directly use the corresponding y as target

    # Padding and DataLoader creation
    padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=char_to_idx[''])
    targets = torch.cat(target_tensors)
    dataset = TensorDataset(padded_inputs, targets)
    
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Prepare the validation data_loader
    train_size = int(0.9 * len(train_dataset))  # 90% of data for training
    val_size = len(train_dataset) - train_size  # 10% of data for validation
    train_data, validation_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Prepare the testing data_loader
    input_tensors = []
    target_tensors = []
    for user in test_dataset['users']:
        for entry, target in zip(test_dataset['user_data'][user]['x'], test_dataset['user_data'][user]['y']):
            input_tensors.append(char_to_tensor(entry))  # Use the full sequence of x
            target_tensors.append(char_to_tensor(target))  # Directly use the corresponding y as target

    # Padding e creazione di DataLoader
    padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=char_to_idx[''])
    targets = torch.cat(target_tensors)
    dataset = TensorDataset(padded_inputs, targets)
    
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    

    # EXPERIMENTS

    local_steps = [4, 8, 16] #what is called J -> # Number of local training steps
    # Scale the number of rounds inversely with J to maintain a constant computational budget
    num_rounds = {4: 200, 8: 100, 16: 50} # Number of federated communication rounds


    # The first FL baseline
    print("FIRST FL BASELINE")

    num_clients = num_clients
    num_classes = 100 
    iid = True #iid
    C = 0.1
    local_steps = 4

    rounds = num_rounds[local_steps]

    global_model = CharLSTM(n_vocab, embedding_size, hidden_dim, seq_length, num_layers=2) # Initialize global LSTM model
    server = Server(test_loader, validation_dataset, global_model, char_to_idx, device)
    criterion = nn.CrossEntropyLoss()  # Loss function
    optimizer = optim.SGD(global_model.parameters(), learning_rate, momentum, weight_decay)  # Optimizer
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # Learning rate scheduler
    global_model, train_accuracies, train_losses, client_sel_count = server.train_federated(
        train_dataset, criterion, rounds, num_classes, num_clients, fraction, device, learning_rate, momentum, 
        batch_size, weight_decay, seq_length, C, local_steps, iid, "uniform")

    # Test
    test_loss, test_accuracy = evaluate_model(global_model, test_loader, criterion, device) 
    print(f"Local steps={local_steps} -> Test Accuracy: {test_accuracy}")

    filename = f"First_baseline_Num_classes_{num_classes}_local_steps_{local_steps}"
    save_results_federated(global_model, train_accuracies, train_losses, test_accuracy, test_loss, client_sel_count, filename)
    plot_results_federated(train_losses, train_accuracies, filename)
    plot_sampling_distributions(client_sel_count, f"{filename}_distribution")


    # The impact of client participation
    print("THE IMPACT OF CLIENT PARTECIPATION")

    num_clients = num_clients
    num_classes = 100
    iid = True #iid
    C = 0.1

    local_steps = 4

    rounds = num_rounds[local_steps]

    print("Uniform partecipation")

    global_model = CharLSTM(n_vocab, embedding_size, hidden_dim, seq_length, num_layers=2) # Initialize global LSTM model
    server = Server(test_loader, validation_dataset, global_model, device)
    criterion = nn.CrossEntropyLoss()  # Loss function
    optimizer = optim.SGD(global_model.parameters(), learning_rate, momentum, weight_decay)  # Optimizer
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # Learning rate scheduler
    global_model, train_accuracies, train_losses, client_sel_count = server.train_federated(
        train_dataset, criterion, rounds, num_classes, num_clients, fraction, device, learning_rate, momentum, 
        batch_size, weight_decay, seq_length, C, local_steps, iid, "uniform")

    # Test
    test_loss, test_accuracy = evaluate_model(global_model, test_loader, criterion, device) 
    print(f"Local steps={local_steps} -> Test Accuracy: {test_accuracy}")

    filename = f"Uniform_Client_partecipation_Num_classes_{num_classes}_local_steps_{local_steps}"
    save_results_federated(global_model, train_accuracies, train_losses, test_accuracy, test_loss, client_sel_count, filename)
    plot_results_federated(train_losses, train_accuracies, filename)
    plot_sampling_distributions(client_sel_count, f"{filename}_distribution")


    num_clients = num_clients
    num_classes = 100
    iid = True #iid
    C = 0.1
    local_steps = 4

    rounds = num_rounds[local_steps]

    print("Skewed partecipation")

    # Values of gamma to test
    gamma_values = [0.1, 0.5, 1.0, 5.0]  # Skewness parameter for Dirichlet sampling

    for gamma in gamma_values:
        global_model = CharLSTM(n_vocab, embedding_size, hidden_dim, seq_length, num_layers=2) # Initialize global LSTM model
        server = Server(test_loader, validation_dataset, global_model, device)
        criterion = nn.CrossEntropyLoss()  # Loss function
        optimizer = optim.SGD(global_model.parameters(), learning_rate, momentum, weight_decay)  # Optimizer
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # Learning rate scheduler
        global_model, train_accuracies, train_losses, client_sel_count = server.train_federated(
            train_dataset, criterion, rounds, num_classes, num_clients, fraction, device, learning_rate, momentum, 
            batch_size, weight_decay, seq_length, C, local_steps, iid, "skewed", gamma)

        # Test
        test_loss, test_accuracy = evaluate_model(global_model, test_loader, criterion, device) 
        print(f"Local steps={local_steps} -> Test Accuracy: {test_accuracy}")

        filename = f"Skewed_Client_partecipation_Gamma_{gamma}_Num_classes_{num_classes}_local_steps_{local_steps}"
        save_results_federated(global_model, train_accuracies, train_losses, test_accuracy, test_loss, client_sel_count, filename)
        plot_results_federated(train_losses, train_accuracies, filename)
        plot_sampling_distributions(client_sel_count, f"{filename}_distribution")


    # Simulate heterogeneous distributions 
    print("SIMULATE HETEROGENEOUS DISTRIBUTIONS")


    print("Non-iid shardings")
    num_clients = 100
    num_classes = [1, 5, 10, 50]
    iid = False # non-iid
    C = 0.1

    local_steps_list = [4, 8, 16]  # Varying local steps

    for nc in num_classes:
        for local_steps in local_steps_list:
            rounds = num_rounds[local_steps]
            global_model = CharLSTM(n_vocab, embedding_size, hidden_dim, seq_length, num_layers=2) # Initialize global LSTM model
            server = Server(test_loader, validation_dataset, global_model, device)
            criterion = nn.CrossEntropyLoss()  # Loss function
            optimizer = optim.SGD(global_model.parameters(), learning_rate, momentum, weight_decay)  # Optimizer
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # Learning rate scheduler
            global_model, train_accuracies, train_losses, client_sel_count = server.train_federated(
                train_dataset, criterion, rounds, nc, num_clients, fraction, device, learning_rate, momentum, 
                batch_size, weight_decay, seq_length, C, local_steps, iid, "uniform")

            # Test
            test_loss, test_accuracy = evaluate_model(global_model, test_loader, criterion, device) 
            print(f"Local steps={local_steps} -> Test Accuracy: {test_accuracy}")

            filename = f"Non_iid_Num_classes_{nc}_local_steps_{local_steps}"
            save_results_federated(global_model, train_accuracies, train_losses, test_accuracy, test_loss, client_sel_count, filename)
            plot_results_federated(train_losses, train_accuracies, filename)
            plot_sampling_distributions(client_sel_count, f"{filename}_distribution")


    print("iid shardings")
    num_clients = 100
    num_classes = 100 
    iid = True # iid
    C = 0.1

    local_steps_list = [4, 8, 16]  # Varying local steps

    for local_steps in local_steps_list:

        rounds = num_rounds[local_steps]
    
        global_model = CharLSTM(n_vocab, embedding_size, hidden_dim, seq_length, num_layers=2) # Initialize global LSTM model
        server = Server(test_loader, validation_dataset, global_model, device)
        criterion = nn.CrossEntropyLoss()  # Loss function
        optimizer = optim.SGD(global_model.parameters(), learning_rate, momentum, weight_decay)  # Optimizer
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # Learning rate scheduler
        global_model, train_accuracies, train_losses, client_sel_count = server.train_federated(
            train_dataset, criterion, rounds, nc, num_clients, fraction, device, learning_rate, momentum, 
            batch_size, weight_decay, seq_length, C, local_steps, iid, "uniform")

        # Test
        test_loss, test_accuracy = evaluate_model(global_model, test_loader, criterion, device) 
        print(f"Local steps={local_steps} -> Test Accuracy: {test_accuracy}")

        filename = f"iid_Num_classes_{nc}_local_steps_{local_steps}"
        save_results_federated(global_model, train_accuracies, train_losses, test_accuracy, test_loss, client_sel_count, filename)
        plot_results_federated(train_losses, train_accuracies, filename)
        plot_sampling_distributions(client_sel_count, f"{filename}_distribution")
    
    print("All experiments completed!")

# Script entry point: start the experiment suite only when this file is run
# directly; importing the module defines everything without launching training.
if __name__ == "__main__":
    main()


Number of clients: 11


KeyboardInterrupt: 