In [47]:
## Script to split the data into a federated Dataset

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os



def split_data(dataset, num_clients, deviation_factor=0.1):
    """
    Splits data across clients with slight deviation
    """
    client_indices = {i: [] for i in range(num_clients)}

    # Split indices based on classes
    for class_idx in range(10):
        class_indices = np.where(np.array(dataset.targets) == class_idx)[0]
        np.random.shuffle(class_indices)

        # Calculate the number of samples
        total_samples = len(class_indices)
        avg_samples_per_client = total_samples // num_clients


        splits = []
        remaining_samples = total_samples
        for client in range(num_clients):
            # Apply deviation to the average samples
            deviation = np.random.randint(-int(avg_samples_per_client * deviation_factor),
                                          int(avg_samples_per_client * deviation_factor) + 1)
            client_sample_count = avg_samples_per_client + deviation

            if client == num_clients - 1:
                client_sample_count = remaining_samples
            else:
                client_sample_count = min(client_sample_count, remaining_samples)

            remaining_samples -= client_sample_count
            splits.append(client_sample_count)

        split_class_indices = np.split(class_indices, np.cumsum(splits[:-1]))

        for client, indices in enumerate(split_class_indices):
            client_indices[client].extend(indices)

    return client_indices



# Function to save client data as jpg
def save_client_data_as_jpg(dataset, indices, client_dir, client_id):
    client_path = os.path.join(client_dir, f'client_{client_id}')
    os.makedirs(client_path, exist_ok=True)


    for idx in indices:
        img, label = dataset[idx]
        label_folder = os.path.join(client_path, f'label_{label}')
        os.makedirs(label_folder, exist_ok=True)

        # Save the image as a .jpg
        img_path = os.path.join(label_folder, f'{idx}.jpg')

        img_pil = transforms.ToPILImage()(img)
        img_pil.save(img_path)





def main():
    transform = transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.3),
                                transforms.RandomApply([transforms.RandomRotation(10)], p=0.3),
                                transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
                                # transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))], p=0.2),
                                transforms.ToTensor(),
                                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    num_clients = 10
    client_dir = './federated_clients_data/'

    # Spliting data for training and testing
    train_indices = split_data(train_dataset, num_clients, deviation_factor=0.2)
    test_indices = split_data(test_dataset, num_clients, deviation_factor=0.2)

    # Saving the  data for each client
    for client_id in range(num_clients):
        save_client_data_as_jpg(train_dataset, train_indices[client_id], client_dir, client_id)
        save_client_data_as_jpg(test_dataset, test_indices[client_id], client_dir, client_id)

    print(f"Data has been split and saved into '{client_dir}' for {num_clients} clients.")

if __name__ == '__main__':
    main()

Files already downloaded and verified
Files already downloaded and verified
Data has been split and saved into './federated_clients_data/' for 10 clients.


In [48]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [49]:
print(torch.__version__)

2.4.1+cu121


In [50]:
!pip install -q flwr[simulation] flwr-datasets[vision]

In [51]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

Training on cpu
Flower 1.12.0 / PyTorch 2.4.1+cu121


In [52]:
NUM_CLIENTS = 10
BATCH_SIZE = 32

In [53]:
NUM_CLIENTS,BATCH_SIZE

(10, 32)

In [54]:
import os
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image

In [55]:
class ClientDataset(Dataset):
  def __init__(self,root_dir,client_id,transform=None):
    self.client_dir = os.path.join(root_dir,client_id)
    self.transform = transform
    self.images = []
    self.labels = []
    self._load_data()
    print(f"Loading data for {client_id}")
  def _load_data(self):
    for label in os.listdir(self.client_dir):
            label_dir = os.path.join(self.client_dir, label)
            if os.path.isdir(label_dir):
                for img_file in os.listdir(label_dir):
                    img_path = os.path.join(label_dir, img_file)
                    self.images.append(img_path)
                    self.labels.append(int(label[-1]))

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img_path = self.images[idx]
    image = Image.open(img_path)
    label = self.labels[idx]
    if self.transform:
        image = self.transform(image)
    return image, label


In [56]:
def load_datasets(partition_id:int):
    root_dir = "/content/federated_clients_data"
    client_id = "client_" + str(partition_id)
    transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to match CIFAR-10 dimensions
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    client_dataset = ClientDataset(root_dir=root_dir, client_id=client_id, transform=transform)

    train_size = int(0.8 * len(client_dataset))
    test_size = len(client_dataset) - train_size

    train_dataset, test_dataset = random_split(client_dataset, [train_size, test_size])

    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset,[train_size,val_size])
    trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return trainloader,valloader,testloader




