In [1]:
import flwr as fl
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix

# -------------------------
#   Load and preprocess data
# -------------------------
img_size = (128, 128)
batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
    "AlzheimerDataset/train",
    image_size=img_size,
    batch_size=batch_size,
    shuffle=True
)

test_ds = tf.keras.utils.image_dataset_from_directory(
    "AlzheimerDataset/test",
    image_size=img_size,
    batch_size=batch_size,
    shuffle=False
)

# Normalize images
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))

# Partition data across clients
def partition_dataset(dataset, client_id, num_clients=2):
    return dataset.shard(num_shards=num_clients, index=client_id)

client_id = int(input("Enter client ID (0 or 1): "))

# Each client gets its own unique training and testing data partition
train_client_ds = partition_dataset(train_ds, client_id, num_clients=2)
test_client_ds = partition_dataset(test_ds, client_id, num_clients=2) # ✅ PARTITION THE TEST DATA

# Memory-safe class weight calculation
y_train_list = []
for _, batch_y in train_client_ds:
    y_train_list.append(batch_y.numpy())
y_train_client = np.concatenate(y_train_list)

classes = np.unique(y_train_client)
class_weights = compute_class_weight("balanced", classes=classes, y=y_train_client)
class_weights = dict(zip(classes, class_weights))
print(f"Client {client_id} class weights computed.")

# -------------------------
#   Define CNN model
# -------------------------
def create_cnn_model():
    """Creates a memory-efficient CNN model."""
    model = models.Sequential([
        layers.Input(shape=(128, 128, 3)),
        layers.Conv2D(16, (3,3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(32, (3,3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(4, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

model = create_cnn_model()
print("✅ CNN model created.")

# -------------------------
#   Flower client class
# -------------------------
class AlzheimerClient(fl.client.NumPyClient):
    def get_parameters(self, config=None):
        return model.get_weights()

    def fit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(train_client_ds, epochs=1, class_weight=class_weights, verbose=0)
        return model.get_weights(), len(y_train_client), {}

    def evaluate(self, parameters, config=None):
        model.set_weights(parameters)
        
        # ✅ EVALUATE ON THE CLIENT'S PARTITIONED TEST DATA
        loss, acc = model.evaluate(test_client_ds, verbose=0)
        
        y_true, y_pred = [], []
        # ✅ PREDICT ON THE CLIENT'S PARTITIONED TEST DATA
        for batch_x, batch_y in test_client_ds:
            preds = model.predict(batch_x, verbose=0)
            preds = np.argmax(preds, axis=1)
            y_true.extend(batch_y.numpy())
            y_pred.extend(preds)
        
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        
        cm = confusion_matrix(y_true, y_pred, labels=classes) # Ensure all classes are represented
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        
        class_acc = {f"class_{cls}_acc": report.get(str(cls), {}).get("recall", 0.0) 
                     for cls in classes}
        
        metrics = {"loss": loss, "accuracy": acc}
        metrics.update(class_acc)
        
        cm_flat_list = cm.flatten().tolist()
        cm_string = ",".join(map(str, cm_flat_list))
        
        metrics["confusion_matrix_str"] = cm_string
        metrics["num_classes"] = cm.shape[0]

        num_test_examples = len(y_true)
        
        return loss, num_test_examples, metrics

# -------------------------
#   Start client
# -------------------------
print(f"🚀 Starting client {client_id}...")
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=AlzheimerClient())



Found 10240 files belonging to 4 classes.
Found 1283 files belonging to 4 classes.


Enter client ID (0 or 1):  1


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        


Client 1 class weights computed.
✅ CNN model created.
🚀 Starting client 1...


[92mINFO [0m:      
[92mINFO [0m:      Received: train message dade3672-ee26-4e1a-adee-9357c347c574
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message ec045fe2-16d8-4390-a3bf-c55550687f30
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 521f38e1-d7a0-4317-b41f-bbe09ec86c04
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message a2ce51b0-421d-43e2-96f7-efcff2ced5bd
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 7e3a454c-2a2b-4430-955a-1a079062f9e7
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 37541a79-ece5-44e0-bbbd-393261c721ef
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: reconnect message e6334d3d-4b3e-45dc-9542-c5d93ff1cd66
[92mINFO [0m:      Disconnect and shut down
