# AML Project 5 - Federated Learning - Track B



## Setup & Installation

In [49]:
# %%capture
# %!pip install torch torchvision tqdm tensorboard

In [50]:
import copy
import dataclasses
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision.datasets.cifar import CIFAR100
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm


In [51]:
class TqdmLoggingHandler(logging.Handler):
    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.write("\r\033[K" + msg)
            self.flush()
        except Exception:
            self.handleError(record)


class ColoredFormatter(logging.Formatter):
    COLORS = {
        "DEBUG": "\033[1;34m",
        "INFO": "\033[1;32m",
        "WARNING": "\033[1;33m",
        "ERROR": "\033[1;31m",
        "CRITICAL": "\033[1;35m",
        "RESET": "\033[0m",
    }

    def format(self, record):
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = (
                f"{self.COLORS[levelname]}{levelname}{self.COLORS['RESET']}"
            )
        return super().format(record)


def setup_logging(level=logging.INFO):
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.setLevel(level)

    tqdm_handler = TqdmLoggingHandler()
    formatter = ColoredFormatter(
        fmt="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    tqdm_handler.setFormatter(formatter)
    root_logger.addHandler(tqdm_handler)


setup_logging()

In [52]:
@dataclass(frozen=True)
class BaseConfig:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CPU_COUNT = os.cpu_count() or 1
    NUM_WORKERS = min(4, CPU_COUNT)
    SEED = 42

    # Paths
    ROOT_DIR: Path = Path.cwd()
    CONFIGS_DIR: Path = ROOT_DIR / "configs"
    DATA_DIR: Path = ROOT_DIR / "data"
    MODELS_DIR: Path = ROOT_DIR / "models"
    RESULTS_DIR: Path = ROOT_DIR / "results"
    RUNS_DIR: Path = ROOT_DIR / "runs"
    OLD_RUNS_DIR: Path = RUNS_DIR / "old_runs"

    # Training Parameters
    BATCH_SIZE = 64
    LEARNING_RATE = 0.01
    NUM_EPOCHS = 20
    MOMENTUM = 0.9
    WEIGHT_DECAY = 4e-4
    NUM_CLASSES = 100


# Create directories
config = BaseConfig()
for dir_path in [
    config.DATA_DIR,
    config.MODELS_DIR,
    config.RESULTS_DIR,
    config.CONFIGS_DIR,
    config.RUNS_DIR,
    config.OLD_RUNS_DIR,
]:
    dir_path.mkdir(parents=True, exist_ok=True)


In [53]:
@dataclass(frozen=True)
class FederatedConfig(BaseConfig):
    """Federated Learning specific configuration."""

    NUM_CLIENTS: int = 100
    PARTICIPATION_RATE: float = 0.1
    LOCAL_EPOCHS: int = 4
    NUM_ROUNDS: int = 2000
    CLASSES_PER_CLIENT: Optional[int] = None  # None for IID
    PARTICIPATION_MODE: str = "uniform"
    DIRICHLET_ALPHA: Optional[float] = None

## Model

In [54]:
class LeNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.connected = nn.Sequential(
            nn.Linear(5 * 5 * 64, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(192, config.NUM_CLASSES)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x.view(x.size(0), -1)
        x = self.connected(x)
        x = self.classifier(x)
        return x

In [55]:
class MetricsManager:
    """Manages logging and visualization of training metrics."""

    def __init__(
        self,
        config: BaseConfig,
        model_name: str,
        training_type: Literal["centralized", "federated"],
        experiment_name: Optional[str] = None,
    ):
        self.config = config
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

        # Archive old runs
        old_runs = list(config.RUNS_DIR.glob(f"{training_type}_{model_name}_*"))
        if old_runs:
            archive_dir = config.OLD_RUNS_DIR
            archive_dir.mkdir(exist_ok=True)
            for run in old_runs:
                run.rename(archive_dir / run.name)

        # Create descriptive run name for FL experiments
        if training_type == "federated" and isinstance(config, FederatedConfig):
            distribution = (
                "iid"
                if config.CLASSES_PER_CLIENT is None
                else f"noniid_{config.CLASSES_PER_CLIENT}cls"
            )
            participation = f"{config.PARTICIPATION_MODE}"
            if config.PARTICIPATION_MODE == "skewed":
                participation += f"_alpha{config.DIRICHLET_ALPHA}"
            clients_info = f"C{config.NUM_CLIENTS}_P{config.PARTICIPATION_RATE}_E{config.LOCAL_EPOCHS}"
            experiment_suffix = f"{distribution}_{participation}_{clients_info}"
        else:
            experiment_suffix = experiment_name if experiment_name else timestamp

        run_name = f"{training_type}_{model_name}_{experiment_suffix}"
        self.writer = SummaryWriter(config.RUNS_DIR / run_name)

    def log_metrics(
        self,
        split: Literal["train", "validation", "test"],
        loss: float,
        accuracy: float,
        step: int,
    ) -> None:
        """Log metrics for specified split."""
        self.writer.add_scalars("metrics/loss", {split: loss}, step)
        self.writer.add_scalars("metrics/accuracy", {split: accuracy}, step)

    def log_fl_metrics(
        self, round_idx: int, metrics: Dict, client_stats: Optional[Dict] = None
    ) -> None:
        """Log federated learning specific metrics."""
        # Log test metrics
        self.log_metrics(
            "test", metrics["test_loss"], metrics["test_accuracy"], round_idx
        )

        # Log client participation if available
        if client_stats:
            self.writer.add_scalars("federated/client_stats", client_stats, round_idx)

    def close(self) -> None:
        """Close TensorBoard writer."""
        self.writer.close()

In [56]:
class Cifar100DatasetManager:
    config: BaseConfig
    validation_split: float
    train_transform: transforms.Compose
    test_transform: transforms.Compose
    train_loader: DataLoader[CIFAR100]
    val_loader: DataLoader[CIFAR100]
    test_loader: DataLoader[CIFAR100]

    def __init__(self, config: BaseConfig, validation_split: float = 0.1) -> None:
        self.config = config
        self.validation_split = validation_split

        self.train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
                ),
            ]
        )

        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
                ),
            ]
        )

        self.train_loader, self.val_loader, self.test_loader = self._prepare_data()

    def _prepare_data(
        self,
    ) -> Tuple[DataLoader[CIFAR100], DataLoader[CIFAR100], DataLoader[CIFAR100]]:
        full_trainset: CIFAR100 = CIFAR100(
            root=self.config.DATA_DIR,
            train=True,
            download=True,
            transform=self.train_transform,
        )

        train_size: int = int((1 - self.validation_split) * len(full_trainset))
        val_size: int = len(full_trainset) - train_size

        train_dataset, val_dataset = random_split(
            full_trainset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.config.SEED),
        )

        test_dataset: CIFAR100 = CIFAR100(
            root=self.config.DATA_DIR,
            train=False,
            download=False,
            transform=self.test_transform,
        )

        loader_kwargs = {"num_workers": self.config.NUM_WORKERS, "pin_memory": True}

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=True,
            **loader_kwargs,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
            **loader_kwargs,
        )

        test_loader: DataLoader[CIFAR100] = DataLoader(
            test_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
            **loader_kwargs,
        )

        return train_loader, val_loader, test_loader

    @property
    def train_dataset(self) -> Dataset[CIFAR100]:
        return self.train_loader.dataset

    @property
    def val_dataset(self) -> Dataset[CIFAR100]:
        return self.val_loader.dataset

    @property
    def test_dataset(self) -> Dataset[CIFAR100]:
        return self.test_loader.dataset

