<a href="https://colab.research.google.com/github/zoyahav/tft_notebooks/blob/main/TFT_Cache_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow-transform

In [34]:
%tensorflow_version 2.x
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
import apache_beam as beam

import tempfile
import pprint
import os

In [55]:
# One analyzer-cache dataset key per span of input data used in the
# iterations below.
span_0_key, span_1_key = (
    tft_beam.analyzer_cache.DatasetKey(span_name)
    for span_name in ('span-0', 'span-1'))

def preprocessing_fn(inputs):
  """Builds the transform graph for the single input feature 'x'.

  Args:
    inputs: dict of input tensors; only the 'x' entry is read.

  Returns:
    A dict with one entry, 'x_mean': the `tft.mean` of 'x' added to
    `tf.zeros_like(x)`.
  """
  x = inputs['x']
  # The analyzer is named 'x' explicitly.
  x_mean = tft.mean(x, name='x')
  # Adding zeros shaped like `x` presumably broadcasts the analyzer result so
  # that every example carries the mean.
  return {'x_mean': x_mean + tf.zeros_like(x)}
# Directory the analysis cache is written to and read back from across the
# iterations below.
cache_dir = tempfile.mkdtemp()

# Nothing has been analyzed yet, so the first iteration starts without cache.
input_cache = dict()

# Raw input schema: a single float32 feature 'x' with shape [].
feature_spec = {'x': tf.io.FixedLenFeature(shape=[], dtype=tf.float32)}
input_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(feature_spec))

#### Iteration 0 - only span-0 as analysis data, no cache #####

# Analysis input: one span holding x = 0..99.
input_data_dict = {
    span_0_key: [{'x': x} for x in range(0, 100)],
}

# Report which dataset keys are required for analysis given the current
# `input_cache`; exactly one key is expected in this iteration.
filtered_analysis_dataset_keys = (
    tft_beam.analysis_graph_builder.get_analysis_dataset_keys(
        preprocessing_fn, feature_spec,
        list(input_data_dict), input_cache, True))
print('Analysis dataset keys required: {}'.format(filtered_analysis_dataset_keys))
assert len(filtered_analysis_dataset_keys) == 1

# Nothing below reads the TFT scratch directory once the pipeline has finished,
# so it is created as a TemporaryDirectory and removed on exit instead of
# being leaked via mkdtemp().
with tempfile.TemporaryDirectory() as tft_temp_dir:
  with tft_beam.Context(temp_dir=tft_temp_dir):
    with beam.Pipeline() as p:
      # Analyze the input spans; this yields the transform function and the
      # cache produced by this run.
      transform_fn, output_cache = (
          (input_data_dict, input_cache, input_metadata)
          | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))

      # Persist the produced cache so the next iteration can read it back.
      output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS(p, cache_dir)

      # Apply the transform to the span-0 records and pretty-print one sampled
      # output.
      transform_data = (
          p
          | 'CreateTransformDataIteration0' >> beam.Create(input_data_dict[span_0_key]))
      transformed_dataset = (
          ((transform_data, input_metadata), transform_fn)
          | tft_beam.TransformDataset())
      transformed_data, transformed_metadata = transformed_dataset
      (transformed_data
       | beam.combiners.Sample.FixedSizeGlobally(1)
       | beam.Map(pprint.pprint))



#### Iteration 1 - span-0 and span-1 as analysis data, span-0 has cache #####

# Add a second span (x = 100..199) next to the existing span-0 data.
input_data_dict.update({
    span_1_key: [{'x': x} for x in range(100, 200)],
})

# Nothing below reads the TFT scratch directory once the pipeline has finished,
# so it is created as a TemporaryDirectory and removed on exit instead of
# being leaked via mkdtemp().
with tempfile.TemporaryDirectory() as tft_temp_dir:
  with tft_beam.Context(temp_dir=tft_temp_dir):
    with beam.Pipeline() as p:
      # Read back whatever cache exists in `cache_dir` for the two spans.
      input_cache = (
          p
          | 'ReadCache' >> tft_beam.analyzer_cache.ReadAnalysisCacheFromFS(
              cache_dir, [span_0_key, span_1_key]))

      # Report which dataset keys are required for analysis given the cache
      # that was just read; exactly one key is expected in this iteration.
      filtered_analysis_dataset_keys = (
          tft_beam.analysis_graph_builder.get_analysis_dataset_keys(
              preprocessing_fn, feature_spec,
              [span_0_key, span_1_key], input_cache, True))
      print('Analysis dataset keys required: {}'.format(filtered_analysis_dataset_keys))
      assert len(filtered_analysis_dataset_keys) == 1

      # Analyze with the cache as an additional input; this yields the
      # transform function and the cache produced by this run.
      transform_fn, output_cache = (
          (input_data_dict, input_cache, input_metadata)
          | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))

      # Persist the produced cache so the next iteration can read it back.
      output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS(p, cache_dir)

      # Apply the transform to the span-0 records and pretty-print one sampled
      # output.
      transform_data = (
          p
          | 'CreateTransformDataIteration1' >> beam.Create(input_data_dict[span_0_key]))
      transformed_dataset = (
          ((transform_data, input_metadata), transform_fn)
          | tft_beam.TransformDataset())
      transformed_data, transformed_metadata = transformed_dataset
      (transformed_data
       | beam.combiners.Sample.FixedSizeGlobally(1)
       | beam.Map(pprint.pprint))



