# TFX pipeline in the Flower framework

This notebook is used for local testing. Following Flower's tutorial, it implements the client and the server in the same place.

In [None]:
import glob
import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

DEVICE = torch.device("cpu")  # pass "cuda" instead to train on a GPU

# Report the execution device and library versions, one per line.
print(
    f"Training on {DEVICE}",
    f"Flower {flwr.__version__} / PyTorch {torch.__version__}",
    sep="\n",
)

Training on cpu
Flower 1.18.0 / PyTorch 2.2.2+cpu


In [None]:
import absl
import tensorflow_model_analysis as tfma
from tfx.components import CsvExampleGen
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.dsl.components.common import resolver
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing

import tensorflow as tf

In [None]:
from taxi_utils_native_keras import _build_keras_model

def save_model_template(save_path: str):
    """Build a fresh model with ``_build_keras_model`` and save it to *save_path*."""
    template = _build_keras_model()
    template.save(save_path)


# One-off step: write the model template that server-side evaluation loads.
save_model_template("taxi_utils_native_keras_model_template")

### Load dataset

In [None]:
NUM_PARTITIONS = 5  # presumably one partition per client CSV file
NUM_CLIENTS = 5  # number of simulated Flower clients
BATCH_SIZE = 250  # DataLoader batch size
CSV_DIR = "../tfx-flower/data/simple"  # holds client_<k>.csv and global_test.csv

# Model inputs, in the column order used to build each feature vector.
FEATURE_COLUMNS = [
    "pickup_community_area",
    "fare",
    "trip_start_month",
    "trip_start_hour",
    "trip_start_day",
    "trip_start_timestamp",
    "pickup_latitude",
    "pickup_longitude",
    "dropoff_latitude",
    "dropoff_longitude",
    "trip_miles",
    "pickup_census_tract",
    "dropoff_census_tract",
    "payment_type",
    "company",
    "trip_seconds",
    "dropoff_community_area",
]

# Column used as the prediction label.
TARGET_COLUMN = "tips"

class TaxiDataset(Dataset):
    """In-memory PyTorch dataset of ``(features, label)`` float32 tensors.

    Args:
        df: Source DataFrame; must contain every feature column and the
            target column.
        feature_columns: Columns stacked, in order, into each feature vector.
            Defaults to the module-level ``FEATURE_COLUMNS``.
        target_column: Column used as the label. Defaults to the module-level
            ``TARGET_COLUMN``.

    Every selected column must be numeric (or convertible to float); a text
    column raises ``ValueError`` during conversion.
    """

    def __init__(self, df, feature_columns=None, target_column=None):
        # Resolve defaults lazily so callers may pass explicit columns.
        if feature_columns is None:
            feature_columns = FEATURE_COLUMNS
        if target_column is None:
            target_column = TARGET_COLUMN
        # NOTE(review): FEATURE_COLUMNS includes "payment_type" and "company",
        # which look like text columns in the taxi data -- confirm they are
        # numerically encoded before reaching this dataset.
        # torch.tensor copies, so the tensors never alias the DataFrame memory.
        self.features = torch.tensor(
            df[list(feature_columns)].to_numpy(dtype=np.float32)
        )
        self.labels = torch.tensor(df[target_column].to_numpy(dtype=np.float32))

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

def load_datasets(partition_id: int, num_partitions: int):
    """Build train/validation/test DataLoaders for one client partition.

    Args:
        partition_id: Zero-based partition index; the client's rows are read
            from ``client_{partition_id + 1}.csv`` in ``CSV_DIR``.
        num_partitions: Total number of partitions; ``partition_id`` must lie
            in ``[0, num_partitions)``.

    Returns:
        ``(trainloader, valloader, testloader)``. The client's rows are split
        80/20 into train/validation with a fixed seed (42). The test loader
        covers ``global_test.csv`` in ``CSV_DIR``.

    Raises:
        ValueError: If ``partition_id`` is outside ``[0, num_partitions)``.
    """
    if not 0 <= partition_id < num_partitions:
        raise ValueError(
            f"partition_id must be in [0, {num_partitions}), got {partition_id}"
        )

    file_path = os.path.join(CSV_DIR, f"client_{partition_id + 1}.csv")
    dataset = TaxiDataset(pd.read_csv(file_path))

    # 80/20 train/val split; seeded so repeated calls give the same partition.
    val_size = int(0.2 * len(dataset))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Derived from CSV_DIR rather than a duplicated literal (same file as before).
    test_dataset = TaxiDataset(pd.read_csv(os.path.join(CSV_DIR, "global_test.csv")))
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    return trainloader, valloader, testloader

