## Import required libraries for dataset management, model building, training, and visualization.

In [235]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
import collections
from collections import defaultdict
from json import JSONEncoder
import random
import kagglehub
import shutil
import glob
import re
from copy import deepcopy
from tqdm import tqdm  # For progress tracking.

## Dataset Parsing

In [236]:
# Regular expressions for parsing Shakespeare text
CHARACTER_RE = re.compile(r'^  ([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Matches character lines
CONT_RE = re.compile(r'^    (.*)')  # Matches continuation lines
COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')  # Special regex for Comedy of Errors
COE_CONT_RE = re.compile(r'^(.*)')  # Continuation for Comedy of Errors


# Get current script directory
SCRIPT_DIR = os.getcwd()

# Download dataset
path = kagglehub.dataset_download("kewagbln/shakespeareonline")

# Debug: print downloaded files
print(f"Downloaded path: {path}")
print("Files in downloaded path:")
for file in glob.glob(os.path.join(path, "*")):
    print(f" - {file}")

# Set up paths relative to script location
DATA_PATH = os.path.join(SCRIPT_DIR, "shakespeare.txt")
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "processed_data")

# Create directories if they don't exist
os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Find and copy Shakespeare text file
shakespeare_file = None
for file in glob.glob(os.path.join(path, "*.txt")):
    shakespeare_file = file
    break

if shakespeare_file:
    shutil.copy2(shakespeare_file, DATA_PATH)
    print(f"Dataset saved to: {DATA_PATH}")
else:
    raise FileNotFoundError(f"Could not find Shakespeare text file in {path}")


Downloaded path: C:\Users\rosif\.cache\kagglehub\datasets\kewagbln\shakespeareonline\versions\1
Files in downloaded path:
 - C:\Users\rosif\.cache\kagglehub\datasets\kewagbln\shakespeareonline\versions\1\t8.shakespeare.txt
Dataset saved to: c:\Users\rosif\OneDrive\Desktop\Advance Machine Learning\Project2024\AdvanceML_project5\shakespeare.txt


## Dataset Preprocessing

In [237]:
def __txt_to_data(txt_dir, seq_length=80):
    """Parses text file in given directory into data for next-character model.

    Args:
        txt_dir: path to text file
        seq_length: length of strings in X
    """
    raw_text = ""
    with open(txt_dir,'r') as inf:
        raw_text = inf.read()
    raw_text = raw_text.replace('\n', ' ')
    raw_text = re.sub(r"   *", r' ', raw_text)
    dataX = []
    dataY = []
    for i in range(0, len(raw_text) - seq_length, 1):
        seq_in = raw_text[i:i + seq_length]
        seq_out = raw_text[i + seq_length]
        dataX.append(seq_in)
        dataY.append(seq_out)
    return dataX, dataY

def parse_data_in(data_dir, users_and_plays_path, raw=False):
    '''
    returns dictionary with keys: users, num_samples, user_data
    raw := bool representing whether to include raw text in all_data
    if raw is True, then user_data key
    removes users with no data
    '''
    with open(users_and_plays_path, 'r') as inf:
        users_and_plays = json.load(inf)
    files = os.listdir(data_dir)
    users = []
    hierarchies = []
    num_samples = []
    user_data = {}
    for f in files:
        user = f[:-4]
        passage = ''
        filename = os.path.join(data_dir, f)
        with open(filename, 'r') as inf:
            passage = inf.read()
        dataX, dataY = __txt_to_data(filename)
        if(len(dataX) > 0):
            users.append(user)
            if raw:
                user_data[user] = {'raw': passage}
            else:
                user_data[user] = {}
            user_data[user]['x'] = dataX
            user_data[user]['y'] = dataY
            hierarchies.append(users_and_plays[user])
            num_samples.append(len(dataY))
    all_data = {}
    all_data['users'] = users
    all_data['hierarchies'] = hierarchies
    all_data['num_samples'] = num_samples
    all_data['user_data'] = user_data
    return all_data

