In [1]:
! pip install torch torchvision medmnist flwr[simulation]



In [2]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context


from medmnist import INFO, Evaluator
import medmnist

In [3]:
# --- PathMNIST setup --------------------------------------------------------
# Look up the dataset metadata and resolve the medmnist class it names.
dataset_info = INFO['pathmnist']
DataClass = getattr(medmnist, dataset_info['python_class'])

# Convert images to tensors, then normalise each value as (x - 0.5) / 0.5.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
    ]
)

# Load every split once at module level; load_datasets() reads these globals.
full_train_dataset = DataClass(split='train', download=True, transform=data_transforms)
val_dataset = DataClass(split='val', download=True, transform=data_transforms)
test_dataset = DataClass(split='test', download=True, transform=data_transforms)


  and should_run_async(code)


Using downloaded and verified file: /root/.medmnist/pathmnist.npz
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
Using downloaded and verified file: /root/.medmnist/pathmnist.npz


In [4]:
# Sanity check: how many samples each split holds.
print("Train samples: {}, Validation samples: {}, Test samples: {}".format(
    len(full_train_dataset), len(val_dataset), len(test_dataset)))

Train samples: 89996, Validation samples: 10004, Test samples: 7180


In [5]:
NUM_PARTITIONS = 5
BATCH_SIZE = 128
LOCAL_EPOCHS = 5
NUM_ROUNDS = 6