In [57]:
class CentralizedTrainer:
    model: LeNet
    config: BaseConfig
    device: torch.device
    metrics: MetricsManager

    def __init__(self, model: LeNet, config: BaseConfig) -> None:
        self.model = model.to(config.DEVICE)
        self.config = config
        self.device = config.DEVICE
        self.metrics = MetricsManager(
            config, model.__class__.__name__.lower(), "centralized"
        )

    def evaluate_model(
        self, model: LeNet, data_loader: DataLoader[CIFAR100]
    ) -> Tuple[float, float]:
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                total_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        avg_loss: float = total_loss / total
        accuracy: float = 100.0 * correct / total
        return avg_loss, accuracy

    def train(
        self,
        train_loader: DataLoader[CIFAR100],
        val_loader: DataLoader[CIFAR100],
        test_loader: DataLoader[CIFAR100],
        max_epochs: Optional[int] = None,
    ) -> None:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            momentum=self.config.MOMENTUM,
            weight_decay=self.config.WEIGHT_DECAY,
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.config.NUM_EPOCHS
        )

        if max_epochs is None:
            max_epochs = self.config.NUM_EPOCHS

        best_val_loss = float("inf")
        best_model_state = None
        patience = 10
        patience_counter = 0
        train_acc = 0.0
        avg_train_loss = 0.0

        epoch_pbar = tqdm(
            range(max_epochs or self.config.NUM_EPOCHS),
            desc="Training",
            unit="epoch",
            position=0,
            leave=True,
        )
        epoch = 0

        try:
            for epoch in epoch_pbar:
                self.model.train()
                train_loss = 0
                correct = 0
                total = 0

                # batch_pbar = tqdm(
                #     train_loader,
                #     desc=f"Epoch {epoch}",
                #     colour="yellow",
                #     unit="batch",
                #     leave=True,
                #     position=1,
                #     bar_format='{l_bar}{bar:20}{r_bar}{bar:-10b}'
                # )
                for batch_idx, (inputs, targets) in enumerate(train_loader):
                    inputs, targets = inputs.to(self.device), targets.to(self.device)

                    optimizer.zero_grad()
                    outputs = self.model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

                    # Update metrics
                    train_acc = 100.0 * correct / total
                    avg_train_loss = train_loss / (batch_idx + 1)

                    global_step = epoch * len(train_loader) + batch_idx
                    self.metrics.log_metrics(
                        "train", avg_train_loss, train_acc, global_step
                    )

                    # batch_pbar.set_postfix(
                    #     {"loss": f"{avg_train_loss:.3f}", "acc": f"{train_acc:.2f}%"}
                    # )

                # Validation phase
                val_loss, val_acc = self.evaluate_model(self.model, val_loader)
                scheduler.step()

                self.metrics.log_metrics("validation", val_loss, val_acc, epoch)
                epoch_pbar.set_postfix(
                    {
                        "ep": f"{epoch+1}/{max_epochs or self.config.NUM_EPOCHS}",
                        "tr_loss": f"{avg_train_loss:.3f}",
                        "tr_acc": f"{train_acc:.1f}%",
                        "val_loss": f"{val_loss:.3f}",
                        "val_acc": f"{val_acc:.1f}%",
                    },
                    refresh=True,
                )

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_state = self.model.state_dict().copy()
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    logging.info(f"Early stopping at epoch {epoch}")
                    break

            else:
                logging.info("Training completed!")
            # Final evaluation
            if best_model_state is not None:
                self.model.load_state_dict(best_model_state)

            test_loss, test_acc = self.evaluate_model(self.model, test_loader)
            self.metrics.log_metrics("test", test_loss, test_acc, epoch)
            logging.info(
                f"Final Test Results - Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%"
            )

        finally:
            self.metrics.close()

