# TFX Pipeline

# TODO: Organise
Relevant Guides:
- https://www.tensorflow.org/recommenders/examples/basic_retrieval
- https://www.tensorflow.org/tfx/tutorials/tfx/recommenders
- https://www.tensorflow.org/recommenders/examples/basic_ranking
- https://www.tensorflow.org/recommenders/examples/ranking_tfx

In [1]:
from importlib import reload
from pathlib import Path

import tensorflow_model_analysis as tfma
from absl import logging
from tfx import v1 as tfx
from tfx.components import (
    CsvExampleGen,
    Evaluator,
    Pusher,
    SchemaGen,
    StatisticsGen,
    Transform,
)
from tfx.orchestration.experimental.interactive.interactive_context import (
    InteractiveContext,
)
from tfx.types.standard_component_specs import (
    BLESSING_KEY,
    EVALUATION_KEY,
    EXAMPLES_KEY,
    MODEL_KEY,
    POST_TRANSFORM_SCHEMA_KEY,
    SCHEMA_KEY,
    STATISTICS_KEY,
    TRANSFORM_GRAPH_KEY,
    TRANSFORMED_EXAMPLES_KEY,
)

from recommender_systems import evaluator_module, trainer_module, transform_module
from recommender_systems.features import ProductFeatures
from recommender_systems.splits import Splits
from tfx_tfrs.trainer import Trainer

# Emit INFO-level absl log messages in the notebook output.
logging.set_verbosity(logging.INFO)

# Input data directory: "data", located beside the current working directory
# (i.e. inside the working directory's parent).
DATA = Path.cwd().parent.joinpath("data")

# Used both as the pipeline name and as the last segment of the pipeline root.
PIPELINE_NAME = "recommender_systems"

# Interactive TFX context for running components cell by cell; its pipeline
# root is "pipeline-root/<PIPELINE_NAME>", relative to the working directory.
context = InteractiveContext(
    pipeline_root=str(Path("pipeline-root", PIPELINE_NAME)),
    pipeline_name=PIPELINE_NAME,
)

2025-06-25 14:03:39.768453: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-25 14:03:39.822260: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-25 14:03:39.822298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-25 14:03:39.823762: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-25 14:03:39.832616: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-25 14:03:39.834198: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [None]:
# Participant identifier. It is interpolated into the Pusher's `base_directory`
# (the "gs://.../{PARTICIPANT}" path in the Pusher cell of this notebook).
PARTICIPANT = "stefan-dominicus"

## Ingest Reviews

### Examples
Docs:
- https://www.tensorflow.org/tfx/guide/examplegen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/CsvExampleGen
- https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto

In [None]:
# One input split per review file: "<split>.csv" for the train and validation
# splits, both resolved relative to the `input_base` directory below.
reviews_input_config = tfx.proto.Input(
    splits=[
        tfx.proto.Input.Split(name=Splits.TRAIN, pattern=f"{Splits.TRAIN}.csv"),
        tfx.proto.Input.Split(
            name=Splits.VALIDATION, pattern=f"{Splits.VALIDATION}.csv"
        ),
    ]
)

# Ingest the review CSVs found under DATA / "reviews".
reviews_example_gen_component = CsvExampleGen(
    input_base=str(DATA / "reviews"),
    input_config=reviews_input_config,
)
context.run(reviews_example_gen_component, enable_cache=True)

### Statistics
Docs:
- https://www.tensorflow.org/tfx/guide/statsgen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/StatisticsGen

In [None]:
# Feed the ingested review examples into StatisticsGen.
reviews_examples = reviews_example_gen_component.outputs[EXAMPLES_KEY]
reviews_statistics_gen_component = StatisticsGen(examples=reviews_examples)
context.run(reviews_statistics_gen_component, enable_cache=True)

In [None]:
# Display the statistics output of the reviews StatisticsGen component.
context.show(reviews_statistics_gen_component.outputs[STATISTICS_KEY])

### Schema
Docs:
- https://www.tensorflow.org/tfx/guide/schemagen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/SchemaGen

In [None]:
# Generate a schema for the reviews from the statistics computed above.
reviews_statistics = reviews_statistics_gen_component.outputs[STATISTICS_KEY]
reviews_schema_gen_component = SchemaGen(statistics=reviews_statistics)
context.run(reviews_schema_gen_component, enable_cache=True)