def parse_shakespeare(filepath, train_split=0.8):
    """
    Parses Shakespeare's text into training and testing datasets.
    """
    with open(filepath, "r") as file:
        raw_text = file.read()

    plays_data, _ = process_plays(raw_text)
    _, training_set, testing_set = split_train_test_data(plays_data, 1.0 - train_split)

    total_train = sum(len(lines) for lines in training_set.values())
    total_test = sum(len(lines) for lines in testing_set.values())
    print(f"Training examples: {total_train}")
    print(f"Testing examples: {total_test}")
    
    assert total_train > total_test, "Training set should be larger than test set"

    return training_set, testing_set

def process_plays(shakespeare_full):
    """
    Processes the Shakespeare text into individual plays and characters' dialogues.
    Handles special cases for "The Comedy of Errors".
    """
    plays = []
    slines = shakespeare_full.splitlines(True)[1:]  # Skip the first line (title/header)
    current_character = None
    comedy_of_errors = False

    for i, line in enumerate(slines):
        # Detect play titles and initialize character dictionary
        if "by William Shakespeare" in line:
            current_character = None
            characters = defaultdict(list)
            title = slines[i - 2].strip() if slines[i - 2].strip() else slines[i - 3].strip()
            comedy_of_errors = title == "THE COMEDY OF ERRORS"
            plays.append((title, characters))
            continue

        # Match character lines or continuation lines
        match = _match_character_regex(line, comedy_of_errors)
        if match:
            character, snippet = match.group(1).upper(), match.group(2)
            if not (comedy_of_errors and character.startswith("ACT ")):
                characters[character].append(snippet)
                current_character = character
        elif current_character:
            match = _match_continuation_regex(line, comedy_of_errors)
            if match:
                characters[current_character].append(match.group(1))

    # Filter out plays with insufficient dialogue data
    return [play for play in plays if len(play[1]) > 1], []

def _match_character_regex(line, comedy_of_errors=False):
    """Matches character dialogues, with special handling for 'The Comedy of Errors'."""
    return COE_CHARACTER_RE.match(line) if comedy_of_errors else CHARACTER_RE.match(line)

def _match_continuation_regex(line, comedy_of_errors=False):
    """Matches continuation lines of dialogues."""
    return COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line)

def extract_play_title(lines, index):
    """
    Extracts the title of the play from the lines of the text.
    """
    for i in range(index - 1, -1, -1):
        if lines[i].strip():
            return lines[i].strip()
    return "UNKNOWN"

def detect_character_line(line, comedy_of_errors):
    """
    Matches a line of character dialogue.
    """
    return COE_CHARACTER_RE.match(line) if comedy_of_errors else CHARACTER_RE.match(line)

def detect_continuation_line(line, comedy_of_errors):
    """
    Matches a continuation line of dialogue.
    """
    return COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line)

def _split_into_plays(shakespeare_full):
    """Splits the full data by play."""
    # List of tuples (play_name, dict from character to list of lines)
    plays = []
    discarded_lines = []  # Track discarded lines.
    slines = shakespeare_full.splitlines(True)[1:]

    # skip contents, the sonnets, and all's well that ends well
    author_count = 0
    start_i = 0
    for i, l in enumerate(slines):
        if 'by William Shakespeare' in l:
            author_count += 1
        if author_count == 2:
            start_i = i - 5
            break
    slines = slines[start_i:]

    current_character = None
    comedy_of_errors = False
    for i, line in enumerate(slines):
        # This marks the end of the plays in the file.
        if i > 124195 - start_i:
            break
        # This is a pretty good heuristic for detecting the start of a new play:
        if 'by William Shakespeare' in line:
            current_character = None
            characters = collections.defaultdict(list)
            # The title will be 2, 3, 4, 5, 6, or 7 lines above "by William Shakespeare".
            if slines[i - 2].strip():
                title = slines[i - 2]
            elif slines[i - 3].strip():
                title = slines[i - 3]
            elif slines[i - 4].strip():
                title = slines[i - 4]
            elif slines[i - 5].strip():
                title = slines[i - 5]
            elif slines[i - 6].strip():
                title = slines[i - 6]
            else:
                title = slines[i - 7]
            title = title.strip()

            assert title, (
                'Parsing error on line %d. Expecting title 2 or 3 lines above.' %
                i)
            comedy_of_errors = (title == 'THE COMEDY OF ERRORS')
            # Degenerate plays are removed at the end of the method.
            plays.append((title, characters))
            continue
        match = _match_character_regex(line, comedy_of_errors)
        if match:
            character, snippet = match.group(1), match.group(2)
            # Some character names are written with multiple casings, e.g., SIR_Toby
            # and SIR_TOBY. To normalize the character names, we uppercase each name.
            # Note that this was not done in the original preprocessing and is a
            # recent fix.
            character = character.upper()
            if not (comedy_of_errors and character.startswith('ACT ')):
                characters[character].append(snippet)
                current_character = character
                continue
            else:
                current_character = None
                continue
        elif current_character:
            match = _match_continuation_regex(line, comedy_of_errors)
            if match:
                if comedy_of_errors and match.group(1).startswith('<'):
                    current_character = None
                    continue
                else:
                    characters[current_character].append(match.group(1))
                    continue
        # Didn't consume the line.
        line = line.strip()
        if line and i > 2646:
            # Before 2646 are the sonnets, which we expect to discard.
            discarded_lines.append('%d:%s' % (i, line))
    # Remove degenerate "plays".
    return [play for play in plays if len(play[1]) > 1], discarded_lines