In [58]:
class DataSharder:
    def create_iid_shards(
        self, dataset: Dataset, num_clients: int
    ) -> List[Subset[CIFAR100]]:
        # FIXME: use numpy as in the non-iid case
        total_size = len(dataset)
        shard_size = total_size // num_clients
        indices = torch.randperm(total_size).tolist()
        return [
            Subset(dataset, indices[i : i + shard_size])
            for i in range(0, total_size, shard_size)
        ]

    def create_noniid_shards(
        self,
        dataset: Dataset[CIFAR100] | Subset[CIFAR100],
        num_clients: int,
        classes_per_client: int,
    ) -> List[Subset[CIFAR100]]:
        # Get labels - handle both CIFAR100 and Subset cases
        if isinstance(dataset, CIFAR100):
            labels = torch.tensor(dataset.targets)
        else:
            # For Subset, get the dataset's targets using indices
            labels = torch.tensor([dataset.dataset.targets[i] for i in dataset.indices])
        # Get labels directly from CIFAR100 targets
        labels = torch.tensor(dataset.targets)
        class_indices = {i: [] for i in range(100)}

        for idx, label in enumerate(labels):
            class_indices[int(label)].append(idx)  # Convert to int explicitly

        client_indices = [[] for _ in range(num_clients)]
        classes = list(range(100))

        for client_id in range(num_clients):
            client_classes = np.random.choice(
                classes, size=classes_per_client, replace=False
            )
            for class_id in client_classes:
                # Convert numpy types to Python int
                class_id_int = int(class_id)
                samples_per_client = len(class_indices[class_id_int]) // num_clients

                client_indices[client_id].extend(
                    np.random.choice(
                        class_indices[class_id_int],
                        size=samples_per_client,
                        replace=False,
                    ).tolist()  # Convert to list for Subset
                )

        return [Subset(dataset, indices) for indices in client_indices]