In [None]:
# Display the schema output of the reviews SchemaGen component.
context.show(reviews_schema_gen_component.outputs[SCHEMA_KEY])

## Transform Reviews

### Transform
Docs:
- https://www.tensorflow.org/tfx/guide/transform
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/Transform

In [None]:
# Re-import the transform module before its file is handed to the component
# (presumably so that edits made during the session are picked up).
reload(transform_module)

# The same two splits are used for both phases:
# - analyse: all splits, for full vocabulary coverage (default: train only);
# - transform: all splits are transformed and materialised (the default).
review_splits = [Splits.TRAIN, Splits.VALIDATION]

transform_component = Transform(
    module_file=transform_module.__file__,
    examples=reviews_example_gen_component.outputs[EXAMPLES_KEY],
    schema=reviews_schema_gen_component.outputs[SCHEMA_KEY],
    splits_config=tfx.proto.SplitsConfig(
        analyze=review_splits,
        transform=review_splits,
    ),
)
context.run(transform_component, enable_cache=True)

### Statistics
Docs:
- https://www.tensorflow.org/tfx/guide/statsgen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/StatisticsGen

In [None]:
# Run StatisticsGen again, this time over the examples emitted by Transform.
transformed_review_examples = transform_component.outputs[TRANSFORMED_EXAMPLES_KEY]
post_transform_statistics_gen_component = StatisticsGen(
    examples=transformed_review_examples
)
context.run(post_transform_statistics_gen_component, enable_cache=True)

In [None]:
# Display the statistics computed over the transformed review examples.
context.show(post_transform_statistics_gen_component.outputs[STATISTICS_KEY])

### Schema

In [None]:
# Display the post-transform schema output of the Transform component
# (no separate SchemaGen run is used for the transformed examples).
context.show(transform_component.outputs[POST_TRANSFORM_SCHEMA_KEY])

## Ingest Products

### Examples
Docs:
- https://www.tensorflow.org/tfx/guide/examplegen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/CsvExampleGen
- https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto

In [None]:
# Input side: a single split that matches "products.csv" inside `input_base`.
product_input_config = tfx.proto.Input(
    splits=[tfx.proto.Input.Split(name=Splits.SINGLE, pattern="products.csv")]
)

# Output side: one split with a single hash bucket, under the same split name.
product_output_config = tfx.proto.Output(
    split_config=tfx.proto.SplitConfig(
        splits=[tfx.proto.SplitConfig.Split(name=Splits.SINGLE, hash_buckets=1)]
    )
)

# Ingest the products CSV located directly in the DATA directory.
product_example_gen_component = CsvExampleGen(
    input_base=str(DATA),
    input_config=product_input_config,
    output_config=product_output_config,
)
context.run(product_example_gen_component, enable_cache=True)

### Statistics
Docs:
- https://www.tensorflow.org/tfx/guide/statsgen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/StatisticsGen

In [None]:
# Feed the ingested product examples into StatisticsGen.
product_examples = product_example_gen_component.outputs[EXAMPLES_KEY]
product_statistics_gen_component = StatisticsGen(examples=product_examples)
context.run(product_statistics_gen_component, enable_cache=True)

In [None]:
# Display the statistics output of the products StatisticsGen component.
context.show(product_statistics_gen_component.outputs[STATISTICS_KEY])

### Schema
Docs:
- https://www.tensorflow.org/tfx/guide/schemagen
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/SchemaGen

In [None]:
# Generate a schema for the products from the statistics computed above.
product_statistics = product_statistics_gen_component.outputs[STATISTICS_KEY]
product_schema_gen_component = SchemaGen(statistics=product_statistics)
context.run(product_schema_gen_component, enable_cache=True)

In [None]:
context.show(product_schema_gen_component.outputs[SCHEMA_KEY])

## Train Model

### Trainer
Docs:
- https://www.tensorflow.org/tfx/guide/trainer
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/Trainer
- https://github.com/tensorflow/tfx/blob/master/tfx/proto/trainer.proto

In [None]:
# Re-import the trainer module before its file is handed to the component
# (presumably so that edits made during the session are picked up).
reload(trainer_module)

