In [2]:
import flwr as fl
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import os
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

# -------------------------
#   Load and preprocess data
# -------------------------
img_size = (128, 128)
batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
    "AlzheimerDataset/train",
    image_size=img_size,
    batch_size=batch_size,
    shuffle=True
)

test_ds = tf.keras.utils.image_dataset_from_directory(
    "AlzheimerDataset/test",
    image_size=img_size,
    batch_size=batch_size
)

# Normalize images
train_ds = train_ds.map(lambda x, y: (x/255.0, y))
test_ds = test_ds.map(lambda x, y: (x/255.0, y))

# Partition data across clients
def partition_dataset(dataset, client_id, num_clients=2):
    dataset = dataset.shuffle(10000)
    total_size = len(dataset)
    part_size = total_size // num_clients
    start = client_id * part_size
    end = start + part_size
    return dataset.skip(start).take(part_size)

client_id = int(input("Enter client ID (0 or 1): "))
train_client_ds = partition_dataset(train_ds, client_id, num_clients=2)

# Convert partitioned dataset to numpy arrays for computing class weights
x_train_list, y_train_list = [], []
for batch_x, batch_y in train_client_ds:
    x_train_list.append(batch_x.numpy())
    y_train_list.append(batch_y.numpy())

x_train_client = np.concatenate(x_train_list)
y_train_client = np.concatenate(y_train_list)

# Compute class weights
classes = np.unique(y_train_client)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=classes,
    y=y_train_client
)
class_weights = dict(zip(classes, class_weights))

# -------------------------
#   Define CNN model
# -------------------------
def create_cnn_model():
    model = models.Sequential([
        layers.Input(shape=img_size + (3,)),
        layers.Conv2D(32, (3,3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(128, (3,3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(4, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

model = create_cnn_model()

# -------------------------
#   Flower client class
# -------------------------
class AlzheimerClient(fl.client.NumPyClient):
    def get_parameters(self, config=None):
        return model.get_weights()

    def fit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(
            x_train_client, y_train_client,
            epochs=1,
            batch_size=32,
            class_weight=class_weights,   # ✅ Apply weights
            verbose=0
        )
        return model.get_weights(), len(x_train_client), {}

    def evaluate(self, parameters, config=None):
        model.set_weights(parameters)

        # Evaluate overall loss/accuracy
        loss, acc = model.evaluate(test_ds, verbose=0)

        # Get predictions for per-class accuracy
        y_true, y_pred = [], []
        for batch_x, batch_y in test_ds:
            preds = model.predict(batch_x, verbose=0)
            preds = np.argmax(preds, axis=1)
            y_true.extend(batch_y.numpy())
            y_pred.extend(preds)

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        # Generate per-class metrics
        report = classification_report(
            y_true, y_pred, output_dict=True, zero_division=0
        )

        # Extract per-class accuracy
        class_acc = {f"class_{cls}_acc": report[str(cls)]["recall"] 
                     for cls in np.unique(y_true)}

        # Merge metrics
        metrics = {"loss": loss, "accuracy": acc}
        metrics.update(class_acc)

        return loss, len(test_ds), metrics

# -------------------------
#   Start client
# -------------------------
print(f"🚀 Starting client {client_id}")
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=AlzheimerClient())


Found 10240 files belonging to 4 classes.
Found 1283 files belonging to 4 classes.


Enter client ID (0 or 1):  1


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        


🚀 Starting client 1


[92mINFO [0m:      
[92mINFO [0m:      Received: train message 84064cb8-f0dd-4719-8d4f-412d83f1c541
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message e17ff0fa-8427-4eb6-ab20-e1ec672b478f
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 2f068893-c28d-40ac-8c08-bcb45a2509f3
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message f665a17e-99a9-4daa-9833-0c01838e0af9
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 60c6adce-12dc-4b01-b45c-72cd9e5e1c58
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 02d0df4f-ef9c-48d7-a309-eef96a18f3b6
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: reconnect message 557c01b7-4340-4272-befb-0c8c755b1586
[92mINFO [0m:      Disconnect and shut down
