Cannot convert model containing categorical_column_with_vocabulary_list op #37844

Closed
icoffeebeans opened this issue Mar 23, 2020 · 18 comments
Labels: comp:lite, TF 2.2, TFLiteConverter, type:bug


@icoffeebeans

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): CentOS
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (or github SHA if from source): 2.2.0rc0

Command used to run the converter or code if you’re using the Python API
If possible, please share a link to Colab/Jupyter/any notebook.

import tensorflow as tf
import os

model_dir = "models/feature_column_example"
category = tf.constant(["A", "B", "A", "C", "C", "A"])
label = tf.constant([1, 0, 1, 0, 0, 0])

ds = tf.data.Dataset.from_tensor_slices(({"category": category}, label))
ds = ds.batch(2)

fc_category = tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        "category", vocabulary_list=["A", "B", "C"]
    )
)
feature_layer = tf.keras.layers.DenseFeatures([fc_category])

model = tf.keras.Sequential(
    [
        feature_layer,
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
)

model.fit(ds, epochs=2)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.allow_custom_ops = True
# converter.experimental_new_converter = True
# converter.experimental_new_quantizer = True

# Convert the model.
tflite_model = converter.convert()
open(os.path.join(model_dir, "output.tflite"), "wb").write(tflite_model)

The output from the converter invocation

Cannot convert a Tensor of dtype resource to a NumPy array.

Also, please include a link to the saved model or GraphDef

saved_model_cli show --dir models/feature_column_example/ --tag_set serve --signature_def serving_default

The given SavedModel SignatureDef contains the following input(s):
  inputs['category'] tensor_info:
      dtype: DT_STRING
      shape: (-1, 1)
      name: serving_default_category:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_1'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: StatefulPartitionedCall_7:0
Method name is: tensorflow/serving/predict

Failure details
Cannot convert a Tensor of dtype resource to a NumPy array.

According to my analysis, this is likely caused by the HashTable ops, which create table handles (tensors of dtype resource). An additional question: can the TFLite converter handle a model that contains HashTableV2 initialization and LookupTableImportV2 ops? Thank you.
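For anyone who wants to verify which ops are involved, here is a rough sketch (an assumption, using the models/feature_column_example SavedModel shown by the saved_model_cli output below) that lists the table-related ops in the serving graph:

import tensorflow as tf

# Load the exported SavedModel and grab the serving signature.
loaded = tf.saved_model.load("models/feature_column_example")
func = loaded.signatures["serving_default"]

graph_def = func.graph.as_graph_def()
ops = {node.op for node in graph_def.node}
# The table ops usually live inside nested functions, so scan those as well.
for fn in graph_def.library.function:
    ops.update(node.op for node in fn.node_def)

print(sorted(op for op in ops if "Table" in op or "Hash" in op))
# Expected to include ops such as HashTableV2, LookupTableFindV2, LookupTableImportV2.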

Any other info / logs: Full logs

Traceback (most recent call last):
  File "feature_column_example.py", line 62, in <module>
    export_keras_hashtable_model()
  File "feature_column_example.py", line 58, in export_keras_hashtable_model
    tflite_model = converter.convert()
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/lite/python/lite.py", line 464, in convert
    self._funcs[0], lower_control_flow=False))
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/python/framework/convert_to_constants.py", line 706, in convert_variables_to_constants_v2_as_graph
    func, lower_control_flow, aggressive_inlining)
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/python/framework/convert_to_constants.py", line 457, in _convert_variables_to_constants_v2_impl
    tensor_data = _get_tensor_data(func)
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/python/framework/convert_to_constants.py", line 217, in _get_tensor_data
    data = val_tensor.numpy()
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 961, in numpy
    maybe_arr = self._numpy()  # pylint: disable=protected-access
  File "/root/tf2.2/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 929, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

@amahendrakar
Contributor

Was able to reproduce the issue with TF v2.2.0-rc1 and TF-nightly. Please find the attached gist. Thanks!

@icoffeebeans
Author

icoffeebeans commented Mar 24, 2020

Thanks for the prompt response to this issue.

I also read the official documentation on operator compatibility, which says that ops such as EMBEDDING_LOOKUP and HASHTABLE_LOOKUP are present but not ready for custom models. Is it feasible to write my own TFLite kernels for these lookup ops, or does the TFLite converter simply not support the related operators? Thanks!

@abattery
Contributor

abattery commented Mar 27, 2020

EMBEDDING_LOOKUP is already supported via the TensorFlow Lite builtin ops, and the experimental hashtable op kernels live under the following directory: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/kernels.
Here is example code showing how to enable them in Python.

# Imports added for completeness. The interpreter wrapper import is an
# assumption: in TF 2.x source trees it lives under tensorflow.lite.python.
import numpy as np
import tensorflow as tf
from tensorflow.lite.python import interpreter as interpreter_wrapper

# Note: this snippet uses TF1-style graph-mode APIs (tf.Session, tf.placeholder);
# under TF 2.x these are available via tf.compat.v1.
with tf.Session() as sess:
    int64_values = tf.constant([1, 2, 3], dtype=tf.int64)
    string_values = tf.constant(['bar', 'foo', 'baz'], dtype=tf.string)
    int64_to_string_initializer = tf.lookup.KeyValueTensorInitializer(
        int64_values, string_values)
    string_to_int64_initializer = tf.lookup.KeyValueTensorInitializer(
        string_values, int64_values)
    # The second argument is the default value returned for missing keys.
    int64_to_string_table = tf.lookup.StaticHashTable(
        int64_to_string_initializer, '7')
    string_to_int64_table = tf.lookup.StaticHashTable(
        string_to_int64_initializer, 4)

    # The control dependency keeps the table import ordered before the lookups.
    with tf.control_dependencies([tf.initializers.tables_initializer()]):
      input_int64_tensor = tf.placeholder(tf.int64, shape=[1])
      input_string_tensor = tf.placeholder(tf.string, shape=[1])
      out_string_tensor = int64_to_string_table.lookup(input_int64_tensor)
      out_int64_tensor = string_to_int64_table.lookup(input_string_tensor)

    graph_def = tf.get_default_graph().as_graph_def()
    tf.io.write_graph(graph_def, '/tmp/', 'hashtable.pbtxt')

converter = tf.lite.TFLiteConverter(graph_def,
                                    [input_int64_tensor, input_string_tensor],
                                    [out_string_tensor, out_int64_tensor])

supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.target_spec.supported_ops = supported_ops
converter.allow_custom_ops = True
tflite_model = converter.convert()

# Initialize the interpreter with the HashTableV2 custom ops.
model_interpreter = interpreter_wrapper.InterpreterWithCustomOps(
    model_content=tflite_model, custom_op_registerers=['AddHashtableOps'])

input_details = model_interpreter.get_input_details()
output_details = model_interpreter.get_output_details()
print('Input details: ', input_details)
print('Output details: ', output_details)

model_interpreter.allocate_tensors()
model_interpreter.reset_all_variables()
model_interpreter.set_tensor(input_details[0]['index'],
                             np.array([2], dtype=np.int64))
model_interpreter.set_tensor(input_details[1]['index'],
                             np.array(['foo'], dtype=np.string_))
model_interpreter.invoke()

for out in output_details:
  result = model_interpreter.get_tensor(out['index'])  # Expect no errors.
  print('Result tensor: ', out['name'], result.shape, result)

The AddHashtableOps Python module is defined here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/kernels/BUILD#L177

If you want to make your own custom ops, please check out https://www.tensorflow.org/lite/guide/ops_custom.

As for the conversion error below, we are working on a fix and hope to land it soon. I will leave a comment here once it is fixed in the nightly build.

Cannot convert a Tensor of dtype resource to a NumPy array.

@icoffeebeans
Author

icoffeebeans commented Mar 31, 2020

@abattery Hi Jaesung, thanks a lot for your feedback. I had actually seen your hashtable-op commits in the TensorFlow project, so I was looking forward to your reply when I created this issue :)

After running the above example, I still have a couple of questions.

  1. Does the code only work for TF v2? I've tried it on both TF v2.2 and TF v1.15, and the tflite file converted by TF v1.15 does not contain the LookupTableImportV2 op. Can this conversion succeed in TF 1.15 at all?

tflite_converted_by_tf1_15

  2. What does with tf.control_dependencies([tf.initializers.tables_initializer()]) do to a graph that is converted to tflite? If I don't use this context manager and instead call tf.initializers.tables_initializer() separately, the graph also cannot be converted correctly.

tf22_no_dep

Thank you again!

@abattery
Contributor

Hi,

I have tested the above sample with both TF 1.15.0 and TF 2.2.0rc2. As you said, the model generated by TF 1.15.0 does not contain a LookupTableImportV2 node. In your case, would it be okay to use a TF 2.x version? Thank you for reporting the bug.

The tf.control_dependencies method creates an explicit dependency between ops in a graph. Here it is required because the LookupTableImportV2 op node must come before the LookupTableFindV2 op node.

Best regards,
Jaesung
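To illustrate the ordering, here is a minimal sketch (not from the thread; graph mode via tf.compat.v1 is assumed) of where the control dependency goes:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

keys = tf.constant(['a', 'b'], dtype=tf.string)
values = tf.constant([0, 1], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

# Ops created inside this context get a control edge from the table
# initializer, so LookupTableImportV2 is ordered before LookupTableFindV2.
with tf.control_dependencies([tf.initializers.tables_initializer()]):
  inp = tf.placeholder(tf.string, shape=[1])
  out = table.lookup(inp)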

@lizhen2017

lizhen2017 commented Apr 13, 2020

Hi, I've tried compiling the TF 2.x source code and pip-installing it, but when I run the example above I get a "no symbol AddHashtableOps" error.
Thanks for your help. @abattery

@abattery
Contributor

AddHashtableOps support in Python was removed temporarily. However, you can still add the hashtable ops to an interpreter in C++.

@lizhen2017

Would you please show me a demo? Thank you very much!

@abattery
Contributor

How to include hashtable ops in your TFLite build:

Currently, the hashtable ops are still experimental. You need to add them manually by including the following dependency:

"//tensorflow/lite/experimental/kernels:hashtable_op_kernels"

Then your op resolver should register them with the following statements:

  // Add hashtable op handlers.
  resolver->AddCustom("HashTableV2",    
                      tflite::ops::custom::Register_HASHTABLE());
  resolver->AddCustom("LookupTableFindV2",
                      tflite::ops::custom::Register_HASHTABLE_FIND());
  resolver->AddCustom("LookupTableImportV2",
                      tflite::ops::custom::Register_HASHTABLE_IMPORT());
  resolver->AddCustom("LookupTableSizeV2",
                      tflite::ops::custom::Register_HASHTABLE_SIZE());

@icoffeebeans
Author

@abattery Hi Jaesung, thanks a lot for your reply. Since export+convert with TF 1.15.0 does not succeed, I'm wondering whether a model exported from a TF version below 2.x can be converted correctly at all.

Take the lookup example you posted previously: I first exported the model with SavedModelBuilder using TF v1.15.0 and then converted the saved model with TF 2.2.0, but it still does not succeed; the tflite file is missing the LookupTableImportV2 op. Is there any method that could make this workflow succeed, i.e. export the model with TF v1.15.0 and convert it with TF 2.2?

@abattery
Contributor

The model needs to store an explicit dependency between the LookupTableImportV2 op and the other lookup ops. However, it seems that TF v1.15.0 does not store that information.

If the information is already lost, the TF 2.2 converter cannot recover it.

@icoffeebeans
Author

@abattery May I ask for some details about this dependency? Is it handled in the Python API or in the C++ kernels? I'm wondering whether I could get this working by refactoring some TensorFlow code.

@abattery
Contributor

Could you try setting drop_control_dependency=False in the TFLiteConverter V1 API?
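For reference, a minimal sketch of that setting on the V1 converter API (the SavedModel path is a placeholder):

import tensorflow as tf

saved_model_dir = '/tmp/hashtable_saved_model'  # placeholder path

# The V1 converter drops control dependencies by default, which loses the
# edge that orders LookupTableImportV2 before the lookup ops.
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.drop_control_dependency = False
converter.allow_custom_ops = True
tflite_model = converter.convert()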

@tleyden
Contributor

tleyden commented Jan 24, 2021

@abattery is this still the case?

Removed AddHashtableOps support in Python temporarily. However, you can still add this to an interpreter in C++.

If it's been added back, do you have any example code on how to use it from Python? As mentioned in this Stack Overflow post, I was able to add converter.allow_custom_ops = True to get past the tf.HashTableV2 missing custom implementation errors during conversion to tflite; however, I'm not clear on how to perform inference with the resulting tflite model.

@abattery
Contributor

Please take a look at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/hashtable/README.md in order to use the provided hash table op kernels as a custom op library.

@sushreebarsa
Contributor

Was able to replicate the issue with TF v2.5. Please find the gist here. Thanks!

@abattery
Contributor

The converter error above goes away when the SavedModel converter is chosen (TFLiteConverter.from_saved_model instead of from_keras_model).
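A rough sketch of that workaround, assuming the Keras model from the original report is first exported as a SavedModel (the directory name is an assumption):

import tensorflow as tf

saved_model_dir = "models/feature_column_example/saved_model"  # assumed path
model.save(saved_model_dir)  # export the Keras model in SavedModel format

# Convert from the SavedModel instead of directly from the Keras model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.allow_custom_ops = True  # the hashtable ops still need custom kernels
tflite_model = converter.convert()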

