In [1]:
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import ndarrays_to_parameters, Context
from flwr.server import ServerApp, ServerConfig
from flwr.server import ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr.common import Metrics, NDArrays, Scalar
from typing import List, Tuple, Dict, Optional

In [2]:
import tensorflow as tf
# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # Normalize to [0, 1]
x_train = x_train.reshape(-1, 28 * 28)  # Flatten images
x_test = x_test.reshape(-1, 28 * 28)



In [3]:
import numpy as np

def exclude_digits(x_data, y_data, excluded_digits):
    mask = ~np.isin(y_data, excluded_digits)  # Create a mask for non-excluded digits
    x_filtered = x_data[mask]  # Filter input data
    y_filtered = y_data[mask]  # Filter labels
    return x_filtered, y_filtered

### Create Train dataset

In [4]:
x_train_1, y_train_1  = exclude_digits(x_train, y_train, excluded_digits=[1, 3, 7])
print("y_train_1 : ",set(y_train_1))


x_train_2, y_train_2 = exclude_digits(x_train, y_train, excluded_digits=[2, 5, 8])
print("y_train_2 : ",set(y_train_2))

x_train_3, y_train_3 = exclude_digits(x_train, y_train, excluded_digits=[4, 6, 9])
print("y_train_3 : ",set(y_train_3))


y_train_1 :  {0, 2, 4, 5, 6, 8, 9}
y_train_2 :  {0, 1, 3, 4, 6, 7, 9}
y_train_3 :  {0, 1, 2, 3, 5, 7, 8}


In [5]:
train_sets = [(x_train_1, y_train_1 ),(x_train_2, y_train_2),(x_train_3, y_train_3)]

### Create Test dataset

In [6]:
x_test_1, y_test_1  = exclude_digits(x_test, y_test, excluded_digits=[1, 3, 7])
print("y_test_1 : ",set(y_test_1))

x_test_2, y_test_2  = exclude_digits(x_test, y_test, excluded_digits=[2, 5, 8])
print("y_test_2 : ",set(y_test_2))

x_test_3, y_test_3  = exclude_digits(x_test, y_test, excluded_digits=[4,6,9])
print("y_test_3 : ",set(y_test_3))


y_test_1 :  {0, 2, 4, 5, 6, 8, 9}
y_test_2 :  {0, 1, 3, 4, 6, 7, 9}
y_test_3 :  {0, 1, 2, 3, 5, 7, 8}


In [7]:
test_sets = [(x_test_1, y_test_1),(x_test_2, y_test_2),(x_test_3, y_test_3)]

### Prepare model training
- Sample model
- training method
- Evaluation method

In [8]:
class SimpleNN(tf.keras.Model):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

In [9]:
def train_model(model,pth_x,pth_y):
    batch_size = 64
    epochs = 20
    num_batches = len(pth_x) // batch_size
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    # Convert labels to one-hot encoding
    pth_y_onehot = tf.keras.utils.to_categorical(pth_y, num_classes=10)

    
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for i in range(num_batches):
            # Get a batch of data
            start = i * batch_size
            end = start + batch_size
            x_batch = pth_x[start:end]
            y_batch = pth_y_onehot[start:end]
            
            with tf.GradientTape() as tape:
                predictions = model(x_batch, training=True)  # Forward pass
                loss = loss_fn(y_batch, predictions)        # Compute loss
            

            gradients = tape.gradient(loss, model.trainable_variables) 
        
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))  # Update weights

            if i % 200 == 0:  # Print progress every 200 batches
                print(f"Batch {i}/{num_batches}, Loss: {loss.numpy():.4f}")

In [26]:
def evaluate_model(model,x_test,y_test):
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=10)

    # Evaluate the model
    test_loss = loss_fn(y_test_onehot, model(x_test))
    print("test_loss : ",test_loss)
    test_accuracy = tf.keras.metrics.categorical_accuracy(y_test_onehot, model(x_test))
    print("test_accuracy : ",test_accuracy)
    return test_loss, test_accuracy

### Fed learning