class ClientManager:
    def __init__(
        self, num_clients, participation_rate, mode="uniform", dirichlet_alpha=None
    ) -> None:
        self.num_clients = num_clients
        self.num_selected = int(participation_rate * num_clients)

        if mode == "skewed":
            if dirichlet_alpha is None:
                raise ValueError("dirichlet_alpha required for skewed mode")
            self.selection_probs = np.random.dirichlet([dirichlet_alpha] * num_clients)
        else:
            self.selection_probs = np.ones(num_clients) / num_clients

    def select_clients(self) -> npt.NDArray[np.int64]:
        return np.random.choice(
            self.num_clients,
            size=self.num_selected,
            replace=False,
            p=self.selection_probs,
        )


class FederatedClient:
    config: FederatedConfig
    model: LeNet
    client_id: int
    train_loader: DataLoader[CIFAR100]
    local_epochs: int
    device: torch.device

    def __init__(
        self,
        client_id: int,
        model: LeNet,
        train_loader: DataLoader[CIFAR100],
        config: FederatedConfig,
        local_epochs: int,
    ) -> None:
        self.config = config
        self.model = copy.deepcopy(model)
        self.client_id = client_id
        self.train_loader = train_loader
        self.local_epochs = local_epochs
        self.device = config.DEVICE

    # TODO: call the centralized trainer?
    def train(self):
        self.model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            momentum=self.config.MOMENTUM,
            weight_decay=self.config.WEIGHT_DECAY,
        )

        for _ in range(self.local_epochs):
            for inputs, targets in self.train_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        return self.model


class FederatedServer:
    global_model: LeNet
    client_manager: ClientManager
    test_loader: DataLoader[CIFAR100]
    config: FederatedConfig
    device: torch.device

    def __init__(
        self,
        model: LeNet,
        client_manager: ClientManager,
        test_loader: DataLoader[CIFAR100],
        config: FederatedConfig,
    ) -> None:
        self.global_model = model
        self.client_manager = client_manager
        self.test_loader = test_loader
        self.config = config
        self.device = config.DEVICE

    def aggregate_models(self, client_models):
        global_dict = self.global_model.state_dict()

        for k in global_dict.keys():
            global_dict[k] = torch.stack(
                [
                    client_model.state_dict()[k].float()
                    for client_model in client_models
                ],
                0,
            ).mean(0)

        self.global_model.load_state_dict(global_dict)

    def evaluate(self):
        return CentralizedTrainer(self.global_model, self.config).evaluate_model(
            self.global_model, self.test_loader
        )