def _remove_nonalphanumerics(filename):
    return re.sub('\\W+', '_', filename)

def play_and_character(play, character):
    return _remove_nonalphanumerics((play + '_' + character).replace(' ', '_'))

def split_train_test_data(plays, test_fraction=0.2):
    """
    Splits the plays into training and testing datasets by character dialogues.
    """
    skipped_characters = 0
    all_train_examples = collections.defaultdict(list)
    all_test_examples = collections.defaultdict(list)

    def add_examples(example_dict, example_tuple_list):
        for play, character, sound_bite in example_tuple_list:
            example_dict[play_and_character(
                play, character)].append(sound_bite)

    users_and_plays = {}
    for play, characters in plays:
        curr_characters = list(characters.keys())
        for c in curr_characters:
            users_and_plays[play_and_character(play, c)] = play
        for character, sound_bites in characters.items():
            examples = [(play, character, sound_bite)
                        for sound_bite in sound_bites]
            if len(examples) <= 2:
                skipped_characters += 1
                # Skip characters with fewer than 2 lines since we need at least one
                # train and one test line.
                continue
            train_examples = examples
            if test_fraction > 0:
                num_test = max(int(len(examples) * test_fraction), 1)
                train_examples = examples[:-num_test]
                test_examples = examples[-num_test:]
                
                assert len(test_examples) == num_test
                assert len(train_examples) >= len(test_examples)

                add_examples(all_test_examples, test_examples)
                add_examples(all_train_examples, train_examples)

    return users_and_plays, all_train_examples, all_test_examples


def _write_data_by_character(examples, output_directory):
    """Writes a collection of data files by play & character."""
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    for character_name, sound_bites in examples.items():
        filename = os.path.join(output_directory, character_name + '.txt')
        with open(filename, 'w') as output:
            for sound_bite in sound_bites:
                output.write(sound_bite + '\n')


## Dataset Utilities

In [238]:
def letter_to_vec(c, n_vocab=90):
    """Converts a single character to a vector index based on the vocabulary size."""
    return ord(c) % n_vocab

def word_to_indices(word, n_vocab=90):
    """
    Converts a word or list of words into a list of indices.
    Each character is mapped to an index based on the vocabulary size.
    """
    if isinstance(word, list):  # If input is a list of words
        res = []
        for stringa in word:
            res.extend([ord(c) % n_vocab for c in stringa])  # Convert each word to indices
        return res
    else:  # If input is a single word
        return [ord(c) % n_vocab for c in word]

def process_x(raw_x_batch, seq_len, n_vocab):
    """
    Processes raw input data into padded sequences of indices.
    Ensures all sequences are of uniform length.
    """
    x_batch = [word_to_indices(word, n_vocab) for word in raw_x_batch]
    x_batch = [x[:seq_len] + [0] * (seq_len - len(x)) for x in x_batch]
    return torch.tensor(x_batch, dtype=torch.long)

