# TFX Iterative Development Example
This notebook demonstrates how to use Jupyter notebooks for TFX iterative development.  Here, we walk through the Chicago Taxi example in an interactive Jupyter notebook.

## Setup
First, we install the necessary packages, download data, import modules and set up paths.

### Install TFX and TensorFlow

In [0]:
!pip install http://tfx-ccy-public.storage.googleapis.com/tfx-0.14.0.dev0-py3-none-any.whl tensorflow==1.14.0

### Import packages
We import necessary packages, including standard TFX component classes.

In [0]:
import os
import tempfile
import urllib.request

import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.proto.evaluator_pb2 import SingleSlicingSpec
from tfx.utils.dsl_utils import csv_input

### (in progress: inject IPython formatter integration)

In [0]:
# TFX IPython formatter integration
# TODO: integrate this
from tfx.orchestration.interactive.interactive_context import ExecutionResult

def style():
    return '''<style>
.object {
}
.object.expanded {
padding: 4px 8px 4px 8px;
border: 1px solid #bbbbbb;
box-shadow: 4px 4px 2px rgba(0,0,0,0.05);
background: white;
border-radius: 0;
}
.object, .object * {
font-size: 11pt;
}
.object > .title {
cursor: pointer;
}
.expansion-marker {
color: #999999;
}
.object.expanded > .title > .expansion-marker:before {
content: '▼'
}
.object.collapsed > .title > .expansion-marker:before {
content: '▶'
}
.classname {
font-weight: bold
}
.deemph {
opacity: 0.5;
}
.object.collapsed > table.attrtable {
display: none;
}
.object.expanded > table.attrtable {
display: block;
}
table.attrtable {
border: 2px solid white;
margin-top: 5px;
}
table.attrtable tr {
}
td.attrname {
vertical-align: top;
font-weight: bold;
}
td.attrvalue {
text-align: left
}
</style>
<script>
function toggle(element) {
var obj_element = element.parentElement;
if (obj_element.classList.contains('collapsed')) {
obj_element.classList.remove('collapsed');
obj_element.classList.add('expanded');
} else {
obj_element.classList.add('collapsed');
obj_element.classList.remove('expanded');
}
}
</script>
'''

def execution_result(obj):
    return style() + '''''' % (obj.__class__.__name__, id(obj), obj.execution_id, obj.component.component_name, obj.component.inputs) + repr(obj)

from typing import Text, Type

from tfx.components.base.base_component import _PropertyDictWrapper
from tfx.utils.channel import Channel
from tfx.utils.types import TfxArtifact

class NotebookFormatter(object):
    _DEFAULT_TITLE_FORMAT = ('<span class="classname">%s</span>', ['__class__.__name__'])
    
    def __init__(self, cls: Type, expandable: bool=False, attributes=None, title_format=None):
        self.cls = cls
        self.expandable = expandable
        self.attributes = attributes or []
        self.title_format = title_format or NotebookFormatter._DEFAULT_TITLE_FORMAT

    def _extended_getattr(self, obj, property_name: Text):
        # print('_extended_getattr', obj.__class__, property_name)
        if callable(property_name):
            return property_name(obj)
        parts = property_name.split('.')
        current = obj
        for part in parts:
            current = getattr(current, part)
        return current
    
    def render(self, obj: object, expanded=True, seen_elements=None):
        # Use an explicit None check so that an empty set shared by a caller
        # is reused rather than replaced.
        if seen_elements is None:
            seen_elements = set()
        if not isinstance(obj, self.cls):
            raise ValueError('Expected object of type %s but got %s.' % (self.cls, obj))
        if id(obj) in seen_elements:
            return '(recursion in rendering object)'
        # Keep the object marked as seen while its attributes are rendered, so
        # that reference cycles are actually detected.
        seen_elements.add(id(obj))
        attributes_html = self.render_attributes(obj, seen_elements)
        seen_elements.remove(id(obj))
        return style() + '''<div class="object%s">
<div class = 'title' onclick='toggle(this)'><span class="expansion-marker"></span>
%s<span class="deemph"> at 0x%x</span></div>%s
</div>''' % (' expanded' if expanded else ' collapsed', self.render_title(obj), id(obj),
             attributes_html)
    
    def render_title(self, obj: object):
        title_format = self.title_format
        values = []
        for property_name in title_format[1]:
            values.append(self._extended_getattr(obj, property_name))
        return title_format[0] % tuple(values)
    
    def render_value(self, value: object, seen_elements: set):
        if isinstance(value, _PropertyDictWrapper):
            value = value.get_all()
        if isinstance(value, dict):
            value = self.render_dict(value, seen_elements)
        if isinstance(value, list):
            value = self.render_list(value, seen_elements)
        for cls in value.__class__.mro():
            # print(value.__class__, cls)
            if cls in FORMATTER_REGISTRY:
                value = FORMATTER_REGISTRY[cls].render(value, expanded=False, seen_elements=seen_elements)
                break
        return value
    
    def render_attributes(self, obj: object, seen_elements: set):
        attr_trs = []
        for property_name in self.attributes:
            value = self._extended_getattr(obj, property_name)
            value = self.render_value(value, seen_elements)
            attr_trs.append('''<tr><td class="attrname">.%s</td>
<td class = "attrvalue">%s</td></tr>''' % (property_name, value))
        return '''<table class="attrtable">%s</table>''' % ''.join(attr_trs)
    
    def render_dict(self, obj: dict, seen_elements: set):
        if not obj:
            return '{}'
        attr_trs = []
        for key, value in obj.items():
            value = self.render_value(value, seen_elements)
            attr_trs.append('''<tr><td class="attrname">[%r]</td>
<td class = "attrvalue">%s</td></tr>''' % (key, value))
        return '''<table class="attrtable">%s</table>''' % ''.join(attr_trs)
    
    def render_list(self, obj: list, seen_elements: set):
        if not obj:
            return '[]'
        attr_trs = []
        for i, value in enumerate(obj):
            value = self.render_value(value, seen_elements)
            attr_trs.append('''<tr><td class="attrname">[%d]</td>
<td class = "attrvalue">%s</td></tr>''' % (i, value))
        return '''<table class="attrtable">%s</table>''' % ''.join(attr_trs)
        
        
def create_formatters(formatters_spec):
    result = {}
    for cls, kwargs in formatters_spec.items():
        formatter = NotebookFormatter(cls, **kwargs)
        result[cls] = formatter
    return result

from tfx.components.base.base_component import BaseComponent

FORMATTER_REGISTRY = create_formatters({
    ExecutionResult: {'attributes': ['execution_id', 'component', 'component.inputs', 'component.outputs']},
    BaseComponent: {'attributes': ['inputs', 'outputs', 'exec_properties']},
    Channel: {'attributes': ['type_name', '_artifacts'],
                 'title_format': ('<span class="classname">%s</span> of type <span class="classname">%r</span> (%d artifact%s)',
                                  ['__class__.__name__', 'type_name', lambda o: len(o._artifacts), lambda o: '' if len(o._artifacts) == 1 else 's']),
             },
    TfxArtifact: {'attributes': ['type_name', 'uri', 'span', 'split'],
                 'title_format': ('<span class="classname">%s</span> of type <span class="classname">%r</span> (uri: %s)',
                                  ['__class__.__name__', 'type_name', 'uri']),
                 }
})

html_formatter = get_ipython().display_formatter.formatters['text/html']
for cls, formatter in FORMATTER_REGISTRY.items():
    html_formatter.for_type(cls, formatter.render)
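
The formatter cell above picks a renderer by walking the method resolution order (MRO) of a value's class, so that a formatter registered for a base class also applies to its subclasses. The sketch below isolates that lookup with only the standard library. `Base`, `Child`, `REGISTRY`, `find_formatter` and `render` are hypothetical names used for illustration and are not part of TFX.

```python
class Base(object):
    pass


class Child(Base):
    pass


# Hypothetical registry mapping a class to a rendering function.
REGISTRY = {
    Base: lambda obj: '<b>%s</b>' % obj.__class__.__name__,
}


def find_formatter(value, registry):
    """Return the formatter of the most specific registered class, or None."""
    for cls in type(value).mro():
        if cls in registry:
            return registry[cls]
    return None


def render(value, registry):
    formatter = find_formatter(value, registry)
    # Fall back to repr() when no class in the MRO is registered.
    return formatter(value) if formatter else repr(value)


print(render(Child(), REGISTRY))  # <b>Child</b>
print(render(42, REGISTRY))       # 42
```

Because the MRO lists the most derived class first, a formatter registered for a subclass would take precedence over one registered for its parent.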
    

### Download example data
We download the sample dataset for use in our TFX pipeline.

In [0]:
# Download the example data.
_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
with open(os.path.join(_data_root, 'data.csv'), 'wb') as f:
  contents = urllib.request.urlopen(DATA_PATH).read()
  f.write(contents)
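
When a notebook is re-run often, it can be convenient to skip the download if the file is already present. The following is a minimal sketch of that idea using only the standard library. The helper name `fetch_once` is hypothetical and is not part of TFX.

```python
import os
import urllib.request


def fetch_once(url, dest_path):
    """Download `url` to `dest_path` unless the file already exists.

    Hypothetical helper for illustration. Returns True if a download
    happened and False if the existing file was kept.
    """
    if os.path.exists(dest_path):
        return False
    with urllib.request.urlopen(url) as response:
        contents = response.read()
    with open(dest_path, 'wb') as f:
        f.write(contents)
    return True
```

In this notebook it could be called as `fetch_once(DATA_PATH, os.path.join(_data_root, 'data.csv'))`, although a fresh temporary directory never contains the file, so the check only helps when `_data_root` points at a persistent location.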

### Set up pipeline paths

In [0]:
# Set up paths.
_taxi_root = os.path.join(tfx.__path__[0], 'examples/chicago_taxi_pipeline')
# Python module file to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_taxi_module_file = os.path.join(_taxi_root, 'taxi_utils.py')
# Path which can be listened to by the model server.  Pusher will output the
# trained model here.
_serving_model_dir = os.path.join(tempfile.mkdtemp(), 'serving_model/taxi_simple')

## Create the InteractiveContext
We now create the interactive context.

In [0]:
# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext.
context = InteractiveContext()

## Run TFX components interactively
Next, we construct TFX components and run each one within the interactive session to obtain `ExecutionResult` objects.

### ExampleGen
`ExampleGen` brings data into the TFX pipeline.

In [0]:
# Use the CSV data downloaded above.
examples = csv_input(_data_root)

# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = CsvExampleGen(input_base=examples)
context.run(example_gen)

### StatisticsGen (using TensorFlow Data Validation)
`StatisticsGen` computes statistics for visualization and example validation. This uses the [TensorFlow Data Validation](https://www.tensorflow.org/tfx/data_validation/get_started) library.

#### Run TFDV statistics computation using the StatisticsGen component

In [0]:
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    input_data=example_gen.outputs['examples'])
context.run(statistics_gen)

#### Import TFDV and visualize the statistics result

In [0]:
# Import TFDV and get the train statistics path.
import tensorflow_data_validation as tfdv
from tfx.utils.types import get_split_uri
artifact_list = statistics_gen.outputs['output'].get()
train_artifact_uri = get_split_uri(artifact_list, 'train')
train_stats_path = os.path.join(train_artifact_uri, 'stats_tfrecord')

In [0]:
# Load statistics and visualize training data stats.
train_stats = tfdv.load_statistics(train_stats_path)
tfdv.visualize_statistics(train_stats)
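
To build intuition for the kind of per-feature summary shown in the visualization, here is a plain-Python sketch. It is not TFDV and does not reproduce its output format. The sample data, the column names and the `column_stats` helper are made up for illustration.

```python
import csv
import io
import statistics

# Tiny made-up sample; the values are illustrative only.
SAMPLE_CSV = """fare,payment_type
5.0,Cash
12.5,Credit Card
,Cash
8.0,Cash
"""


def column_stats(rows, column):
    """Summarise one column: missing count plus numeric or categorical stats."""
    values = [row[column] for row in rows]
    present = [v for v in values if v != '']
    summary = {'count': len(values), 'missing': len(values) - len(present)}
    try:
        numbers = [float(v) for v in present]
    except ValueError:
        # Not numeric: report the number of distinct values instead.
        summary['unique'] = len(set(present))
    else:
        summary['mean'] = statistics.mean(numbers)
        summary['min'] = min(numbers)
        summary['max'] = max(numbers)
    return summary


rows = list(csv.DictReader(io.StringIO(SAMPLE_CSV)))
fare_stats = column_stats(rows, 'fare')
payment_stats = column_stats(rows, 'payment_type')
print(fare_stats)
print(payment_stats)
```

Numeric columns get a mean and range, while string columns get a distinct-value count, which is the same broad split between numeric and categorical features that the statistics viewer makes.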

### SchemaGen (using TensorFlow Data Validation)
`SchemaGen` generates a schema for your data based on computed statistics. This component also uses the [TensorFlow Data Validation](https://www.tensorflow.org/tfx/data_validation/get_started) library.

#### Run TFDV schema inference using the SchemaGen component

In [0]:
# Generates schema based on statistics files.
infer_schema = SchemaGen(stats=statistics_gen.outputs['output'])
context.run(infer_schema)

#### Visualize the inferred schema



In [0]:
# Get the schema path.
schema_dir = infer_schema.outputs['output'].get()[0].uri
schema_path = os.path.join(schema_dir, 'schema.pbtxt')

In [0]:
# Load and visualize the generated schema.
schema = tfdv.load_schema_text(schema_path)
tfdv.display_schema(schema)
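
The idea of inferring a schema from observed data can be sketched in a few lines of plain Python. This is an illustration only, not TFDV's inference logic. The helper names, the sample rows and the type labels are assumptions chosen for the sketch.

```python
def _parses_as(text, cast):
    """Return True if `cast(text)` succeeds, for example int('7')."""
    try:
        cast(text)
    except ValueError:
        return False
    return True


def infer_schema_sketch(rows):
    """Infer a type label and a 'required' flag for every column.

    `rows` is a list of dicts mapping column name to a string value, with ''
    standing for a missing value.
    """
    schema = {}
    for column in rows[0]:
        present = [row[column] for row in rows if row[column] != '']
        if all(_parses_as(v, int) for v in present):
            inferred = 'INT'
        elif all(_parses_as(v, float) for v in present):
            inferred = 'FLOAT'
        else:
            inferred = 'BYTES'
        schema[column] = {'type': inferred,
                          'required': len(present) == len(rows)}
    return schema


sample_rows = [
    {'trip_start_hour': '7', 'fare': '5.0', 'company': 'Acme'},
    {'trip_start_hour': '19', 'fare': '12.5', 'company': ''},
]
sketch_schema = infer_schema_sketch(sample_rows)
print(sketch_schema)
```

A column is marked as required only when every example has a value for it, so `company` comes out as optional in this sample.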

### ExampleValidator
`ExampleValidator` performs anomaly detection based on computed statistics and your data schema.

#### Run TFDV data validation using the ExampleValidator component

In [0]:
# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    stats=statistics_gen.outputs['output'],
    schema=infer_schema.outputs['output'])
context.run(validate_stats)

#### Visualize the detected anomalies

In [0]:
# Get the validation path.
validation_dir = validate_stats.outputs['output'].get()[0].uri
anomalies_path = os.path.join(validation_dir, 'anomalies.pbtxt')

In [0]:
# Load and visualize the anomalies.

# Utility function backported from TFDV 0.14.
def load_anomalies_text(input_path):
  from google.protobuf import text_format
  from tensorflow.python.lib.io import file_io
  from tensorflow_metadata.proto.v0 import anomalies_pb2
  anomalies = anomalies_pb2.Anomalies()
  anomalies_text = file_io.read_file_to_string(input_path)
  text_format.Parse(anomalies_text, anomalies)
  return anomalies

anomalies = load_anomalies_text(anomalies_path)
# anomalies = tfdv.load_anomalies_text(anomalies_path)  # Will be available once TFDV 0.14 is released.
tfdv.display_anomalies(anomalies)
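
Conceptually, anomaly detection compares the data against the expectations recorded in the schema. The sketch below shows two such checks, a missing value in a required column and a value outside an allowed domain, in plain Python. It is not how TFDV is implemented. The schema layout, the `find_anomalies` helper and the sample rows are hypothetical.

```python
def find_anomalies(rows, schema):
    """Return a dict mapping column name to the first problem found."""
    anomalies = {}
    for column, spec in schema.items():
        for row in rows:
            value = row.get(column, '')
            if value == '':
                if spec['required']:
                    anomalies.setdefault(
                        column, 'missing value in required column')
            elif 'domain' in spec and value not in spec['domain']:
                anomalies.setdefault(column, 'unexpected value: %r' % value)
    return anomalies


# Hypothetical schema: one required categorical column, one optional column.
toy_schema = {
    'payment_type': {'required': True, 'domain': {'Cash', 'Credit Card'}},
    'tips': {'required': False},
}
toy_rows = [
    {'payment_type': 'Cash', 'tips': ''},
    {'payment_type': 'Bitcoin', 'tips': '1.0'},
]
toy_anomalies = find_anomalies(toy_rows, toy_schema)
print(toy_anomalies)
```

An empty result means the data matches the schema, which corresponds to the "no anomalies found" case in the display above.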

### Transform
`Transform` performs data transformations and feature engineering which are kept in sync for training and serving.

#### Run the Transform component

In [0]:
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    input_data=example_gen.outputs['examples'],
    schema=infer_schema.outputs['output'],
    module_file=_taxi_module_file)
context.run(transform)
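
The phrase "kept in sync for training and serving" can be illustrated without TensorFlow Transform: constants are computed once over the training data, and the same function with the same constants is then applied both to training examples and to serving requests. The sketch below assumes z-score scaling as the transformation. The function names and the sample values are made up for illustration.

```python
import statistics


def analyze(training_values):
    """Full pass over the training data to compute scaling constants."""
    return {
        'mean': statistics.mean(training_values),
        'stdev': statistics.pstdev(training_values),
    }


def scale_to_z_score(value, constants):
    """Apply the scaling using constants computed at analysis time."""
    return (value - constants['mean']) / constants['stdev']


training_fares = [4.0, 8.0, 12.0]
constants = analyze(training_fares)

# Training time: transform every training example.
train_features = [scale_to_z_score(v, constants) for v in training_fares]

# Serving time: reuse the stored constants instead of recomputing them, so a
# fare of 8.0 maps to the same feature value it had during training.
serving_feature = scale_to_z_score(8.0, constants)
print(train_features, serving_feature)
```

Recomputing the mean from serving traffic instead would silently change what the model sees, which is the training/serving skew that the shared transformation avoids.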

### Trainer
`Trainer` trains your custom model using TF-Learn.

In [0]:
# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
    module_file=_taxi_module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['output'],
    transform_output=transform.outputs['transform_output'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)

### Evaluator (using TensorFlow Model Analysis)
The `Evaluator` computes evaluation statistics over features of your model using [TensorFlow Model Analysis](https://www.tensorflow.org/tfx/model_analysis/get_started). In this section, we run TFMA in our TFX pipeline and then visualize the results to analyze the performance of our model.

#### Run TFMA using the Evaluator component

Here, we first define slicing specs for analyzing our data. Next, we run TFMA using these specs to generate results.

In [0]:
# An empty slice spec means the overall slice, that is, the whole dataset.
OVERALL_SLICE_SPEC = evaluator_pb2.SingleSlicingSpec()

# Data can be sliced along a feature column
# In this case, data is sliced along feature column trip_start_hour.
FEATURE_COLUMN_SLICE_SPEC = evaluator_pb2.SingleSlicingSpec(column_for_slicing=['trip_start_hour'])

# Data can be sliced by crossing feature columns
# In this case, slices are computed for trip_start_day x trip_start_month.
FEATURE_COLUMN_CROSS_SPEC = evaluator_pb2.SingleSlicingSpec(column_for_slicing=['trip_start_day', 'trip_start_month'])

ALL_SPECS = [
    OVERALL_SLICE_SPEC,
    FEATURE_COLUMN_SLICE_SPEC, 
    FEATURE_COLUMN_CROSS_SPEC,
]

In [0]:
# Use TFMA to compute evaluation statistics over features of a model.
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model_exports=trainer.outputs['output'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(
        specs=ALL_SPECS
    ))
context.run(model_analyzer)
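
The three kinds of slicing spec used above (overall, single column, and column cross) can be pictured as grouping the evaluation examples by zero, one, or several feature columns and computing a metric per group. The following plain-Python sketch uses accuracy as the metric. It is not TFMA, and `accuracy_by_slice` and the sample examples are hypothetical.

```python
from collections import defaultdict


def accuracy_by_slice(examples, slice_columns):
    """Group examples by the values of `slice_columns` and compute accuracy.

    An empty `slice_columns` yields a single overall slice keyed by ().
    """
    totals = defaultdict(lambda: [0, 0])  # slice key -> [correct, count]
    for example in examples:
        key = tuple((column, example[column]) for column in slice_columns)
        totals[key][0] += int(example['label'] == example['prediction'])
        totals[key][1] += 1
    return {key: correct / count for key, (correct, count) in totals.items()}


examples = [
    {'trip_start_hour': 7, 'trip_start_day': 1, 'label': 1, 'prediction': 1},
    {'trip_start_hour': 7, 'trip_start_day': 2, 'label': 0, 'prediction': 1},
    {'trip_start_hour': 19, 'trip_start_day': 1, 'label': 0, 'prediction': 0},
    {'trip_start_hour': 19, 'trip_start_day': 1, 'label': 1, 'prediction': 1},
]

overall = accuracy_by_slice(examples, [])
by_hour = accuracy_by_slice(examples, ['trip_start_hour'])
crossed = accuracy_by_slice(examples, ['trip_start_hour', 'trip_start_day'])
print(overall)
print(by_hour)
print(crossed)
```

The overall accuracy of 0.75 hides the fact that the hour-7 slice scores only 0.5, which is the kind of discrepancy sliced evaluation is meant to surface.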

#### Get the TFMA output result path

In [0]:
PATH_TO_RESULT = model_analyzer.outputs['output'].get()[0].uri

#### Import TFMA and load the result

In [0]:
import tensorflow_model_analysis as tfma
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)

#### Visualization: Slicing Metrics

To see the slices, either use the name of the column (by setting `slicing_column`) or provide a `tfma.slicer.SingleSliceSpec` (by setting `slicing_spec`). If neither is provided, an overall visualization will be displayed.

The default visualization is the **slice overview** when the number of slices is small. It shows the value of a metric for each slice, sorted by another metric. It is also possible to set a threshold to filter out slices with smaller weights.

This view also supports the **metrics histogram** as an alternative visualization. It is also the default view when the number of slices is large. The results are divided into buckets, and the number of slices, the total weights, or both can be visualized. Slices with small weights can be filtered out by setting the threshold. Further filtering can be applied by dragging the grey band. To reset the range, double-click the band. Filtering can be used to remove outliers in the visualization and the metrics table below.

In [0]:
# Show data sliced along feature column trip_start_hour.
tfma.view.render_slicing_metrics(tfma_result, slicing_column='trip_start_hour')

In [0]:
# Show metrics sliced by 'trip_start_day' x 'trip_start_month'.
tfma.view.render_slicing_metrics(tfma_result,
                                 slicing_spec=tfma.slicer.SingleSliceSpec(
                                     columns=['trip_start_day', 'trip_start_month']))

In [0]:
# Show overall metrics.
tfma.view.render_slicing_metrics(tfma_result)

### ModelValidator
`ModelValidator` performs validation of your candidate model compared to a baseline.

In [0]:
# Performs quality validation of a candidate model (compared to a baseline).
model_validator = ModelValidator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['output'])
context.run(model_validator)

### Pusher
`Pusher` checks whether a model has passed validation, and if so, pushes the model to a file destination.

In [0]:
# Checks whether the model passed the validation steps and pushes the model
# to a file destination if check passed.
pusher = Pusher(
    model_export=trainer.outputs['output'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher)
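
The last two steps amount to a gate followed by a copy: a model is "blessed" if it passes validation, and only a blessed model is copied to the serving directory. The sketch below illustrates that flow with the standard library. It is not the TFX implementation. The single higher-is-better metric, the timestamp-named version subdirectory, and the names `is_blessed` and `push_if_blessed` are all assumptions made for the example.

```python
import os
import shutil
import time


def is_blessed(candidate_metric, baseline_metric):
    """Bless when no baseline exists or the candidate is at least as good.

    Assumes a single metric where higher is better.
    """
    return baseline_metric is None or candidate_metric >= baseline_metric


def push_if_blessed(model_dir, base_directory, blessed):
    """Copy `model_dir` into a new version directory when `blessed` is True.

    Returns the destination path, or None if nothing was pushed.
    """
    if not blessed:
        return None
    # Assumed layout: one numeric subdirectory per pushed version.
    destination = os.path.join(base_directory, str(int(time.time())))
    shutil.copytree(model_dir, destination)
    return destination
```

For example, `push_if_blessed(exported_dir, _serving_model_dir, is_blessed(0.81, 0.79))` would copy the export, whereas a candidate scoring below the baseline would leave the serving directory untouched. Here `exported_dir` stands for a hypothetical path to an exported model.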