class EnhancedFederatedTrainer:
    config: FederatedConfig
    model: LeNet
    device: torch.device
    metrics: MetricsManager
    sharder: DataSharder
    client_manager: ClientManager
    server: FederatedServer
    client_loaders: List[DataLoader[CIFAR100]]

    def __init__(
        self,
        model: LeNet,
        train_dataset: Dataset[CIFAR100],
        test_loader: DataLoader[CIFAR100],
        config: FederatedConfig,
    ) -> None:
        self.config = config
        self.model = model.to(config.DEVICE)
        self.device = config.DEVICE
        self.metrics = MetricsManager(
            config, model.__class__.__name__.lower(), "federated"
        )

        # Setup data sharding
        self.sharder = DataSharder()
        shards: List[Subset[CIFAR100]] = (
            self.sharder.create_noniid_shards(
                train_dataset, config.NUM_CLIENTS, config.CLASSES_PER_CLIENT
            )
            if config.CLASSES_PER_CLIENT
            else self.sharder.create_iid_shards(train_dataset, config.NUM_CLIENTS)
        )

        self.client_loaders = [
            DataLoader(
                shard,
                batch_size=config.BATCH_SIZE,
                shuffle=True,
                num_workers=config.NUM_WORKERS,
                pin_memory=True,
            )
            for shard in shards
        ]

        self.client_manager = ClientManager(
            config.NUM_CLIENTS, config.PARTICIPATION_RATE
        )

        self.server = FederatedServer(model, self.client_manager, test_loader, config)

    def train(self) -> None:
        round_pbar = tqdm(
            range(self.config.NUM_ROUNDS),
            desc="Training FL",
            colour="blue",
            unit="round",
            position=0,
            leave=True,
        )
        try:
            for round_idx in round_pbar:
                selected_clients = self.client_manager.select_clients()

                client_models = []
                for client_idx in selected_clients:
                    client = FederatedClient(
                        client_idx,
                        self.server.global_model,
                        self.client_loaders[client_idx],
                        self.config,
                        self.config.LOCAL_EPOCHS,
                    )
                    client_models.append(client.train())

                self.server.aggregate_models(client_models)

                test_loss, accuracy = self.server.evaluate()
                self.metrics.log_metrics("test", test_loss, accuracy, round_idx)

                if round_idx % 10 == 0:
                    logging.info(
                        f"Round {round_idx}/{self.config.NUM_ROUNDS}: "
                        f"test_loss: {test_loss:.4f}, "
                        f"test_acc: {accuracy:.2f}%"
                    )

        finally:
            self.metrics.close()

In [59]:
data = Cifar100DatasetManager(config)

Files already downloaded and verified


In [60]:
model = LeNet(config)
trainer = CentralizedTrainer(model, config)
trainer.train(data.train_loader, data.val_loader, data.test_loader)


Training: 100%|██████████| 20/20 [30:04<00:00, 90.25s/epoch, ep=20/20, tr_loss=1.574, tr_acc=56.4%, val_loss=2.083, val_acc=46.5%] 


