In [None]:
from __future__ import print_function

import os
import tempfile
import pandas as pd

import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import beam as tft_beam
import tfx_utils
from tfx.utils import io_utils
from tensorflow_metadata.proto.v0 import schema_pb2

# For DatasetMetadata boilerplate
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

def _make_default_sqlite_uri(pipeline_name):
    return os.path.join(os.environ['HOME'], 'airflow/tfx/metadata', pipeline_name, 'metadata.db')

def get_metadata_store(pipeline_name):
    """Open the metadata store kept in the pipeline's default SQLite DB.

    Args:
        pipeline_name: Name of the pipeline whose metadata should be read.

    Returns:
        A `tfx_utils.TFXReadonlyMetadataStore` built from the SQLite file at
        the pipeline's default metadata path.
    """
    sqlite_db_path = _make_default_sqlite_uri(pipeline_name)
    return tfx_utils.TFXReadonlyMetadataStore.from_sqlite_db(sqlite_db_path)

# Name of the pipeline whose metadata is inspected in the cells below.
pipeline_name = 'taxi'

# Show which SQLite file the metadata store will be read from.
pipeline_db_path = _make_default_sqlite_uri(pipeline_name)
print('Pipeline DB:\n{}'.format(pipeline_db_path))

# Metadata store handle used by the later cells to look up artifacts.
store = get_metadata_store(pipeline_name)

In [None]:
# Fetch all Schema artifacts recorded in the metadata store and list their URIs.
schemas = store.get_artifacts_of_type_df(tfx_utils.TFXArtifactTypes.SCHEMA)
print(schemas['URI'])

# Use the first listed schema artifact; its URI is the directory that holds
# the text-format schema file.
schema_uri = schemas['URI'].iloc[0] + '/schema.pbtxt'
print('Schema URI:\n{}'.format(schema_uri))

In [None]:
# Load the text-format schema file into a Schema proto.
schema_proto = io_utils.parse_pbtxt_file(
    file_name=schema_uri,
    message=schema_pb2.Schema())

# Round-trip the schema through a feature spec to build the DatasetMetadata
# that is handed to tf.Transform in the last cell.
feature_spec, domains = schema_utils.schema_as_feature_spec(schema_proto)
legacy_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(feature_spec, domains))

In [None]:
# Categorical features are assumed to each have a maximum value in the dataset.
_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]

_NUMERICAL_FEATURES = ['trip_miles', 'fare', 'trip_seconds']

_BUCKET_FEATURES = [
    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
    'dropoff_longitude'
]
# Number of buckets used by tf.transform for encoding each feature.
_FEATURE_BUCKET_COUNT = 10

_CATEGORICAL_NUMERICAL_FEATURES = [
    'trip_start_hour', 'trip_start_day', 'trip_start_month',
    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
    'dropoff_community_area'
]

_CATEGORICAL_STRING_FEATURES = [
    'payment_type',
    'company',
]

# Number of vocabulary terms used for encoding categorical features.
_VOCAB_SIZE = 1000

# Count of out-of-vocab buckets in which unrecognized categorical are hashed.
_OOV_SIZE = 10

# Keys
_LABEL_KEY = 'tips'
_FARE_KEY = 'fare'


def t_name(key):
  """Return the name under which a transformed feature is stored.

  Transformed features get a distinct key so that they don't clash with the
  raw keys when running the Evaluator component.

  Args:
    key: The original feature key.

  Returns:
    `key` with '_xf' appended.
  """
  transformed_suffix = '_xf'
  return key + transformed_suffix


def _make_one_hot(x, key):
  """One-hot encode a categorical feature through a computed vocabulary.

  Args:
    x: A dense tensor holding the raw categorical values.
    key: A string key for the feature in the input; also used as the
      vocabulary file name and as the name of the vocabulary op.

  Returns:
    A dense float one-hot tensor reshaped to `[-1, depth]`, where `depth` is
    the vocabulary size plus `_OOV_SIZE`.
  """
  # Map every value to an integer id; unknown values land in the OOV buckets.
  vocab_ids = tft.compute_and_apply_vocabulary(
      x,
      top_k=_VOCAB_SIZE,
      num_oov_buckets=_OOV_SIZE,
      vocab_filename=key,
      name=key)

  # The encoding needs one slot per vocabulary entry plus one per OOV bucket.
  vocab_size = tft.experimental.get_vocabulary_size_by_name(key)
  num_classes = vocab_size + _OOV_SIZE

  dense_one_hot = tf.one_hot(
      vocab_ids,
      depth=tf.cast(num_classes, tf.int32),
      on_value=1.0,
      off_value=0.0)
  return tf.reshape(dense_one_hot, [-1, num_classes])


