# TFX pipeline in the Flower framework

Second attempt

This notebook is used for local testing. Following Flower's tutorial, it implements the client and the server in the same place.

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation, start_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

# Compute device for the run; switch to torch.device("cuda") to train on GPU.
DEVICE = torch.device("cpu")
print("Training on " + str(DEVICE))
print("Flower {} / PyTorch {}".format(flwr.__version__, torch.__version__))

  from .autonotebook import tqdm as notebook_tqdm
2025-06-20 21:57:02,020	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cpu
Flower 1.18.0 / PyTorch 2.2.2+cpu


: 

In [None]:
import absl
import tensorflow_model_analysis as tfma
from tfx.components import CsvExampleGen
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.dsl.components.common import resolver
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing

import tensorflow as tf

In [None]:
def create_pipeline(client_id: int, module_file: str, output_dir: str) -> pipeline.Pipeline:
    """Build the per-client TFX preprocessing pipeline.

    The pipeline ingests ``client_<client_id>.csv``, computes statistics,
    infers a schema, validates the examples against it and runs ``Transform``
    with the preprocessing code in *module_file*.

    Args:
        client_id: Index of the client whose CSV file is ingested.
        module_file: Path of the preprocessing module passed to ``Transform``.
        output_dir: Root directory for this client's pipeline artifacts and
            its SQLite ML Metadata store.

    Returns:
        A ``pipeline.Pipeline`` ready to be executed by a DAG runner.
    """
    from tfx.proto import example_gen_pb2

    data_root = f"../tfx-flower/data/simple/client_{client_id}.csv"
    pipeline_root = f"{output_dir}/pipeline_client_{client_id}"
    metadata_path = f"{output_dir}/metadata_client_{client_id}.db"

    # Brings data into the pipeline or otherwise joins/converts training data.
    # Every client's CSV lives in the same directory, and CsvExampleGen reads
    # all files under `input_base` by default. Restrict ingestion to this
    # client's file so each pipeline only sees its own partition. Per the TFX
    # ExampleGen docs, a single input split without an output_config is
    # hash-split into 'train' and 'eval'.
    example_gen = CsvExampleGen(
        input_base=os.path.dirname(data_root),
        input_config=example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='single_split',
                pattern=os.path.basename(data_root),
            ),
        ]),
    )

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_gen.outputs['schema'],
        module_file=module_file)

    return pipeline.Pipeline(
        pipeline_name=f"client_{client_id}_pipeline",
        pipeline_root=pipeline_root,
        components=[example_gen, statistics_gen, schema_gen, example_validator, transform],
        enable_cache=True,
        # The runner records executions/artifacts in ML Metadata (and caching
        # depends on it), so give each client its own SQLite store.
        metadata_connection_config=metadata.sqlite_metadata_connection_config(metadata_path),
    )

def run_pipeline(client_id: int, module_file: str, output_dir: str):
    """Execute one client's TFX pipeline locally with the Beam DAG runner.

    Args:
        client_id: Which client's pipeline to build and run.
        module_file: Preprocessing module forwarded to ``create_pipeline``.
        output_dir: Directory under which ``create_pipeline`` places outputs.
    """
    print(f"Running TFX pipeline for client {client_id}...")
    runner = BeamDagRunner()
    client_pipeline = create_pipeline(
        client_id=client_id,
        module_file=module_file,
        output_dir=output_dir,
    )
    runner.run(client_pipeline)
    print(f"Pipeline completed for client {client_id}")

In [None]:
# Run the TFX preprocessing pipeline for each of the 5 clients.
MODULE_FILE = "taxi_utils.py"  # Your preprocessing module file
OUTPUT_DIR = "./tfx_output"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Best-effort loop: a failing client is reported and the remaining clients
# still run, but the final message must not claim success if any failed.
failed_clients = []
for client_id in range(5):
    try:
        run_pipeline(client_id, MODULE_FILE, OUTPUT_DIR)
    except Exception as e:
        print(f"Error running pipeline for client {client_id}: {e}")
        failed_clients.append(client_id)

if failed_clients:
    print(f"TFX pipelines failed for clients: {failed_clients}")
else:
    print("All TFX pipelines completed!")

## Define the model
DNN

We define the feature constants used by the original TFX taxi model.

In [None]:
# ---- Label and raw key names ------------------------------------------------
_LABEL_KEY = 'tips'
_FARE_KEY = 'fare'

