# Iris Demo - Federated Learning
In this notebook, the Iris dataset and the model developed in the previous notebook are used to train a model with federated learning. There are 3 clients, each holding its own partition of the Iris dataset. In each round, every client trains the current global model on its local data. The central server then collects the updated client weights and averages them into a new global model (FedAvg, weighted by each client's number of training examples). The weights of this global model are then sent back to every client for the next round.

This notebook draws heavily on the examples at https://flower.dev.

In [1]:
# if using Google Colab
!pip install -q flwr[simulation]

In [2]:
# load libraries
import os

# Make TensorFlow logs less verbose. This must happen *before* TensorFlow is
# imported: the native runtime reads TF_CPP_MIN_LOG_LEVEL when it is loaded,
# so setting it afterwards has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from typing import Dict, List, Optional, Tuple

import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import NDArrays, Scalar

In [3]:
python_version = !python --version
print(
    f"Training on {'GPU' if tf.config.get_visible_devices('GPU') else 'CPU'}\
    using TensorFlow {tf.__version__}, Flower {fl.__version__} and {python_version[0]}"
)

Training on CPU    using TensorFlow 2.12.0, Flower 1.4.0 and Python 3.10.12


In [4]:
# global variables
NUM_CLIENTS = 3  # number of simulated clients (one data partition each)
EPOCHS = 50      # local training epochs per client in every round
ROUNDS = 5       # number of federated (server) rounds
SEED = 42        # seed for NumPy's global RNG (used to shuffle the dataset)

# Seed NumPy's global RNG so the dataset shuffle, which runs in this process,
# is reproducible. Client models are built inside the simulation's Ray worker
# processes, so their weight initialisation is not covered by this seed.
np.random.seed(SEED)

In [5]:
def datasets(
    num_clients: Optional[int] = None,
    test_size: float = 0.2,
    seed: Optional[int] = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Load Iris and split it into one local train/test set per client.

    The data is shuffled and the labels are one-hot encoded. The data is then
    partitioned into ``num_clients`` roughly equal parts. Each part is split
    into a local train set and a local test set.

    Parameters
    ----------
    num_clients : int, optional
        Number of client partitions. Defaults to the notebook-level
        ``NUM_CLIENTS`` (read at call time), so the data split always matches
        the number of simulated clients.
    test_size : float, default 0.2
        Fraction of each client's partition held out as its local test set.
    seed : int, optional
        Seed for the shuffle. ``None`` (default) uses NumPy's global RNG, so
        the shuffle is reproducible only if that RNG has been seeded.

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]
        ``(X_train, y_train, X_test, y_test)``. Element ``i`` of each list
        belongs to client ``i``. Features have 4 columns; labels are one-hot
        encoded with one column per class.

    Raises
    ------
    ValueError
        If ``num_clients`` is smaller than 1.
    """
    if num_clients is None:
        num_clients = NUM_CLIENTS
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # load the Iris dataset
    iris = load_iris()
    X = iris.data

    # One-hot encode the integer class labels with NumPy. For sorted integer
    # classes this matches OneHotEncoder's (float64) output. It also avoids
    # OneHotEncoder's `sparse` argument, which newer scikit-learn releases
    # renamed to `sparse_output`.
    n_classes = len(iris.target_names)
    y = np.eye(n_classes)[iris.target]

    # Shuffle before partitioning. The Iris samples are ordered by class, so
    # unshuffled partitions would each contain mostly a single species.
    # np.random.permutation(n) draws from the global RNG exactly like the
    # previous arange + np.random.shuffle did.
    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.permutation(len(X))
    X, y = X[indices], y[indices]

    X_train, y_train, X_test, y_test = [], [], [], []
    # np.array_split (unlike np.split) tolerates partitions of unequal size.
    # When the length divides evenly, it yields the same split as np.split.
    for X_client, y_client in zip(
        np.array_split(X, num_clients), np.array_split(y, num_clients)
    ):
        # Fixed random_state keeps each client's train/test split deterministic
        # given its partition.
        X_tr, X_te, y_tr, y_te = train_test_split(
            X_client, y_client, test_size=test_size, random_state=42
        )
        X_train.append(X_tr)
        y_train.append(y_tr)
        X_test.append(X_te)
        y_test.append(y_te)

    return X_train, y_train, X_test, y_test

In [6]:
# Instantiate the per-client datasets. Element i of each list belongs to client i.
trainloaders_x, trainloaders_y, testloaders_x, testloaders_y = datasets()

# client_fn indexes these lists by client id, so there must be exactly one
# partition per simulated client.
assert len(trainloaders_x) == len(testloaders_x) == NUM_CLIENTS, (
    "expected one data partition per client"
)



In [7]:
# Define the Flower client
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one Keras model and one client's data split.

    Parameters
    ----------
    cid : str
        Client id assigned by the simulation (stored for reference).
    model : tf.keras.Model
        Compiled Keras model. ``evaluate`` unpacks ``model.evaluate`` into
        ``(loss, accuracy)``, so the model must be compiled with exactly one
        metric (accuracy).
    x_train, y_train : np.ndarray
        Local training features and one-hot labels.
    x_test, y_test : np.ndarray
        Local held-out features and one-hot labels used by ``evaluate``.
    """

    def __init__(self, cid, model, x_train, y_train, x_test, y_test):
        self.cid = cid
        self.model = model
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test

    def get_parameters(self, config):
        """Return the current local model weights (a list of NumPy arrays)."""
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Load the server weights, train on the local training data, and return the update.

        Parameters
        ----------
        parameters : NDArrays
            Global model weights received from the server.
        config : Dict[str, Scalar]
            Server-side configuration for this round. An optional ``"epochs"``
            entry overrides the notebook-level ``EPOCHS``.

        Returns
        -------
        Tuple[NDArrays, int, Dict[str, Scalar]]
            Updated weights, the number of local training examples (FedAvg
            uses it as this client's aggregation weight), and an empty
            metrics dict.
        """
        self.model.set_weights(parameters)
        # Let the server change the local epoch count per round; otherwise use EPOCHS.
        epochs = int((config or {}).get("epochs", EPOCHS))
        self.model.fit(self.x_train, self.y_train, epochs=epochs, verbose=2)
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Load the server weights and evaluate them on the local test data.

        Parameters
        ----------
        parameters : NDArrays
            Global model weights received from the server.
        config : Dict[str, Scalar]
            Server-side configuration for this round (unused here).

        Returns
        -------
        loss : float
            Loss of the model on the local test set.
        num_examples : int
            Number of local test examples.
        metrics : Dict[str, Scalar]
            ``{"accuracy": ...}`` for the local test set. The strategy's
            ``evaluate_metrics_aggregation_fn`` aggregates it server-side.
        """
        self.model.set_weights(parameters)
        loss, acc = self.model.evaluate(self.x_test, self.y_test, verbose=2)
        # Cast to built-in floats: Flower metric values must be bool/bytes/float/int/str.
        return float(loss), len(self.x_test), {"accuracy": float(acc)}
    

In [8]:
# Create a Flower client on demand for a given client id.
def _build_model() -> tf.keras.Model:
    """Build and compile the Iris classifier (4 inputs -> 16 ReLU -> 3-way softmax)."""
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(16, input_shape=(4,), activation='relu'),
        tf.keras.layers.Dense(3, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def client_fn(cid: str) -> fl.client.NumPyClient:
    """Create the Flower client for simulated client ``cid``.

    Parameters
    ----------
    cid : str
        Client id supplied by the simulation. It must be a string integer in
        ``[0, len(trainloaders_x))``.

    Returns
    -------
    FlowerClient
        A client holding a freshly built model and the cid-th data partition.

    Raises
    ------
    IndexError
        If ``cid`` does not correspond to a data partition. An explicit check
        is needed because a negative index would silently select another
        client's data.
    """
    print("\nThis is client: ", cid)

    idx = int(cid)
    if not 0 <= idx < len(trainloaders_x):
        raise IndexError(
            f"client id {cid!r} has no data partition "
            f"(expected 0..{len(trainloaders_x) - 1})"
        )

    x_train_cid, y_train_cid = trainloaders_x[idx], trainloaders_y[idx]
    x_test_cid, y_test_cid = testloaders_x[idx], testloaders_y[idx]

    print("Loaded data for client: ", cid, "\n")

    # Every client starts from the same architecture. Its weights are
    # overwritten by the server's parameters in fit/evaluate.
    model = _build_model()

    # Create and return client
    print("\nClient CID: " + str(cid) + " is done.\n")
    return FlowerClient(cid, model, x_train_cid, y_train_cid, x_test_cid, y_test_cid)

In [9]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracies into one example-weighted mean.

    Parameters
    ----------
    metrics : List[Tuple[int, Metrics]]
        One ``(num_examples, metrics_dict)`` pair per client. Each dict must
        contain an ``"accuracy"`` entry.

    Returns
    -------
    Metrics
        ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``, or ``{}`` when no
        examples were reported (instead of raising ZeroDivisionError).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    # Weight each client's accuracy by the number of examples it evaluated on.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

In [10]:

# Instantiate the aggregation strategy, in this case FedAvg. Every client is
# sampled for both training and evaluation in every round. The minimums are
# tied to NUM_CLIENTS: a hard-coded value larger than NUM_CLIENTS could never
# be satisfied.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Launch the simulation. The returned History object holds the per-round
# distributed loss and the aggregated evaluation accuracy.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=ROUNDS),
    strategy=strategy,
)

INFO flwr 2023-06-12 16:54:44,811 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-06-12 16:54:51,233	INFO worker.py:1636 -- Started a local Ray instance.
INFO flwr 2023-06-12 16:54:54,439 | app.py:180 | Flower VCE: Ray initialized with resources: {'node:172.28.0.12': 1.0, 'CPU': 2.0, 'memory': 7759567259.0, 'object_store_memory': 3879783628.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'node:172.28.0.12': 1.0, 'CPU': 2.0, 'memory': 7759567259.0, 'object_store_memory': 3879783628.0}
INFO flwr 2023-06-12 16:54:54,451 | server.py:86 | Initializing global parameters
INFO:flwr:Initializing global parameters
INFO flwr 2023-06-12 16:54:54,455 | server.py:273 | Requesting initial parameters from one random client
INFO:flwr:Requesting initial parameters from one random client
INFO flwr 2023-06-12 16:55:08,861 | server.py:277 | Received i

[2m[36m(launch_and_get_parameters pid=28529)[0m 
[2m[36m(launch_and_get_parameters pid=28529)[0m This is client:  2
[2m[36m(launch_and_get_parameters pid=28529)[0m Loaded data for client:  2 
[2m[36m(launch_and_get_parameters pid=28529)[0m 
[2m[36m(launch_and_get_parameters pid=28529)[0m 
[2m[36m(launch_and_get_parameters pid=28529)[0m Client CID: 2 is done.
[2m[36m(launch_and_get_parameters pid=28529)[0m 
[2m[36m(launch_and_fit pid=28529)[0m 
[2m[36m(launch_and_fit pid=28529)[0m This is client:  2
[2m[36m(launch_and_fit pid=28529)[0m Loaded data for client:  2 
[2m[36m(launch_and_fit pid=28529)[0m 
[2m[36m(launch_and_fit pid=28529)[0m 
[2m[36m(launch_and_fit pid=28529)[0m Client CID: 2 is done.
[2m[36m(launch_and_fit pid=28529)[0m 
[2m[36m(launch_and_fit pid=28529)[0m Epoch 1/50
[2m[36m(launch_and_fit pid=28529)[0m 2/2 - 2s - loss: 2.1892 - accuracy: 0.3000 - 2s/epoch - 784ms/step
[2m[36m(launch_and_fit pid=28529)[0m Epoch 2/50
[2m[

DEBUG flwr 2023-06-12 16:55:16,543 | server.py:232 | fit_round 1 received 3 results and 0 failures
DEBUG:flwr:fit_round 1 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:16,560 | server.py:168 | evaluate_round 1: strategy sampled 3 clients (out of 3)
DEBUG:flwr:evaluate_round 1: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28528)[0m Epoch 22/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 23/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 24/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 25/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 26/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 27/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 28/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 29/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 30/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 31/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 32/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 33/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 34/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 35/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 36/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 37/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 38/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 39/50
[2m[36m(launch_and_fit pid=28528)[0m Epoch 40/50
[2m[36m(la

DEBUG flwr 2023-06-12 16:55:17,559 | server.py:182 | evaluate_round 1 received 3 results and 0 failures
DEBUG:flwr:evaluate_round 1 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:17,569 | server.py:218 | fit_round 2: strategy sampled 3 clients (out of 3)
DEBUG:flwr:fit_round 2: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28529)[0m [32m [repeated 26x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Client CID: 2 is done.[32m [repeated 6x across cluster][0m


DEBUG flwr 2023-06-12 16:55:20,578 | server.py:232 | fit_round 2 received 3 results and 0 failures
DEBUG:flwr:fit_round 2 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:20,590 | server.py:168 | evaluate_round 2: strategy sampled 3 clients (out of 3)
DEBUG:flwr:evaluate_round 2: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28529)[0m Epoch 45/50[32m [repeated 145x across cluster][0m
[2m[36m(launch_and_evaluate pid=28528)[0m 1/1 - 1s - loss: 0.6388 - accuracy: 0.7000 - 855ms/epoch - 855ms/step[32m [repeated 196x across cluster][0m
[2m[36m(launch_and_evaluate pid=28528)[0m This is client:  0[32m [repeated 7x across cluster][0m
[2m[36m(launch_and_evaluate pid=28528)[0m Loaded data for client:  0 [32m [repeated 7x across cluster][0m


DEBUG flwr 2023-06-12 16:55:22,741 | server.py:182 | evaluate_round 2 received 3 results and 0 failures
DEBUG:flwr:evaluate_round 2 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:22,745 | server.py:218 | fit_round 3: strategy sampled 3 clients (out of 3)
DEBUG:flwr:fit_round 3: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28529)[0m [32m [repeated 20x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Client CID: 2 is done.[32m [repeated 5x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Epoch 12/50[32m [repeated 18x across cluster][0m
[2m[36m(launch_and_fit pid=28528)[0m 2/2 - 0s - loss: 0.6098 - accuracy: 0.8000 - 16ms/epoch - 8ms/step[32m [repeated 91x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m This is client:  0[32m [repeated 4x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Loaded data for client:  0 [32m [repeated 4x across cluster][0m


DEBUG flwr 2023-06-12 16:55:28,171 | server.py:232 | fit_round 3 received 3 results and 0 failures
DEBUG:flwr:fit_round 3 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:28,181 | server.py:168 | evaluate_round 3: strategy sampled 3 clients (out of 3)
DEBUG:flwr:evaluate_round 3: strategy sampled 3 clients (out of 3)
DEBUG flwr 2023-06-12 16:55:29,230 | server.py:182 | evaluate_round 3 received 3 results and 0 failures
DEBUG:flwr:evaluate_round 3 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:29,238 | server.py:218 | fit_round 4: strategy sampled 3 clients (out of 3)
DEBUG:flwr:fit_round 4: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28528)[0m [32m [repeated 24x across cluster][0m
[2m[36m(launch_and_fit pid=28528)[0m Client CID: 0 is done.[32m [repeated 6x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Epoch 13/50[32m [repeated 155x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m 2/2 - 0s - loss: 0.5569 - accuracy: 0.8250 - 8ms/epoch - 4ms/step[32m [repeated 181x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m This is client:  1[32m [repeated 6x across cluster][0m
[2m[36m(launch_and_fit pid=28529)[0m Loaded data for client:  1 [32m [repeated 6x across cluster][0m


DEBUG flwr 2023-06-12 16:55:32,360 | server.py:232 | fit_round 4 received 3 results and 0 failures
DEBUG:flwr:fit_round 4 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:32,370 | server.py:168 | evaluate_round 4: strategy sampled 3 clients (out of 3)
DEBUG:flwr:evaluate_round 4: strategy sampled 3 clients (out of 3)
DEBUG flwr 2023-06-12 16:55:33,389 | server.py:182 | evaluate_round 4 received 3 results and 0 failures
DEBUG:flwr:evaluate_round 4 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:33,393 | server.py:218 | fit_round 5: strategy sampled 3 clients (out of 3)
DEBUG:flwr:fit_round 5: strategy sampled 3 clients (out of 3)


[2m[36m(launch_and_fit pid=28528)[0m [32m [repeated 28x across cluster][0m
[2m[36m(launch_and_fit pid=28528)[0m Client CID: 2 is done.[32m [repeated 7x across cluster][0m
[2m[36m(launch_and_fit pid=28528)[0m Epoch 11/50[32m [repeated 243x across cluster][0m


DEBUG flwr 2023-06-12 16:55:36,288 | server.py:232 | fit_round 5 received 3 results and 0 failures
DEBUG:flwr:fit_round 5 received 3 results and 0 failures
DEBUG flwr 2023-06-12 16:55:36,301 | server.py:168 | evaluate_round 5: strategy sampled 3 clients (out of 3)
DEBUG:flwr:evaluate_round 5: strategy sampled 3 clients (out of 3)
DEBUG flwr 2023-06-12 16:55:37,266 | server.py:182 | evaluate_round 5 received 3 results and 0 failures
DEBUG:flwr:evaluate_round 5 received 3 results and 0 failures
INFO flwr 2023-06-12 16:55:37,270 | server.py:147 | FL finished in 28.396297485000105
INFO:flwr:FL finished in 28.396297485000105
INFO flwr 2023-06-12 16:55:37,279 | app.py:218 | app_fit: losses_distributed [(1, 0.8680046598116556), (2, 0.6690971851348877), (3, 0.5392269790172577), (4, 0.45218942562739056), (5, 0.3788420458634694)]
INFO:flwr:app_fit: losses_distributed [(1, 0.8680046598116556), (2, 0.6690971851348877), (3, 0.5392269790172577), (4, 0.45218942562739056), (5, 0.3788420458634694)]
INF

In [11]:
# Display the simulation History: the per-round distributed (evaluation) loss
# and the "accuracy" metric aggregated across clients by weighted_average.
history

History (loss, distributed):
	round 1: 0.8680046598116556
	round 2: 0.6690971851348877
	round 3: 0.5392269790172577
	round 4: 0.45218942562739056
	round 5: 0.3788420458634694
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.8333333333333334), (2, 0.6666666567325592), (3, 0.8333333333333334), (4, 0.8666666547457377), (5, 0.9666666587193807)]}

[2m[36m(launch_and_evaluate pid=28529)[0m 1/1 - 0s - loss: 0.4859 - accuracy: 1.0000 - 285ms/epoch - 285ms/step[32m [repeated 189x across cluster][0m
[2m[36m(launch_and_evaluate pid=28529)[0m This is client:  1[32m [repeated 9x across cluster][0m
[2m[36m(launch_and_evaluate pid=28529)[0m Loaded data for client:  1 [32m [repeated 9x across cluster][0m