def process_y(raw_y_batch, seq_len, n_vocab):
    """
    Processes raw target data into padded sequences of indices.
    Shifts the sequence by one character to the right.
    y[1:seq_len + 1] takes the input data, right shift of an
    element and uses the next element of the sequence to fill
    and at the end (with [0]) final padding (zeros) are (eventually)
    added to reach the desired sequence length.
    """
    y_batch = [word_to_indices(word, n_vocab) for word in raw_y_batch]
    y_batch = [y[1:seq_len + 1] + [0] * (seq_len - len(y[1:seq_len + 1])) for y in y_batch]  # Shifting and final padding
    return torch.tensor(y_batch, dtype=torch.long)

def create_batches(data, batch_size, seq_len, n_vocab):
    """
    Creates batches of input and target data from dialogues.
    Each batch contains sequences of uniform length.
    """
    x_batches = []
    y_batches = []
    dialogues = list(data.values())
    random.shuffle(dialogues)  # Shuffle to ensure randomness in batches

    batch = []
    for dialogue in dialogues:
        batch.append(dialogue)
        if len(batch) == batch_size:
            x_batch = process_x(batch, seq_len, n_vocab)
            y_batch = process_y(batch, seq_len, n_vocab)
            x_batches.append(x_batch)
            y_batches.append(y_batch)
            batch = []

    # Add the last batch if it's not full
    if batch:
        x_batch = process_x(batch, seq_len, n_vocab)
        y_batch = process_y(batch, seq_len, n_vocab)
        x_batches.append(x_batch)
        y_batches.append(y_batch)

    return x_batches, y_batches


## Save Results