In [57]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # convolutional layer (sees 32x32x3 image tensor)
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        # convolutional layer (sees 16x16x16 tensor)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # convolutional layer (sees 8x8x32 tensor)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # linear layer (64 * 4 * 4 -> 500)
        self.fc1 = nn.Linear(64 * 4 * 4, 500)
        # linear layer (500 -> 10)
        self.fc2 = nn.Linear(500, 10)
        # dropout layer (p=0.25)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # add sequence of convolutional and max pooling layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flatten image input
        x = x.view(-1, 64 * 4 * 4)
        # add dropout layer
        x = self.dropout(x)
        # add 1st hidden layer, with relu activation function
        x = F.relu(self.fc1(x))
        # add dropout layer
        x = self.dropout(x)
        # add 2nd hidden layer, with relu activation function
        x = self.fc2(x)
        return x

In [58]:
import torch.optim as optim

In [59]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()*labels.size(0)
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch
            outputs = net(images)
            loss += criterion(outputs, labels).item()*labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total

    return loss, accuracy

In [60]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [61]:
metrics_globally_stored = {}

In [62]:

class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, valloader,partitionId):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.id = partitionId
        self.metrics_list = []

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=3)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        self.metrics_list.append((loss,accuracy))
        print("Adding into global")
        print(f"Global dict {metrics_globally_stored}")
        print(f"Metric List for client {self.id} : {self.metrics_list}: ")
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

In [63]:
clients_cache = {}

def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    if partition_id not in clients_cache:
      trainloader, valloader, _ = load_datasets(partition_id=partition_id)
      clients_cache[partition_id] = FlowerClient(net, trainloader, valloader,partition_id).to_client()

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return clients_cache[partition_id]


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [64]:
from flwr.server.strategy import FedAvg

class CustomStrategy(FedAvg):
  def aggregate_evaluate(self,rnd,results,failure):
    if failure:
      print(f"Round {rnd} had {len(failure)} failures")
    accuracy_list = [r.metrics["accuracy"] for _, r in results]
    big_list = [(_.cid,r.metrics["accuracy"]) for _, r in results]
    average_accuracy = sum(accuracy_list) / len(accuracy_list)
    print(f"Round {rnd} accuracy: {average_accuracy}")
    print(f"Accuracy List for Round {rnd} : {accuracy_list}")
    for id,accuracy in big_list:
      if id not in metrics_globally_stored:
        metrics_globally_stored[id] = []
    for id,accuracy in big_list:
      metrics_globally_stored[id].append((rnd,accuracy))

    print(f"Big List for Round {rnd} : {big_list}")

    return super().aggregate_evaluate(rnd, results, failure)

In [65]:
# # Create FedAvg strategy
# strategy = FedAvg(
#     fraction_fit=1.0,  # Sample 100% of available clients for training
#     fraction_evaluate=1.0,  # Sample 50% of available clients for evaluation
#     min_fit_clients=10,  # Never sample less than 10 clients for training
#     min_evaluate_clients=10,  # Never sample less than 5 clients for evaluation
#     min_available_clients=10,  # Wait until all 10 clients are available
# )

In [66]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

In [67]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

In [68]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Create FedAvg strategy
    # strategy = FedAvg(
    #     fraction_fit=1.0,
    #     fraction_evaluate=1.0,
    #     min_fit_clients=10,
    #     min_evaluate_clients=10,
    #     min_available_clients=10,
    #     evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    # )

    strategy = CustomStrategy(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=10,
        min_evaluate_clients=10,
        min_available_clients=10,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)

    return ServerAppComponents(strategy=strategy, config=config)

In [69]:
import os
os.environ["RAY_DEDUP_LOGS"] = "0"

# Create a new server instance with the updated FedAvg strategy
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