def _fill_in_missing(x):
  """Densify a SparseTensor, substituting a default for missing values.

  Missing entries become '' for string tensors and 0 otherwise. Inputs that
  are not a `SparseTensor` are returned untouched.

  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if not isinstance(x, tf.sparse.SparseTensor):
    # Not sparse: there is nothing to fill in.
    return x

  if x.dtype == tf.string:
    fill_value = ''
  else:
    fill_value = 0

  # Force the second dimension to 1 so that it can be squeezed away below.
  single_column = tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1])
  dense = tf.sparse.to_dense(single_column, fill_value)
  return tf.squeeze(dense, axis=1)

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Numerical features are scaled to a z-score, bucket features are bucketized
  and cast to float32, and categorical features are one-hot encoded. The
  label is 1 when the tip exceeds 20% of the fare and 0 otherwise.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations. Feature
    keys carry the `t_name` suffix; the label is stored under `_LABEL_KEY`.
  """
  outputs = {}

  # Numerical features: make dense (missing values become 0), then z-score.
  for key in _NUMERICAL_FEATURES:
    dense_value = _fill_in_missing(inputs[key])
    outputs[t_name(key)] = tft.scale_to_z_score(dense_value, name=key)

  # Bucket features: the bucket index is emitted as a float.
  for key in _BUCKET_FEATURES:
    bucket_index = tft.bucketize(
        _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT, name=key)
    outputs[t_name(key)] = tf.cast(bucket_index, dtype=tf.float32)

  # Categorical features (string-valued first, then numeric-valued) are all
  # one-hot encoded in the same way.
  for feature_group in (_CATEGORICAL_STRING_FEATURES,
                        _CATEGORICAL_NUMERICAL_FEATURES):
    for key in feature_group:
      outputs[t_name(key)] = _make_one_hot(_fill_in_missing(inputs[key]), key)

  # Was this passenger a big tipper?
  fare = _fill_in_missing(inputs[_FARE_KEY])
  tips = _fill_in_missing(inputs[_LABEL_KEY])
  fare_is_nan = tf.math.is_nan(fare)
  zero_label = tf.cast(tf.zeros_like(fare), tf.int64)
  # Test if the tip was > 20% of the fare.
  tip_threshold = tf.multiply(fare, tf.constant(0.2))
  big_tipper = tf.cast(tf.greater(tips, tip_threshold), tf.int64)
  # A NaN fare always yields label 0.
  outputs[_LABEL_KEY] = tf.where(fare_is_nan, zero_label, big_tipper)

  return outputs

In [None]:
from IPython.display import display
# A single hand-written raw example used to exercise preprocessing_fn.
raw_examples = [
    {
        "fare": [100.0],
        "trip_start_hour": [12],
        "pickup_census_tract": ['abcd'],
        "dropoff_census_tract": [12345],
        "company": ['taxi inc.'],
        "trip_start_timestamp": [123456],
        "pickup_longitude": [12.0],
        "trip_start_month": [5],
        "trip_miles": [8.0],
        "dropoff_longitude": [12.05],
        "dropoff_community_area": [123],
        "pickup_community_area": [123],
        "payment_type": ['visa'],
        "trip_seconds": [600],
        "trip_start_day": [12],
        "tips": [10.0],
        "pickup_latitude": [80.0],
        "dropoff_latitude": [80.01],
    }
]

# Analyze and transform the sample in one pass, using a fresh temporary
# directory as tf.Transform's working area, then show the result as a table.
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    (transformed_examples, transformed_metadata), transform_fn = (
        (raw_examples, legacy_metadata)
        | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(
            preprocessing_fn))
    display(pd.DataFrame(transformed_examples))