In [None]:
!pip install -q flwr[simulation] torch torchvision

In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import time
import flwr as fl

DEVICE = torch.device("cuda")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

Training on cuda using PyTorch 1.13.1 and Flower 1.4.0


In [2]:
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
   process = psutil.Process(os.getpid())
   print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
   print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()

Gen RAM Free: 125.4 GB  | Proc size: 304.4 MB
GPU RAM Free: 15850MB | Used: 318MB | Util   2% | Total 16376MB


In [3]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
from multiprocessing.connection import Client

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):  # Use the passed 'epochs' variable here
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()  # Make sure to call .item() to get the scalar value
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch}: train loss {epoch_loss:.6f}, accuracy {epoch_acc:.6f}")



def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [5]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        epochs = config.get("epochs", 1)
        start_time = time.time()  # Start time measurement
        train(self.net, self.trainloader, epochs)
        training_time = time.time() - start_time  # Calculate duration
        print(f"Training time for Client {self.cid}: {training_time:.2f} seconds")
        return get_parameters(self.net), len(self.trainloader), {"training_time": training_time}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [6]:
from typing import Callable, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


class FedCustom(fl.server.strategy.Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.client_training_times = {}
    def __repr__(self) -> str:
        return "FedCustom"

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        return fl.common.ndarrays_to_parameters(ndarrays)

    def configure_fit(self, server_round: int, parameters: Parameters, client_manager: ClientManager):
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        epochs_sc = 5
        epochs_hl = 4

        standard_config = {"lr": 0.001, "epochs": epochs_sc}
        higher_lr_config = {"lr": 0.0001, "epochs": epochs_hl}
        fit_configurations = []

        for client in clients:
            # Choose config based on the previous training time
            last_time = self.client_training_times.get(client.cid, 0)  # Default to 0 if no time recorded
            print(f"This is the last time {last_time}")

            config_to_use = standard_config if last_time < 13.8 else higher_lr_config
            fit_configurations.append((client, FitIns(parameters, config_to_use)))

        return fit_configurations

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        for client, fit_res in results:
            # Update training times for each client
            self.client_training_times[client.cid] = fit_res.metrics.get("training_time", 0)
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        metrics_aggregated = {}
        return parameters_aggregated, metrics_aggregated


    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evalua
        tion function."""

        # Let's assume we won't perform the global model evaluation on the server side.
        return None

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

In [7]:
if DEVICE.type == "cuda":
    # Use a single client to train the global model
    client_resources = {"num_gpus": .25, "num_cpus": 2}

In [8]:

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=FedCustom(),  # <-- pass the new strategy here
    client_resources=client_resources,
)

INFO flwr 2024-06-26 19:27:34,205 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2024-06-26 19:27:38,354	INFO worker.py:1752 -- Started a local Ray instance.
INFO flwr 2024-06-26 19:27:40,355 | app.py:180 | Flower VCE: Ray initialized with resources: {'node:127.0.0.1': 1.0, 'object_store_memory': 37484094259.0, 'memory': 77462886605.0, 'GPU': 1.0, 'accelerator_type:RTX': 1.0, 'node:__internal_head__': 1.0, 'CPU': 32.0}
INFO flwr 2024-06-26 19:27:40,358 | server.py:86 | Initializing global parameters
INFO flwr 2024-06-26 19:27:40,366 | server.py:269 | Using initial parameters provided by strategy
INFO flwr 2024-06-26 19:27:40,367 | server.py:88 | Evaluating initial parameters
INFO flwr 2024-06-26 19:27:40,369 | server.py:101 | FL starting
DEBUG flwr 2024-06-26 19:27:40,370 | server.py:218 | fit_round 1: strategy sampled 10 clients (out of 10)


This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
This is the last time 0
[36m(launch_and_fit pid=30860)[0m [Client 4] fit, config: {'lr': 0.001, 'epochs': 5}
[36m(launch_and_fit pid=32076)[0m Epoch 0: train loss 0.064907, accuracy 0.233556
[36m(launch_and_fit pid=17624)[0m [Client 6] fit, config: {'lr': 0.001, 'epochs': 5}[32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(launch_and_fit pid=30860)[0m Epoch 2: train loss 0.051712, accuracy 0.391333[32m [repeated 8x across cluster][0m
[36m(launch_and_fit pid=32076)[0m Training time for Client 5: 21.60 seconds
[36m(launch_and_fit pid=30860)[0m Epoch 4: train loss 0.047215, acc

DEBUG flwr 2024-06-26 19:29:09,753 | server.py:232 | fit_round 1 received 10 results and 0 failures
DEBUG flwr 2024-06-26 19:29:09,787 | server.py:168 | evaluate_round 1: strategy sampled 10 clients (out of 10)


[36m(launch_and_evaluate pid=26100)[0m [Client 3] evaluate, config: {}
[36m(launch_and_fit pid=25520)[0m Epoch 4: train loss 0.046451, accuracy 0.451111[32m [repeated 4x across cluster][0m
[36m(launch_and_fit pid=25520)[0m Training time for Client 9: 20.67 seconds
[36m(launch_and_evaluate pid=16172)[0m [Client 0] evaluate, config: {}[32m [repeated 4x across cluster][0m
[36m(launch_and_evaluate pid=33496)[0m [Client 8] evaluate, config: {}
[36m(launch_and_evaluate pid=18692)[0m [Client 1] evaluate, config: {}
[36m(launch_and_evaluate pid=22336)[0m [Client 5] evaluate, config: {}[32m [repeated 2x across cluster][0m
[36m(launch_and_evaluate pid=21992)[0m [Client 4] evaluate, config: {}


ERROR flwr 2024-06-26 19:31:42,598 | ray_client_proxy.py:104 | [36mray::launch_and_evaluate()[39m (pid=21992, ip=127.0.0.1)
  File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\simulation\ray_transport\ray_client_proxy.py", line 160, in launch_and_evaluate
    return maybe_call_evaluate(
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\client.py", line 205, in maybe_call_evaluate
    return client.evaluate(evaluate_ins)
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\app.py", line 321, in _evaluate
    results = self.numpy_client.evaluate(parameters, ins.config)  # type: ignore
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\1810498520.py", line 25, in evaluate
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\756452527.py", line 63, in test
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\torch

[33m(raylet)[0m Traceback (most recent call last):
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\simulation\ray_transport\ray_client_proxy.py", line 160, in launch_and_evaluate
    return maybe_call_evaluate(
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\client.py", line 205, in maybe_call_evaluate
    return client.evaluate(evaluate_ins)
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\app.py", line 321, in _evaluate
    results = self.numpy_client.evaluate(parameters, ins.config)  # type: ignore
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\1810498520.py", line 25, in evaluate
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\756452527.py", line 63, in test
  

[36m(launch_and_evaluate pid=22336)[0m Stack (most recent call first):
[36m(launch_and_evaluate pid=22336)[0m   File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\ray\_private\worker.py", line 879 in main_loop
[36m(launch_and_evaluate pid=22336)[0m   File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\ray\_private\workers\default_worker.py", line 282 in <module>


[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: 2896c8822a94d77a320bfbc25a92e52a2a0cbc2701000000 Worker ID: 1726800cfa4272d43d6822356a1f27f44c83334c16f978add39cac6a Node ID: 5e0795a7e98bd965c7dad56b63592a1970a24ce210fd709fe4c53c77 Worker IP address: 127.0.0.1 Worker port: 62375 Worker PID: 22336 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 10054. An existing connection was forcibly closed by the remote host. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
[36m(launch_and_evaluate pid=28288)[0m [Client 5] evaluate, config: {}


[36m(launch_and_evaluate pid=28288)[0m fatal   : Memory allocation failure[32m [repeated 4x across cluster][0m
[36m(launch_and_evaluate pid=28288)[0m *** SIGABRT received at time=1719410513 ***
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA28B1DD61  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FF996415136  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA28B1D492  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FF6354C2297  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA285CDD31  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA2AF4AD6C  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA2AF33CC6  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA2AF48CDF  (unknown)  (unknown)
[36m(launch_and_evaluate pid=28288)[0m     @   00007FFA2AED5BEA  (unknown)  (unknown)
[36m(la

[33m(raylet)[0m Traceback (most recent call last):
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\simulation\ray_transport\ray_client_proxy.py", line 160, in launch_and_evaluate
    return maybe_call_evaluate(
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\client.py", line 205, in maybe_call_evaluate
    return client.evaluate(evaluate_ins)
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\app.py", line 321, in _evaluate
    results = self.numpy_client.evaluate(parameters, ins.config)  # type: ignore
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\1810498520.py", line 25, in evaluate
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\756452527.py", line 63, in test
  

[36m(launch_and_evaluate pid=8224)[0m     @   00007FF996415136  (unknown)  PyInit__raylet
[36m(launch_and_evaluate pid=8224)[0m     @   00007FF6354C2297  (unknown)  OPENSSL_Applink
[36m(launch_and_evaluate pid=8224)[0m     @   00007FFA2AF4AD6C  (unknown)  memset
[36m(launch_and_evaluate pid=8224)[0m     @   00007FFA2AF33CC6  (unknown)  _C_specific_handler
[36m(launch_and_evaluate pid=8224)[0m     @   00007FFA2AF48CDF  (unknown)  _chkstk
[36m(launch_and_evaluate pid=8224)[0m     @   00007FFA2AED5BEA  (unknown)  RtlRestoreContext
[36m(launch_and_evaluate pid=8224)[0m     @   00007FFA2AED2EF1  (unknown)  RtlRaiseException
[36m(launch_and_evaluate pid=8224)[0m     @   00007FF9964E8B08  (unknown)  PyInit__raylet
[36m(launch_and_evaluate pid=8224)[0m     @   00007FF9964ED947  (unknown)  PyInit__raylet
[36m(launch_and_evaluate pid=8224)[0m     @   00007FF9964F3354  (unknown)  PyInit__raylet
[36m(launch_and_evaluate pid=8224)[0m     @   00007FF9964E8852  (unknown)  PyInit

[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: 1562fa3f988b548fdc2ff62a83689557accf538001000000 Worker ID: a24a6812863c9e5ef55244ea6530cda19871f7441b8c9fd5be98077d Node ID: 5e0795a7e98bd965c7dad56b63592a1970a24ce210fd709fe4c53c77 Worker IP address: 127.0.0.1 Worker port: 62475 Worker PID: 8224 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 10054. An existing connection was forcibly closed by the remote host. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
[36m(launch_and_evaluate pid=33320)[0m [Client 5] evaluate, config: {}


ERROR flwr 2024-06-26 19:32:15,803 | ray_client_proxy.py:104 | The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
DEBUG flwr 2024-06-26 19:32:15,812 | server.py:182 | evaluate_round 1 received 8 results and 2 failures
DEBUG flwr 2024-06-26 19:32:15,814 | server.py:218 | fit_round 2: strategy sampled 10 clients (out of 10)
[36m(launch_and_evaluate pid=33320)[0m fatal   : Memory allocation failure[32m [repeated 5x across cluster][0m
DEBUG flwr 2024-06-26 19:32:15,847 | server.py:232 | fit_round 2 received 0 results and 10 failures
DEBUG flwr 2024-06-26 19:32:15,849 | server.py:168 | evaluate_round 2: strategy sampled 10 clients (out of 10)
DEBUG flwr 2024-06-26 19:32:15,879 | server.py:182 | evaluate_round 2 received 0 results and 10 failures
DEBUG flwr 2024-06-26 19:32:15,882 | server.py:218 | fit_round 3: strategy sampled 10 clients (out of 10)
DEBUG flwr 2024-06-26 19:32:15,919 | server.py:232 | fit_round 3 received 0

[33m(raylet)[0m Traceback (most recent call last):
  File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\simulation\ray_transport\ray_client_proxy.py", line 160, in launch_and_evaluate
    return maybe_call_evaluate(
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\client.py", line 205, in maybe_call_evaluate
    return client.evaluate(evaluate_ins)
  File "c:\Users\Admin\anaconda3\envs\flwrpytorch\lib\site-packages\flwr\client\app.py", line 321, in _evaluate
    results = self.numpy_client.evaluate(parameters, ins.config)  # type: ignore
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\1810498520.py", line 25, in evaluate
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_26896\756452527.py", line 63, in test
  

INFO flwr 2024-06-26 19:32:15,958 | server.py:147 | FL finished in 275.5888034
INFO flwr 2024-06-26 19:32:15,962 | app.py:218 | app_fit: losses_distributed [(1, 0.0568882015645504)]
INFO flwr 2024-06-26 19:32:15,965 | app.py:219 | app_fit: metrics_distributed_fit {}
INFO flwr 2024-06-26 19:32:15,969 | app.py:220 | app_fit: metrics_distributed {}
INFO flwr 2024-06-26 19:32:15,971 | app.py:221 | app_fit: losses_centralized []
INFO flwr 2024-06-26 19:32:15,975 | app.py:222 | app_fit: metrics_centralized {}


History (loss, distributed):
	round 1: 0.0568882015645504

In [None]:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)

In [None]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=5)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        print(f"Client {self.cid} loss {loss}")
        print(f"Client {self.cid} accuracy {accuracy}")
        
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE) #Load Model from here
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)