print(metrics_globally_stored)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=5, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=9672)[0m 2024-10-14 21:42:20.094827: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=9672)[0m 2024-10-14 21:42:20.147876: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=9672)[0m 2024-10-14 21:42:20.164571: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      S

[36m(ClientAppActor pid=9672)[0m Loading data for client_8
[36m(ClientAppActor pid=9672)[0m Loading data for client_7
[36m(ClientAppActor pid=9671)[0m Loading data for client_8
[36m(ClientAppActor pid=9672)[0m Loading data for client_2
[36m(ClientAppActor pid=9672)[0m Loading data for client_5[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Loading data for client_4
[36m(ClientAppActor pid=9672)[0m Loading data for client_9
[36m(ClientAppActor pid=9671)[0m Loading data for client_1
[36m(ClientAppActor pid=9672)[0m Loading data for client_0
[36m(ClientAppActor pid=9671)[0m Loading data for client_6


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9672)[0m Loading data for client_6
[36m(ClientAppActor pid=9671)[0m Adding into global
[36m(ClientAppActor pid=9671)[0m Global dict {}
[36m(ClientAppActor pid=9671)[0m Metric List for client 0 : [(2.260071405543935, 0.22484881803188564)]: 
[36m(ClientAppActor pid=9672)[0m Loading data for client_4[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Global dict {}[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Metric List for client 8 : [(2.2625038871552023, 0.21508379888268156)]: [32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Loading data for client_2[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Global dict {}[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid

[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


Round 1 accuracy: 0.2178186235249925
Accuracy List for Round 1 : [0.21389496717724288, 0.22484881803188564, 0.19915029208709506, 0.20728291316526612, 0.2294816414686825, 0.23111612175873733, 0.21508379888268156, 0.20512820512820512, 0.22320441988950276, 0.22899505766062603]
Big List for Round 1 : [('12245189676292809591', 0.21389496717724288), ('9301257855516925193', 0.22484881803188564), ('9392099085599778965', 0.19915029208709506), ('8061912160097829685', 0.20728291316526612), ('9620330208725875134', 0.2294816414686825), ('4799881870640953848', 0.23111612175873733), ('10822315389268395931', 0.21508379888268156), ('2204578931904311408', 0.20512820512820512), ('1294640040899982045', 0.22320441988950276), ('7789576785706923716', 0.22899505766062603)]
[36m(ClientAppActor pid=9672)[0m Loading data for client_8[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Adding into global
[36m(ClientAppActor pid=9671)[0m Global dict {}
[36m(ClientAppActor pid=9671)[0m Me

[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9672)[0m Loading data for client_3
[36m(ClientAppActor pid=9672)[0m Adding into global
[36m(ClientAppActor pid=9672)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288)], '9301257855516925193': [(1, 0.22484881803188564)], '9392099085599778965': [(1, 0.19915029208709506)], '8061912160097829685': [(1, 0.20728291316526612)], '9620330208725875134': [(1, 0.2294816414686825)], '4799881870640953848': [(1, 0.23111612175873733)], '10822315389268395931': [(1, 0.21508379888268156)], '2204578931904311408': [(1, 0.20512820512820512)], '1294640040899982045': [(1, 0.22320441988950276)], '7789576785706923716': [(1, 0.22899505766062603)]}
[36m(ClientAppActor pid=9672)[0m Metric List for client 3 : [(1.9788643793531668, 0.2896890343698854)]: 
[36m(ClientAppActor pid=9671)[0m Loading data for client_4[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=967

[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


Round 2 accuracy: 0.30040827644069923
Accuracy List for Round 2 : [0.2896890343698854, 0.27987254381306426, 0.2927374301675978, 0.30883977900552484, 0.3173618940248027, 0.31017505470459517, 0.3086216364634816, 0.30196078431372547, 0.30345572354211664, 0.291368884002199]
Big List for Round 2 : [('2204578931904311408', 0.2896890343698854), ('9392099085599778965', 0.27987254381306426), ('10822315389268395931', 0.2927374301675978), ('1294640040899982045', 0.30883977900552484), ('4799881870640953848', 0.3173618940248027), ('12245189676292809591', 0.31017505470459517), ('7789576785706923716', 0.3086216364634816), ('8061912160097829685', 0.30196078431372547), ('9620330208725875134', 0.30345572354211664), ('9301257855516925193', 0.291368884002199)]
[36m(ClientAppActor pid=9672)[0m Loading data for client_2[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Adding into global[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Global dict {'122451896

[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9672)[0m Loading data for client_1
[36m(ClientAppActor pid=9672)[0m Adding into global
[36m(ClientAppActor pid=9672)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027)], '10822315389268395931': [(1, 0.21508379888268156), (2, 0.2927374301675978)], '2204578931904311408': [(1, 0.20512820512820512), (2, 0.2896890343698854)], '1294640040899982045': [(1, 0.22320441988950276), (2, 0.30883977900552484)], '7789576785706923716': [(1, 0.22899505766062603), (2, 0.3086216364634816)]}
[36m(ClientAppActor pid=9672)[0m Metric List for client 1 : [(1.799838

[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9672)[0m Adding into global[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027)], '10822315389268395931': [(1, 0.21508379888268156), (2, 0.2927374301675978)], '2204578931904311408': [(1, 0.20512820512820512), (2, 0.2896890343698854)], '1294640040899982045': [(1, 0.22320441988950276), (2, 0.30883977900552484)], '7789576785706923716': [(1, 0.22899505766062603), (2, 0.3086216364634816)]}[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Metric List for client

[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9671)[0m Loading data for client_5[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Adding into global
[36m(ClientAppActor pid=9671)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517), (3, 0.3435448577680525)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199), (3, 0.35294117647058826)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426), (3, 0.31917153478491767)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547), (3, 0.3366946778711485)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664), (3, 0.37634989200863933)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027), (3, 0.36189402480270577)], '10822315389268395931': [(1, 0.21508379888268156), (2, 0.2927374301675978), (3, 0.35977653631284917)], '2204578931904311408': [(1, 0.20512820512820512), (2, 0.2896890343698854), (3, 0.33

[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


Round 4 accuracy: 0.3864601008889635
Accuracy List for Round 4 : [0.3906426155580609, 0.4005586592178771, 0.3779697624190065, 0.3938375350140056, 0.3840698308783415, 0.3862144420131291, 0.3904448105436573, 0.37758895379713225, 0.38592633315008246, 0.37734806629834255]
Big List for Round 4 : [('4799881870640953848', 0.3906426155580609), ('10822315389268395931', 0.4005586592178771), ('9620330208725875134', 0.3779697624190065), ('8061912160097829685', 0.3938375350140056), ('2204578931904311408', 0.3840698308783415), ('12245189676292809591', 0.3862144420131291), ('7789576785706923716', 0.3904448105436573), ('9392099085599778965', 0.37758895379713225), ('9301257855516925193', 0.38592633315008246), ('1294640040899982045', 0.37734806629834255)]
[36m(ClientAppActor pid=9671)[0m Loading data for client_5[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Global dict {'122451896762

[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=9672)[0m Loading data for client_2[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global
[36m(ClientAppActor pid=9672)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517), (3, 0.3435448577680525), (4, 0.3862144420131291)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199), (3, 0.35294117647058826), (4, 0.38592633315008246)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426), (3, 0.31917153478491767), (4, 0.37758895379713225)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547), (3, 0.3366946778711485), (4, 0.3938375350140056)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664), (3, 0.37634989200863933), (4, 0.3779697624190065)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027), (3, 0.36189402480270577), (4, 0.3906426155580609)], '10822315389268395931': [(1, 0.215

[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 1720.60s


[36m(ClientAppActor pid=9671)[0m Adding into global[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=9671)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517), (3, 0.3435448577680525), (4, 0.3862144420131291)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199), (3, 0.35294117647058826), (4, 0.38592633315008246)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426), (3, 0.31917153478491767), (4, 0.37758895379713225)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547), (3, 0.3366946778711485), (4, 0.3938375350140056)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664), (3, 0.37634989200863933), (4, 0.3779697624190065)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027), (3, 0.36189402480270577), (4, 0.3906426155580609)], '10822315389268395931': [(1, 0.21508379888268156), (2, 0.2927374301675978), (3, 0.3597765363128

[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.2637362612560388
[92mINFO [0m:      		round 2: 1.9590754944978435
[92mINFO [0m:      		round 3: 1.829113985346162
[92mINFO [0m:      		round 4: 1.691438128520647
[92mINFO [0m:      		round 5: 1.584078641776395
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'accuracy': [(1, 0.21774464680662126),
[92mINFO [0m:      	              (2, 0.3003211929421944),
[92mINFO [0m:      	              (3, 0.34767076895276433),
[92mINFO [0m:      	              (4, 0.38636477384584617),
[92mINFO [0m:      	              (5, 0.4297909058828216)]}
[92mINFO [0m:      


[36m(ClientAppActor pid=9672)[0m Loading data for client_9[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=9672)[0m Adding into global
[36m(ClientAppActor pid=9672)[0m Global dict {'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517), (3, 0.3435448577680525), (4, 0.3862144420131291)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199), (3, 0.35294117647058826), (4, 0.38592633315008246)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426), (3, 0.31917153478491767), (4, 0.37758895379713225)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547), (3, 0.3366946778711485), (4, 0.3938375350140056)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664), (3, 0.37634989200863933), (4, 0.3779697624190065)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027), (3, 0.36189402480270577), (4, 0.3906426155580609)], '10822315389268395931': [(1, 0.215

[36m(pid=9671)[0m 2024-10-14 21:42:20.094827: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=9671)[0m 2024-10-14 21:42:20.147891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=9671)[0m 2024-10-14 21:42:20.167154: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


{'12245189676292809591': [(1, 0.21389496717724288), (2, 0.31017505470459517), (3, 0.3435448577680525), (4, 0.3862144420131291), (5, 0.4310722100656455)], '9301257855516925193': [(1, 0.22484881803188564), (2, 0.291368884002199), (3, 0.35294117647058826), (4, 0.38592633315008246), (5, 0.4403518416712479)], '9392099085599778965': [(1, 0.19915029208709506), (2, 0.27987254381306426), (3, 0.31917153478491767), (4, 0.37758895379713225), (5, 0.4142326075411577)], '8061912160097829685': [(1, 0.20728291316526612), (2, 0.30196078431372547), (3, 0.3366946778711485), (4, 0.3938375350140056), (5, 0.4380952380952381)], '9620330208725875134': [(1, 0.2294816414686825), (2, 0.30345572354211664), (3, 0.37634989200863933), (4, 0.3779697624190065), (5, 0.41576673866090713)], '4799881870640953848': [(1, 0.23111612175873733), (2, 0.3173618940248027), (3, 0.36189402480270577), (4, 0.3906426155580609), (5, 0.43122886133032695)], '10822315389268395931': [(1, 0.21508379888268156), (2, 0.2927374301675978), (3, 0.

In [None]:
print(metrics_globally_stored.items())


In [None]:
strategy = CustomStrategy(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=10,
        min_evaluate_clients=10,
        min_available_clients=10,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

In [None]:
strategy = CustomStrategy(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=10,
        min_evaluate_clients=10,
        min_available_clients=10,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

    # Configure the server for 5 rounds of training
config = ServerConfig(num_rounds=30)

serverac = ServerAppComponents(strategy=strategy, config=config)

server_config = ServerConfig(num_rounds=10)



In [None]:
from flwr.simulation import start_simulation

import os
os.environ["RAY_DEDUP_LOGS"] = "0"

start_simulation(
    client_fn=client_fn,   # Function to create client instances
    strategy = strategy,
    num_clients=NUM_CLIENTS,
    # Number of clients to simulate
    # server = server_config,
    # config={"num_rounds": 10},# Number of federated learning rounds (optional, adjust as needed)
)

[92mINFO [0m:      Starting Flower simulation, config: num_rounds=1, no round_timeout
2024-10-14 05:46:47,261	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0, 'memory': 7969249691.0, 'object_store_memory': 3984624844.0, 'CPU': 2.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 2 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=13675)[0m 2024-10-14 05:46:52.584481: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT 

[36m(ClientAppActor pid=13677)[0m Loading data for client_3
[36m(ClientAppActor pid=13677)[0m Loading data for client_9
[36m(ClientAppActor pid=13675)[0m Loading data for client_5
[36m(ClientAppActor pid=13677)[0m Loading data for client_7
[36m(ClientAppActor pid=13675)[0m Loading data for client_0
[36m(ClientAppActor pid=13677)[0m Loading data for client_1
[36m(ClientAppActor pid=13675)[0m Loading data for client_4
[36m(ClientAppActor pid=13677)[0m Loading data for client_6
[36m(ClientAppActor pid=13675)[0m Loading data for client_2
[36m(ClientAppActor pid=13677)[0m Loading data for client_8
[36m(ClientAppActor pid=13675)[0m Loading data for client_3


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13675)[0m Loading data for client_7
[36m(ClientAppActor pid=13675)[0m Metric List for client 7 : [(0.07221347790259819, 0.12337662337662338)]: 
[36m(ClientAppActor pid=13677)[0m Loading data for client_9[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=13675)[0m Metric List for client 6 : [(0.07313690422955206, 0.11228813559322035)]: [32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 1 round(s) in 82.00s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0724840299472361
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'accuracy': [(1, 0.11558737948525408)]}
[92mINFO [0m:      


Round 1 accuracy: 0.11560796491104978
Accuracy List for Round 1 : [0.10228509249183895, 0.17395833333333333, 0.11979695431472082, 0.1099476439790576, 0.12337662337662338, 0.08481675392670157, 0.11228813559322035, 0.10122699386503067, 0.11356466876971609, 0.11481844946025516]
Big List for Round 1 : [('7306461449108323292', 0.10228509249183895), ('5039418663426602635', 0.17395833333333333), ('246721759113017047', 0.11979695431472082), ('-4310606299311021011', 0.1099476439790576), ('5256554666922738489', 0.12337662337662338), ('6165841709946672229', 0.08481675392670157), ('-2251380921068617213', 0.11228813559322035), ('4856979522770448443', 0.10122699386503067), ('-1394596644972081091', 0.11356466876971609), ('-7577731239845339981', 0.11481844946025516)]


History (loss, distributed):
	round 1: 0.0724840299472361
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.11558737948525408)]}