[K2025-01-14 13:38:19 [[1;32mINFO[0m] Training completed!
[K2025-01-14 13:38:32 [[1;32mINFO[0m] Final Test Results - Loss: 1.9105, Accuracy: 50.30%


In [61]:
fed_config = FederatedConfig()

In [62]:
fed_model = LeNet(config)
fed_trainer = EnhancedFederatedTrainer(
    model=fed_model,
    train_dataset=data.train_dataset,
    test_loader=data.test_loader,
    config=fed_config,
)
fed_trainer.train()

Training:   0%|[34m          [0m| 0/2000 [05:26<?, ?round/s]


KeyboardInterrupt: 

In [None]:
non_iid_config: FederatedConfig = dataclasses.replace(
    fed_config,
    CLASSES_PER_CLIENT=5,
)

# Train with non-iid distribution
non_iid_model = LeNet(config)
non_iid_trainer = EnhancedFederatedTrainer(
    model=non_iid_model,
    train_dataset=data.train_loader.dataset,
    test_loader=data.test_loader,
    config=non_iid_config,
)
non_iid_trainer.train()


In [None]:
skewed_config = dataclasses.replace(
    fed_config, PARTICIPATION_MODE="skewed", DIRICHLET_ALPHA=0.5
)

skewed_model = LeNet(config)
skewed_trainer = EnhancedFederatedTrainer(
    model=skewed_model,
    train_dataset=data.train_loader.dataset,
    test_loader=data.test_loader,
    config=skewed_config,
)
skewed_trainer.train()

## Experiments

In [63]:
local_epochs_configs: list[FederatedConfig] = [
    dataclasses.replace(fed_config, LOCAL_EPOCHS=e) for e in [4, 8, 16]
]

# Experiment with different client counts
client_counts_configs = [
    dataclasses.replace(fed_config, NUM_CLIENTS=c) for c in [50, 100, 200]
]

# Experiment with different participation rates
participation_rates_configs = [
    dataclasses.replace(fed_config, PARTICIPATION_RATE=r) for r in [0.05, 0.1, 0.2]
]

In [71]:
import itertools

# Combine all configurations

# Get all combinations of parameters
all_configs = list(
    itertools.product(
        [4, 8, 16],  # LOCAL_EPOCHS
        [50, 100, 200],  # NUM_CLIENTS
        [0.05, 0.1, 0.2],  # PARTICIPATION_RATE
    )
)

# Create FederatedConfig objects for each combination
all_experiment_configs = [
    dataclasses.replace(
        fed_config, LOCAL_EPOCHS=epochs, NUM_CLIENTS=clients, PARTICIPATION_RATE=rate
    )
    for epochs, clients, rate in all_configs
]
print(len(all_experiment_configs))
# Print all combinations
for i, config in enumerate(all_experiment_configs):
    print(i, 
        f"LOCAL_EPOCHS={config.LOCAL_EPOCHS}, "
        f"NUM_CLIENTS={config.NUM_CLIENTS}, "
        f"PARTICIPATION_RATE={config.PARTICIPATION_RATE}"
    )


27
0 LOCAL_EPOCHS=4, NUM_CLIENTS=50, PARTICIPATION_RATE=0.05
1 LOCAL_EPOCHS=4, NUM_CLIENTS=50, PARTICIPATION_RATE=0.1
2 LOCAL_EPOCHS=4, NUM_CLIENTS=50, PARTICIPATION_RATE=0.2
3 LOCAL_EPOCHS=4, NUM_CLIENTS=100, PARTICIPATION_RATE=0.05
4 LOCAL_EPOCHS=4, NUM_CLIENTS=100, PARTICIPATION_RATE=0.1
5 LOCAL_EPOCHS=4, NUM_CLIENTS=100, PARTICIPATION_RATE=0.2
6 LOCAL_EPOCHS=4, NUM_CLIENTS=200, PARTICIPATION_RATE=0.05
7 LOCAL_EPOCHS=4, NUM_CLIENTS=200, PARTICIPATION_RATE=0.1
8 LOCAL_EPOCHS=4, NUM_CLIENTS=200, PARTICIPATION_RATE=0.2
9 LOCAL_EPOCHS=8, NUM_CLIENTS=50, PARTICIPATION_RATE=0.05
10 LOCAL_EPOCHS=8, NUM_CLIENTS=50, PARTICIPATION_RATE=0.1
11 LOCAL_EPOCHS=8, NUM_CLIENTS=50, PARTICIPATION_RATE=0.2
12 LOCAL_EPOCHS=8, NUM_CLIENTS=100, PARTICIPATION_RATE=0.05
13 LOCAL_EPOCHS=8, NUM_CLIENTS=100, PARTICIPATION_RATE=0.1
14 LOCAL_EPOCHS=8, NUM_CLIENTS=100, PARTICIPATION_RATE=0.2
15 LOCAL_EPOCHS=8, NUM_CLIENTS=200, PARTICIPATION_RATE=0.05
16 LOCAL_EPOCHS=8, NUM_CLIENTS=200, PARTICIPATION_RATE=0.1
17 L