##### Copyright &copy; 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Python Function Custom Component Example

***An example of creating a custom component using the `@component` decorator***

Note: We recommend running this tutorial in a Colab notebook, with no setup required!  Just click "Run in Google Colab".

<div class="devsite-table-wrapper"><table class="tfo-notebook-buttons">
<td><a target="_blank" href="https://www.tensorflow.org/tfx/tutorials/tfx/python_component_simple">
<img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a></td>
<td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/python_component_simple.ipynb">
<img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a></td>
<td><a target="_blank" href="https://github.com/tensorflow/tfx/tree/master/docs/tutorials/tfx/python_component_simple.ipynb">
<img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a></td>
</table></div>

This tutorial shows how to create a new [custom component using a Python function and the `@component` decorator](https://www.tensorflow.org/tfx/guide/custom_function_component).

This is intended as a "Hello World" example, showing:

* Using the `@component` decorator and annotations to create a custom component
* Loading and saving artifacts
* Parsing and updating a [Schema protocol buffer](https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto)
* Adding a feature to the dataset

## Setup and Imports

### Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab.  Local systems can of course be upgraded separately.

In [None]:
try:
  import colab
  !pip install --upgrade pip
except:
  pass

### Install TFX

**Note: In Google Colab, because of package updates, the first time you run this cell you must restart the runtime (Runtime > Restart runtime ...).**

In [None]:
# Install the pinned TFX release used by this tutorial (`-q` keeps pip quiet).
!pip install -q tfx==0.22.0

### Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.

### Import packages
We import necessary packages, including standard TFX component classes.

In [None]:
import os
import tempfile
import urllib

import absl
import tensorflow as tf
# Stop records from TensorFlow's logger from being passed up to the handlers
# of its ancestor loggers.
tf.get_logger().propagate = False

import tfx
from tfx.components import CsvExampleGen
from tfx.components import StatisticsGen
from tfx.components import SchemaGen
from tfx.components import ExampleValidator
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.types import Channel
from tfx.utils.dsl_utils import external_input

# Load the TFX interactive-notebook "skip" extension.
# NOTE(review): presumably this registers a magic for skipping cells when the
# notebook is exported -- confirm against the TFX documentation.
%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip

Let's check the library versions.

In [None]:
# Report the installed library versions, one per line.
print(
    'TensorFlow version: {}'.format(tf.__version__),
    'TFX version: {}'.format(tfx.__version__),
    sep='\n')

### Set up pipeline paths

In [None]:
# Root directory of the installed TFX pip package.
_tfx_root = tfx.__path__[0]

# Directory of the TFX Chicago Taxi Pipeline example inside that package.
_taxi_root = os.path.join(_tfx_root, 'examples/chicago_taxi_pipeline')

# Path where a model would be pushed for serving.
# NOTE: `tempfile.mkdtemp()` creates a new temporary directory every time this
# cell runs, and nothing here removes it afterwards.
_serving_model_dir = os.path.join(
    tempfile.mkdtemp(), 'serving_model/taxi_simple')

# Show INFO-level (and more severe) absl log messages.
absl.logging.set_verbosity(absl.logging.INFO)

### Download example data

In [None]:
# `import urllib` on its own does not guarantee that the `urllib.request`
# submodule is loaded, so import it explicitly before using it.
import urllib.request

# Download the sample Chicago Taxi CSV into a fresh temporary directory.
# `_data_root` is the directory that is later handed to ExampleGen.
_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
_data_filepath = os.path.join(_data_root, "data.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)

### Create the InteractiveContext
Last, we create an InteractiveContext, which will allow us to run TFX components interactively in this notebook.

In [None]:
# Create an InteractiveContext with its default parameters. This uses a
# temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, pass the optional properties
# `pipeline_root` and `metadata_connection_config` to InteractiveContext.
# Calls to InteractiveContext are no-ops outside of the notebook.
context = InteractiveContext()

## Run TFX components interactively
In the cells that follow, we create and run TFX components one-by-one.  The first three components - `ExampleGen`, `StatisticsGen`, and `SchemaGen` - set up the artifacts and metadata that `HelloComponent` relies on.

`HelloComponent` is the focus of this tutorial.

### Ingest the Data With ExampleGen

In [None]:
# Point CsvExampleGen at the directory holding the downloaded CSV and run it
# right away; caching is explicitly disabled for this run.
example_gen = CsvExampleGen(input=external_input(_data_root))
context.run(example_gen, enable_cache=False)

### Calculate the Dataset Statistics With StatisticsGen

In [None]:
# Compute statistics over the examples produced by ExampleGen.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)

### Infer the Feature Types With SchemaGen

In [None]:
# Infer a schema from the statistics, with feature-shape inference turned off.
# NOTE(review): presumably this is why HelloComponent has to cope with parsed
# tensors that expose `.values` -- confirm against the SchemaGen documentation.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(schema_gen)

## Add the HelloComponent

`HelloComponent` is our custom Python function component.  In it we read the dataset which was ingested by `ExampleGen` and the schema that was inferred by `SchemaGen`.  We then add a new feature to the dataset and schema, and output the results.

In [None]:
import six
import tensorflow_data_validation as tfdv
from tfx.dsl.component.experimental.annotations import OutputDict
from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component
from tfx.types import artifact_utils
from tfx.types.standard_artifacts import Examples
from tfx.types.standard_artifacts import Schema
from tensorflow_transform.tf_metadata import schema_utils
from tensorflow_metadata.proto.v0 import schema_pb2

class _feature_utils(object):
  """Helpers for rebuilding a `tf.train.Example` from parsed tensors.

  Each `_*_feature` helper accepts a scalar or a tensor of any rank and writes
  every element it contains, so multi-valued features are kept in full.
  """

  @staticmethod
  def _flat_values(value) -> list:
    """Returns every element of `value` as a flat Python list.

    `tf.reshape(value, [-1])` collapses a scalar or a tensor of any rank into
    one dimension, so scalars, single values and multi-valued features are all
    handled the same way.
    """
    return tf.reshape(value, [-1]).numpy().tolist()

  @staticmethod
  def _bytes_feature(value) -> tf.train.Feature:
    """Returns a bytes_list feature holding every string in `value`.

    An empty input is written as a single empty byte string rather than as an
    empty list; this matches how this helper has always treated missing
    string values.
    """
    values = _feature_utils._flat_values(value)
    if not values:
      values = [bytes([])]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

  @staticmethod
  def _float_feature(value) -> tf.train.Feature:
    """Returns a float_list feature holding every float in `value`.

    An empty input produces an empty float_list.
    """
    values = _feature_utils._flat_values(value)
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

  @staticmethod
  def _int64_feature(value) -> tf.train.Feature:
    """Returns an int64_list feature holding every integer in `value`.

    An empty input produces an empty int64_list.
    """
    values = _feature_utils._flat_values(value)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  @classmethod
  def example_from_tensor_dict(cls, tensor_dict) -> tf.train.Example:
    """Builds a `tf.train.Example` from a mapping of feature name to tensor.

    Args:
      tensor_dict: Mapping from feature name to a tensor. Tensors exposing a
        `values` attribute contribute their `values`; other tensors are used
        as they are.

    Returns:
      A `tf.train.Example` with one feature per entry of `tensor_dict`.

    Raises:
      ValueError: If a tensor's dtype is not int64, float32 or string.
    """
    tfexample_dict = {}
    for name, tensor in tensor_dict.items():
      val = tensor.values if hasattr(tensor, 'values') else tensor
      if tensor.dtype == 'int64':
        tfexample_dict[name] = cls._int64_feature(val)
      elif tensor.dtype == 'float32':
        tfexample_dict[name] = cls._float_feature(val)
      elif tensor.dtype == 'string':
        tfexample_dict[name] = cls._bytes_feature(val)
      else:
        raise ValueError('{} is an unknown type: {}'.format(name, tensor.dtype))

    return tf.train.Example(features=tf.train.Features(feature=tfexample_dict))

@component
def HelloComponent(
    input_data: InputArtifact[Examples],
    schema: InputArtifact[Schema],
    output_data: OutputArtifact[Examples],
    new_schema: OutputArtifact[Schema],
    new_feature_name: Parameter[str],
    component_name: Parameter[str]
    ) -> None:
  """Copies the input examples, adding one constant int64 feature to each.

  Every GZIP-compressed TFRecord file found in each split directory of
  `input_data` is rewritten, under the same split and file name, into
  `output_data`. The input schema is extended with the new feature and written
  to `new_schema`.

  Args:
    input_data: Examples artifact to read; one sub-directory per split.
    schema: Schema artifact whose URI contains `schema.pbtxt`.
    output_data: Examples artifact that receives the rewritten records.
    new_schema: Schema artifact that receives the extended `schema.pbtxt`.
    new_feature_name: Name of the feature added to every record (value 42).
    component_name: Not read by this function body.
  """
  schema_proto = tfdv.load_schema_text(os.path.join(schema.uri, 'schema.pbtxt'))
  # Only the feature spec is needed for parsing; the domains are not used.
  feature_spec, _ = schema_utils.schema_as_feature_spec(schema_proto)

  # Get a list of the splits in input_data
  splits_list = artifact_utils.decode_split_names(
      split_names=input_data.split_names)

  # The added value is identical for every record, so build it only once.
  new_feature_value = tf.constant(42, dtype=tf.int64)

  for split in splits_list:
    input_dir = os.path.join(input_data.uri, split)
    output_dir = os.path.join(output_data.uri, split)
    # Tolerate a split directory that already exists and create any missing
    # parent directories, instead of failing as a plain `os.mkdir` would.
    os.makedirs(output_dir, exist_ok=True)

    for tfrecord_filename in os.listdir(input_dir):
      input_path = os.path.join(input_dir, tfrecord_filename)
      output_path = os.path.join(output_dir, tfrecord_filename)
      with tf.io.TFRecordWriter(output_path, options="GZIP") as writer:
        # Read each tfrecord file in the input split
        for tfrecord in tf.data.TFRecordDataset(input_path, compression_type="GZIP"):
          tensor_dict = tf.io.parse_single_example(tfrecord, feature_spec)

          # Imagine that we want to add a new feature
          tensor_dict[new_feature_name] = new_feature_value

          result = _feature_utils.example_from_tensor_dict(tensor_dict)
          writer.write(result.SerializeToString())

  # Add the new feature to the schema
  new_feature = schema_pb2.Feature(
      name=new_feature_name,
      type=schema_pb2.INT,
      value_count={'min': 1, 'max': 1},
      presence={'min_fraction': 1.0, 'min_count': 1})
  schema_proto.feature.append(new_feature)
  schema_path = os.path.join(new_schema.uri, 'schema.pbtxt')
  tfdv.write_schema_text(schema_proto, schema_path)

  # For completeness, encode the splits names.
  # We could also just use input_data.split_names.
  output_data.split_names = artifact_utils.encode_split_names(
      splits=splits_list)

### Run HelloComponent

In [None]:
# Wire HelloComponent to the examples from ExampleGen and the schema from
# SchemaGen, then run it with caching explicitly disabled.
# NOTE(review): the two output artifacts are not passed here; presumably the
# `@component` decorator creates their channels -- confirm.
hello = HelloComponent(input_data=example_gen.outputs['examples'],
                       schema=schema_gen.outputs['schema'],
                       component_name=u'HelloWorld',
                       new_feature_name=u'new_feature')
context.run(hello, enable_cache=False)

### Examine the Output Artifacts

In [None]:
# Take the first artifact from each of HelloComponent's output channels...
output_data = hello.outputs['output_data'].get()[0]
new_schema = hello.outputs['new_schema'].get()[0]

# ...and report where each one was written.
print('output_data: splits={}, URI={}'.format(output_data.split_names, output_data.uri))
print('new_schema: URI={}'.format(new_schema.uri))

In [None]:
# Directory that holds the training split of the output examples.
train_uri = os.path.join(output_data.uri, 'train')

# Full path of every (compressed TFRecord) file in that directory.
tfrecord_filenames = []
for name in os.listdir(train_uri):
  tfrecord_filenames.append(os.path.join(train_uri, name))

# Dataset that reads the GZIP-compressed records from those files.
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Decode and print the first 2 records.
for tfrecord in dataset.take(2):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  print(example)

### Examine the Updated Schema

In [None]:
# Display the schema written by HelloComponent.
context.show(hello.outputs['new_schema'])

### Test With ExampleValidator

As a test, run `ExampleValidator` against the schema produced by `HelloComponent`.

Note: Since we didn't update the statistics after adding a new feature, `ExampleValidator` should see this as an anomaly.  That's OK for this example: the statistics were computed from the original dataset, before the new feature was added, so the anomaly is expected.  If you want the statistics to reflect the new feature, add another output artifact to `HelloComponent`, or follow `HelloComponent` with a second instance of `StatisticsGen`.

In [None]:
# Validate the statistics computed from the original examples against the
# schema produced by HelloComponent; caching is explicitly disabled.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=hello.outputs['new_schema'])
context.run(example_validator, enable_cache=False)

In [None]:
# Display the anomalies reported by ExampleValidator.
context.show(example_validator.outputs['anomalies'])