### TFX Flower Client

In [None]:
def _create_pipeline(client_id: int, module_file: str, output_dir: str):
    """Assemble the per-client TFX training pipeline.

    Components: CsvExampleGen -> StatisticsGen -> SchemaGen -> ExampleValidator
    -> Transform -> Trainer. The Evaluator/Pusher stages are kept below but
    disabled.

    Args:
        client_id: Selects the client CSV (``client_<client_id>.csv``) and names
            the pipeline, its root directory and its MLMD SQLite database.
        module_file: Python module supplying the Transform and Trainer code.
        output_dir: Directory that receives the pipeline root and metadata DB.

    Returns:
        The assembled ``pipeline.Pipeline``; ``run_pipeline`` executes it.
    """
    # Local import keeps this fix self-contained; tfx.proto is already used here.
    from tfx.proto import example_gen_pb2

    # NOTE(review): load_datasets() reads client_{partition_id + 1}.csv from
    # CSV_DIR ("../tfx-flower/data/simple"), whereas this reads
    # client_{client_id}.csv from "../data/simple", and fit() passes the 0-based
    # partition id -- confirm the intended directory and file numbering.
    data_root = f"../data/simple/client_{client_id}.csv"
    pipeline_root = f"{output_dir}/pipeline_client_{client_id}"
    metadata_path = f"{output_dir}/metadata_client_{client_id}.db"
    # Only referenced by the (currently disabled) Pusher below.
    serving_model_dir = os.path.join(output_dir, f"serving_model_client_{client_id}")

    # Brings data into the pipeline. CsvExampleGen reads every file under
    # ``input_base``; all client CSVs (and global_test.csv) share that directory,
    # so restrict ingestion to this client's file. With a single input split,
    # ExampleGen's default output config still splits it into train/eval.
    example_gen = CsvExampleGen(
        input_base=os.path.dirname(data_root),
        input_config=example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name="single_split", pattern=os.path.basename(data_root)),
        ]),
    )

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_gen.outputs['schema'],
        module_file=module_file)

    # Uses user-provided Python function that implements a model.
    trainer = Trainer(
        module_file=module_file,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=1000),
        eval_args=trainer_pb2.EvalArgs(num_steps=150))

    # # Get the latest blessed model for model validation.
    # model_resolver = resolver.Resolver(
    #     strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
    #     model=Channel(type=Model),
    #     model_blessing=Channel(
    #         type=ModelBlessing)).with_id('latest_blessed_model_resolver')

    # # Uses TFMA to compute a evaluation statistics over features of a model and
    # # perform quality validation of a candidate model (compared to a baseline).
    # eval_config = tfma.EvalConfig(
    #     model_specs=[
    #         tfma.ModelSpec(
    #             signature_name='serving_default', label_key='tips_xf',
    #             preprocessing_function_names=['transform_features'])
    #     ],
    #     slicing_specs=[tfma.SlicingSpec()],
    #     metrics_specs=[
    #         tfma.MetricsSpec(metrics=[
    #             tfma.MetricConfig(
    #                 class_name='BinaryAccuracy',
    #                 threshold=tfma.MetricThreshold(
    #                     value_threshold=tfma.GenericValueThreshold(
    #                         lower_bound={'value': 0.6}),
    #                     # Change threshold will be ignored if there is no
    #                     # baseline model resolved from MLMD (first run).
    #                     change_threshold=tfma.GenericChangeThreshold(
    #                         direction=tfma.MetricDirection.HIGHER_IS_BETTER,
    #                         absolute={'value': -1e-10})))
    #         ])
    #     ])
    # evaluator = Evaluator(
    #     examples=example_gen.outputs['examples'],
    #     model=trainer.outputs['model'],
    #     baseline_model=model_resolver.outputs['model'],
    #     eval_config=eval_config)

    # # Checks whether the model passed the validation steps and pushes the model
    # # to a file destination if check passed.
    # pusher = Pusher(
    #     model=trainer.outputs['model'],
    #     model_blessing=evaluator.outputs['blessing'],
    #     push_destination=pusher_pb2.PushDestination(
    #         filesystem=pusher_pb2.PushDestination.Filesystem(
    #             base_directory=serving_model_dir)))

    # NOTE(review): with caching enabled, a rerun whose inputs are unchanged may
    # reuse the previous Trainer output instead of retraining -- confirm that is
    # intended when the pipeline is rerun every federated round.
    return pipeline.Pipeline(
        pipeline_name=f"client_{client_id}_pipeline",
        pipeline_root=pipeline_root,
        components=[example_gen, statistics_gen, schema_gen, example_validator, transform, trainer],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(metadata_path)
    )