# ---- Dense numeric features -------------------------------------------------
_DENSE_FLOAT_FEATURE_KEYS = [
    'trip_miles',
    'fare',
    'trip_seconds',
]

# ---- Bucketized features ----------------------------------------------------
# How many buckets tf.transform uses when encoding each of these features.
_FEATURE_BUCKET_COUNT = 10
_BUCKET_FEATURE_KEYS = [
    'pickup_latitude',
    'pickup_longitude',
    'dropoff_latitude',
    'dropoff_longitude',
]

# ---- Vocabulary features ----------------------------------------------------
# Size of the tf.transform vocabulary built for each of these features ...
_VOCAB_SIZE = 1000
# ... plus the number of out-of-vocabulary buckets unknown values hash into.
_OOV_SIZE = 10
_VOCAB_FEATURE_KEYS = [
    'payment_type',
    'company',
]

# ---- Categorical integer features -------------------------------------------
# Each categorical feature is assumed to have a maximum value in the dataset;
# the two lists below are parallel (i-th key <-> i-th maximum).
_CATEGORICAL_FEATURE_KEYS = [
    'trip_start_hour',
    'trip_start_day',
    'trip_start_month',
    'pickup_census_tract',
    'dropoff_census_tract',
    'pickup_community_area',
    'dropoff_community_area',
]
_MAX_CATEGORICAL_FEATURE_VALUES = [
    24,
    31,
    13,
    2000,
    2000,
    80,
    80,
]

In [None]:
from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx_bsl.tfxio import dataset_options

def _transformed_name(key):
  """Return *key* with the ``_xf`` suffix that marks a transformed feature."""
  suffix = '_xf'
  return key + suffix

def _transformed_names(keys):
  """Map every key through ``_transformed_name``, keeping order; returns a list."""
  return list(map(_transformed_name, keys))