In [239]:
class NumpyTensorEncoder(JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        if isinstance(obj, (np.float32, np.float64)):
            return float(obj)
        if isinstance(obj, (np.int32, np.int64)):
            return int(obj)
        return super().default(obj)

def save_results_federated(model, train_accuracies, train_losses, test_accuracy, test_loss, client_selection, filename):
    """
    Save federated learning results in both .pth and .json formats.
    Handles PyTorch tensors and NumPy arrays serialization.
    """
    try:
        # Create output directory
        subfolder_path = os.path.join(OUTPUT_DIR, "Federated")
        os.makedirs(subfolder_path, exist_ok=True)
        
        # Define file paths
        filepath_pth = os.path.join(subfolder_path, f"{filename}.pth")
        filepath_json = os.path.join(subfolder_path, f"{filename}.json")
        
        # Prepare results dictionary
        results = {
            'model_state': model.state_dict(),
            'train_accuracies': train_accuracies,
            'train_losses': train_losses,
            'test_accuracy': test_accuracy,
            'test_loss': test_loss,
            'client_count': client_selection
        }
        
        # Save model checkpoint
        torch.save(results, filepath_pth)
        
        # Save JSON metrics with custom encoder
        with open(filepath_json, 'w') as json_file:
            json.dump(results, json_file, indent=4, cls=NumpyTensorEncoder)
            
        print(f"Results saved successfully to {subfolder_path}")
        
    except Exception as e:
        print(f"Error saving results: {str(e)}")
        raise

## Plot results

In [240]:
def plot_results_federated(train_losses, train_accuracies, filename):   
    # Plot federated training performance
    subfolder_path = os.path.join(OUTPUT_DIR, "Federated")
    os.makedirs(subfolder_path, exist_ok=True)

    file_path = os.path.join(subfolder_path, filename)

    # Create epochs list
    epochs = list(range(1, len(train_losses) + 1))
    
    # Create subplot figure
    plt.figure(figsize=(15, 6))
    
    # Plot Training Loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss', color='blue')
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Federated Training Loss', fontsize=14)
    plt.legend()
    plt.grid(True)
    
    # Plot Training Accuracy 
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Train Accuracy', color='blue')
    plt.xlabel('Rounds', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Federated Training Accuracy', fontsize=14)
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(f"{file_path}.png")
    plt.close()


## Shakespeare Dataset

In [241]:
# Class to handle the Shakespeare dataset in a way suitable for PyTorch.
class ShakespeareDataset(Dataset):
    def __init__(self, text, clients=None, seq_length=80, n_vocab=90):
        """
        Initialize the dataset by loading and preprocessing the data.
        Args:
        - data_path: Path to the JSON file containing the dataset.
        - clients: List of client IDs to load data for (default: all clients).
        - seq_length: Sequence length for character-level data.
        """
        self.seq_length = seq_length  # Sequence length for the model
        self.n_vocab = n_vocab  # Vocabulary size

        # Create character mappings
        self.data = list(text.values())  # Convert the dictionary values to a list
            

    def __len__(self):
        """
        Return the number of sequences in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve the input-target pair at the specified index.
        """
        diag = self.data[idx]
        x = process_x(diag, self.seq_length, self.n_vocab)
        y = process_y(diag, self.seq_length, self.n_vocab)
        return x[0], y[0]


## LSTM Model

In [242]:
# Custom LSTM model tailored for FL with Mixture of Experts and Personalization
class FLTailoredCharLSTM(nn.Module):
    def __init__(self, n_vocab=90, embedding_dim=8, hidden_dim=256, seq_length=80, num_layers=1, num_experts=5, personalization=False):
        super(FLTailoredCharLSTM, self).__init__()
        self.n_vocab = n_vocab  # Store vocabulary size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.personalization = personalization

        # Shared global layers
        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        self.shared_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)

        # Mixture of Experts for personalization
        self.experts = nn.ModuleList([nn.LSTM(hidden_dim, hidden_dim, num_layers=1, batch_first=True) for _ in range(num_experts)])
        self.expert_selector = nn.Linear(hidden_dim, num_experts)

        # Personalized local layers (only if personalization is enabled)
        self.personalized_lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=1, batch_first=True) if personalization else None

        self.fc = nn.Linear(hidden_dim, n_vocab)

    def forward(self, x, hidden=None, personalize=False):
        x = self.embedding(x)

        # Se hidden è None, inizializza lo stato nascosto
        if hidden is None:
            hidden = self.init_hidden(x.size(0), x.device)

        if personalize:
            x, hidden = self.personalized_lstm(x, hidden)
        else:
            expert_weights = torch.softmax(self.expert_selector(x[:, -1, :]), dim=-1)  # [batch_size, num_experts]
            expert_weights = expert_weights.unsqueeze(1).unsqueeze(1)  # [batch_size, 1, 1, num_experts]

            expert_outputs = torch.stack([expert(x)[0] for expert in self.experts], dim=-1)  
            # expert_outputs shape: [batch_size, seq_length, hidden_dim, num_experts]

            x = torch.sum(expert_weights * expert_outputs, dim=-1)  # Reduce over experts

        output = self.fc(x)
        return output, hidden

    def init_hidden(self, batch_size, device):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device),
                torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device))


## Federated Training 

### Sample Clients and Create Shards

In [243]:
# Sample clients uniformly for a round of training.
def sample_clients_uniform(clients, fraction):
    """
    Sample a fraction of clients uniformly.
    Args:
    - clients: List of all clients.
    - fraction: Fraction of clients to sample.
    Returns:
    - A list of selected clients.
    """
    num_clients = clients
    num_selected = max(1, int(fraction * num_clients))  # Compute number of selected clients.
    selected = np.random.choice(range(num_clients), num_selected, replace=False)
    return selected.tolist()  # Convert to list for consistent indexing



# Sample clients skewed using Dirichlet distribution.
def sample_clients_skewed(clients, fraction, gamma):
    """
    Sample a fraction of clients based on Dirichlet distribution.
    Args:
    - clients: List of all clients.
    - fraction: Fraction of clients to sample.
    - gamma: Skewness parameter for Dirichlet distribution.
    Returns:
    - List of selected clients and their probabilities.
    """
    num_clients = clients
    num_selected = max(1, int(fraction * num_clients))
    probabilities = np.random.dirichlet([gamma] * num_clients)  # Generate skewed probabilities.
    selected_indices = np.random.choice(range(num_clients), num_selected, replace=False, p=probabilities)
    return selected_indices.tolist(), probabilities