def run_pipeline(client_id: int, module_file: str, output_dir: str):
    """Build the TFX pipeline for ``client_id`` and execute it with the local Beam runner."""
    runner = BeamDagRunner()
    client_pipeline = _create_pipeline(
        client_id=client_id,
        module_file=module_file,
        output_dir=output_dir,
    )
    runner.run(client_pipeline)

ModuleNotFoundError: No module named 'absl'

### Flower Client

In [None]:
def get_parameters_from_keras_model(model):
    """Return the model's weights as a list of NumPy arrays (Flower ``NDArrays``).

    Keras ``get_weights()`` already yields NumPy arrays, which have no
    ``.numpy()`` method; ``np.asarray`` normalizes without copying ndarrays.
    """
    return [np.asarray(w) for w in model.get_weights()]

def set_parameters_to_keras_model(model, weights):
    """Overwrite *model*'s weights in place via ``model.set_weights``; returns ``None``."""
    apply_weights = model.set_weights
    apply_weights(weights)

def _run_dir_sort_key(path):
    """Order run directories numerically by name; non-numeric names rank lowest."""
    name = os.path.basename(path)
    return (1, int(name), name) if name.isdigit() else (0, 0, name)


def find_latest_model_path(client_id: int, base_dir="tfx_output"):
    """Return the directory of the newest model exported by a client's TFX Trainer.

    Searches ``<base_dir>/pipeline_client_<client_id>/Trainer/model`` -- the
    layout ``_create_pipeline(client_id, ..., output_dir=base_dir)`` writes --
    and then the location this function originally searched,
    ``tfx-flower/<base_dir>/...``.

    Run directories (presumably TFX execution ids) are compared numerically,
    so "10" is newer than "9". If the chosen run contains a ``Format-Serving``
    subdirectory (the TFX Trainer's serving-model location), that
    subdirectory is returned so it can be passed to ``load_model`` directly.

    Raises:
        FileNotFoundError: If no run directory exists in any searched location.
    """
    # dict.fromkeys de-duplicates while preserving search order (the two
    # candidates coincide when base_dir is absolute).
    trainer_dirs = list(dict.fromkeys([
        os.path.join(base_dir, f"pipeline_client_{client_id}", "Trainer"),
        os.path.join("tfx-flower", base_dir, f"pipeline_client_{client_id}", "Trainer"),
    ]))
    for model_dir in trainer_dirs:
        runs = [
            p for p in glob.glob(os.path.join(model_dir, "model", "*"))
            if os.path.isdir(p)
        ]
        if runs:
            latest = max(runs, key=_run_dir_sort_key)
            serving_dir = os.path.join(latest, "Format-Serving")
            return serving_dir if os.path.isdir(serving_dir) else latest
    raise FileNotFoundError(f"No model found in {' or '.join(trainer_dirs)}")