def _input_fn(file_pattern: List[str],
              data_accessor: DataAccessor,
              tf_transform_output: 'tft.TFTransformOutput',
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset of (features, label) tuples, where features is a dictionary of
    Tensors and label is the transformed label Tensor (keyed by
    `_transformed_name(_LABEL_KEY)`). The dataset is `.repeat()`-ed without a
    count, so it never ends; callers must bound iteration (e.g. steps).
  """
  # The `tft` annotation above is a string on purpose: unquoted annotations
  # are evaluated when the function is defined, and tensorflow_transform is
  # only imported in a later cell of this notebook (NameError otherwise).
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)),
      tf_transform_output.transformed_metadata.schema).repeat()

In [None]:
def create_model(hidden_units: Optional[List[int]] = None) -> tf.keras.Model:
    """Creates a wide-and-deep Keras model for classifying taxi data.

    Mostly identical to the _build_keras_model in the taxi_utils_native_keras.py

    Args:
        hidden_units: Sizes of the Dense layers stacked on the deep branch,
            first layer first. ``None`` (or an empty list) falls back to
            ``[100, 70, 50, 25]``.

    Returns:
        A compiled wide-and-deep ``tf.keras.Model``. Its input is a dict with
        one ``(1,)``-shaped tensor per transformed (``*_xf``) feature; its
        output is a single sigmoid probability. It is compiled with binary
        cross-entropy, Adam (lr=0.001) and ``BinaryAccuracy``.
    """
    # Keras needs the feature definitions at compile time.
    deep_input = {
        colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
        for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS)
    }

    wide_vocab_input = {
        colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
        for colname in _transformed_names(_VOCAB_FEATURE_KEYS)
    }

    wide_bucket_input = {
        colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
        for colname in _transformed_names(_BUCKET_FEATURE_KEYS)
    }

    wide_categorical_input = {
        colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
        for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS)
    }

    input_layers = {
        **deep_input,
        **wide_vocab_input,
        **wide_bucket_input,
        **wide_categorical_input,
    }

    # Build deep branch: normalize each dense feature, concatenate, then stack
    # Dense layers.
    # NOTE(review): these Normalization layers are never adapt()-ed in this
    # function; presumably Transform already scales the dense features — confirm.
    deep = tf.keras.layers.concatenate(
        [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()]
    )

    # NOTE(review): the Dense layers have no activation, so the stack composes
    # to a single linear map — confirm whether e.g. activation='relu' was intended.
    for numnodes in (hidden_units or [100, 70, 50, 25]):
        deep = tf.keras.layers.Dense(numnodes)(deep)

    # Build wide branch: encode every integer feature with CategoryEncoding.
    # num_tokens must cover the largest id each feature can take.
    wide_layers = []

    for key in _transformed_names(_VOCAB_FEATURE_KEYS):
        # Vocabulary ids plus the out-of-vocabulary buckets.
        wide_layers.append(
            tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)(
                input_layers[key]
            )
        )

    for key in _transformed_names(_BUCKET_FEATURE_KEYS):
        wide_layers.append(
            tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)(
                input_layers[key]
            )
        )

    for key, num_tokens in zip(
        _transformed_names(_CATEGORICAL_FEATURE_KEYS),
        _MAX_CATEGORICAL_FEATURE_VALUES,
    ):
        # +1 because the maximum value itself is a valid id (ids 0..max).
        wide_layers.append(
            tf.keras.layers.CategoryEncoding(num_tokens=num_tokens + 1)(
                input_layers[key]
            )
        )

    wide = tf.keras.layers.concatenate(wide_layers)

    # Combine and create output
    combined = tf.keras.layers.concatenate([deep, wide])
    output = tf.keras.layers.Dense(1, activation='sigmoid')(combined)
    output = tf.keras.layers.Reshape((1,))(output)

    model = tf.keras.Model(inputs=input_layers, outputs=output)

    # Compile
    model.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[tf.keras.metrics.BinaryAccuracy()],
    )

    return model

In [None]:
# Smoke test: build the model once and report its size and layout.
print("Testing model creation")
test_model = create_model(hidden_units=None)
print("Model created. Total parameters: {}".format(test_model.count_params()))
test_model.summary()

## Data loading functions

In [None]:
import tensorflow_transform as tft


def _latest_artifact_dir(base_dir: str) -> str:
    """Return the newest numeric execution sub-directory of *base_dir*.

    NOTE(review): TFX 1.x typically writes each component output under
    ``<pipeline_root>/<Component>/<output>/<execution_id>/`` — verify against
    your pipeline_root. Falls back to *base_dir* itself when it has no
    numeric sub-directories (flat layout) or does not exist.
    """
    if not os.path.isdir(base_dir):
        return base_dir
    execution_ids = [
        name for name in os.listdir(base_dir)
        if name.isdigit() and os.path.isdir(os.path.join(base_dir, name))
    ]
    if not execution_ids:
        return base_dir
    return os.path.join(base_dir, max(execution_ids, key=int))


def load_transformed_data(client_id: int, split: str = 'train', batch_size: int = 32):
    """Load one client's TFX Transform output as a batched ``tf.data.Dataset``.

    Args:
        client_id: Index of the client whose pipeline output is read.
        split: Data split to read ('train' or 'eval').
        batch_size: Batch size for the returned dataset.

    Returns:
        A dataset of ``(features, label)`` batches, where ``features`` is a dict
        keyed by transformed feature name and ``label`` is the transformed
        label; or ``None`` if anything fails (the error is printed).
    """
    transform_root = f"{OUTPUT_DIR}/pipeline_client_{client_id}/Transform"

    try:
        # Load the transform graph / transformed schema.
        transform_output_path = _latest_artifact_dir(f"{transform_root}/transform_graph")
        tf_transform_output = tft.TFTransformOutput(transform_output_path)

        # Find the gzipped tfrecord files; TFX names split folders
        # "Split-<name>", while the plain "<name>" layout is kept as a fallback.
        examples_dir = _latest_artifact_dir(f"{transform_root}/transformed_examples")
        patterns = [
            os.path.join(examples_dir, f"Split-{split}", "*.gz"),
            os.path.join(examples_dir, split, "*.gz"),
        ]
        tfrecord_files = []
        for pattern in patterns:
            tfrecord_files = sorted(glob.glob(pattern))
            if tfrecord_files:
                break
        if not tfrecord_files:
            raise FileNotFoundError(f"No tfrecord files found at {patterns}")

        print(f"Loading data for client {client_id}, split: {split}")
        print(f"Found {len(tfrecord_files)} tfrecord files")

        # Parse the examples using the transformed schema.
        feature_spec = tf_transform_output.transformed_feature_spec()
        label_key = _transformed_name(_LABEL_KEY)

        def parse_fn(example_proto):
            parsed = tf.io.parse_single_example(example_proto, feature_spec)
            label = parsed.pop(label_key)
            return parsed, label

        dataset = (
            tf.data.TFRecordDataset(tfrecord_files, compression_type='GZIP')
            .map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE)
        )

        # Sanity-check the first batch.
        for features, labels in dataset.take(1):
            print(f"Feature keys: {list(features.keys())}")
            print(f"Batch - {len(features)} feature tensors, label shape: {labels.shape}")

        return dataset

    except Exception as e:
        # Best-effort loader: callers check for None.
        print(f"Error loading data for client {client_id}: {e}")
        return None

In [None]:
# Smoke test: load client 0's training split and report whether it worked.
print("Test data loading")
test_dataset = load_transformed_data(0, split='train')
print("Data loading successful" if test_dataset else "Data loading failed")


## Flower Client Implementation

In [None]:
class TaxiFlowerClient(NumPyClient):
    """Flower NumPy client that trains the taxi Keras model on one client's data.

    Each instance builds its own model with ``create_model`` and loads the
    client's transformed 'train'/'eval' splits with ``load_transformed_data``.
    Parameters are exchanged as the list returned by ``Model.get_weights``.
    """

    def __init__(self, client_id: int):
        """Build the model and load this client's datasets.

        Args:
            client_id: Index of the client whose Transform output is loaded.

        Raises:
            RuntimeError: If the train or eval split could not be loaded.
        """
        self.client_id = client_id
        self.model = create_model()

        # Load client's data
        self.train_dataset = load_transformed_data(client_id, 'train', batch_size=32)
        self.eval_dataset = load_transformed_data(client_id, 'eval', batch_size=32)
        # load_transformed_data returns None on failure; fail here with a clear
        # message instead of an AttributeError from `.unbatch()` below.
        if self.train_dataset is None or self.eval_dataset is None:
            raise RuntimeError(
                f"Could not load transformed train/eval data for client {client_id}"
            )

        # These counts are reported as num_examples from fit()/evaluate(); they
        # are computed with one full pass over each split.
        self.train_size = sum(1 for _ in self.train_dataset.unbatch())
        self.eval_size = sum(1 for _ in self.eval_dataset.unbatch())

        print(f"Client {client_id} initialized with {self.train_size} train examples and {self.eval_size} eval examples")

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return self.model.get_weights()

    def set_parameters(self, parameters) -> None:
        """Load weights received from the server into the local model."""
        self.model.set_weights(parameters)

    def fit(self, parameters, config):
        """Train the model on the local dataset.

        Reads ``config["epochs"]`` (default 1). Returns the updated weights,
        the number of training examples, and the last epoch's
        ``train_loss``/``train_accuracy``.
        """
        self.set_parameters(parameters)  # set parameters received from the server

        epochs = int(config.get("epochs", 1))

        history = self.model.fit(
            self.train_dataset,
            epochs=epochs,
            validation_data=self.eval_dataset,
            verbose=1
        )

        return (
            self.get_parameters({}),
            self.train_size,
            {
                "train_loss": float(history.history["loss"][-1]),
                "train_accuracy": float(history.history["binary_accuracy"][-1])
            }
        )

    def evaluate(self, parameters, config):
        """Evaluate the received weights on the local eval split.

        Returns ``(loss, num_eval_examples, {"accuracy": ...})``.
        """
        self.set_parameters(parameters)

        # evaluate() returns [loss, binary_accuracy] given the compiled metrics.
        loss, accuracy = self.model.evaluate(self.eval_dataset, verbose=0)

        return float(loss), self.eval_size, {"accuracy": float(accuracy)}

In [None]:
# Smoke test: instantiate the client for partition 0 and report the outcome.
print("Test client creation")
try:
    test_client = TaxiFlowerClient(client_id=0)
except Exception as e:
    print(f"Client creation failed: {e}")
else:
    print("Client creation successful")

In [None]:
def client_fn(context: Context) -> Client:
    """Create the Flower client for the partition this node simulates.

    Args:
        context: Flower run context; the partition index is read from
            ``context.node_config["partition-id"]``. For backward
            compatibility a bare client id (``str``/``int``) is also accepted.

    Returns:
        A ``Client`` wrapping a ``TaxiFlowerClient`` for that partition.
    """
    if isinstance(context, (str, int)):
        client_id = int(context)
    else:
        client_id = int(context.node_config["partition-id"])
    return TaxiFlowerClient(client_id).to_client()

# Create the ClientApp (for simulation)
client = ClientApp(client_fn=client_fn)


## Flower Server

In [None]:
def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate client accuracies, weighting each by its number of examples.

    Args:
        metrics: One ``(num_examples, metrics_dict)`` pair per client.

    Returns:
        ``{}`` when *metrics* is empty, ``{"accuracy": 0.0}`` when the clients
        report zero examples in total, otherwise ``{"accuracy": mean}`` where
        the mean is example-weighted and a missing "accuracy" counts as 0.0.
    """
    if not metrics:
        return {}

    total_examples = 0
    weighted_sum = 0
    for num_examples, client_metrics in metrics:
        total_examples += num_examples
        weighted_sum += num_examples * client_metrics.get("accuracy", 0.0)

    if total_examples == 0:
        return {"accuracy": 0.0}
    return {"accuracy": weighted_sum / total_examples}

In [None]:
NUM_CLIENTS = 5
NUM_ROUNDS = 5


def _aggregate_fit_metrics(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Example-weighted average of the training metrics returned by ``fit``.

    Clients report ``train_loss`` and ``train_accuracy`` from ``fit`` (not
    ``accuracy``), so ``weighted_average`` would always yield 0.0 for them.
    Each key is averaged only over clients that reported it; keys with zero
    total examples are omitted.
    """
    aggregated: Dict[str, float] = {}
    for key in ("train_loss", "train_accuracy"):
        reported = [(num_examples, float(m[key])) for num_examples, m in metrics if key in m]
        total_examples = sum(num_examples for num_examples, _ in reported)
        if total_examples > 0:
            aggregated[key] = sum(n * value for n, value in reported) / total_examples
    return aggregated


strategy = FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,
    fit_metrics_aggregation_fn=_aggregate_fit_metrics,
)

def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Configure the server for NUM_ROUNDS rounds of training
    config = ServerConfig(num_rounds=NUM_ROUNDS)

    return ServerAppComponents(strategy=strategy, config=config)

# Create the ServerApp
server = ServerApp(server_fn=server_fn)


## Federated Learning Simulation

In [None]:
# NOTE(review): this line relies on IPython's `pip` magic (automagic), so it only
# runs inside Jupyter/IPython, not as plain Python (`%pip` is the explicit form).
# flwr is already imported in the first cell, so an upgrade here presumably
# needs a kernel restart before the new version takes effect.
pip install -U "flwr[simulation]"

In [None]:
# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client.
# DEVICE is a torch.device, so compare its `.type`; comparing the device object
# itself to the string "cuda" never matches.
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

# Global variable to store history for analysis
fl_history = None

# `run_simulation(server_app=..., client_app=...)` returns None, so it cannot
# feed analyze_fl_results(). The (deprecated) legacy `start_simulation` runs
# the same client_fn / strategy / number of rounds and returns a
# `flwr.server.History` with the per-round losses and metrics.
try:
    fl_history = start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
        client_resources=backend_config["client_resources"],
    )
except Exception as e:
    print(f"Error during federated learning: {e}")

## Result analysis

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
def _history_to_dataframe(history):
    """Flatten a Flower ``History``-like object into one row per round.

    Assumes Flower's History layout: ``losses_*`` are lists of
    ``(round, value)`` and ``metrics_*`` map a metric name to a list of
    ``(round, value)``. Missing attributes are treated as empty. Columns are
    named e.g. ``distributed_loss`` or ``distributed_accuracy``.

    Returns:
        A DataFrame sorted by round, or ``None`` when nothing was recorded.
    """
    rows = {}

    def _add(round_num, column, value):
        rows.setdefault(round_num, {'round': round_num})[column] = value

    for attr, column in (('losses_distributed', 'distributed_loss'),
                         ('losses_centralized', 'centralized_loss')):
        losses = getattr(history, attr, None) or []
        if losses:
            print(f"Found {attr}")
        for round_num, value in losses:
            _add(round_num, column, value)

    for attr, prefix in (('metrics_distributed', 'distributed'),
                         ('metrics_distributed_fit', 'fit'),
                         ('metrics_centralized', 'centralized')):
        metrics = getattr(history, attr, None) or {}
        if metrics:
            print(f"Found {attr}")
        for name, series in metrics.items():
            for round_num, value in series:
                _add(round_num, f'{prefix}_{name}', value)

    if not rows:
        return None
    return pd.DataFrame([rows[r] for r in sorted(rows)])


def _non_null(df, column):
    """Return ``df[['round', column]]`` without missing values, or None if empty/absent."""
    if column not in df.columns:
        return None
    subset = df[['round', column]].dropna()
    return None if subset.empty else subset


def _plot_results(df):
    """Draw a 2x2 overview: accuracy, loss, accuracy comparison, improvement."""
    dist_acc = _non_null(df, 'distributed_accuracy')
    dist_loss = _non_null(df, 'distributed_loss')
    cent_acc = _non_null(df, 'centralized_accuracy')

    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Distributed accuracy
    if dist_acc is not None:
        ax = axes[0, 0]
        ax.plot(dist_acc['round'], dist_acc['distributed_accuracy'], 'b-o', linewidth=2, markersize=6)
        ax.set_title('Distributed accuracy over rounds')
        ax.set_xlabel('Round')
        ax.set_ylabel('Accuracy')
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    # 2. Distributed loss
    if dist_loss is not None:
        ax = axes[0, 1]
        ax.plot(dist_loss['round'], dist_loss['distributed_loss'], 'r-o', linewidth=2, markersize=6)
        ax.set_title('Distributed loss over rounds')
        ax.set_xlabel('Round')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)

    # 3. Centralized vs distributed accuracy (distributed only when no centralized data)
    if dist_acc is not None:
        ax = axes[1, 0]
        if cent_acc is not None:
            ax.plot(dist_acc['round'], dist_acc['distributed_accuracy'], 'b-o', label='Distributed', linewidth=2)
            ax.plot(cent_acc['round'], cent_acc['centralized_accuracy'], 'g-s', label='Centralized', linewidth=2)
            ax.set_title('Accuracy comparison')
            ax.legend()
        else:
            ax.plot(dist_acc['round'], dist_acc['distributed_accuracy'], 'b-o', linewidth=2, markersize=6)
            ax.set_title('Distributed accuracy')
        ax.set_xlabel('Round')
        ax.set_ylabel('Accuracy')
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    # 4. Training progress summary
    if dist_acc is not None:
        values = dist_acc['distributed_accuracy']
        initial_acc, final_acc = values.iloc[0], values.iloc[-1]
        ax = axes[1, 1]
        ax.bar(['Initial', 'Final'], [initial_acc, final_acc],
               color=['lightblue', 'darkblue'], alpha=0.7)
        # Signed format: a drop reads "-0.050" instead of "+-0.050".
        ax.set_title(f"Accuracy improvement: {final_acc - initial_acc:+.3f}")
        ax.set_ylabel('Accuracy')
        ax.set_ylim([0, 1])
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def _print_summary(df):
    """Print the run configuration plus first/last/best accuracy and loss."""
    print(f"\n{'='*50}")
    print("FEDERATED LEARNING SUMMARY")
    print(f"{'='*50}")
    print(f"Number of Clients: {NUM_CLIENTS}")
    print(f"Number of Rounds: {NUM_ROUNDS}")
    print(f"Device Used: {DEVICE}")

    dist_acc = _non_null(df, 'distributed_accuracy')
    if dist_acc is not None:
        values = dist_acc['distributed_accuracy']
        initial_acc, final_acc = values.iloc[0], values.iloc[-1]
        print(f"\nAccuracy Metrics:")
        print(f"  Initial Accuracy: {initial_acc:.4f}")
        print(f"  Final Accuracy: {final_acc:.4f}")
        print(f"  Best Accuracy: {values.max():.4f}")
        print(f"  Total Improvement: {final_acc - initial_acc:.4f}")

    dist_loss = _non_null(df, 'distributed_loss')
    if dist_loss is not None:
        values = dist_loss['distributed_loss']
        initial_loss, final_loss = values.iloc[0], values.iloc[-1]
        print(f"\nLoss Metrics:")
        print(f"  Initial Loss: {initial_loss:.4f}")
        print(f"  Final Loss: {final_loss:.4f}")
        print(f"  Best Loss: {values.min():.4f}")
        print(f"  Loss Reduction: {initial_loss - final_loss:.4f}")


def analyze_fl_results(history):
    """Analyze and visualize FL results.

    Args:
        history: A Flower ``History`` (or ``None``).

    Returns:
        A per-round DataFrame of losses/metrics, or ``None`` when *history*
        is ``None`` or holds no recorded losses or metrics.
    """
    if history is None:
        print("No history avaiable")
        return None

    print(f"History object type: {type(history)}")
    print(f"Available attributes: {dir(history)}")

    df = _history_to_dataframe(history)
    if df is None:
        print("No metrics data found in history object")
        print("Available history attributes:")
        for attr in dir(history):
            if not attr.startswith('_'):
                print(f"  - {attr}: {getattr(history, attr)}")
        return None

    print(f"\nResults DataFrame:")
    print(df)

    _plot_results(df)
    _print_summary(df)
    return df


# Summarize and plot the simulation history (None if the run produced none).
results_df = analyze_fl_results(history=fl_history)