In [None]:
# Automatically reload imported Python modules before executing each cell, so
# edits to local .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pprint
import tempfile
import urllib

import absl
import tensorflow as tf
import tensorflow_model_analysis as tfma
# Pretty-printer made available to later cells for readable dumps of nested objects.
pp = pprint.PrettyPrinter()
# Stop records on TensorFlow's Python logger from propagating to ancestor loggers.
tf.get_logger().propagate = False

from typing import List, Text

from tfx.components import CsvExampleGen
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec
from tfx.dsl.components.common import resolver
from tfx.dsl.experimental import latest_artifacts_resolver
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing

In [None]:
# Interactive orchestrator: each context.run(...) below runs a single TFX
# component, and context.show(...) renders an output artifact inline.
# NOTE(review): created without arguments, so _pipeline_root / _metadata_path
# from the config cell are not used by this context; InteractiveContext
# presumably falls back to a temporary pipeline root and metadata store.
# Define those paths first and pass them in if artifacts should persist.
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
context = InteractiveContext()

In [None]:
_pipeline_name = 'sampling_credit_card'
# Directory the notebook runs from, as an absolute path. (`os.path.dirname(".")`
# evaluates to '' and only produced bare relative paths.)
_sampling_root = os.path.abspath('.')
_data_root = os.path.join(_sampling_root, 'data')
# Python module file to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_module_file = os.path.join(_sampling_root, 'sampler_utils.py')
# Push destination used by the Pusher component.
_serving_model_dir = os.path.join(_sampling_root, 'serving_model', _pipeline_name)
# TFX outputs live under ~/tfx. expanduser('~') also works where HOME is unset
# (e.g. Windows), whereas os.environ['HOME'] raises KeyError there.
_tfx_root = os.path.join(os.path.expanduser('~'), 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')

In [None]:
# Extra Apache Beam options for Beam-based components: run the DirectRunner in
# multi-processing mode.
# NOTE(review): this list is not passed to InteractiveContext() or to any
# component in the cells shown here; presumably it is meant for a full
# pipeline run (e.g. pipeline.Pipeline(beam_pipeline_args=...)) — confirm.
_beam_pipeline_args = [
    '--direct_running_mode=multi_processing',
    # 0 means auto-detect based on the number of CPUs available
    # during execution time.
    '--direct_num_workers=0',
]

In [None]:
# Ingest the CSV files under _data_root as TFX Examples; downstream cells
# consume example_gen.outputs['examples'].
example_gen = CsvExampleGen(input_base=_data_root)
context.run(example_gen)

In [None]:
# Statistics over the raw (pre-sampling) examples; used for schema inference
# and validation below.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)

In [None]:
# Visualize the raw-data statistics inline.
context.show(statistics_gen.outputs['statistics'])

In [None]:
# Infer a schema from the raw statistics. This same schema is later handed to
# Transform and Trainer for the sampled data.
# With infer_feature_shape=False, feature shapes are not inferred (per the
# SchemaGen docs, Transform then parses inputs as tf.SparseTensor).
schema_gen = SchemaGen(
  statistics=statistics_gen.outputs['statistics'],
  infer_feature_shape=False)
context.run(schema_gen)

In [None]:
# Check the raw statistics against the inferred schema for anomalies.
example_validator = ExampleValidator(
  statistics=statistics_gen.outputs['statistics'],
  schema=schema_gen.outputs['schema'])
context.run(example_validator)

In [None]:
# Resample the 'train' split using 'Class' as the label column.
# NOTE(review): the sampling strategy is the Sampler's default (not set here),
# and how splits not listed in `splits` are carried through is not visible
# from this notebook — check the tfx_addons Sampler documentation.
from tfx_addons.sampling.component import Sampler

sampler = Sampler(
  input_data=example_gen.outputs['examples'],
  splits=['train'],
  label='Class',
)
context.run(sampler)

In [None]:
# Statistics over the sampled examples, for comparison with the raw statistics.
sampler_stats = StatisticsGen(examples=sampler.outputs['output_data'])
context.run(sampler_stats)

In [None]:
# Visualize the post-sampling statistics inline.
context.show(sampler_stats.outputs['statistics'])

In [None]:
# Feature preprocessing on the *sampled* examples, using the schema inferred
# from the raw data. The preprocessing logic comes from _module_file
# (presumably its preprocessing_fn — confirm in sampler_utils.py).
transform = Transform(
  examples=sampler.outputs['output_data'],
  schema=schema_gen.outputs['schema'],
  module_file=_module_file)
context.run(transform)

In [None]:
# Resolve the most recent Model artifact (LatestArtifactsResolver, i.e. not
# filtered by blessing); Trainer receives it as base_model, presumably for
# warm-starting.
latest_model_resolver = resolver.Resolver(
  strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
  latest_model=Channel(type=Model)).with_id('latest_model_resolver')
context.run(latest_model_resolver)

In [None]:
trainer = Trainer(
  module_file=_module_file,
  custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
  transformed_examples=transform.outputs['transformed_examples'],
  schema=schema_gen.outputs['schema'],
  base_model=latest_model_resolver.outputs['latest_model'],
  transform_graph=transform.outputs['transform_graph'],
  train_args=trainer_pb2.TrainArgs(num_steps=10000),
  eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)

In [None]:
# Resolve the latest *blessed* model; the Evaluator below uses it as the
# baseline for its change thresholds.
# NOTE(review): on a first run no blessed model exists yet, so presumably no
# baseline is resolved.
model_resolver = resolver.Resolver(
  strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
  model=Channel(type=Model),
  model_blessing=Channel(
      type=ModelBlessing)).with_id('latest_blessed_model_resolver')
context.run(model_resolver)

In [None]:
eval_config = tfma.EvalConfig(
  model_specs=[tfma.ModelSpec(signature_name='eval')],
  slicing_specs=[
      # Overall (unsliced) metrics.
      tfma.SlicingSpec(),
      # NOTE(review): no per-feature slice is configured. A previous
      # SlicingSpec(feature_keys=['trip_start_hour']) appears to be a leftover
      # from the TFX Chicago Taxi example; this credit-card pipeline
      # (label 'Class') presumably has no such column — confirm against the
      # SchemaGen output. Add tfma.SlicingSpec(feature_keys=[<schema column>])
      # if per-segment metrics are wanted.
  ],
  metrics_specs=[
      tfma.MetricsSpec(
          # Blessing gate on 'accuracy': the candidate must reach at least 0.6
          # (value_threshold) and must not fall below the baseline model by
          # more than 1e-10 (change_threshold with HIGHER_IS_BETTER).
          # NOTE(review): if the eval data keeps a heavily imbalanced class
          # distribution (common for fraud data), accuracy >= 0.6 is a weak
          # gate; consider AUC- or recall-based thresholds.
          thresholds={
              'accuracy':
                  tfma.config.MetricThreshold(
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.6}),
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-10}))
          })
  ])

In [None]:
# Evaluate the newly trained model against the blessed baseline using
# eval_config. Evaluation reads the raw ExampleGen examples rather than the
# sampler output, presumably so metrics reflect the original class
# distribution. Its 'blessing' output gates the Pusher below.
evaluator = Evaluator(
  examples=example_gen.outputs['examples'],
  model=trainer.outputs['model'],
  baseline_model=model_resolver.outputs['model'],
  eval_config=eval_config)
context.run(evaluator)

In [None]:
# Copy the trained model to _serving_model_dir on the local filesystem; the
# Evaluator's blessing is passed so that (per the Pusher docs) only blessed
# models are pushed.
pusher = Pusher(
  model=trainer.outputs['model'],
  model_blessing=evaluator.outputs['blessing'],
  push_destination=pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=_serving_model_dir)))
context.run(pusher)