In [None]:
class TFXFlowerClient(NumPyClient):
    """Flower NumPy client whose local training is a per-client TFX pipeline.

    ``fit`` runs the client's TFX pipeline, loads the Keras model the Trainer
    exported and reports its weights to the server. ``evaluate`` scores given
    weights on the global test CSV.
    """

    # Model written by save_model_template(); used when no TFX-trained model
    # exists yet, so the client can still report or accept parameters.
    _TEMPLATE_MODEL_PATH = "taxi_utils_native_keras_model_template"

    def __init__(self, partition_id):
        self.partition_id = partition_id
        # Loaded by fit(), or lazily by _ensure_model() on first use.
        self.model = None

    def _ensure_model(self):
        """Return ``self.model``, loading it from disk first if necessary."""
        if self.model is None:
            try:
                model_path = find_latest_model_path(self.partition_id)
            except FileNotFoundError:
                # No pipeline run has produced a model for this client yet.
                model_path = self._TEMPLATE_MODEL_PATH
            self.model = tf.keras.models.load_model(model_path)
        return self.model

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        # Keras get_weights() already returns NumPy arrays (no .numpy() method).
        # The server may call this before any fit(), hence _ensure_model().
        return [np.asarray(w) for w in self._ensure_model().get_weights()]

    def fit(self, parameters, config):
        run_pipeline(self.partition_id, "taxi_utils_native_keras.py", "tfx_output")
        model_path = find_latest_model_path(self.partition_id)
        self.model = tf.keras.models.load_model(model_path)
        # Report the freshly trained weights. Overwriting them with the incoming
        # global ``parameters`` (as before) discarded all local training.
        # TODO: the TFX Trainer trains from scratch, so ``parameters`` is not yet
        # used as the starting point of local training.
        # num_examples is fixed at 1, so FedAvg weighs every client equally.
        return self.get_parameters({}), 1, {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")

        model = self._ensure_model()
        set_parameters_to_keras_model(model, parameters)

        # Evaluate on the shared global test set (same file as before, via CSV_DIR).
        test_df = pd.read_csv(os.path.join(CSV_DIR, "global_test.csv"))
        x_test = test_df[FEATURE_COLUMNS].values
        y_test = test_df[TARGET_COLUMN].values

        # NOTE(review): a TFX-exported model presumably expects the transformed
        # feature dict produced by module_file's preprocessing rather than a raw
        # 2-D array of CSV columns -- confirm the evaluate() input format.
        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
        return float(loss), len(x_test), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the Flower client for the simulated node described by *context*."""
    numpy_client = TFXFlowerClient(context.node_config["partition-id"])
    return numpy_client.to_client()


# ClientApp handed to the simulation engine.
client = ClientApp(client_fn=client_fn)

### Flower Server

In [None]:
from flwr.simulation import start_simulation
from flwr.server.strategy import FedAvg
from flwr.server import ServerConfig

# Number of federated rounds; passed to ServerConfig(num_rounds=...).
ROUNDS = 3

def evaluate_global_model(weights):
    """Evaluate aggregated global weights on global_test.csv.

    Args:
        weights: List of arrays accepted by the template model's
            ``set_weights``.

    Returns:
        ``(loss, {"accuracy": accuracy})`` as plain Python floats, because
        Flower metrics must be Scalars (bool/bytes/float/int/str).
    """
    print("[Server] Evaluating global model...")

    # Load global test set.
    # NOTE(review): clients read global_test.csv from CSV_DIR
    # ("../tfx-flower/data/simple"); confirm this path refers to the same file.
    test_df = pd.read_csv("../data/simple/global_test.csv")
    x_test = test_df[FEATURE_COLUMNS].values
    y_test = test_df[TARGET_COLUMN].values

    # Load a fresh model (must match the clients' architecture); this path is
    # the one save_model_template() writes.
    model = tf.keras.models.load_model("taxi_utils_native_keras_model_template")
    model.set_weights(weights)

    loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
    loss, accuracy = float(loss), float(accuracy)
    print(f"[Server] Global evaluation - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
    return loss, {"accuracy": accuracy}

def get_evaluate_fn():
    """Return the centralized evaluation callback for the Flower strategy.

    Flower strategies call ``evaluate_fn(server_round, parameters, config)``
    with the aggregated weights as a list of NumPy arrays, so the callback
    must accept all three arguments.

    Returns:
        A callable forwarding the weights to ``evaluate_global_model`` and
        returning its ``(loss, metrics)`` tuple.
    """

    def evaluate(server_round: int, parameters, config):
        # server_round and config are unused: every round is scored identically.
        return evaluate_global_model(parameters)

    return evaluate

### Run

In [None]:
# Seed the global model from the saved template so the server does not need to
# ask a client for initial parameters (requires save_model_template() to have run).
initial_weights = tf.keras.models.load_model(
    "taxi_utils_native_keras_model_template"
).get_weights()

strategy = FedAvg(
    fraction_fit=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    on_evaluate_config_fn=lambda rnd: {},
    evaluate_fn=get_evaluate_fn(),
    initial_parameters=ndarrays_to_parameters(initial_weights),
)


def server_fn(context: Context) -> ServerAppComponents:
    """Provide the strategy and round count for the simulated server."""
    return ServerAppComponents(
        strategy=strategy,
        config=ServerConfig(num_rounds=ROUNDS),
    )


# start_simulation() takes ``client_fn`` rather than ``client_app``;
# run_simulation() is the API that pairs a ServerApp with a ClientApp.
run_simulation(
    server_app=ServerApp(server_fn=server_fn),
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config={"client_resources": {"num_cpus": 1}},
)