# TFX Iterative Development Example
This notebook demonstrates how to use Jupyter notebooks for TFX iterative development.  Here, we walk through the Chicago Taxi example in an interactive Jupyter notebook.

## Setup
First, import modules, download the data, and set up paths.

### Import packages
We import necessary packages, including standard TFX component classes.

In [0]:
import os
import tempfile
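# Import the submodule explicitly; a bare `import urllib` does not guarantee
# that `urllib.request` is available in Python 3.
import urllib.request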

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import csv_input

### Download example data
We download the sample dataset for use in our TFX pipeline.

In [0]:
# Download the example data.
_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
with open(os.path.join(_data_root, 'data.csv'), 'wb') as f:
  contents = urllib.request.urlopen(DATA_PATH).read()
  f.write(contents)
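
Before feeding the file to the pipeline, it can help to confirm that the download produced a well-formed CSV. The following is a minimal sketch using only the standard library; `preview_csv` is a hypothetical helper name and is not part of TFX.

```python
import csv
import itertools


def preview_csv(path, num_rows=3):
    """Return the header and the first `num_rows` data rows of a CSV file."""
    with open(path, newline='') as f:
        reader = csv.reader(f)
        header = next(reader)  # The first line holds the column names.
        rows = list(itertools.islice(reader, num_rows))
    return header, rows
```

In this notebook, calling `preview_csv(os.path.join(_data_root, 'data.csv'))` should return the column names of the taxi dataset and its first three records.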

### Set up pipeline paths

In [0]:
# Set up paths.
_taxi_root = os.path.join(tfx.__path__[0], 'examples/chicago_taxi_pipeline')
# Python module file to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_taxi_module_file = os.path.join(_taxi_root, 'taxi_utils.py')
# Path which can be listened to by the model server.  Pusher will output the
# trained model here.
_serving_model_dir = os.path.join(tempfile.mkdtemp(), 'serving_model/taxi_simple')
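
Because Transform and Trainer both load user-defined functions from the module file, a quick check that the file defines them can save a failed run later. The sketch below parses the file without importing it. The function names `preprocessing_fn` and `trainer_fn` are assumptions about what this version of TFX looks for, and both helpers are hypothetical; adjust the names to match your TFX release.

```python
import ast


def top_level_functions(module_path):
    """Return the names of functions defined at the top level of a Python file."""
    with open(module_path) as f:
        tree = ast.parse(f.read(), filename=module_path)
    return {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}


def missing_functions(module_path, required=('preprocessing_fn', 'trainer_fn')):
    """Return the required function names that the module does not define.

    The default `required` names are an assumption, not taken from this notebook.
    """
    return sorted(set(required) - top_level_functions(module_path))
```

If the assumed names are right, `missing_functions(_taxi_module_file)` should return an empty list.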

## Create the InteractiveContext
We now create the interactive context.

In [0]:
# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext.
context = InteractiveContext()

## Run TFX components interactively
Next, we construct TFX components and run each one within the interactive session to obtain `ExecutionResult` objects.

### ExampleGen
`ExampleGen` brings data into the TFX pipeline.

In [0]:
# Use the packaged CSV input data.
examples = csv_input(_data_root)

# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = CsvExampleGen(input_base=examples)
context.run(example_gen)

### StatisticsGen
`StatisticsGen` computes statistics for visualization and example validation.

In [0]:
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    input_data=example_gen.outputs['examples'])
context.run(statistics_gen)
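
To make "computes statistics" concrete, here is a minimal pure-Python sketch of the kind of per-column summary such a step produces. It is an illustration only and does not reflect how `StatisticsGen` is implemented; `column_statistics` and the sample values are hypothetical.

```python
def column_statistics(rows):
    """Summarize rows given as dicts of column name -> string ('' means missing)."""
    collected = {}
    for row in rows:
        for name, value in row.items():
            col = collected.setdefault(
                name, {'present': 0, 'missing': 0, 'numeric': []})
            if value == '':
                col['missing'] += 1
                continue
            col['present'] += 1
            try:
                col['numeric'].append(float(value))
            except ValueError:
                pass  # Non-numeric values are counted but not aggregated.
    summary = {}
    for name, col in collected.items():
        entry = {'present': col['present'], 'missing': col['missing']}
        # Report min/max/mean only when every present value parsed as a number.
        if col['numeric'] and len(col['numeric']) == col['present']:
            entry['min'] = min(col['numeric'])
            entry['max'] = max(col['numeric'])
            entry['mean'] = sum(col['numeric']) / len(col['numeric'])
        summary[name] = entry
    return summary


# Made-up sample rows, for illustration only.
sample_rows = [
    {'fare': '10.0', 'company': 'A'},
    {'fare': '', 'company': 'B'},
    {'fare': '20.0', 'company': 'A'},
]
sample_summary = column_statistics(sample_rows)
```

Summaries like this are what make both visualization and later validation possible: the missing counts and value ranges describe what "normal" data looks like.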

### SchemaGen
`SchemaGen` generates a schema for your data based on computed statistics.

In [0]:
# Generates schema based on statistics files.
infer_schema = SchemaGen(stats=statistics_gen.outputs['output'])
context.run(infer_schema)
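
As a rough illustration of what inferring a schema from data involves, the sketch below derives a type for each column and, for string columns, the set of observed values. It is a simplified stand-in, not the algorithm `SchemaGen` uses; the helper names, the type labels, and the sample values are assumptions made for this example.

```python
def infer_type(values):
    """Return 'INT', 'FLOAT', or 'BYTES' for a list of non-empty strings."""
    for type_name, parser in (('INT', int), ('FLOAT', float)):
        try:
            for value in values:
                parser(value)
            return type_name
        except ValueError:
            continue  # At least one value failed to parse; try the next type.
    return 'BYTES'


def infer_simple_schema(rows):
    """Build a {column: {'type': ..., 'domain': ...}} mapping from string rows."""
    columns = {}
    for row in rows:
        for name, value in row.items():
            if value != '':  # Skip missing values when inferring types.
                columns.setdefault(name, []).append(value)
    schema = {}
    for name, values in columns.items():
        feature = {'type': infer_type(values)}
        if feature['type'] == 'BYTES':
            feature['domain'] = sorted(set(values))
        schema[name] = feature
    return schema


# Made-up sample rows, for illustration only.
sample_schema = infer_simple_schema([
    {'trip_miles': '1.5', 'trip_start_hour': '7', 'company': 'A'},
    {'trip_miles': '2', 'trip_start_hour': '19', 'company': 'B'},
    {'trip_miles': '', 'trip_start_hour': '3', 'company': 'A'},
])
```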

### ExampleValidator
`ExampleValidator` performs anomaly detection based on computed statistics and your data schema.

In [0]:
# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    stats=statistics_gen.outputs['output'],
    schema=infer_schema.outputs['output'])
context.run(validate_stats)
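
The idea of checking data against a schema can be sketched in a few lines. The example below flags columns that are absent from the schema, values that do not parse as the declared type, and string values outside the declared domain. It is a toy illustration under an assumed schema layout, not the checks `ExampleValidator` actually performs; `find_anomalies` and the sample data are hypothetical.

```python
def find_anomalies(rows, schema):
    """Return a sorted list of (column, description) pairs for rule violations."""
    parsers = {'INT': int, 'FLOAT': float}
    anomalies = set()
    for row in rows:
        for name, value in row.items():
            if name not in schema:
                anomalies.add((name, 'column not in schema'))
                continue
            if value == '':
                continue  # Missing values are ignored in this sketch.
            feature = schema[name]
            parser = parsers.get(feature['type'])
            if parser is not None:
                try:
                    parser(value)
                except ValueError:
                    anomalies.add((name, 'unexpected type: %r' % value))
            elif 'domain' in feature and value not in feature['domain']:
                anomalies.add((name, 'value outside domain: %r' % value))
    return sorted(anomalies)


# Made-up schema and rows, for illustration only.
sample_anomalies = find_anomalies(
    [
        {'trip_start_hour': '7', 'company': 'A'},
        {'trip_start_hour': 'noon', 'company': 'C'},
        {'trip_start_hour': '3', 'company': 'B', 'extra': '1'},
    ],
    {
        'trip_start_hour': {'type': 'INT'},
        'company': {'type': 'BYTES', 'domain': ['A', 'B']},
    },
)
```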

### Transform
`Transform` performs data transformations and feature engineering that are kept in sync between training and serving.

In [0]:
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    input_data=example_gen.outputs['examples'],
    schema=infer_schema.outputs['output'],
    module_file=_taxi_module_file)
context.run(transform)
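
Keeping transformations in sync means that any parameters learned from the training data, such as a mean and standard deviation, are stored and reused unchanged at serving time. The sketch below shows that pattern with z-score scaling in plain Python. It stands in for the idea only; `fit_scaler` and `apply_scaler` are hypothetical helpers, not TFX or TensorFlow Transform functions.

```python
import math


def fit_scaler(values):
    """Compute scaling parameters once, from the training data only."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return {'mean': mean, 'std': math.sqrt(variance)}


def apply_scaler(value, params):
    """Apply the stored parameters; used identically in training and serving."""
    if params['std'] == 0:
        return 0.0  # A constant column carries no signal to scale.
    return (value - params['mean']) / params['std']


# Made-up training values, for illustration only.
sample_params = fit_scaler([2.0, 4.0, 6.0])
# A serving request is scaled with the training-time parameters, not its own.
sample_scaled = apply_scaler(4.0, sample_params)  # 0.0, since 4.0 is the mean
```

Recomputing the mean from serving traffic instead would silently shift the inputs the model sees, which is the skew this component is meant to prevent.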

### Trainer
`Trainer` trains your custom model using TF-Learn.

In [0]:
# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
    module_file=_taxi_module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['output'],
    transform_output=transform.outputs['transform_output'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)

### Evaluator
`Evaluator` computes evaluation statistics over features of your model.

In [0]:
# Uses TFMA to compute evaluation statistics over features of a model.
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model_exports=trainer.outputs['output'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
        evaluator_pb2.SingleSlicingSpec(
            column_for_slicing=['trip_start_hour'])
    ]))
context.run(model_analyzer)
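
Slicing by `trip_start_hour` means computing each metric separately for every distinct value of that column, rather than once over the whole evaluation set. The following minimal sketch shows the idea for accuracy. It is not TFMA code; `accuracy_by_slice`, the `label` and `prediction` keys, and the sample records are assumptions made for this example.

```python
from collections import defaultdict


def accuracy_by_slice(records, slice_column):
    """Return {slice value: accuracy} for records with 'label' and 'prediction'."""
    totals = defaultdict(lambda: [0, 0])  # slice value -> [correct, total]
    for record in records:
        counts = totals[record[slice_column]]
        counts[0] += int(record['label'] == record['prediction'])
        counts[1] += 1
    return {key: correct / total for key, (correct, total) in totals.items()}


# Made-up evaluation records, for illustration only.
sample_records = [
    {'trip_start_hour': 7, 'label': 1, 'prediction': 1},
    {'trip_start_hour': 7, 'label': 0, 'prediction': 1},
    {'trip_start_hour': 19, 'label': 1, 'prediction': 1},
]
sample_accuracy = accuracy_by_slice(sample_records, 'trip_start_hour')
# sample_accuracy == {7: 0.5, 19: 1.0}
```

A model with acceptable overall accuracy can still perform poorly on one slice, which is why per-slice metrics are worth inspecting.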

### ModelValidator
`ModelValidator` performs validation of your candidate model compared to a baseline.

In [0]:
# Performs quality validation of a candidate model (compared to a baseline).
model_validator = ModelValidator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['output'])
context.run(model_validator)
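
Conceptually, validation against a baseline reduces to a comparison of evaluation metrics. The sketch below shows one possible rule: bless the candidate if it is at least as good as the baseline on a chosen metric. The rule, the choice to bless when no baseline exists, and the name `is_blessed` are all assumptions made for illustration; they are not taken from the `ModelValidator` implementation.

```python
def is_blessed(candidate_metrics, baseline_metrics, metric='accuracy'):
    """Return True if the candidate matches or beats the baseline on `metric`."""
    if baseline_metrics is None:
        # Assumption: with nothing to compare against, accept the candidate.
        return True
    return candidate_metrics[metric] >= baseline_metrics[metric]


# Made-up metric values, for illustration only.
sample_first_run = is_blessed({'accuracy': 0.71}, None)
sample_regression = is_blessed({'accuracy': 0.71}, {'accuracy': 0.78})
```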

### Pusher
`Pusher` checks whether a model has passed validation, and if so, pushes the model to a file destination.

In [0]:
# Checks whether the model passed the validation steps and, if so, pushes the
# model to a file destination.
pusher = Pusher(
    model_export=trainer.outputs['output'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher)
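
The gate-then-copy behavior described above can be sketched with the standard library. This is a simplified illustration, not the `Pusher` implementation: `push_if_blessed` is a hypothetical helper, and the timestamp-named version subdirectory is an assumption about how a model server might expect versions to be laid out.

```python
import os
import shutil
import time


def push_if_blessed(model_dir, base_directory, blessed):
    """Copy `model_dir` into a new version subdirectory when `blessed` is true.

    Returns the destination path, or None if the model was not pushed.
    """
    if not blessed:
        return None
    # Assumption: versions are named by integer timestamp. Two pushes within
    # the same second would collide and raise FileExistsError.
    version = str(int(time.time()))
    destination = os.path.join(base_directory, version)
    shutil.copytree(model_dir, destination)  # Creates parent directories too.
    return destination
```

With this sketch, an unblessed model leaves the serving directory untouched, so a model server watching that path keeps serving the previous version.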