In [None]:
import os
import tensorflow_data_validation as tfdv
import tensorflow_model_analysis as tfma
import tempfile
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.interactive.interactive_context import InteractiveContext
from tfx.proto import example_gen_pb2
from tfx.proto import evaluator_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import external_input
from tfx.utils.types import get_split_uri

In [None]:
# Paths shared by the rest of the notebook.
# Everything is resolved relative to the current working directory.
_root = '.'

# Directory later passed to external_input() as the ExampleGen input base.
_data_root = _root + '/data'

# Destination handed to the Pusher. tempfile.mkdtemp() creates a new
# directory on every execution, so re-running this cell changes the path.
_serving_model_dir = os.path.join(
    tempfile.mkdtemp(),
    'serving_model/mnist',
)

In [None]:
# Interactive context used by every later cell to execute a component via
# context.run(...).
# NOTE(review): no arguments are passed, so the pipeline root and metadata
# store are whatever InteractiveContext chooses by default — confirm that is
# acceptable for this notebook.
context = InteractiveContext()

In [None]:
# Declare which file pattern under the data root feeds each named split.
train_split = example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord')
eval_split = example_gen_pb2.Input.Split(name='eval', pattern='test.tfrecord')
input_split = example_gen_pb2.Input(splits=[train_split, eval_split])

# Wrap the data directory as an external input and import the examples
# using the split configuration above.
examples = external_input(_data_root)
example_gen = ImportExampleGen(
    input_base=examples,
    input_config=input_split,
)

context.run(example_gen)

In [None]:
# Feed the examples produced by ExampleGen into StatisticsGen and run it.
ingested_examples = example_gen.outputs['examples']
statistics_gen = StatisticsGen(input_data=ingested_examples)

context.run(statistics_gen)

In [None]:
# Resolve, for each split, the artifact URI of the StatisticsGen output and
# the 'stats_tfrecord' file path beneath it.
artifact_list = statistics_gen.outputs['output'].get()

train_artifact_uri, eval_artifact_uri = (
    get_split_uri(artifact_list, split_name)
    for split_name in ('train', 'eval')
)

train_stats_path, eval_stats_path = (
    os.path.join(split_uri, 'stats_tfrecord')
    for split_uri in (train_artifact_uri, eval_artifact_uri)
)

In [None]:
# Load the statistics for both splits from the paths resolved earlier.
train_stats = tfdv.load_statistics(train_stats_path)
eval_stats = tfdv.load_statistics(eval_stats_path)

# Render both sets of statistics together: eval on the left-hand side,
# train on the right-hand side.
tfdv.visualize_statistics(
    lhs_statistics=eval_stats,
    lhs_name='EVAL_DATASET',
    rhs_statistics=train_stats,
    rhs_name='TRAIN_DATASET',
)

In [None]:
# Feed the StatisticsGen output into SchemaGen and run it.
stats_channel = statistics_gen.outputs['output']
infer_schema = SchemaGen(stats=stats_channel)

context.run(infer_schema)

In [None]:
# Take the first artifact SchemaGen produced and read 'schema.pbtxt' from
# its directory.
schema_artifact = infer_schema.outputs['output'].get()[0]
schema_dir = schema_artifact.uri
schema_path = os.path.join(schema_dir, 'schema.pbtxt')

# Load the text-format schema and display it.
schema = tfdv.load_schema_text(schema_path)
tfdv.display_schema(schema)

In [None]:
# Run ExampleValidator with the computed statistics and the generated schema.
validate_stats = ExampleValidator(
    stats=statistics_gen.outputs['output'],
    schema=infer_schema.outputs['output'],
)

context.run(validate_stats)

In [None]:
# Take the first artifact ExampleValidator produced and read
# 'anomalies.pbtxt' from its directory.
validation_artifact = validate_stats.outputs['output'].get()[0]
validation_dir = validation_artifact.uri
anomalies_path = os.path.join(validation_dir, 'anomalies.pbtxt')

# Load the text-format anomalies and display them.
anomalies = tfdv.load_anomalies_text(anomalies_path)
tfdv.display_anomalies(anomalies)

In [None]:
# User module passed to Transform here and reused by the Trainer cell.
# NOTE(review): the module is named 'cifar10_utils.py' while the serving
# directory configured in this notebook ends in 'serving_model/mnist' —
# confirm this is the intended module for this dataset.
_module_file = os.path.join(_root, 'cifar10_utils.py')

# Run Transform over the ingested examples using the generated schema.
transform = Transform(
    module_file=_module_file,
    schema=infer_schema.outputs['output'],
    input_data=example_gen.outputs['examples'],
)

context.run(transform)

In [None]:
# Step budgets for the training and evaluation phases.
train_args = trainer_pb2.TrainArgs(num_steps=1000)
eval_args = trainer_pb2.EvalArgs(num_steps=500)

# The Trainer uses the same module file as Transform and consumes the
# Transform outputs plus the generated schema.
trainer = Trainer(
    module_file=_module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    transform_output=transform.outputs['transform_output'],
    schema=infer_schema.outputs['output'],
    train_args=train_args,
    eval_args=eval_args,
)
context.run(trainer)

In [None]:
# Slicing spec with no columns set; by its name, the overall (unsliced) view.
OVERALL_SLICE_SPEC = evaluator_pb2.SingleSlicingSpec()

# Slicing spec keyed on the 'label' column.
LABEL_COLUMN_SLICE_SPEC = evaluator_pb2.SingleSlicingSpec(column_for_slicing=['label'])

# Every spec passed to the Evaluator, in this order.
ALL_SPECS = [OVERALL_SLICE_SPEC, LABEL_COLUMN_SLICE_SPEC]

In [None]:
# Bundle the individual slicing specs into the message the Evaluator expects.
slicing_spec = evaluator_pb2.FeatureSlicingSpec(specs=ALL_SPECS)

# Run the Evaluator on the trainer output using the ingested examples.
model_analyzer = Evaluator(
    examples=example_gen.outputs['examples'],
    model_exports=trainer.outputs['output'],
    feature_slicing_spec=slicing_spec,
)

context.run(model_analyzer)

In [None]:
# Take the URI of the first Evaluator output artifact, load the evaluation
# result stored there, and render its slicing metrics.
evaluation_artifact = model_analyzer.outputs['output'].get()[0]
PATH_TO_RESULT = evaluation_artifact.uri

tfma_result = tfma.load_eval_result(PATH_TO_RESULT)
tfma.view.render_slicing_metrics(tfma_result)

In [None]:
# Run ModelValidator on the trainer output using the ingested examples.
# Its 'blessing' output is what the Pusher cell consumes.
model_validator = ModelValidator(
    model=trainer.outputs['output'],
    examples=example_gen.outputs['examples'],
)

context.run(model_validator)

In [None]:
# Filesystem destination rooted at the serving directory configured at the
# top of the notebook.
serving_filesystem = pusher_pb2.PushDestination.Filesystem(
    base_directory=_serving_model_dir)
push_destination = pusher_pb2.PushDestination(filesystem=serving_filesystem)

# Run the Pusher with the trained model and the ModelValidator blessing.
pusher = Pusher(
    model_export=trainer.outputs['output'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=push_destination,
)

context.run(pusher)