def load_datasets(partition_id: int, num_partitions: int, seed: int = 42):
    """Build the DataLoaders for one simulated client.

    The module-level train and validation splits are each shuffled with a
    fixed seed and cut into ``num_partitions`` disjoint shards; this client
    receives shard ``partition_id`` of each. The test split is shared by all
    clients and is not partitioned.

    Args:
        partition_id: Index of this client's shard, in ``[0, num_partitions)``.
        num_partitions: Total number of shards (simulated clients).
        seed: Shuffle seed. All calls must use the same value so that separate
            calls (one per client, possibly in different processes) agree on
            the permutation and therefore produce disjoint shards.

    Returns:
        ``(trainloader, valloader, testloader)``.

    Raises:
        ValueError: If ``num_partitions`` < 1 or ``partition_id`` is outside
            ``[0, num_partitions)``.
    """
    if num_partitions < 1:
        raise ValueError(f"num_partitions must be >= 1, got {num_partitions}")
    if not 0 <= partition_id < num_partitions:
        raise ValueError(
            f"partition_id must be in [0, {num_partitions}), got {partition_id}"
        )

    def shard(dataset):
        # Seeded, call-local generator: an unseeded global shuffle would differ
        # between calls, making different clients' shards overlap.
        rng = np.random.default_rng(seed)
        order = rng.permutation(len(dataset))
        # array_split keeps the len(dataset) % num_partitions leftover samples
        # (shard sizes differ by at most one) instead of dropping them.
        return Subset(dataset, np.array_split(order, num_partitions)[partition_id])

    client_train_dataset = shard(full_train_dataset)
    client_val_dataset = shard(val_dataset)

    trainloader = DataLoader(client_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(client_val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return trainloader, valloader, testloader

In [6]:
# Example: load and inspect the data of a single client
partition_id = 4  # Client ID to inspect (0 .. NUM_PARTITIONS - 1)
trainloader, valloader, testloader = load_datasets(partition_id=partition_id, num_partitions=NUM_PARTITIONS)

# Print dataset sizes
print(f"Client {partition_id} Train Samples: {len(trainloader.dataset)}")
print(f"Client {partition_id} Validation Samples: {len(valloader.dataset)}")
print(f"Test Samples (Global): {len(testloader.dataset)}")


Client 4 Train Samples: 17999
Client 4 Validation Samples: 2000
Test Samples (Global): 7180


In [7]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    """Small CNN: two conv/pool stages followed by three fully connected layers.

    The ``fc1`` input size (16 * 4 * 4) assumes 28x28 inputs: each 5x5 conv
    (no padding) removes 4 pixels and each 2x2 max-pool halves the size, so
    28 -> 24 -> 12 -> 8 -> 4.

    Args:
        num_classes: Number of output logits. Defaults to 10 so existing
            parameter lists keep their shapes. NOTE(review): PathMNIST
            presumably has fewer classes (see
            ``len(INFO['pathmnist']['label'])``); confirm before changing.
        in_channels: Number of input image channels (3 = RGB by default).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3) -> None:
        super().__init__()
        # Attribute order fixes state_dict order, which get/set_parameters rely on.
        self.conv1 = nn.Conv2d(in_channels, 6, 5)  # First convolution
        self.pool = nn.MaxPool2d(2, 2)             # Shared 2x2 pooling
        self.conv2 = nn.Conv2d(6, 16, 5)           # Second convolution
        self.fc1 = nn.Linear(16 * 4 * 4, 120)      # Fully connected layer 1
        self.fc2 = nn.Linear(120, 84)              # Fully connected layer 2
        self.fc3 = nn.Linear(84, num_classes)      # Output logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(N, in_channels, 28, 28)`` to logits ``(N, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))  # Conv1 + ReLU + Pooling
        x = self.pool(F.relu(self.conv2(x)))  # Conv2 + ReLU + Pooling
        x = torch.flatten(x, 1)               # Flatten all but the batch dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def get_parameters(net: nn.Module) -> List[np.ndarray]:
    """Snapshot every ``state_dict`` entry of ``net`` as a NumPy array, in order."""
    state = net.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]


def set_parameters(net: nn.Module, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net`` (the inverse of ``get_parameters``).

    ``parameters`` must hold exactly one array per ``net.state_dict()`` entry,
    in the same order. Shape mismatches are reported by ``load_state_dict``.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state_dict entries (``zip`` would otherwise drop extras silently).
    """
    keys = list(net.state_dict().keys())
    if len(keys) != len(parameters):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    # torch.tensor keeps each array's dtype (torch.Tensor would force float32,
    # losing precision in large integer buffers) and copies the data.
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def train(net: nn.Module, trainloader: torch.utils.data.DataLoader, epochs: int):
    """Train ``net`` in place on ``trainloader`` for ``epochs`` epochs.

    A fresh Adam optimizer with default hyper-parameters is created on every
    call. Batches are moved to the module-level ``DEVICE``. Per-epoch mean loss
    (per sample) and accuracy are printed; an empty loader reports 0.0 for both
    instead of dividing by zero.

    Labels are flattened to shape ``(B,)`` — NOTE(review): presumably they
    arrive as ``(B, 1)`` from the MedMNIST loaders; confirm.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # reshape(-1) instead of squeeze(): squeeze() turns a final batch of
            # a single sample into a 0-d tensor, breaking the loss and size(0).
            labels = labels.reshape(-1).long()
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics, weighted by batch size so a short final batch counts correctly.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        epoch_loss = loss_sum / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        print(f"Epoch {epoch+1}: train loss {epoch_loss:.4f}, accuracy {epoch_acc * 100:.2f}%")


def test(net: nn.Module, testloader: torch.utils.data.DataLoader):
    """Evaluate the network on the entire test set."""
    criterion = nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:  # Updated for PathMNIST DataLoader
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            labels = labels.squeeze().long()
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader)
    accuracy = correct / total
    return loss, accuracy


In [8]:
class FlowerClient(NumPyClient):
    """Flower NumPyClient that wraps one partition's model and DataLoaders."""

    def __init__(self, partition_id, net, trainloader, valloader):
        # Everything the fit/evaluate callbacks need lives on the instance.
        self.partition_id, self.net = partition_id, net
        self.trainloader, self.valloader = trainloader, valloader

    def get_parameters(self, config):
        """Return the current local weights as a list of NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        # Inside a method this name resolves to the module-level helper.
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        ``config`` must provide ``server_round`` and ``local_epochs``.
        """
        server_round, local_epochs = config["server_round"], config["local_epochs"]
        print(f"[Client {self.partition_id}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        updated = get_parameters(self.net)
        return updated, len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and score them on this client's validation shard."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        metrics = {"accuracy": float(accuracy)}
        return float(loss), len(self.valloader.dataset), metrics


def client_fn(context: Context) -> Client:
    """Build the Flower client for the partition named in ``context.node_config``."""
    model = Net().to(DEVICE)

    node_cfg = context.node_config
    partition_id = node_cfg["partition-id"]
    num_partitions = node_cfg["num-partitions"]

    # The shared test loader is not needed on the client side.
    trainloader, valloader, _testloader = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, model, trainloader, valloader).to_client()


# Register the client factory with Flower.
client = ClientApp(client_fn=client_fn)

In [9]:
# Server-side evaluation: Flower calls `evaluate` on the global model before the
# first round (server round 0) and after every round.
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Score the current global parameters on the shared test split.

    Args:
        server_round: Current round number (used only for logging).
        parameters: Global model weights, in ``get_parameters`` order.
        config: Evaluation config supplied by the strategy; unused here.

    Returns:
        ``(loss, {"accuracy": accuracy})`` measured on ``test_dataset``.
    """
    net = Net().to(DEVICE)
    set_parameters(net, parameters)

    # Build the test loader directly; going through load_datasets() would
    # re-shuffle and re-partition the full train/validation splits on every
    # call only to discard them.
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    loss, accuracy = test(net, testloader)

    print(f"[Server Round {server_round}] Evaluation - Loss: {loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
    return loss, {"accuracy": accuracy}


In [10]:
def fit_config(server_round: int):
    """Per-round training config sent to every selected client.

    Rounds below 2 (i.e. round 1) are a single-epoch warm-up; later rounds
    train for ``LOCAL_EPOCHS`` epochs.
    """
    if server_round < 2:
        local_epochs = 1
    else:
        local_epochs = LOCAL_EPOCHS
    return {"server_round": server_round, "local_epochs": local_epochs}

In [11]:
# Initial global weights handed to the strategy in server_fn: a freshly
# initialised Net converted to a list of NumPy arrays.
params = get_parameters(net=Net())

In [12]:
def server_fn(context: Context) -> ServerAppComponents:
    """Build the FedAvg strategy and round configuration for the ServerApp.

    ``context`` is part of the ServerApp factory signature and is not used here.
    """
    # Client sampling: per Flower's FedAvg, the number of clients sampled is
    # max(int(fraction * available), min_*_clients). With NUM_PARTITIONS = 5 the
    # 0.3 fractions alone would give int(1.5) = 1, so the min_* = 3 floors decide
    # the sample size (the logs show "sampled 3 clients (out of 5)").
    strategy = FedAvg(
        fraction_fit=0.3,                # Requested share of clients for training (floored by min_fit_clients)
        fraction_evaluate=0.3,           # Requested share for evaluation (floored by min_evaluate_clients)
        min_fit_clients=3,               # At least 3 clients train each round
        min_evaluate_clients=3,          # At least 3 clients evaluate each round
        min_available_clients=NUM_PARTITIONS,  # Require all NUM_PARTITIONS clients to be available
        initial_parameters=ndarrays_to_parameters(params),  # Starting global weights
        evaluate_fn=evaluate,            # Server-side evaluation on the shared test split
        on_fit_config_fn=fit_config,     # Per-round fit config (server_round, local_epochs)
    )

    # Number of federated rounds to run.
    config = ServerConfig(num_rounds=NUM_ROUNDS)

    print(f"Server initialized with {NUM_PARTITIONS} clients and {config.num_rounds} rounds.")
    return ServerAppComponents(strategy=strategy, config=config)


In [13]:
# Wrap the server factory in a ServerApp; it is passed to run_simulation as server_app.
server = ServerApp(server_fn=server_fn)

In [None]:
# Resources reserved for each simulated client: 2 CPUs, plus one whole GPU
# when CUDA is available.
# NOTE(review): with num_gpus=1 on a single-GPU machine only one client can run
# at a time (the logs show every client in the same actor process); a fraction
# such as 1 / NUM_PARTITIONS would let clients share the GPU — check memory first.
client_resources = {"num_cpus": 2}
if DEVICE.type == "cuda":
    print("CUDA is available. Allocating 1 GPU per client.")
    client_resources["num_gpus"] = 1
backend_config = {"client_resources": client_resources}

# Start simulation
print(f"Starting federated simulation with {NUM_PARTITIONS} clients...")
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,  # One simulated node per data partition
    backend_config=backend_config,
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=6, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters


Starting federated simulation with 5 clients...
Server initialized with 5 clients and 6 rounds.


[92mINFO [0m:      initial parameters (loss, other metrics): 2.289621457718966, {'accuracy': 0.18635097493036212}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[Server Round 0] Evaluation - Loss: 2.2896, Accuracy: 18.64%


[36m(pid=25396)[0m 2024-11-20 01:10:18.820010: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=25396)[0m 2024-11-20 01:10:18.844531: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=25396)[0m 2024-11-20 01:10:18.852654: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[36m(ClientAppActor pid=25396)[0m [Client 0, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 1.8442, accuracy 28.19%
[36m(ClientAppActor pid=25396)[0m [Client 1, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 1.7836, accuracy 32.30%
[36m(ClientAppActor pid=25396)[0m [Client 3, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 1.7555, accuracy 32.50%


[92mINFO [0m:      fit progress: (1, 2.665029555036311, {'accuracy': 0.31142061281337047}, 50.81896580600005)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


[Server Round 1] Evaluation - Loss: 2.6650, Accuracy: 31.14%
[36m(ClientAppActor pid=25396)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=25396)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=25396)[0m [Client 3] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=25396)[0m [Client 0, round 2] fit, config: {'server_round': 2, 'local_epochs': 5}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 1.5288, accuracy 41.83%
[36m(ClientAppActor pid=25396)[0m Epoch 2: train loss 1.1937, accuracy 53.78%
[36m(ClientAppActor pid=25396)[0m Epoch 3: train loss 1.0785, accuracy 57.98%
[36m(ClientAppActor pid=25396)[0m Epoch 4: train loss 0.9978, accuracy 62.08%
[36m(ClientAppActor pid=25396)[0m Epoch 5: train loss 0.9385, accuracy 64.66%
[36m(ClientAppActor pid=25396)[0m [Client 2, round 2] fit, config: {'server_round': 2, 'local_epochs': 5}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 1.4814, accuracy 43.72%
[36m(ClientAppActor pid=25396)[0m Epoch 2: train loss 1.1688, accuracy 54.60%
[36m(ClientAppActor pid=25396)[0m Epoch 3: train loss 1.0863, accuracy 57.29%
[36m(ClientAppActor pid=25396)[0m Epoch 4: train loss 0.9930, accuracy 62.05%
[36m(ClientAppActor pid=25396)[0m Epoch 5: train loss 0.9494, 

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=25396)[0m Epoch 5: train loss 0.9417, accuracy 64.42%


[92mINFO [0m:      fit progress: (2, 0.9556933861029776, {'accuracy': 0.68008356545961}, 215.45561791199998)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


[Server Round 2] Evaluation - Loss: 0.9557, Accuracy: 68.01%
[36m(ClientAppActor pid=25396)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=25396)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=25396)[0m [Client 4] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=25396)[0m [Client 0, round 3] fit, config: {'server_round': 3, 'local_epochs': 5}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 0.9210, accuracy 65.65%
[36m(ClientAppActor pid=25396)[0m Epoch 2: train loss 0.8491, accuracy 67.97%
[36m(ClientAppActor pid=25396)[0m Epoch 3: train loss 0.8304, accuracy 68.86%
[36m(ClientAppActor pid=25396)[0m Epoch 4: train loss 0.7997, accuracy 69.71%
[36m(ClientAppActor pid=25396)[0m Epoch 5: train loss 0.7883, accuracy 70.09%
[36m(ClientAppActor pid=25396)[0m [Client 2, round 3] fit, config: {'server_round': 3, 'local_epochs': 5}
[36m(ClientAppActor pid=25396)[0m Epoch 1: train loss 0.9349, accuracy 64.91%