#### Iteration 2 - No new data, no analysis needed #####

# No need to even read the dataset.
input_data_dict = {
    span_0_key: [],
    span_1_key: [],
}

# Nothing below reads the TFT scratch directory once the pipeline has finished,
# so it is created as a TemporaryDirectory and removed on exit instead of
# being leaked via mkdtemp().
with tempfile.TemporaryDirectory() as tft_temp_dir:
  with tft_beam.Context(temp_dir=tft_temp_dir):
    with beam.Pipeline() as p:
      # Read back whatever cache exists in `cache_dir` for the two spans.
      input_cache = (
          p
          | 'ReadCache' >> tft_beam.analyzer_cache.ReadAnalysisCacheFromFS(
              cache_dir, [span_0_key, span_1_key]))

      # Report which dataset keys are required for analysis given the cache
      # that was just read; none are expected in this iteration.
      filtered_analysis_dataset_keys = (
          tft_beam.analysis_graph_builder.get_analysis_dataset_keys(
              preprocessing_fn, feature_spec,
              [span_0_key, span_1_key], input_cache, True))
      print('Analysis dataset keys required: {}'.format(filtered_analysis_dataset_keys))
      assert len(filtered_analysis_dataset_keys) == 0

      # The span inputs are empty lists here, so the cache is the only
      # non-empty analysis input.
      transform_fn, output_cache = (
          (input_data_dict, input_cache, input_metadata)
          | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))

      output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS(p, cache_dir)

      # Apply the transform to a single hand-written record and pretty-print
      # the sampled output.
      transform_data = (
          p
          | 'CreateTransformDataIteration2' >> beam.Create([{'x': 1}]))
      transformed_dataset = (
          ((transform_data, input_metadata), transform_fn)
          | tft_beam.TransformDataset())
      transformed_data, transformed_metadata = transformed_dataset
      (transformed_data
       | beam.combiners.Sample.FixedSizeGlobally(1)
       | beam.Map(pprint.pprint))

/tmp/tmppik594kh
Analysis dataset keys required: {DatasetKey(key='span-0')}








INFO:tensorflow:Assets written to: /tmp/tmpmr0w3pat/tftransform_tmp/557ca93612ef4b12a0112b85ab98fba6/assets


INFO:tensorflow:Assets written to: /tmp/tmpmr0w3pat/tftransform_tmp/557ca93612ef4b12a0112b85ab98fba6/assets


INFO:tensorflow:Assets written to: /tmp/tmpmr0w3pat/tftransform_tmp/1052ad75faa24a15b83fac079a71dad0/assets


INFO:tensorflow:Assets written to: /tmp/tmpmr0w3pat/tftransform_tmp/1052ad75faa24a15b83fac079a71dad0/assets










[{'x_mean': 49.5}]
Analysis dataset keys required: {DatasetKey(key='span-1')}
















INFO:tensorflow:Assets written to: /tmp/tmpnzio4s48/tftransform_tmp/221973b6727a4baeb42dc2bf58cf19ee/assets


INFO:tensorflow:Assets written to: /tmp/tmpnzio4s48/tftransform_tmp/221973b6727a4baeb42dc2bf58cf19ee/assets


INFO:tensorflow:Assets written to: /tmp/tmpnzio4s48/tftransform_tmp/6bd8f47e958944e18561a09cded3bd6b/assets


INFO:tensorflow:Assets written to: /tmp/tmpnzio4s48/tftransform_tmp/6bd8f47e958944e18561a09cded3bd6b/assets


[{'x_mean': 99.5}]
Analysis dataset keys required: set()
















INFO:tensorflow:Assets written to: /tmp/tmpjswz9sxq/tftransform_tmp/30ac848bd1c243a9bfeca632d377e729/assets


INFO:tensorflow:Assets written to: /tmp/tmpjswz9sxq/tftransform_tmp/30ac848bd1c243a9bfeca632d377e729/assets


[{'x_mean': 99.5}]