trainer_component = Trainer(
    module_file=trainer_module.__file__,
    # Review side: transformed examples plus the Transform graph and schema.
    examples=transform_component.outputs[TRANSFORMED_EXAMPLES_KEY],
    transform_graph=transform_component.outputs[TRANSFORM_GRAPH_KEY],
    schema=transform_component.outputs[POST_TRANSFORM_SCHEMA_KEY],
    # Product (item) side: the raw product examples and their generated schema.
    item_examples=product_example_gen_component.outputs[EXAMPLES_KEY],
    item_schema=product_schema_gen_component.outputs[SCHEMA_KEY],
    # Train on the train split and evaluate on the validation split.
    train_args=tfx.proto.TrainArgs(splits=[Splits.TRAIN]),
    eval_args=tfx.proto.EvalArgs(splits=[Splits.VALIDATION]),
    # Empty for now; the original cell kept a commented-out
    # `tensorboard_log_dir=""` entry here as a placeholder.
    custom_config={},
)
# Caching is switched off for this component.
context.run(trainer_component, enable_cache=False)

### Evaluator
Docs:
- https://www.tensorflow.org/tfx/guide/evaluator
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/Evaluator
- https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto

In [None]:
# Re-import the evaluator module, whose name is referenced by the custom
# metric configuration below.
reload(evaluator_module)

# Built-in example count with a lower bound of 1 as its value threshold.
example_count_metric = tfma.MetricConfig(
    class_name="ExampleCount",
    threshold=tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(lower_bound={"value": 1}),
    ),
)

# Metric class resolved by name from the evaluator module; no threshold is set.
top_k_accuracy_metric = tfma.MetricConfig(
    class_name="TopKAccuracy",
    module=evaluator_module.__name__,
)

# The product ID is the label, and the model is evaluated through its
# "evaluate_products_for_customer" signature.
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            label_key=ProductFeatures.ID,
            signature_name="evaluate_products_for_customer",
        ),
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[example_count_metric, top_k_accuracy_metric]),
    ],
)

# Evaluate the trained model on the validation split of the raw review
# examples, using the raw reviews schema.
evaluator_component = Evaluator(
    model=trainer_component.outputs[MODEL_KEY],
    examples=reviews_example_gen_component.outputs[EXAMPLES_KEY],
    example_splits=[Splits.VALIDATION],
    schema=reviews_schema_gen_component.outputs[SCHEMA_KEY],
    eval_config=eval_config,
)
# Caching is switched off for this component.
context.run(evaluator_component, enable_cache=False)

In [None]:
# TODO: Figure out what I actually want to show in this cell

# Everything below is read from the URI of the first evaluation artifact.
evaluation_artifact = evaluator_component.outputs[EVALUATION_KEY].get()[0]
output_path = evaluation_artifact.uri

# 1. The evaluation result as a whole.
eval_result = tfma.load_eval_result(output_path)
print("EvalResult:", eval_result)

# 2. The individual metrics, materialised into a list for printing.
metrics = tfma.load_metrics(output_path)
print("Metrics:", list(metrics))

# 3. The validation (threshold) outcome; a failure means the model was not blessed.
validation_result = tfma.load_validation_result(output_path)
print("ValidationResult:", validation_result)
if not validation_result.validation_ok:
    print("Validation failed (model not blessed).")

### Pusher
Docs:
- https://www.tensorflow.org/tfx/guide/pusher
- https://www.tensorflow.org/tfx/api_docs/python/tfx/v1/components/Pusher
- https://github.com/tensorflow/tfx/blob/master/tfx/proto/pusher.proto

In [None]:
# TODO: Consider pushing to a GCS bucket so I can easily access their models
# NOTE(review): the destination below is already a gs:// path, so this TODO
# looks done — confirm and remove.

# Filesystem destination: one directory per participant inside the bucket,
# with UNIX-timestamp versioning.
model_push_destination = tfx.proto.PushDestination(
    filesystem=tfx.proto.PushDestination.Filesystem(
        base_directory=f"gs://tal-deep-learning-indabax-models/{PARTICIPANT}",
        versioning=tfx.proto.Versioning.UNIX_TIMESTAMP,
    )
)

# The Pusher receives the trained model together with the Evaluator's blessing.
pusher_component = Pusher(
    model=trainer_component.outputs[MODEL_KEY],
    model_blessing=evaluator_component.outputs[BLESSING_KEY],
    push_destination=model_push_destination,
)
context.run(pusher_component, enable_cache=True)