# Federated model training simulation and upload

This notebook simulates the federated training and evaluation procedure for a federated model. Once the simulation completes successfully, the model is uploaded to a local model catalogue.

## Preparation of the environment

This simulation requires an environment with Python 3.11.* and a few additional packages.

If the virtual environment does not yet contain the dependencies needed by the simulation and by the model, the following cell installs all the necessary packages in the current environment.

For questions about the requirements or about how to create a federated model, see the GitHub repository of the FedModelKit package [here](https://github.com/synthema-project/app-model_store-interface/tree/dev).

In [None]:
!pip install flwr[simulation]==1.9
# !pip install 'flwr[simulation]==1.9' # for Mac users
!pip install FedModelKit
!pip install mlflow
!pip install openpyxl

The following cell initializes the directory used to upload a model to the platform. Here the additional argument `app` also creates the server and client scripts used by the simulation.

In [None]:
!fmk init app

## Federated Model Building

### Local Learner Definition

The following cell defines a function `create_local_learner` that sets up the local learner for the federated learning process, and a function `create_aggregator` that defines the federated aggregation strategy. The local learner includes the definition of a simple logistic regression model (`SimpleLR`) and its associated methods for data preparation, training, evaluation, and parameter management. The `SimpleLR` class handles the following tasks:

- **Data Preparation**: Splits the data into training and validation sets, encodes categorical variables, and scales numerical variables.
- **Training**: Trains the logistic regression model on the prepared training data.
- **Evaluation**: Evaluates the model on the validation data and returns accuracy metrics.
- **Parameter Management**: Includes methods to set and get model parameters, which are essential for the federated learning process where model parameters are shared and aggregated across multiple clients.

These functions are central to the federated learning simulation: they provide the local model that each client trains and evaluates during the federated learning rounds, and the aggregation strategy that the central node applies to the model parameters and metrics coming from the clients.
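
The parameter exchange relies on turning NumPy arrays into bytes and back. The sketch below shows that round trip with plain NumPy only; the helper names `to_bytes` and `from_bytes` are hypothetical and simply mirror what `_ndarray_to_array` and `_basic_array_deserialisation` do in the model cell, without the Flower `Array` wrapper.

```python
import numpy as np


def to_bytes(ndarray: np.ndarray) -> dict:
    """Pack an array into raw bytes plus the metadata needed to rebuild it."""
    return {
        "data": ndarray.tobytes(),
        "dtype": str(ndarray.dtype),
        "shape": list(ndarray.shape),
    }


def from_bytes(packed: dict) -> np.ndarray:
    """Rebuild the array from its bytes, dtype and shape."""
    flat = np.frombuffer(packed["data"], dtype=packed["dtype"])
    # np.frombuffer returns a read-only view; copy to get a writable array
    return flat.reshape(packed["shape"]).copy()


coef = np.array([[0.5, -1.25, 2.0]])
restored = from_bytes(to_bytes(coef))
print(np.array_equal(coef, restored))
```

The dtype and shape must travel with the bytes, because the raw buffer alone does not say how to interpret them.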

The `%%writefile model_example.py` line at the top of the cell writes the cell's content to a `model_example.py` script.

---

`THE MODEL CAN BE MODIFIED WITH WHATEVER MODEL THE USER WANTS TO TEST`

The model must be created according to the protocols explained in the GitHub repository of the FedModelKit package [here](https://github.com/synthema-project/app-model_store-interface/tree/dev).

---

In [None]:
%%writefile model_example.py
import FedModelKit as fmk

# Function defining the local learner
def create_local_learner():
    import pandas as pd
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score
    from sklearn.preprocessing import StandardScaler, OneHotEncoder
    import flwr
    from collections import OrderedDict
    import pickle
    from pathlib import Path

    # Define the local learner class
    class SimpleLR:
        def __init__(self,
                     test_size=0.2, 
                     random_state=42) -> None:
            # Initialize the SimpleLR class with test size and random state
            self.test_size = test_size
            self.random_state = random_state
            self.model = LogisticRegression(warm_start=True)

        def prepare_data(self, data: pd.DataFrame) -> None:
            # Divide numerical and categorical columns
            categorical_cols = ['Gender', 'perf_status', 'secondary', 'ahd', 'eln_2017']
            numerical_cols = data.columns.difference(categorical_cols).drop("OS_Status")

            # Encode categorical data using pre-trained OneHotEncoder
            with open(Path(__file__).parent/"src"/"one_hot_encoder.pkl", 'rb') as f:
                one_hot_encoder = pickle.load(f)
            categorical_data = one_hot_encoder.transform(data[categorical_cols])
            
            # Scale numerical data
            scaler = StandardScaler()
            numerical_data = scaler.fit_transform(data[numerical_cols])

            # Get data and labels dataframes
            labels = data['OS_Status']
            data = pd.DataFrame(np.hstack([categorical_data, numerical_data]))

            # Split data into training and validation sets
            self.train_data, self.val_data, self.train_labels, self.val_labels = train_test_split(data, labels, test_size=self.test_size, random_state=self.random_state)

        def _parameters_to_dict(self, params_record: flwr.common.ParametersRecord) -> OrderedDict:
            # Convert ParametersRecord to an OrderedDict
            state_dict = OrderedDict()
            for k, v in params_record.items():
                state_dict[k] = self._basic_array_deserialisation(v)
            return state_dict

        def _dict_to_parameter_record(self, 
            parameters: OrderedDict[str, flwr.common.NDArray],
        ) -> flwr.common.ParametersRecord:
            # Convert OrderedDict to ParametersRecord
            state_dict = OrderedDict()
            for k, v in parameters.items():
                state_dict[k] = self._ndarray_to_array(v)

            return flwr.common.ParametersRecord(state_dict)

        def _ndarray_to_array(self, ndarray: flwr.common.NDArray) -> flwr.common.Array:
            """Represent NumPy ndarray as Array."""
            return flwr.common.Array(
                data=ndarray.tobytes(),
                dtype=str(ndarray.dtype),
                stype="numpy.ndarray.tobytes",
                shape=list(ndarray.shape),
            )

        def _basic_array_deserialisation(self, array: flwr.common.Array) -> flwr.common.NDArray:
            # Deserialize Array to NumPy ndarray
            return np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

        def train_round(self) -> flwr.common.MetricsRecord:
            # Train the model on the training data
            self.model.fit(self.train_data, self.train_labels)
            predictions = self.model.predict(self.train_data)
            # Misclassification rate on the training set, cast to a plain Python float
            loss = float(np.mean(predictions != self.train_labels))

            # Create a MetricsRecord for training loss
            return flwr.common.MetricsRecord({"loss": loss})

        def evaluate(self) -> flwr.common.MetricsRecord:
            # Evaluate the model on the validation data
            predictions = self.model.predict(self.val_data)
            accuracy = float(accuracy_score(self.val_labels, predictions))

            # Return a MetricsRecord for accuracy
            return flwr.common.MetricsRecord({"accuracy": accuracy}) 

        def _set_initial_parameters(self) -> None:
            # Initialize model parameters with dummy data
            X_dummy = np.zeros((2, 21))  # 21 features, the length of the feature vector
            y_dummy = np.array([0, 1])   # Dummy target
            # Fit the model once to initialize coef_
            self.model.fit(X_dummy, y_dummy)

        def set_parameters(self, parameters: flwr.common.ParametersRecord):
            # Convert the ParametersRecord back into an OrderedDict.
            state_dict = self._parameters_to_dict(parameters)
            
            # Set the model's parameters.
            if not hasattr(self.model, "coef_"):
                self._set_initial_parameters()
            self.model.coef_ = state_dict["coef_"]
            self.model.intercept_ = state_dict["intercept_"]

        def get_parameters(self) -> flwr.common.ParametersRecord:
            # Get the model's parameters as a ParametersRecord
            if not hasattr(self.model, "coef_"):
                self._set_initial_parameters()
            param_dict = {}
            param_dict["coef_"] = np.array(self.model.coef_)
            param_dict["intercept_"] = np.array(self.model.intercept_)

            return self._dict_to_parameter_record(OrderedDict(param_dict))
        
    return SimpleLR()


# Function defining the aggregator
def create_aggregator():
    from collections import OrderedDict
    import numpy as np
    import flwr
    from typing import Optional

    # Define the custom aggregator class
    class CustomAggregator:

        def _parameters_to_dict(self, params_record: flwr.common.ParametersRecord) -> OrderedDict:
            # Convert ParametersRecord to an OrderedDict
            state_dict = OrderedDict()
            for k, v in params_record.items():
                state_dict[k] = self._basic_array_deserialisation(v)
            return state_dict

        def _dict_to_parameter_record(self, 
            parameters: OrderedDict[str, flwr.common.NDArray],
        ) -> flwr.common.ParametersRecord:
            # Convert OrderedDict to ParametersRecord
            state_dict = OrderedDict()
            for k, v in parameters.items():
                state_dict[k] = self._ndarray_to_array(v)

            return flwr.common.ParametersRecord(state_dict)

        def _ndarray_to_array(self, ndarray: flwr.common.NDArray) -> flwr.common.Array:
            """Represent NumPy ndarray as Array."""
            return flwr.common.Array(
                data=ndarray.tobytes(),
                dtype=str(ndarray.dtype),
                stype="numpy.ndarray.tobytes",
                shape=list(ndarray.shape),
            )

        def _basic_array_deserialisation(self, array: flwr.common.Array) -> flwr.common.NDArray:
            # Deserialize Array to NumPy ndarray
            return np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

        def aggregate_parameters(self, results: list[flwr.common.ParametersRecord], config: Optional[flwr.common.ConfigsRecord]=None
            ) -> flwr.common.ParametersRecord:
                parameters = [self._parameters_to_dict(param) for param in results]
                keys = parameters[0].keys()
                result = OrderedDict()
                for key in keys:
                    # Init array
                    this_array: np.ndarray = np.zeros_like(parameters[0][key])
                    for p in parameters:
                        this_array += p[key]
                    result[key] = this_array / len(results)
                return self._dict_to_parameter_record(result)

        def aggregate_metrics(self, results: list[flwr.common.MetricsRecord], config: Optional[flwr.common.ConfigsRecord]=None) -> flwr.common.MetricsRecord:
                keys = results[0].keys()
                result = OrderedDict()
                for key in keys:
                    # Init array
                    cumsum = 0.0
                    for m in results:
                        if not isinstance(m[key], (int, float)):
                            raise ValueError(
                                f"flwr.common.MetricsRecord value type not supported: {type(m[key])}"
                            )
                        cumsum += m[key]  # type: ignore
                    result[key] = cumsum / len(results)
                return flwr.common.MetricsRecord(result)
    
    return CustomAggregator()

`Type-check:` make sure to type-check the functions you created, as done in the following cell, against the protocols described in the package [repository](https://github.com/synthema-project/app-model_store-interface/tree/dev) referenced above.

In [None]:
import FedModelKit as fmk
from model_example import create_local_learner as create_ll_check # Use alias here to avoid name conflict when uploading the model
from model_example import create_aggregator as create_agg_check # Use alias here to avoid name conflict when uploading the model

fmk.FederatedModel(create_local_learner=create_ll_check, 
                    model_name='simple_lr',
                    create_aggregator=create_agg_check,
                    aggregator_name='custom_aggregator')
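
To picture what a structural check of this kind involves, here is a minimal sketch based on `typing.Protocol`. The protocol `LearnerLike` and the two toy classes are hypothetical; the actual FedModelKit protocols are defined in its repository and may require different or additional members.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class LearnerLike(Protocol):
    """Hypothetical protocol listing the methods used by SimpleLR."""

    def prepare_data(self, data) -> None: ...
    def train_round(self): ...
    def evaluate(self): ...
    def get_parameters(self): ...
    def set_parameters(self, parameters) -> None: ...


class Complete:
    def prepare_data(self, data) -> None: ...
    def train_round(self): ...
    def evaluate(self): ...
    def get_parameters(self): ...
    def set_parameters(self, parameters) -> None: ...


class Incomplete:
    # Lacks evaluate, get_parameters and set_parameters
    def prepare_data(self, data) -> None: ...
    def train_round(self): ...


# isinstance on a runtime-checkable protocol only verifies that the methods exist
print(isinstance(Complete(), LearnerLike), isinstance(Incomplete(), LearnerLike))
```

A runtime check like this does not validate argument or return types, so a static type checker is still useful for those.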

The following cell defines a function `load_data` that loads and preprocesses the local dataset for each client in the federated learning setup. Here are the key steps and the role of the OneHotEncoder matrix:

1. **Load Local Data**: The function loads the dataset from an Excel file and drops the 'Index' column. The dataset used was obtained from a public data repository (Tazi et al. https://github.com/papaemmelab/Tazi_NatureC_AML?tab=readme-ov-file), then reduced and preprocessed.

2. **OneHotEncoder for Categorical Variables**:
    - The function uses `OneHotEncoder` to encode categorical variables, ensuring consistency across all clients.
    - The encoder is fitted on the entire dataset, so it knows every category of each categorical variable.
    - This matters because each client only holds a portion of the dataset: an encoder fitted on that portion alone could miss categories and produce incomplete or inconsistent encodings.

3. **Save the Fitted OneHotEncoder**:
    - The fitted encoder is saved as a pickle file in the `src` directory.
    - This directory is uploaded with the federated model, so the encoder is available to all clients when they download the model.
    - The model can therefore encode categorical variables identically on every client during training and evaluation.

4. **Split Data Among Clients**: The dataset is split into portions based on the number of clients, and each client receives its respective portion.

This function ensures that all clients have consistent and complete information about the categorical variables, which is essential for the federated learning process.
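
The effect of fitting the encoder on the whole dataset can be seen on a toy example. The sketch below uses a made-up single column (not the AML dataset) in which one client never observes the category `"high"`.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Hypothetical toy column with three categories
full = np.array([["low"], ["mid"], ["high"], ["low"]], dtype=object)
shard = full[:2]  # this client only sees "low" and "mid"

shared_encoder = OneHotEncoder().fit(full)   # fitted once on all the data
local_encoder = OneHotEncoder().fit(shard)   # fitted on the client's shard only

# Number of encoded columns produced for the same shard by each encoder
shared_width = shared_encoder.transform(shard).toarray().shape[1]
local_width = local_encoder.transform(shard).toarray().shape[1]
print(shared_width, local_width)
```

With the shared encoder every client produces feature vectors of the same length, which is what allows coefficient arrays from different clients to be averaged element by element.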

---

`THE LOAD_DATA FUNCTION MUST BE ADAPTED TO THE USER'S CUSTOM MODEL`

If the local data needs to be preprocessed with information from the whole dataset, that information must be stored inside the `src` folder. In this simulation, the structure-related information about the whole dataset is created and stored in the `src` folder by the clients themselves; users who upload their model directly must perform this step just once, right before the upload.

`Careful!:` do not put sensitive information from the dataset in the `src` folder, only structure-related information.

---

In [None]:
%%writefile load_data.py
import pandas as pd
import numpy as np
import pickle
from sklearn.preprocessing import OneHotEncoder
from pathlib import Path

def load_data(num_clients: int, client_id: int) -> pd.DataFrame:
    """
    Load the local dataset for one client in the federated learning process (in this simulation, a split
    of the whole dataset). It also stores the fitted OneHotEncoder in a pickle file in the src directory,
    which is uploaded with the federated model and used by the model to encode the categorical data.

    Args:
        num_clients (int): The total number of clients participating in the federated learning process.
        client_id (int): The unique identifier for the current client.

    Returns:
        pd.DataFrame: The portion of the dataset assigned to the current client.
    """

    # Load local data from an Excel file and drop the 'Index' column
    data = pd.read_excel('./AML_preprocessed_dataset.xlsx').drop("Index", axis=1) # !!Change the path to the dataset!!

    # Get mapping for categorical variables using OneHotEncoder to ensure consistency across clients
    categorical_cols = ['Gender', 'perf_status', 'secondary', 'ahd', 'eln_2017']
    one_hot_encoder = OneHotEncoder(sparse_output=False)
    matrix = one_hot_encoder.fit(data[categorical_cols])
    
    # Store the OneHotEncoder matrix in a pickle file in the src directory 
    # that will be uploaded with the federated model
    # exist_ok avoids a FileExistsError when several simulated clients reach this line at the same time
    Path('src').mkdir(exist_ok=True)
    if not Path('src/one_hot_encoder.pkl').exists():
        with open('src/one_hot_encoder.pkl', 'wb') as f:
            pickle.dump(matrix, f)

    # Split the data among clients based on the number of clients specified in the config
    data_split = np.array_split(data, num_clients, axis=0) 
    data = data_split[client_id]

    return pd.DataFrame(data)
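
The way `np.array_split` distributes rows can be checked in isolation. The following sketch works on row indices only; the helper `client_rows` is hypothetical and is not part of `load_data.py`.

```python
import numpy as np


def client_rows(num_rows: int, num_clients: int, client_id: int) -> np.ndarray:
    """Return the row indices assigned to one client."""
    return np.array_split(np.arange(num_rows), num_clients)[client_id]


# 10 rows shared among 3 clients
shards = [client_rows(10, 3, cid) for cid in range(3)]
print([len(s) for s in shards])
```

Unlike `np.split`, `np.array_split` accepts a row count that is not divisible by the number of clients; the first shards simply receive one extra row. With a DataFrame, the returned indices could be passed to `data.iloc[...]`.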

## Simulation of Federated Training and Evaluation

In this section, we will simulate the federated training and evaluation process using the defined client and server applications. The simulation will involve multiple rounds of training and evaluation, where the global model parameters are updated based on the aggregated results from the clients.

### Client

In the apps folder, generated by the initialization process performed before, there is a Python script named `client_app.py` which defines the client processes for the federated learning simulation. Here are the steps involved:

1. **Import Necessary Libraries**: The script imports essential libraries such as `numpy`, `pandas`, `flwr`, `sklearn`, and `FedModelKit`.

2. **Define the Flower ClientApp**: An instance of `ClientApp` from Flower is created to handle the client-side operations.

3. **Define the `train` Function**:
    - **Log Local Context**: Logs the metrics from the previous round if available.
    - **Instantiate Model**: Creates an instance of the federated model using the `create_local_learner` function.
    - **Set Model Parameters**: Sets the model parameters received from the server.
    - **Load Local Data**: Loads the local dataset from an Excel file and preprocesses it.
    - **Prepare Data**: Prepares the data by encoding categorical variables, scaling numerical variables, and splitting it into training and validation sets.
    - **Local Training**: Trains the local model and retrieves training metrics.
    - **Construct Reply Message**: Constructs a reply message containing updated model parameters and training metrics.

4. **Define the `evaluate` Function**:
    - **Instantiate Model**: Creates an instance of the federated model using the `create_local_learner` function.
    - **Set Model Parameters**: Sets the model parameters received from the server.
    - **Load Local Data**: Loads the local dataset from an Excel file and preprocesses it.
    - **Prepare Data**: Prepares the data by encoding categorical variables, scaling numerical variables, and splitting it into training and validation sets.
    - **Evaluate Model**: Evaluates the model on the validation data and retrieves evaluation metrics.
    - **Construct Reply Message**: Constructs a reply message containing evaluation metrics.


These steps collectively define the client-side operations for training and evaluating the model in a federated learning setup.
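
The order of operations in the `train` step can be sketched without Flower. The code below is only an illustration: `MeanLearner` is a hypothetical stand-in for a local learner, plain dictionaries replace the Flower records, and `client_train` is not the content of `client_app.py`.

```python
import numpy as np


class MeanLearner:
    """Hypothetical learner whose only parameter is an estimate of the data mean."""

    def __init__(self) -> None:
        self.mean = np.zeros(1)
        self.data = np.zeros(0)

    def set_parameters(self, parameters: dict) -> None:
        self.mean = parameters["mean"].copy()

    def get_parameters(self) -> dict:
        return {"mean": self.mean}

    def prepare_data(self, data: np.ndarray) -> None:
        self.data = data

    def train_round(self) -> dict:
        # Move halfway towards the local data mean, then report the mean squared error
        self.mean = 0.5 * (self.mean + self.data.mean(keepdims=True))
        return {"loss": float(((self.data - self.mean) ** 2).mean())}


def client_train(global_parameters: dict, local_data: np.ndarray) -> tuple[dict, dict]:
    learner = MeanLearner()                     # instantiate the model
    learner.set_parameters(global_parameters)   # set the parameters sent by the server
    learner.prepare_data(local_data)            # prepare the local data
    metrics = learner.train_round()             # local training
    return learner.get_parameters(), metrics    # content of the reply message


params, metrics = client_train({"mean": np.array([0.0])}, np.array([2.0, 4.0]))
print(params["mean"], metrics)
```

The `evaluate` step follows the same pattern, except that it calls the learner's evaluation method and replies with metrics only.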

---

`DON'T MODIFY THE CLIENT! AS IT REFLECTS THE STANDARD CLIENT PROCESS` (unless you know what you are doing)

---

### Server

In the apps folder, generated by the initialization process performed before, there is a Python script named `server_app.py` which defines the server-side operations for the federated learning simulation. Here are the steps involved:

1. **Import Necessary Libraries**: The script imports essential libraries such as `flwr`, `FedModelKit`, and `create_local_learner` from `model_example.py`.

2. **Define the Flower ServerApp**: An instance of `ServerApp` from Flower is created to handle the server-side operations.

3. **Define the `main` Function**:
    - **Initialize Federated Model**: Creates an instance of the federated model using the `create_local_learner` function and sets up the global model and aggregation strategy.
    - **Server Rounds**: Iterates through multiple rounds of federated learning.
        - **Get Node IDs**: Retrieves the IDs of the participating clients.
        - **Create and Send Messages**: Constructs messages containing model parameters and configuration settings, and sends them to the clients.
        - **Wait for Client Replies**: Waits for the clients to complete their training and send back their results.
        - **Aggregate Parameters**: Aggregates the parameters received from the clients to update the global model.
        - **Evaluate the Model**: Sends the updated global model to the clients for evaluation and aggregates the evaluation metrics.

These steps collectively define the server-side operations for coordinating the training and evaluation of the model in a federated learning setup.
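
The round loop and the unweighted averaging performed by `CustomAggregator` can be sketched with plain NumPy. Everything below is a simplified assumption: dictionaries stand in for Flower records, `local_update` is a hypothetical client step, and no messages are actually exchanged.

```python
import numpy as np


def average_parameters(results: list[dict]) -> dict:
    """Unweighted mean of each parameter array across clients."""
    return {key: np.mean([r[key] for r in results], axis=0) for key in results[0]}


def average_metrics(results: list[dict]) -> dict:
    """Unweighted mean of each scalar metric across clients."""
    return {key: sum(r[key] for r in results) / len(results) for key in results[0]}


def local_update(global_params: dict, client_data: np.ndarray) -> dict:
    """Hypothetical client step: move the parameter halfway towards the local mean."""
    return {"w": global_params["w"] + 0.5 * (client_data.mean() - global_params["w"])}


client_datasets = [np.array([1.0, 3.0]), np.array([5.0, 7.0])]  # two simulated clients
global_params = {"w": np.zeros(1)}

for server_round in range(3):
    # Send the global parameters to every client and collect the replies
    replies = [local_update(global_params, data) for data in client_datasets]
    # Aggregate the replies into the new global parameters
    global_params = average_parameters(replies)

print(global_params["w"])
print(average_metrics([{"accuracy": 0.5}, {"accuracy": 1.0}]))
```

Because the mean is unweighted, a client with few samples influences the global model as much as a client with many; weighting by sample count is a common alternative.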

---

`DON'T MODIFY THE SERVER! AS IT REFLECTS THE STANDARD SERVER PROCESS` (unless you know what you are doing)

---

### Simulation start

The following command starts the federated learning simulation by specifying the client and server applications and the number of supernodes (the simulated client nodes):


- `--client-app=apps.client_app:app`: Specifies the client application module (`apps/client_app.py`) and the application instance (`app`) to be used for the simulation.
- `--server-app=apps.server_app:app`: Specifies the server application module (`apps/server_app.py`) and the application instance (`app`) to be used for the simulation.
- `--num-supernodes=2`: Defines the number of supernodes used in the simulation. Each supernode runs the client application as one simulated client, so this simulation involves two clients.

In [None]:
!flower-simulation --client-app=apps.client_app:app --server-app=apps.server_app:app --num-supernodes=2

If the simulation completes successfully with the expected results, the model is ready to be uploaded to the model registry. This allows the model to be utilized in future federated learning tasks, ensuring that it behaves in the expected way.

## Model upload

In this section, we will describe the process of uploading the tested federated model to the model registry. This step ensures that the model is stored securely and can be accessed for future federated learning tasks.

A more detailed description of this part is provided at this link: [GitHub/FedModelKit](https://github.com/synthema-project/app-model_store-interface/tree/dev)

### Federated model creation with a custom aggregation strategy

In [None]:
# Run model_example.py to bring the functions defining the local learner and the aggregator into the notebook namespace
%run model_example.py
import FedModelKit as fmk

# Create an instance of the FederatedModel class from the FedModelKit module.
# This class handles the federated learning process, including model creation, training, and aggregation.
# The create_local_learner and create_aggregator functions are passed as arguments to define the local learner (model)
# for each client and the aggregation strategy. Here the functions used are the ones defined in the model_example.py file.
federated_model = fmk.FederatedModel(create_local_learner=create_local_learner, # type: ignore # Make sure you already type-checked the function
                                     model_name='simple_lr',
                                     create_aggregator=create_aggregator, # type: ignore # Make sure you already type-checked the function
                                     aggregator_name='custom_aggregator')

### Local model registry server creation 

The model must be uploaded to a model registry server. For this example, a local MLflow server is enough. It can be started by running the following command in a dedicated terminal (using the same environment as this notebook):
```bash
mlflow server --host 0.0.0.0 --port 5000
```
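
Before uploading, it can be convenient to confirm that the server is reachable. The sketch below uses only the standard library and assumes that the MLflow server answers on a `/health` endpoint; the function name `registry_is_up` is hypothetical.

```python
from urllib.error import URLError
from urllib.request import urlopen


def registry_is_up(base_url: str, timeout: float = 2.0) -> bool:
    """Return True if the server answers the health request with HTTP 200."""
    try:
        with urlopen(f"{base_url}/health", timeout=timeout) as response:
            return response.status == 200
    except (URLError, OSError):
        # Connection refused, timeout, or an HTTP error status
        return False


print(registry_is_up("http://localhost:5000"))
```

If this prints `False`, check that the `mlflow server` command is still running in its terminal and that the port matches the one used in the upload cell.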

### Federated model upload 

In [None]:
# Submit the federated model to the model registry server using the FedModelKit module.
# This function uploads the federated model to the specified platform URL.
# Parameters:
# - model: The federated model instance to be uploaded.
# - platform_url: The URL of the model registry server where the model will be uploaded.
# - username: The username for authentication (if required).
# - password: The password for authentication (if required).
# - experiment_name: The name of the experiment under which the model will be registered.
# - disease: The disease category associated with the model (e.g., AML for Acute Myeloid Leukemia).
# - trained: A boolean flag indicating whether the model has been trained.


# HERE you should produce the information about the whole dataset needed by your model and store it in
# the src directory, UNLESS you performed the simulation and the clients already stored it.

fmk.submit_fl_model(model=federated_model,
                    platform_url='http://localhost:5000',  # URL of the local MLflow server
                    username='username',  # Username for authentication (if no username is required, use a mock username)
                    password='password',  # Password for authentication (if no password is required, use a mock password)
                    experiment_name='simulation_experiment',  # Name of the experiment
                    disease='AML',  # Disease category associated with the model
                    trained=False  # False indicates that the uploaded model has not been trained yet
                    )