In [27]:
class FlowerClient(NumPyClient):
    def __init__(self, net, trainset, testset):
        self.net = net
        self.trainset = trainset
        self.testset = testset

    # Train the model
    def fit(self, parameters, config):
        self.net.set_weights(parameters)
        pth_x,pth_y = self.trainset # (x,y)
        train_model(self.net,pth_x,pth_y )
        return self.net.get_weights(), len(pth_y), {}

    # Test the model
    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        self.net.set_weights(parameters)
        loss, accuracy = evaluate_model(self.net, self.testset)
        return loss, len(self.testset), {"accuracy": accuracy}

In [28]:
# Client function
def client_fn(context: Context) -> Client:
    net = SimpleNN()
    partition_id = int(context.node_config["partition-id"])
    client_train = train_sets[int(partition_id)]
    client_test = test_sets[int(partition_id)]
    return FlowerClient(net, client_train, client_test).to_client()

In [29]:
client = ClientApp(client_fn)

In [30]:
# test code 
net = SimpleNN()
_, accuracy137 = evaluate_model(net, x_test_1,y_test_1)
print("test accuracy on [1,3,7]: ", accuracy137)

test_loss :  tf.Tensor(2.573329, shape=(), dtype=float32)
test_accuracy :  tf.Tensor([0. 0. 1. ... 0. 0. 0.], shape=(6827,), dtype=float32)
test accuracy on [1,3,7]:  tf.Tensor([0. 0. 1. ... 0. 0. 0.], shape=(6827,), dtype=float32)


### Define evaluate for model testing
- The evaluate method evaluates the performance of the neural network model using the provided parameters and the test dataset (testset).

In [14]:
def evaluate(server_round, parameters, config):
    net = SimpleNN()
    net.set_weights(parameters)

    _, accuracy137 = evaluate_model(net, x_test_1,y_test_1)
    _, accuracy258 = evaluate_model(net, x_test_2,y_test_2)
    _, accuracy469 = evaluate_model(net, x_test_3,y_test_3)

    print("test accuracy on [1,3,7]: ", accuracy137)
    print("test accuracy on [2,5,8]: ", accuracy258)
    print("test accuracy on [4,6,9]: ", accuracy469)


The federated averaging strategy (`strategy.FedAvg`) is created for federated learning.

In [15]:
net = SimpleNN()
params = ndarrays_to_parameters(net.get_weights())

def server_fn(context: Context):
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.0,
        initial_parameters=params,
        evaluate_fn=evaluate,
    )
    config=ServerConfig(num_rounds=3)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

In [16]:
server = ServerApp(server_fn=server_fn)

In [17]:
from logging import ERROR
backend_setup = {"init_args": {"logging_level": ERROR, "log_to_driver": False}}


In [18]:
# Initiate the simulation passing the server and client apps
# Specify the number of super nodes that will be selected on every round
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=3,
    #backend_config=backend_setup,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]


test accuracy on [1,3,7]:  tf.Tensor([0. 0. 0. ... 0. 0. 0.], shape=(6827,), dtype=float32)
test accuracy on [2,5,8]:  tf.Tensor([0. 0. 0. ... 0. 0. 0.], shape=(7102,), dtype=float32)
test accuracy on [4,6,9]:  tf.Tensor([0. 0. 0. ... 0. 0. 0.], shape=(7051,), dtype=float32)


[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=25304)[0m Epoch 1/20
[36m(ClientAppActor pid=25304)[0m Batch 0/638, Loss: 2.2316
[36m(ClientAppActor pid=25302)[0m Epoch 2/20[32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(ClientAppActor pid=25304)[0m Batch 200/638, Loss: 0.0225[32m [repeated 10x across cluster][0m
[36m(ClientAppActor pid=25302)[0m Epoch 3/20[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=25304)[0m Batch 400/638, Loss: 0.0916[32m [repeated 10x across cluster][0m
[36m(ClientAppActor pid=25302)[0m Epoch 4/20[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=25304)[0m Batch 600/638, Loss: 0.0750[32m [repeated 10x across cluster][0m
[36m(ClientAppActor pid=25302)[0m Epoch 6/20[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=25302)[0m Batch 0/668

: 

: 