# ====================
# Sharding for iid and non-iid splits
# ====================
def create_sharding(data, num_clients, num_classes=90):
    """Create data shards ensuring correct client count"""
    if len(data) == 0:
        raise ValueError("Empty dataset")

    client_data = []
    indices = np.random.permutation(len(data))
    batch_size = max(1, len(data) // num_clients)  # Ensure at least 1 sample per client
    remainder = len(data) % num_clients

    # IID sharding
    if num_clients == num_classes:
        start_idx = 0
        for i in range(num_clients):
            end_idx = start_idx + batch_size + (1 if i < remainder else 0)
            shard = Subset(data, indices[start_idx:end_idx])
            client_data.append(shard)
            start_idx = end_idx
    
    # Non-IID sharding
    else:
        targets = [t[0].item() for _, t in data]
        unique_targets = np.unique(targets)
        samples_per_client = max(1, len(data) // num_clients)
        
        # Create initial shards
        for i in range(num_clients):
            start_idx = i * samples_per_client
            end_idx = start_idx + samples_per_client + (1 if i < remainder else 0)
            client_data.append(Subset(data, indices[start_idx:end_idx]))

    assert len(client_data) == num_clients, f"Created {len(client_data)} shards, expected {num_clients}"
    return client_data

### Client

In [244]:
class Client:
    def __init__(self, data_loader, id_client, model, device):
        self.data = data_loader
        self.id_client = id_client
        self.model = model.to(device)
        self.device = device

    def train_local_model(self, criterion, optimizer, local_steps, l2_lambda=1e-4):
        self.model.train()
        total_loss, total_correct, total_samples = 0, 0, 0
        
        for _ in range(local_steps):
            for inputs, targets in self.data:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                hidden = self.model.init_hidden(inputs.size(0), self.device)
                hidden = tuple(h.to(self.device) for h in hidden)

                optimizer.zero_grad()
                
                outputs, hidden = self.model(inputs, hidden, personalize=True)
                loss = criterion(outputs.view(-1, self.model.n_vocab), targets.view(-1))
                
                # Apply L2 regularization
                loss += l2_lambda * sum(torch.norm(param) for param in self.model.parameters())

                # Backpropagation with retain_graph=False (default)
                loss.backward()

                optimizer.step()
                
                total_loss += loss.item()
                _, predictions = torch.max(outputs, -1)
                total_correct += (predictions == targets).sum().item()
                total_samples += targets.numel()

        avg_loss = total_loss / total_samples
        accuracy = (total_correct / total_samples) * 100

        print(f"Client {self.id_client}: Loss={avg_loss:.4f}, Acc={accuracy:.4f}")
        return self.model.state_dict(), avg_loss, accuracy

### Server

In [245]:
class Server:
    def __init__(self, test_loader, global_model, device):
        self.test_loader = test_loader
        self.global_model = global_model.to(device)
        self.device = device

    def federated_training(self, clients, rounds, lr, momentum, weight_decay, local_steps):
        
        global_losses = []
        global_accuracies = []

        for round_num in range(rounds):
            local_weights = [
                client.train_local_model(
                    nn.CrossEntropyLoss(), 
                    optim.SGD(client.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay), 
                    local_steps)[0] for client in clients]
            self.aggregate_weights(local_weights)
            print(f"Round {round_num + 1}/{rounds} completed")
            # Evaluate after each round and store the results
            avg_loss, accuracy = self.evaluate_model()
            global_losses.append(avg_loss)
            global_accuracies.append(accuracy)

        return global_losses, global_accuracies

    def aggregate_weights(self, local_weights):
        global_state_dict = deepcopy(self.global_model.state_dict())
        for k in global_state_dict.keys():
            global_state_dict[k] = torch.stack([local[k] for local in local_weights]).mean(dim=0)
        self.global_model.load_state_dict(global_state_dict)

    def evaluate_model(self):
        self.global_model.eval()
        total_loss, total_correct, total_samples = 0, 0, 0
        criterion = nn.CrossEntropyLoss()
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs, _ = self.global_model(inputs)
                loss = criterion(outputs.view(-1, self.global_model.n_vocab), targets.view(-1))
                total_loss += loss.item()
                _, predictions = torch.max(outputs, -1)
                total_correct += (predictions == targets).sum().item()
                total_samples += targets.numel()
        avg_loss = total_loss / total_samples
        accuracy = (total_correct / total_samples) * 100
        print(f"Evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
        return avg_loss, accuracy


## Main

In [246]:
def main():
    # Dataset and training configurations
    data_path = "shakespeare.txt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available
    epochs = 20  # Number of epochs for centralized training
    fraction = 0.1  # Fraction of clients to select each round
    seq_length = 80  # Sequence length for LSTM inputs   
    batch_size = 64 # For local training
    n_vocab = 90 # Character number in vobulary (ASCII)
    embedding_size = 8
    hidden_dim = 256
    train_split = 0.8
    # momentum = 0 # TODO ask to TA if is correct 0
    momentum = 0.9
    learning_rate = 0.01
    weight_decay = 1e-4
    C = 0.1

    # Load data
    train_data, test_data = parse_shakespeare(data_path, train_split)

    # Centralized Dataset Preparation
    train_dataset = ShakespeareDataset(train_data, seq_length=seq_length, n_vocab=n_vocab)
    test_dataset = ShakespeareDataset(test_data, seq_length=seq_length, n_vocab=n_vocab)
    train_size = int(0.9 * len(train_dataset))  # 90% of data for training
    val_size = len(train_dataset) - train_size  # 10% of data for validation
    train_dataset, validation_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    

    # EXPERIMENTS

    local_steps = [4, 8, 16] #what is called J -> # Number of local training steps
    # Scale the number of rounds inversely with J to maintain a constant computational budget
    num_rounds = {4: 200, 8: 100, 16: 50} # Number of federated communication rounds


    num_clients = 10
    global_model = FLTailoredCharLSTM(n_vocab=90, embedding_dim=256, hidden_dim=256, num_experts=3, personalization=True)

    clients = [Client(train_loader, i, deepcopy(global_model), device) for i in range(num_clients)]

    server = Server(test_loader, global_model, device)
    server.federated_training(clients, rounds=5, lr=0.01, momentum=0.9, weight_decay=1e-4, local_steps=4)
    server.evaluate_model()
    print("Federated learning simulation completed.")


if __name__ == "__main__":
    main()


Training examples: 85924
Testing examples: 20794
Client 0: Loss=0.0005, Acc=54.9860
Client 1: Loss=0.0005, Acc=54.9567
Client 2: Loss=0.0005, Acc=54.9779
Client 3: Loss=0.0005, Acc=54.9776
Client 4: Loss=0.0005, Acc=54.9958
Client 5: Loss=0.0005, Acc=54.9836
Client 6: Loss=0.0005, Acc=54.9176
Client 7: Loss=0.0005, Acc=55.0290
Client 8: Loss=0.0005, Acc=54.9287
Client 9: Loss=0.0005, Acc=54.9827
Round 1/5 completed
Evaluation - Loss: 0.0009, Accuracy: 3.63%
Client 0: Loss=0.0004, Acc=61.5715
Client 1: Loss=0.0004, Acc=61.5568
Client 2: Loss=0.0004, Acc=61.5550
Client 3: Loss=0.0004, Acc=61.5619
Client 4: Loss=0.0004, Acc=61.5667
Client 5: Loss=0.0004, Acc=61.5637
Client 6: Loss=0.0004, Acc=61.5709
Client 7: Loss=0.0004, Acc=61.5691
Client 8: Loss=0.0004, Acc=61.5888
Client 9: Loss=0.0004, Acc=61.5795
Round 2/5 completed
Evaluation - Loss: 0.0009, Accuracy: 6.64%
Client 0: Loss=0.0003, Acc=62.6477
Client 1: Loss=0.0003, Acc=62.6537
Client 2: Loss=0.0003, Acc=62.6322
Client 3: Loss=0.000