<a href="https://colab.research.google.com/github/zoyahav/tft_notebooks/blob/main/https_github_com_tensorflow_transform_issues_240.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow-transform

In [1]:
%tensorflow_version 2.x
import tensorflow_transform as tft

print(tft.__version__)

1.0.0


In [3]:
import tempfile
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

import contextlib2

# Scratch directory created once at import time; passed as `temp_dir` to
# `tft_beam.Context` further down in this cell.
_TEMP_DIR = tempfile.mkdtemp()
# Feature name used for the label column in the raw input data.
_LABEL_KEY = "label"
# dtype declared for the label in the raw feature spec.
_INT_DTYPE = tf.int64
# NOTE(review): not referenced anywhere in the visible code.
_FLOAT_DTYPE = tf.float32


def _generate_data_vector():
    """Build a tiny in-memory dataset together with its metadata.

    Returns:
        A `(rows, metadata)` pair. `rows` is a list of four dicts, each
        mapping `_LABEL_KEY` to a one-element list (labels 1, 0, 0, 0).
        `metadata` is a `DatasetMetadata` whose schema declares
        `_LABEL_KEY` as a fixed-length feature of shape `[1]` and dtype
        `_INT_DTYPE`.
    """
    rows = [{_LABEL_KEY: [label]} for label in (1, 0, 0, 0)]

    feature_spec = {
        _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=_INT_DTYPE),
    }
    schema = schema_utils.schema_from_feature_spec(feature_spec)
    metadata = dataset_metadata.DatasetMetadata(schema)

    return rows, metadata

def _lookup_fn(y, deferred_vocab_filename_tensor):
    """Look up `y` in a static hash table initialised from a text file.

    The file is read as space-delimited lines in which column 0 holds the
    int64 value and column 1 holds the string key. Keys that are not in
    the table map to -1.

    Args:
        y: Tensor of keys to look up.
        deferred_vocab_filename_tensor: Filename passed to the
            `TextFileInitializer` that populates the table.

    Returns:
        A pair `(table.lookup(y), table.size())`.
    """
    with contextlib2.ExitStack() as scope_stack:
        # Only when executing eagerly is the table built under
        # tf.init_scope(); otherwise no extra context is entered.
        if tf.executing_eagerly():
            scope_stack.enter_context(tf.init_scope())
        vocab_initializer = tf.lookup.TextFileInitializer(
            filename=deferred_vocab_filename_tensor,
            key_dtype=tf.string,
            key_index=1,
            value_dtype=tf.int64,
            value_index=0,
            delimiter=' ',
        )
        vocab_table = tf.lookup.StaticHashTable(
            vocab_initializer, default_value=-1
        )
        table_size = vocab_table.size()
    # The lookup itself is issued after the context stack has been exited.
    return vocab_table.lookup(y), table_size

def example_label_preprocessing(tensor: tf.Tensor) -> tf.Tensor:
    """Round-trip the label through a string form while exercising a vocabulary.

    The last axis of `tensor` is squeezed away and the values are converted
    to strings. `tft.count_per_key` is then called with
    `key_vocabulary_filename='abc'`, and its result is handed to
    `tft.apply_vocabulary` with `_lookup_fn` as the lookup function.

    The output of `apply_vocabulary` is deliberately not returned; the call
    is kept only so that the analyzer and lookup are part of the
    preprocessing function.
    NOTE(review): presumably `count_per_key` yields the vocabulary filename
    when `key_vocabulary_filename` is given -- confirm against the TFT docs.

    Args:
        tensor: Label tensor whose last dimension has size 1.

    Returns:
        The stringified labels parsed back into numbers with
        `tf.strings.to_number`.
    """
    label_strings = tf.as_string(tf.squeeze(tensor, axis=-1))
    vocab_file = tft.count_per_key(
        label_strings, key_vocabulary_filename='abc'
    )
    # Result intentionally discarded; see docstring.
    _ = tft.apply_vocabulary(label_strings, vocab_file, lookup_fn=_lookup_fn)
    return tf.strings.to_number(label_strings)


def _train_and_retrieve_trained_data(train_data, feature_name):
    """Run `AnalyzeAndTransformDataset` over `train_data`.

    Args:
        train_data: Input piped into the Beam transform; the caller in this
            file passes the `(rows, metadata)` pair produced by
            `_generate_data_vector`.
        feature_name: Key under which the preprocessed label is emitted.

    Returns:
        Whatever `train_data | AnalyzeAndTransformDataset(...)` evaluates to.
    """

    def preprocessing_fn(inputs):
        # Emit a single output feature derived from the raw label.
        processed = example_label_preprocessing(tensor=inputs[_LABEL_KEY])
        return {feature_name: processed}

    analyze_and_transform = tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn
    )
    return train_data | analyze_and_transform


# Drive the example: inside a TFT Beam context rooted at _TEMP_DIR, generate
# the toy dataset, run analyze-and-transform on it, and print the result.
with tft_beam.Context(temp_dir=_TEMP_DIR, force_tf_compat_v1=False):
    _FEATURE_NAME = "feature"
    _data = _generate_data_vector()
    trained_data = _train_and_retrieve_trained_data(_data, _FEATURE_NAME)
    print(trained_data)













INFO:tensorflow:Assets written to: /tmp/tmpubhzvj5o/tftransform_tmp/8af014c4fca647daad7052eaa3e45c37/assets


INFO:tensorflow:Assets written to: /tmp/tmpubhzvj5o/tftransform_tmp/8af014c4fca647daad7052eaa3e45c37/assets


INFO:tensorflow:Assets written to: /tmp/tmpubhzvj5o/tftransform_tmp/72ea648b8696408eb8f888a801350735/assets


INFO:tensorflow:Assets written to: /tmp/tmpubhzvj5o/tftransform_tmp/72ea648b8696408eb8f888a801350735/assets


(([{'feature': 1.0}, {'feature': 0.0}, {'feature': 0.0}, {'feature': 0.0}], BeamDatasetMetadata(dataset_metadata={'_schema': feature {
  name: "feature"
  type: FLOAT
  presence {
    min_fraction: 1.0
  }
  shape {
  }
}
}, deferred_metadata=[{'_schema': feature {
  name: "feature"
  type: FLOAT
  presence {
    min_fraction: 1.0
  }
  shape {
  }
}
}], asset_map={})), (['/tmp/tmpubhzvj5o/tftransform_tmp/72ea648b8696408eb8f888a801350735'], BeamDatasetMetadata(dataset_metadata={'_schema': feature {
  name: "feature"
  type: FLOAT
  presence {
    min_fraction: 1.0
  }
  shape {
  }
}
}, deferred_metadata=[{'_schema': feature {
  name: "feature"
  type: FLOAT
  presence {
    min_fraction: 1.0
  }
  shape {
  }
}
}], asset_map={})))
