TensorFlow Extended (TFX) is an open-source platform for building machine learning pipelines. Many large companies use it to put their machine learning models into production.

![im](https://i.imgur.com/Npjr3NK.png)


**Pipeline:**
A TFX pipeline implements an ML workflow that stays in place for the entire lifetime of the product. Pipelines are built from TFX components.

![](https://i.imgur.com/z3mMdts.png)

**Metadata store**-

It is a single place to manage all the ML metadata about experiments, artifacts, models, and pipelines.
A store consists of:

- Definitions of artifacts and their properties
- Execution records of components
- Lineage tracking across all executions
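
To make the three parts of a metadata store concrete, here is a minimal in-memory sketch. `ToyMetadataStore` and its methods are hypothetical names invented for illustration; this is not the ML Metadata (MLMD) API, only a model of the idea that artifacts, executions, and the links between them are enough to trace lineage.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyMetadataStore:
    """In-memory stand-in for a metadata store (hypothetical, not the MLMD API)."""
    artifacts: Dict[int, dict] = field(default_factory=dict)
    executions: Dict[int, dict] = field(default_factory=dict)

    def put_artifact(self, type_name: str, uri: str) -> int:
        # Artifact definition: a type plus properties such as the URI.
        artifact_id = len(self.artifacts) + 1
        self.artifacts[artifact_id] = {"type": type_name, "uri": uri}
        return artifact_id

    def put_execution(self, component: str, inputs: List[int],
                      outputs: List[int]) -> int:
        # Execution record: which component ran, with which inputs and outputs.
        execution_id = len(self.executions) + 1
        self.executions[execution_id] = {
            "component": component, "inputs": inputs, "outputs": outputs}
        return execution_id

    def lineage(self, artifact_id: int) -> List[str]:
        """Walks upstream from an artifact and lists the components behind it."""
        chain = []
        frontier = [artifact_id]
        while frontier:
            current = frontier.pop()
            for execution in self.executions.values():
                if current in execution["outputs"]:
                    chain.append(execution["component"])
                    frontier.extend(execution["inputs"])
        return chain


store = ToyMetadataStore()
examples = store.put_artifact("Examples", "pipelines/demo/examples/1")
stats = store.put_artifact("ExampleStatistics", "pipelines/demo/statistics/2")
schema = store.put_artifact("Schema", "pipelines/demo/schema/3")
store.put_execution("CsvExampleGen", inputs=[], outputs=[examples])
store.put_execution("StatisticsGen", inputs=[examples], outputs=[stats])
store.put_execution("SchemaGen", inputs=[stats], outputs=[schema])
print(store.lineage(schema))
```

Asking for the lineage of the schema artifact walks back through SchemaGen, StatisticsGen, and CsvExampleGen, which is the kind of question a real metadata store is meant to answer.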

![](https://i.imgur.com/DwDEvHq.png)

**TFX Orchestrator**-
An orchestrator is the system that executes pipeline runs. Apache Airflow and Kubeflow Pipelines are commonly used as orchestrators. A DagRunner is the TFX implementation that runs a pipeline on a particular orchestrator; this notebook uses `LocalDagRunner`.
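
At its core, an orchestrator has to run components in an order that respects their dependencies. The sketch below shows only that idea with the standard library's `graphlib`; the `dependencies` dictionary and `run_pipeline` are hypothetical names, and real orchestrators add scheduling, retries, and caching on top.

```python
from graphlib import TopologicalSorter

# Each key depends on the components in its set (a hypothetical toy pipeline).
dependencies = {
    "ExampleGen": set(),
    "StatisticsGen": {"ExampleGen"},
    "SchemaGen": {"StatisticsGen"},
    "Trainer": {"ExampleGen", "SchemaGen"},
}


def run_pipeline(dependencies):
    """Visits every component only after all of its upstream components."""
    executed = []
    for component in TopologicalSorter(dependencies).static_order():
        # A real runner would launch the component's executor here.
        executed.append(component)
    return executed


order = run_pipeline(dependencies)
print(order)
```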

**Components**-
A component implements one part of the ML pipeline, and a pipeline consists of many such components. Each component is composed of:

- Component specification: defines the component's inputs, outputs, and required parameters.
- Executor: implements the code that does the component's work.
- Component interface: packages the component specification and executor for use in a pipeline.
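
The relationship between the three parts can be sketched with plain dataclasses. `ComponentSpec`, `Component`, and `count_rows_executor` below are hypothetical illustrations, not the TFX base classes; they only show how an interface can check the spec and then delegate to the executor.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ComponentSpec:
    """Declares what a component consumes, produces, and is configured with."""
    inputs: List[str]
    outputs: List[str]
    parameters: Dict[str, object]


@dataclass
class Component:
    """Toy component interface: bundles a spec with an executor."""
    name: str
    spec: ComponentSpec
    executor: Callable[[Dict[str, object], Dict[str, object]], Dict[str, object]]

    def run(self, input_artifacts: Dict[str, object]) -> Dict[str, object]:
        # Check the declared inputs before handing control to the executor.
        missing = [key for key in self.spec.inputs if key not in input_artifacts]
        if missing:
            raise ValueError(f"{self.name} is missing inputs: {missing}")
        outputs = self.executor(input_artifacts, self.spec.parameters)
        # Expose only the outputs that the spec declares.
        return {key: outputs[key] for key in self.spec.outputs}


def count_rows_executor(inputs, parameters):
    """Executor: the code that does the actual work."""
    rows = inputs["examples"]
    return {"statistics": {"num_examples": len(rows)}}


toy_statistics_gen = Component(
    name="ToyStatisticsGen",
    spec=ComponentSpec(inputs=["examples"], outputs=["statistics"], parameters={}),
    executor=count_rows_executor,
)
result = toy_statistics_gen.run({"examples": [{"Age": 39.0}, {"Age": 24.0}]})
print(result)
```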
 
Each standard component is described below.

**ExampleGen**- This component ingests data into the pipeline. It reads external data (from sources such as CSV files, TFRecord files, or BigQuery), partitions and shuffles it, and generates the Examples that downstream components consume.

![](https://i.imgur.com/hrb82Ev.png)
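
One common way to partition records is to hash each record into a fixed number of buckets, so the same record always lands in the same split. The sketch below assumes a 2:1 train/eval ratio, taken here to mirror ExampleGen's default split; `assign_split` is a hypothetical helper and does not reproduce TFX's actual hashing.

```python
import hashlib


def assign_split(record: str, train_buckets: int = 2, eval_buckets: int = 1) -> str:
    """Assigns a record to 'train' or 'eval' from a stable hash of its content."""
    total = train_buckets + eval_buckets
    digest = hashlib.sha256(record.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % total
    return "train" if bucket < train_buckets else "eval"


# Synthetic rows, used only to exercise the split logic.
rows = [f"{i:04d}_01,Earth,False" for i in range(1, 301)]
splits = {"train": [], "eval": []}
for row in rows:
    splits[assign_split(row)].append(row)
print(len(splits["train"]), len(splits["eval"]))
```

Because the assignment depends only on the record's content, rerunning the split on the same data gives the same partition.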

**StatisticsGen**-
It computes feature statistics over the Examples data and makes them available to other pipeline components.

![](https://i.imgur.com/Ym9pOZd.png)
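
As a rough idea of what "feature statistics" means, the following sketch computes counts, missing values, and simple numeric summaries with the standard library. The `sample_csv` text is a made-up sample in the same shape as the dataset, and `summarize` is a hypothetical helper; StatisticsGen itself computes a much richer set of statistics.

```python
import csv
import io
import statistics

# Made-up sample for illustration; the empty RoomService cell is a missing value.
sample_csv = """Age,RoomService,HomePlanet
39.0,0.0,Europa
24.0,109.0,Earth
58.0,43.0,Europa
33.0,,Europa
"""


def summarize(csv_text):
    """Returns per-column counts plus numeric or categorical summaries."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    summary = {}
    for column in rows[0].keys():
        values = [row[column] for row in rows if row[column] != ""]
        stats = {"count": len(values), "missing": len(rows) - len(values)}
        try:
            numbers = [float(v) for v in values]
            stats.update(mean=statistics.mean(numbers),
                         min=min(numbers), max=max(numbers))
        except ValueError:
            # Non-numeric column: report the number of distinct values instead.
            stats["unique"] = len(set(values))
        summary[column] = stats
    return summary


feature_stats = summarize(sample_csv)
for name, stats in feature_stats.items():
    print(name, stats)
```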

**SchemaGen**
A schema describes the input data. SchemaGen automatically generates a schema from the training data, recording each feature's data type and allowed values.

![](https://i.imgur.com/MySRzxs.png)
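
Schema inference can be pictured as scanning the data once and recording, per feature, a type, whether it is always present, and the set of observed values for string features. `infer_schema` below is a hypothetical simplification that works on raw rows; SchemaGen infers its schema from the statistics produced by StatisticsGen.

```python
def infer_schema(rows):
    """Infers a type, presence, and (for strings) a value domain per feature."""
    schema = {}
    for name in rows[0]:
        values = [row[name] for row in rows if row.get(name) is not None]
        if all(isinstance(v, float) for v in values):
            schema[name] = {"type": "FLOAT"}
        else:
            schema[name] = {"type": "STRING",
                            "domain": sorted(set(map(str, values)))}
        # A feature is 'required' only if every row provides a value.
        schema[name]["required"] = len(values) == len(rows)
    return schema


# Illustrative rows; None marks a missing value.
rows = [
    {"HomePlanet": "Europa", "Age": 39.0, "VRDeck": 0.0},
    {"HomePlanet": "Earth", "Age": 24.0, "VRDeck": 44.0},
    {"HomePlanet": "Earth", "Age": 28.0, "VRDeck": None},
]
schema = infer_schema(rows)
for name, rule in schema.items():
    print(name, rule)
```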

**ExampleValidator**
It finds anomalies in training and serving data by comparing the statistics computed by the StatisticsGen component against a schema. For example:

![](https://i.imgur.com/FMi2ZWn.png)
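
The kinds of anomalies involved (missing required values, values outside a domain, features unknown to the schema) can be sketched as follows. `find_anomalies` and the small `schema` dictionary are hypothetical, and this version checks rows directly for readability, whereas ExampleValidator compares aggregate statistics against the schema.

```python
def find_anomalies(rows, schema):
    """Returns (row index, feature, reason) tuples for schema violations."""
    anomalies = []
    for index, row in enumerate(rows):
        for name, rule in schema.items():
            value = row.get(name)
            if value is None:
                if rule.get("required"):
                    anomalies.append((index, name, "missing required value"))
            elif "domain" in rule and value not in rule["domain"]:
                anomalies.append((index, name, f"unexpected value {value!r}"))
        for name in row:
            if name not in schema:
                anomalies.append((index, name, "feature not in schema"))
    return anomalies


schema = {
    "HomePlanet": {"required": True, "domain": ["Earth", "Europa", "Mars"]},
    "Age": {"required": True},
}
# Illustrative serving rows with three deliberate problems.
serving_rows = [
    {"HomePlanet": "Earth", "Age": 24.0},
    {"HomePlanet": "Venus", "Age": None},
    {"HomePlanet": "Mars", "Age": 30.0, "Deck": "B"},
]
anomalies = find_anomalies(serving_rows, schema)
for anomaly in anomalies:
    print(anomaly)
```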

**Transform**
It performs feature engineering on the Examples data and emits a SavedModel (the transform graph), along with statistics on both the pre-transform and post-transform data.

![](https://i.imgur.com/R1OYqkV.png)

**Trainer**
Trainer is used to train models using TensorFlow.
It takes the Examples data, a module file containing the training logic, and protobuf definitions of the train args and eval args. It emits a model for inference and, optionally, another one for evaluation.

![](https://i.imgur.com/u6NqxPJ.png)

**Evaluator**
It analyzes the training results to assess overall model quality and track performance over time, and it validates whether the model can be pushed to production.
If the new model's metrics meet the requirements set relative to the baseline model, the model is said to be "blessed" and can be pushed to production.

![](https://i.imgur.com/9dXiq4y.png)
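
The blessing decision boils down to comparing a candidate's metrics with thresholds and with a baseline. The sketch below is a hypothetical rule with made-up threshold values (`min_accuracy`, `max_regression`); the Evaluator's real thresholds are configured per metric and are not shown in this notebook.

```python
def is_blessed(candidate_metrics, baseline_metrics,
               min_accuracy=0.6, max_regression=0.01):
    """Blesses a candidate that clears an absolute bar and does not regress."""
    accuracy = candidate_metrics["accuracy"]
    if accuracy < min_accuracy:
        return False
    # Allow a small tolerance below the baseline before rejecting.
    return accuracy >= baseline_metrics["accuracy"] - max_regression


print(is_blessed({"accuracy": 0.79}, {"accuracy": 0.78}))  # clears both checks
print(is_blessed({"accuracy": 0.70}, {"accuracy": 0.78}))  # regresses too far
```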

**Pusher**
The Pusher component pushes a validated model to a deployment target during model training. It decides whether to push based on the blessings from other components:

- Evaluator blesses the model if it is "good enough" to be pushed to production.
- InfraValidator blesses the model if the model is mechanically servable in a production environment.

A Pusher component consumes a trained model in SavedModel format and outputs the same SavedModel, along with versioning metadata.
 
![](https://i.imgur.com/rqyuCgC.png)
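
The push step itself can be pictured as "copy the model into a new versioned directory, but only if it was blessed". `push_if_blessed` is a hypothetical helper, the placeholder file stands in for a real SavedModel, and the timestamp-style version directory is modeled on the directory names that appear later in this notebook's serving output.

```python
import os
import shutil
import tempfile
import time


def push_if_blessed(model_dir, serving_root, blessed, version=None):
    """Copies model_dir to serving_root/<version> when the model is blessed."""
    if not blessed:
        return None
    # Default to a Unix timestamp so newer pushes sort after older ones.
    version = str(version if version is not None else int(time.time()))
    destination = os.path.join(serving_root, version)
    shutil.copytree(model_dir, destination)
    return destination


# Build a placeholder "model" in a temporary directory.
workdir = tempfile.mkdtemp(prefix="toy-pusher")
model_dir = os.path.join(workdir, "trained_model")
os.makedirs(model_dir)
with open(os.path.join(model_dir, "saved_model.pb"), "wb") as handle:
    handle.write(b"placeholder")

serving_root = os.path.join(workdir, "serving_model")
pushed = push_if_blessed(model_dir, serving_root, blessed=True, version=1)
skipped = push_if_blessed(model_dir, serving_root, blessed=False)
print(pushed, skipped)
```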
 

The ExampleGen component ingests data into TFX pipelines by consuming external data sources such as CSV, TFRecord, Avro, Parquet, and BigQuery. It generates Examples (tf.Example records, tf.SequenceExample records, or proto format) that are read by other TFX components.

**ExampleGen and Other Components**
ExampleGen provides data to components that make use of the TensorFlow Data Validation library, such as SchemaGen, StatisticsGen, and ExampleValidator. It also provides data to Transform, which makes use of the TensorFlow Transform library, and ultimately to deployment targets during inference.
 

In [None]:
try:
  # Only upgrade pip when running inside Colab.
  import colab
  !pip install --upgrade pip
except ImportError:
  pass

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-22.1.2-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.1.2


In [None]:
!pip install -U tfx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tfx
  Downloading tfx-1.8.0-py3-none-any.whl (2.5 MB)
Collecting apache-beam[gcp]<3,>=2.38
  Downloading apache_beam-2.39.0-cp37-cp37m-manylinux2010_x86_64.whl (10.3 MB)
Collecting tfx-bsl<1.9.0,>=1.8.0
  Downloading tfx_bsl-1.8.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (19.2 MB)
Collecting attrs<21,>=19.3.0
  Downloading attrs-20.3.0-py2.py3-none-any.whl (49 kB)
Collecting tensorfl

Restart runtime before running the next cell

In [None]:
import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))

TensorFlow version: 2.8.0
TFX version: 1.7.1


Setting up variables

In [None]:
import os

# Name of the main training pipeline.
PIPELINE_NAME = "spaceship-simple"
# Name of the pipeline used for schema generation.
SCHEMA_PIPELINE_NAME = "spaceship-tfdv-schema"

# Output directory to store artifacts generated from the pipeline
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
SCHEMA_PIPELINE_ROOT = os.path.join('pipelines', SCHEMA_PIPELINE_NAME)

# Path to a SQLite DB file to use as an MLMD storage
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
SCHEMA_METADATA_PATH = os.path.join('metadata', SCHEMA_PIPELINE_NAME,
                                    'metadata.db')

# Output directory where created models from the pipeline will be exported
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)

from absl import logging
logging.set_verbosity(logging.INFO)  # Set default logging level

Import data

In [None]:
import urllib.request
import tempfile

DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')  # Create a temporary directory.
_data_url = 'https://raw.githubusercontent.com/ushareng/TFX_SpaceShipTitanic/main/train.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")
urllib.request.urlretrieve(_data_url, _data_filepath)

('/tmp/tfx-data78swgyw3/data.csv', <http.client.HTTPMessage at 0x7f311559a550>)

In [None]:
!head {_data_filepath}

PassengerId,HomePlanet,CryoSleep,Cabin,Destination,Age,VIP,RoomService,FoodCourt,ShoppingMall,Spa,VRDeck,Name,Transported
0001_01,Europa,False,B/0/P,TRAPPIST-1e,39.0,False,0.0,0.0,0.0,0.0,0.0,Maham Ofracculy,False
0002_01,Earth,False,F/0/S,TRAPPIST-1e,24.0,False,109.0,9.0,25.0,549.0,44.0,Juanna Vines,True
0003_01,Europa,False,A/0/S,TRAPPIST-1e,58.0,True,43.0,3576.0,0.0,6715.0,49.0,Altark Susent,False
0003_02,Europa,False,A/0/S,TRAPPIST-1e,33.0,False,0.0,1283.0,371.0,3329.0,193.0,Solam Susent,False
0004_01,Earth,False,F/1/S,TRAPPIST-1e,16.0,False,303.0,70.0,151.0,565.0,2.0,Willy Santantines,True
0005_01,Earth,False,F/0/P,PSO J318.5-22,44.0,False,0.0,483.0,0.0,291.0,0.0,Sandie Hinetthews,True
0006_01,Earth,False,F/2/S,TRAPPIST-1e,26.0,False,42.0,1539.0,3.0,0.0,0.0,Billex Jacostaffey,True
0006_02,Earth,True,G/0/S,TRAPPIST-1e,28.0,False,0.0,0.0,0.0,0.0,,Candra Jacostaffey,True
0007_01,Earth,False,F/3/S,TRAPPIST-1e,35.0,False,0.0,785.0,17.0,216.0,0.0,Andona Beston,True


In [None]:
!sed -i '/\bNA\b/d' {_data_filepath}
!head {_data_filepath}

PassengerId,HomePlanet,CryoSleep,Cabin,Destination,Age,VIP,RoomService,FoodCourt,ShoppingMall,Spa,VRDeck,Name,Transported
0001_01,Europa,False,B/0/P,TRAPPIST-1e,39.0,False,0.0,0.0,0.0,0.0,0.0,Maham Ofracculy,False
0002_01,Earth,False,F/0/S,TRAPPIST-1e,24.0,False,109.0,9.0,25.0,549.0,44.0,Juanna Vines,True
0003_01,Europa,False,A/0/S,TRAPPIST-1e,58.0,True,43.0,3576.0,0.0,6715.0,49.0,Altark Susent,False
0003_02,Europa,False,A/0/S,TRAPPIST-1e,33.0,False,0.0,1283.0,371.0,3329.0,193.0,Solam Susent,False
0004_01,Earth,False,F/1/S,TRAPPIST-1e,16.0,False,303.0,70.0,151.0,565.0,2.0,Willy Santantines,True
0005_01,Earth,False,F/0/P,PSO J318.5-22,44.0,False,0.0,483.0,0.0,291.0,0.0,Sandie Hinetthews,True
0006_01,Earth,False,F/2/S,TRAPPIST-1e,26.0,False,42.0,1539.0,3.0,0.0,0.0,Billex Jacostaffey,True
0006_02,Earth,True,G/0/S,TRAPPIST-1e,28.0,False,0.0,0.0,0.0,0.0,,Candra Jacostaffey,True
0007_01,Earth,False,F/3/S,TRAPPIST-1e,35.0,False,0.0,785.0,17.0,216.0,0.0,Andona Beston,True
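
For readers less familiar with `sed`, the command above deletes every line that contains `NA` as a standalone word. A Python sketch of the same filter follows; `drop_na_lines` and the sample lines are hypothetical. Note that empty fields (such as the missing `VRDeck` value in the row for passenger `0006_02`) are not matched by this pattern, which is why that row is still present in the output.

```python
import re


def drop_na_lines(lines):
    """Keeps only the lines that do not contain 'NA' as a whole word."""
    pattern = re.compile(r"\bNA\b")
    return [line for line in lines if not pattern.search(line)]


# Illustrative lines, not taken from the dataset.
sample_lines = [
    "0001_01,Europa,39.0",
    "0002_01,NA,24.0",      # dropped: NA is a standalone token
    "0003_01,NASA,58.0",    # kept: NA is only part of a longer word
    "0004_01,Earth,",       # kept: an empty field is not the token NA
]
kept = drop_na_lines(sample_lines)
print(kept)
```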


# Generating preliminary schema
(ExampleGen, StatisticsGen and SchemaGen)

In [None]:
def _create_schema_pipeline(pipeline_name: str,
                            pipeline_root: str,
                            data_root: str,
                            metadata_path: str) -> tfx.dsl.Pipeline:
  """Creates a pipeline for schema generation."""
  # ExampleGen
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # StatisticsGen
  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'])

  # SchemaGen
  schema_gen = tfx.components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)

Running pipeline

In [None]:
tfx.orchestration.LocalDagRunner().run(
  _create_schema_pipeline(
      pipeline_name=SCHEMA_PIPELINE_NAME,
      pipeline_root=SCHEMA_PIPELINE_ROOT,
      data_root=DATA_ROOT,
      metadata_path=SCHEMA_METADATA_PATH))

INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "SchemaGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.schema_gen.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "CsvExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
      filename_uri: "metada

Review outputs of the pipeline

In [None]:
from ml_metadata.proto import metadata_store_pb2
# Non-public APIs, just for showcase.
from tfx.orchestration.portable.mlmd import execution_lib

# TODO(b/171447278): Move these functions into the TFX library.

def get_latest_artifacts(metadata, pipeline_name, component_id):
  """Output artifacts of the latest run of the component."""
  context = metadata.store.get_context_by_type_and_name(
      'node', f'{pipeline_name}.{component_id}')
  executions = metadata.store.get_executions_by_context(context.id)
  latest_execution = max(executions,
                         key=lambda e:e.last_update_time_since_epoch)
  return execution_lib.get_artifacts_dict(metadata, latest_execution.id,
                                          [metadata_store_pb2.Event.OUTPUT])

# Non-public APIs, just for showcase.
from tfx.orchestration.experimental.interactive import visualizations

def visualize_artifacts(artifacts):
  """Visualizes artifacts using standard visualization modules."""
  for artifact in artifacts:
    visualization = visualizations.get_registry().get_visualization(
        artifact.type_name)
    if visualization:
      visualization.display(artifact)

from tfx.orchestration.experimental.interactive import standard_visualizations
standard_visualizations.register_standard_visualizations()

Examine outputs of pipeline execution

In [None]:
# Non-public APIs, just for showcase.
from tfx.orchestration.metadata import Metadata
from tfx.types import standard_component_specs

metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(
    SCHEMA_METADATA_PATH)

with Metadata(metadata_connection_config) as metadata_handler:
  # Find output artifacts from MLMD.
  stat_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME,
                                         'StatisticsGen')
  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]

  schema_gen_output = get_latest_artifacts(metadata_handler,
                                           SCHEMA_PIPELINE_NAME, 'SchemaGen')
  schema_artifacts = schema_gen_output[standard_component_specs.SCHEMA_KEY]

INFO:absl:MetadataStore with DB connection initialized


Examine outputs of StatisticsGen

In [None]:
# docs-infra: no-execute
visualize_artifacts(stats_artifacts)

Outputs from SchemaGen

In [None]:
visualize_artifacts(schema_artifacts)

| Feature name | Type | Presence | Valency | Domain |
|---|---|---|---|---|
| 'Cabin' | BYTES | required | | - |
| 'CryoSleep' | STRING | required | | 'CryoSleep' |
| 'Destination' | STRING | required | | 'Destination' |
| 'HomePlanet' | STRING | required | | 'HomePlanet' |
| 'Name' | BYTES | required | | - |
| 'Transported' | STRING | required | | 'Transported' |
| 'VIP' | STRING | required | | 'VIP' |
| 'Age' | FLOAT | required | | - |
| 'FoodCourt' | FLOAT | required | | - |
| 'PassengerId' | INT | required | | - |


| Domain | Values |
|---|---|
| 'CryoSleep' | 'False', 'True' |
| 'Destination' | '55 Cancri e', 'PSO J318.5-22', 'TRAPPIST-1e' |
| 'HomePlanet' | 'Earth', 'Europa', 'Mars' |
| 'Transported' | 'False', 'True' |
| 'VIP' | 'False', 'True' |


Export schema

In [None]:
import shutil

_schema_filename = 'schema.pbtxt'
SCHEMA_PATH = 'schema'

os.makedirs(SCHEMA_PATH, exist_ok=True)
_generated_path = os.path.join(schema_artifacts[0].uri, _schema_filename)

# Copy the 'schema.pbtxt' file from the artifact uri to a predefined path.
shutil.copy(_generated_path, SCHEMA_PATH)

'schema/schema.pbtxt'

In [None]:
print(f'Schema at {SCHEMA_PATH}-----')
!cat {SCHEMA_PATH}/*

Schema at schema-----
feature {
  name: "Cabin"
  type: BYTES
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "CryoSleep"
  type: BYTES
  domain: "CryoSleep"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "Destination"
  type: BYTES
  domain: "Destination"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "HomePlanet"
  type: BYTES
  domain: "HomePlanet"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "Name"
  type: BYTES
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "Transported"
  type: BYTES
  domain: "Transported"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
  shape {
    dim {
      size: 1
    }
  }
}
feature {
  name: "VIP"
  type: BYTES
  domain: "VIP"
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "Age"
  type: FLOAT
  presence {
    min_fraction: 1.0
    min_count: 1
  }
}
feature {
  name: "FoodCo

# Create pipeline
(Transform and Trainer components)

In [None]:
_module_file = 'spaceship_trainer.py'

In [None]:
%%writefile {_module_file}


from typing import List, Text
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_metadata.proto.v0 import schema_pb2
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils

from tfx import v1 as tfx
from tfx_bsl.public import tfxio


# Transform component
_FEATURE_KEYS = [
     'Age','RoomService','FoodCourt','ShoppingMall','Spa','VRDeck'
]
_LABEL_KEY = 'Transported'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10


# NEW: TFX Transform will call this function.
def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature.
  """
  outputs = {}

  # Uses features defined in _FEATURE_KEYS only.
  for key in _FEATURE_KEYS:
    # tft.scale_to_z_score computes the mean and variance of the given feature
    # and scales the output based on the result.
    outputs[key] = tft.scale_to_z_score(inputs[key])

  # For the label column we provide the mapping from string to index.
  # We could instead use `tft.compute_and_apply_vocabulary()` in order to
  # compute the vocabulary dynamically and perform a lookup.
  # Since in this example there are only 2 possible values, we use a hard-coded
  # table for simplicity.
  table_keys = ['True', 'False']
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=table_keys,
      values=tf.cast(tf.range(len(table_keys)), tf.float32),
      key_dtype=tf.string,
      value_dtype=tf.float32)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  outputs[_LABEL_KEY] = table.lookup(inputs[_LABEL_KEY])

  return outputs


# NEW: This function will apply the same transform operation to training data
#      and serving requests.
def _apply_preprocessing(raw_features, tft_layer):
  transformed_features = tft_layer(raw_features)
  if _LABEL_KEY in raw_features:
    transformed_label = transformed_features.pop(_LABEL_KEY)
    return transformed_features, transformed_label
  else:
    return transformed_features, None


# NEW: This function will create a handler function which gets a serialized
#      tf.example, preprocess and run an inference with it.
def _get_serve_tf_examples_fn(model, tf_transform_output):
  # We must save the tft_layer to the model to ensure its assets are kept and
  # tracked.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
  ])
  def serve_tf_examples_fn(serialized_tf_examples):
    # Expected input is a string which is serialized tf.Example format.
    feature_spec = tf_transform_output.raw_feature_spec()
    # Because the input schema includes fields the model does not use, such as
    # 'Name' and 'Cabin', we filter feature_spec to include required keys only.
    required_feature_spec = {
        k: v for k, v in feature_spec.items() if k in _FEATURE_KEYS
    }
    parsed_features = tf.io.parse_example(serialized_tf_examples,
                                          required_feature_spec)

    # Preprocess parsed input with transform operation defined in
    # preprocessing_fn().
    transformed_features, _ = _apply_preprocessing(parsed_features,
                                                   model.tft_layer)
    # Run inference with ML model.
    return model(transformed_features)

  return serve_tf_examples_fn


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  dataset = data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
      schema=tf_transform_output.raw_metadata.schema)

  transform_layer = tf_transform_output.transform_features_layer()
  def apply_transform(raw_features):
    return _apply_preprocessing(raw_features, transform_layer)

  return dataset.map(apply_transform).repeat()


def _build_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying spaceship data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [
      keras.layers.Input(shape=(1,), name=key)
      for key in _FEATURE_KEYS
  ]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(2)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      tf_transform_output,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      tf_transform_output,
      batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # NEW: Save a computation graph including transform layer.
  signatures = {
      'serving_default': _get_serve_tf_examples_fn(model, tf_transform_output),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

Overwriting spaceship_trainer.py
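
The `preprocessing_fn` in the module above relies on `tft.scale_to_z_score`, which needs a full pass over the data to learn the mean and variance before any value can be scaled. The plain-Python sketch below separates those two phases; `fit_z_score` and `apply_z_score` are hypothetical helpers, and the use of population variance is an assumption made for this illustration.

```python
import math


def fit_z_score(values):
    """Analyze phase: one full pass to compute the mean and population std."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(variance)


def apply_z_score(value, mean, std):
    """Transform phase: scales a single value using the fitted constants."""
    return (value - mean) / std if std > 0 else 0.0


# The first five Age values from the CSV preview shown earlier.
ages = [39.0, 24.0, 58.0, 33.0, 16.0]
mean, std = fit_z_score(ages)
scaled = [apply_z_score(age, mean, std) for age in ages]
print(mean, std)
print(scaled)
```

Keeping the fitted constants separate from the scaling step is what lets the same transformation be replayed at serving time on a single request.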


In [None]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     schema_path: str, module_file: str, serving_model_dir: str,
                     metadata_path: str) -> tfx.dsl.Pipeline:
  """Implements the spaceship pipeline with TFX."""
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'])

  # Import the schema.
  schema_importer = tfx.dsl.Importer(
      source_uri=schema_path,
      artifact_type=tfx.types.standard_artifacts.Schema).with_id(
          'schema_importer')


########       ExampleValidator                               ##########
  # Performs anomaly detection based on statistics and data schema.
  example_validator = tfx.components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_importer.outputs['result'])


##########       Transform             #################
  # NEW: Transforms input data using preprocessing_fn in the 'module_file'.
  transform = tfx.components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_importer.outputs['result'],
      materialize=False,
      module_file=module_file)

  # Uses user-provided Python function that trains a model.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],

      # NEW: Pass transform_graph to the trainer.
      transform_graph=transform.outputs['transform_graph'],

      train_args=tfx.proto.TrainArgs(num_steps=100),
      eval_args=tfx.proto.EvalArgs(num_steps=5))

  # Pushes the model to a filesystem destination.
  pusher = tfx.components.Pusher(
      model=trainer.outputs['model'],
      push_destination=tfx.proto.PushDestination(
          filesystem=tfx.proto.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  components = [
      example_gen,
      statistics_gen,
      schema_importer,
      example_validator,

      transform,  # NEW: Transform component was added to the pipeline.

      trainer,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)

Running the pipeline with local dag runner

In [None]:
tfx.orchestration.LocalDagRunner().run(
  _create_pipeline(
      pipeline_name=PIPELINE_NAME,
      pipeline_root=PIPELINE_ROOT,
      data_root=DATA_ROOT,
      schema_path=SCHEMA_PATH,
      module_file=_module_file,
      serving_model_dir=SERVING_MODEL_DIR,
      metadata_path=METADATA_PATH))

INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Generating ephemeral wheel package for '/content/spaceship_trainer.py' (including modules: ['spaceship_trainer']).
INFO:absl:User module package has hash fingerprint version 84965ad71ab30879935df295c64020d86df1ed1f8cfb40c743a49d90f4ed811c.
INFO:absl:Executing: ['/usr/bin/python3', '/tmp/tmphbl342pp/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmp/tmpsvn1d34e', '--dist-dir', '/tmp/tmp8ku8l47v']
INFO:absl:Successfully built user code wheel distribution at 'pipelines/spaceship-simple/_wheels/tfx_user_code_Transform-0.0+84965ad71ab30879935df295c64020d86df1ed1f8cfb40c743a49d90f4ed811c-py3-none-any.whl'; target user module is 'spaceship_trainer'.
INFO:absl:Full user module path is 'spaceship_trainer@pipelines/spaceship-simple/_wheels/tfx_user_code_Transform-0.0+84965ad71ab30879935df295c64020d86df1ed1f8cfb40c743a49d90f4ed811c-py3-none-any.whl

INFO:tensorflow:Assets written to: pipelines/spaceship-simple/Transform/transform_graph/11/.temp_path/tftransform_tmp/e6c5c7bdf1a34c3a8648bdb2a9e65347/assets
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:Assets written to: pipelines/spaceship-simple/Transform/transform_graph/11/.temp_path/tftransform_tmp/129733c0de08487781cd29cf481aa489/assets
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 11 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'post_transform_schema': [Artifact(artifact: uri: "pipelines/spaceship-simple/Transform/post_transform_schema/11"
custom_properties {
  key: "name"
  value {
    string_value: "spaceship-simple:2022-05-23T09:20:04.618782:Transform:post_transform_schema:0"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.7.1"
  }
}
, artifact_type: name: "Schema"
)], 'pre_transform_schema': [Artifact(artifact: uri: "pipelines/spaceship-simple/Transform/pre_transform_schema/11"
custom_properties {
  key: "name"
  value {
    string_value: "spaceship-simple:2022-05-23T09:20:04.618782:Transform:pre_transform_schema:0"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.7.1"
  }
}
, artifact_type: name: "Schema"

INFO:absl:Feature Cabin has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature CryoSleep has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature Destination has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature HomePlanet has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature Name has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature Transported has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature VIP has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature Age has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature FoodCourt has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature PassengerId has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature RoomService has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature ShoppingMall has no shape. Setting to VarLenSparseTensor.
INFO:absl:Feature Spa has no shape. Setting to VarLenSparseTe

INFO:tensorflow:Assets written to: pipelines/spaceship-simple/Trainer/model/13/Format-Serving/assets
INFO:absl:Training complete. Model written to pipelines/spaceship-simple/Trainer/model/13/Format-Serving. ModelRun written to pipelines/spaceship-simple/Trainer/model_run/13
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 13 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/spaceship-simple/Trainer/model/13"
custom_properties {
  key: "name"
  value {
    string_value: "spaceship-simple:2022-05-23T09:20:04.618782:Trainer:model:0"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.7.1"
  }
}
, artifact_type: name: "Model"
base_type: MODEL
)], 'model_run': [Artifact(artifact: uri: "pipelines/spaceship-simple/Trainer/model_run/13"
custom_properties {
  key: "name"
  value {
    string_value: "spaceship-simple:2022-05-23T09:20

Pusher component

In [None]:
# List files in created model directory.
!find {SERVING_MODEL_DIR}

serving_model/spaceship-simple
serving_model/spaceship-simple/1653297520
serving_model/spaceship-simple/1653297520/saved_model.pb
serving_model/spaceship-simple/1653297520/keras_metadata.pb
serving_model/spaceship-simple/1653297520/assets
serving_model/spaceship-simple/1653297520/variables
serving_model/spaceship-simple/1653297520/variables/variables.index
serving_model/spaceship-simple/1653297520/variables/variables.data-00000-of-00001
serving_model/spaceship-simple/1653297673
serving_model/spaceship-simple/1653297673/saved_model.pb
serving_model/spaceship-simple/1653297673/keras_metadata.pb
serving_model/spaceship-simple/1653297673/assets
serving_model/spaceship-simple/1653297673/variables
serving_model/spaceship-simple/1653297673/variables/variables.index
serving_model/spaceship-simple/1653297673/variables/variables.data-00000-of-00001


In [None]:
!saved_model_cli show --dir {SERVING_MODEL_DIR}/$(ls -1 {SERVING_MODEL_DIR} | sort -nr | head -1) --tag_set serve --signature_def serving_default

The given SavedModel SignatureDef contains the following input(s):
  inputs['examples'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: serving_default_examples:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_0'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: StatefulPartitionedCall_2:0
Method name is: tensorflow/serving/predict


In [None]:
# Find a model with the latest timestamp.
model_dirs = (item for item in os.scandir(SERVING_MODEL_DIR) if item.is_dir())
model_path = max(model_dirs, key=lambda i: int(i.name)).path

loaded_model = tf.keras.models.load_model(model_path)
inference_fn = loaded_model.signatures['serving_default']





In [None]:
# Prepare an example and run inference.
features = {
  'Age': tf.train.Feature(float_list=tf.train.FloatList(value=[49.9])),
  'RoomService': tf.train.Feature(float_list=tf.train.FloatList(value=[16.1])),
  'FoodCourt': tf.train.Feature(float_list=tf.train.FloatList(value=[23.0])),
  'ShoppingMall': tf.train.Feature(float_list=tf.train.FloatList(value=[12.0])),
  'Spa': tf.train.Feature(float_list=tf.train.FloatList(value=[3.0])),
  'VRDeck': tf.train.Feature(float_list=tf.train.FloatList(value=[4.0])),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=features))
examples = example_proto.SerializeToString()

result = inference_fn(examples=tf.constant([examples]))
print(result['output_0'].numpy())

[[ 0.7996846  -0.02858142]]
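
The two numbers printed above are logits, since the final `Dense(2)` layer has no activation and the loss is configured with `from_logits=True`. A softmax turns them into probabilities. The sketch below uses only the standard library; the `class_names` order follows `table_keys` in `preprocessing_fn`, where `'True'` maps to index 0 and `'False'` to index 1.

```python
import math


def softmax(logits):
    """Converts logits to probabilities, shifting by the max for stability."""
    shifted = [x - max(logits) for x in logits]
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.7996846, -0.02858142]  # the output_0 values printed above
class_names = ["True", "False"]    # order of table_keys in preprocessing_fn
probabilities = softmax(logits)
predicted = class_names[probabilities.index(max(probabilities))]
print(probabilities, predicted)
```

Since the first logit is the larger one, the model leans towards `Transported = True` for this example passenger.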
