In [None]:
!git clone https://github.com/piEsposito/tfx-tutorial-medium.git
!pip install -r requirements.txt
!cd tfx-tutorial-medium/local

In [None]:
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

In [None]:
# Notebook orchestrator: every `context.run(...)` below executes one component.
# NOTE(review): no pipeline_root is given, so TFX picks its own output location
# (documented as an ephemeral temp dir) — confirm for your TFX version.
context = InteractiveContext(pipeline_root=None)

In [None]:
# Ingest the raw CSV files; the path is relative to the kernel's working directory.
raw_data_dir = "data_local/"
example_gen = tfx.components.CsvExampleGen(input_base=raw_data_dir)
context.run(example_gen)

In [None]:
# Compute dataset statistics over the ingested (raw) examples.
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs["examples"])
context.run(statistics_gen)

In [None]:
# Infer a data schema from the computed statistics.
# NOTE(review): infer_feature_shape=False means SchemaGen does not pin fixed
# feature shapes in the schema (per SchemaGen docs) — confirm module.py expects that.
schema_gen = tfx.components.SchemaGen(
    infer_feature_shape=False,
    statistics=statistics_gen.outputs["statistics"],
)
context.run(schema_gen)

In [None]:
# Apply the preprocessing defined in module.py to the raw examples, guided by
# the inferred schema.
transform = tfx.components.Transform(
    module_file="module.py",
    schema=schema_gen.outputs["schema"],
    examples=example_gen.outputs["examples"],
)
context.run(transform)

In [None]:
from tfx.proto import example_gen_pb2, pusher_pb2, trainer_pb2

In [None]:
# Train the model defined in module.py on the Transform output. The transform
# graph is passed so the training code can presumably reuse the fitted
# preprocessing (depends on module.py — verify).
# Runs are deliberately short: 100 training steps, 1 evaluation step.
# Explicit keyword arguments (instead of a **kwargs dict) keep the component
# signature visible; tfx.proto.* is the tfx.v1 alias of the trainer_pb2 messages.
trainer = tfx.components.Trainer(
    module_file="module.py",
    examples=transform.outputs["transformed_examples"],
    transform_graph=transform.outputs["transform_graph"],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=1),
)
context.run(trainer)

In [None]:
import tensorflow_model_analysis as tfma

# Resolve the latest previously-blessed model to serve as the baseline for the
# change_threshold below. Without a baseline_model input the Evaluator has no
# model to compare against, so the "no regression" check cannot apply. On the
# first run no blessed model exists yet and the resolver yields nothing.
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
).with_id("latest_blessed_model_resolver")
context.run(model_resolver)

# Candidate must reach at least this BinaryAccuracy to be blessed.
ACCURACY_LOWER_BOUND = 0.65
# Allowed accuracy change vs. the baseline (HIGHER_IS_BETTER): effectively zero,
# i.e. the candidate must not be worse than the baseline; the tiny negative
# epsilon is presumably there to tolerate float noise on exact ties.
ACCURACY_MAX_REGRESSION = -1e-10

eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            signature_name="serving_default",
            label_key="consumer_disputed",
            # NOTE(review): the Evaluator below is fed RAW examples
            # (example_gen output). This only works if the serving_default
            # signature applies the transform itself; otherwise enable
            # preprocessing_function_names — depends on module.py, verify.
            # preprocessing_function_names=["transform_features"],
        )
    ],
    # Overall metrics plus a per-"product" slice.
    slicing_specs=[tfma.SlicingSpec(), tfma.SlicingSpec(feature_keys=["product"])],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(
                    class_name="BinaryAccuracy",
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={"value": ACCURACY_LOWER_BOUND}
                        ),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={"value": ACCURACY_MAX_REGRESSION},
                        ),
                    ),
                ),
                tfma.MetricConfig(class_name="Precision"),
                tfma.MetricConfig(class_name="Recall"),
                tfma.MetricConfig(class_name="ExampleCount"),
                tfma.MetricConfig(class_name="AUC"),
            ],
        )
    ],
)

evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs["examples"],
    model=trainer.outputs["model"],
    baseline_model=model_resolver.outputs["model"],
    eval_config=eval_config,
)
context.run(evaluator)

In [None]:
# Export the trained model to a local directory; the push is gated on the
# Evaluator's blessing output. tfx.proto.PushDestination is the tfx.v1 alias
# of pusher_pb2.PushDestination.
serving_destination = tfx.proto.PushDestination(
    filesystem=tfx.proto.PushDestination.Filesystem(base_directory="./model-output")
)
pusher = tfx.components.Pusher(
    model=trainer.outputs["model"],
    model_blessing=evaluator.outputs["blessing"],
    push_destination=serving_destination,
)
context.run(pusher)