
Unwanted tf.function retracing when using variable-length inputs #38561

Closed
zaccharieramzi opened this issue Apr 15, 2020 · 27 comments
Assignees
Labels
comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 type:bug Bug

Comments

@zaccharieramzi
Contributor

System information

  • Have I written custom code (as opposed to using a stock
    example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): 2.2.0rc2
  • Python version: 3.6.8

Describe the current behavior

Many warnings about tf.function retracing appear when using a Keras model in a loop with variable-length inputs.

Describe the expected behavior

I would like to avoid retracing when it is not needed (for example, for a fully convolutional model).

Standalone code to reproduce the issue

from random import randint

import tensorflow as tf
from tensorflow.keras.layers import Conv1D
from tensorflow.keras.models import Sequential

model = Sequential()
model.add(Conv1D(8, 3))
model.build([None, 12, 1])

predict_tensors = [
    tf.random.normal([randint(1, 8), randint(4, 40), 1])
    for _ in range(10)
]
for t in predict_tensors:
    _ = model.predict(t)

Other info / logs

Logs:

WARNING: Logging before flag parsing goes to stderr.
W0406 09:22:52.525994 139643050075904 def_function.py:598] 5 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f00a7fc1268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
W0406 09:22:52.615050 139643050075904 def_function.py:598] 6 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f00a7fc1268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
W0406 09:22:52.653312 139643050075904 def_function.py:598] 7 out of the last 8 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f00a7fc1268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
W0406 09:22:52.706550 139643050075904 def_function.py:598] 8 out of the last 10 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f00a7fc1268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

This issue was originally described here, and some other people have had trouble with training as well.

When switching back to 2.1, the problem is gone.

@gurushantj
Contributor

As per my understanding, if the input tensor's shape or dtype changes (i.e. it is not constant), the function gets retraced. You may refer to https://www.tensorflow.org/api_docs/python/tf/function

@zaccharieramzi
Contributor Author

Yes, this is totally true, but I am not using tf.function directly myself. Maybe Keras is under the hood, but in any case it should handle inputs with varying shapes (but the same rank and "compatible" shapes) better, for example by specifying a dynamic input signature (see "Input signatures" in the doc).

Moreover, the behaviour I am describing is for version 2.2.0rc2, while the doc is still for 2.1, where there is no issue.
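For illustration, a dynamic input signature along the lines suggested above could look like this (a minimal sketch, not what Keras actually does internally; the `None` dimensions mark axes that are allowed to vary without retracing):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Conv1D(8, 3)])
model.build([None, None, 1])

# None for the batch and length axes: one trace serves every rank-3,
# single-channel input, so repeated calls do not retrace.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)])
def predict(x):
    return model(x)
```

With this signature, calling `predict` on tensors of shape `[2, 10, 1]` and `[5, 30, 1]` reuses the same concrete function.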

@ngc92
Contributor

ngc92 commented Apr 15, 2020

You can see the current doc here:
https://www.tensorflow.org/api_docs/python/tf/function?version=nightly
I think the option you need should be experimental_relax_shapes.

As a workaround, you could try wrapping the Keras model in an explicit tf.function, like this:

@tf.function(experimental_relax_shapes=True)
def predict(x):
    return model.predict(x)

@gurushantj
Contributor

gurushantj commented Apr 15, 2020

Yes, this is totally true, but I am not using tf.function directly myself. Maybe Keras is under the hood, but in any case it should handle inputs with varying shapes (but the same rank and "compatible" shapes) better, for example by specifying a dynamic input signature (see "Input signatures" in the doc).

Moreover, the behaviour I am describing is for version 2.2.0rc2, while the doc is still for 2.1, where there is no issue.

Following is the output of TF 2.1.0; the output seems to be the same:

/usr/local/bin/python3.7 /Users/gurushant/PycharmProjects/MTCNN/test6.py
2020-04-15 14:12:33.527382: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-04-15 14:12:33.545554: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fa56ad8c050 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-04-15 14:12:33.545588: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
WARNING:tensorflow:5 out of the last 5 calls to <function _make_execution_function.<locals>.distributed_function at 0x134d83290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _make_execution_function.<locals>.distributed_function at 0x134d83290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:7 out of the last 7 calls to <function _make_execution_function.<locals>.distributed_function at 0x134d83290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:8 out of the last 8 calls to <function _make_execution_function.<locals>.distributed_function at 0x134d83290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:9 out of the last 9 calls to <function _make_execution_function.<locals>.distributed_function at 0x134d83290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

@zaccharieramzi
Contributor Author

@gurushantj yes, you are right. I don't know why I thought this was in 2.1; it's actually in 2.0 that the problem is gone.

Still, the documentation regarding the re-tracing is about the same.

@ngc92 I tried your workaround but got the following error:

ValueError: When using data tensors as input to a model, you should specify the `steps` argument.

@gurushantj
Contributor

gurushantj commented Apr 15, 2020

@gurushantj yes, you are right. I don't know why I thought this was in 2.1; it's actually in 2.0 that the problem is gone.

Still, the documentation regarding the re-tracing is about the same.

@ngc92 I tried your workaround but got the following error:

ValueError: When using data tensors as input to a model, you should specify the `steps` argument.

Could you please try the following and let me know:

Disable eager execution:

tf.compat.v1.disable_eager_execution()

and pass steps=1 to model.predict.

@amahendrakar
Contributor

I was able to reproduce the issue with TF v2.1, TF v2.2.0rc3, and TF-nightly. Please find the attached gist. Thanks!

@amahendrakar amahendrakar added TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 comp:autograph Autograph related issues labels Apr 15, 2020
@amahendrakar
Contributor

@zaccharieramzi,
Could you please check this comment from a similar issue and let us know if it works? Thanks!

@zaccharieramzi
Contributor Author

@amahendrakar I am not sure what I am supposed to see in that comment. The issue you linked suggests that this should be dealt with.

@zaccharieramzi
Contributor Author

@ngc92 still got an error: AttributeError: 'Tensor' object has no attribute '_numpy'.

@ngc92
Contributor

ngc92 commented Apr 15, 2020

Is this what you want to do?

@tf.function(experimental_relax_shapes=True)
def predict(t):
    return model(t)

for t in predict_tensors:
    _ = predict(t)

Note that you are no longer using any of the features of model.predict, but since you seem to be looping over examples by hand, that might be OK.

Also, TF 2.2 has support for a custom model.predict_function, i.e. you might be able to do something like

model.predict_function = tf.function(experimental_relax_shapes=True)(model.predict_function)

i.e. just wrapping the provided default function in something that relaxes shapes.
I haven't tried 2.2 yet, so I'm not very sure about the second suggestion.

@zaccharieramzi
Contributor Author

@ngc92 yes, this is a fair workaround. However, there are cases where you would want to use predict, for its callbacks or batch-size handling.

The second option you provided didn't work straight out of the box, but you can try things in TF 2.2 in Colab: https://colab.research.google.com/drive/1MfRPQyRhjrF7he7fymoIEG7k64YCd0Da

You will notice that in the case of evaluate (and I guess train), if you feed the variable-length input through a tf.data dataset, it doesn't retrace the function, suggesting a bug somewhere.
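The tf.data path mentioned above can be sketched as follows (a sketch under the assumption that each generated tensor is one whole batch; `output_signature` requires TF 2.4+, and its `None` dimensions let one trace cover all batch sizes and lengths):

```python
from random import randint

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Conv1D(8, 3)])
model.build([None, None, 1])

# Each generated element is a whole batch, with varying batch size and length.
def gen():
    for _ in range(10):
        yield tf.random.normal([randint(1, 8), randint(4, 40), 1])

# The element spec keeps batch and length dimensions dynamic (None).
ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None, None, 1],
                                        dtype=tf.float32))

# Calling the model on dataset elements avoids one-trace-per-shape.
preds = [model(batch) for batch in ds]
```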

@gowthamkpr gowthamkpr assigned mdanatg and unassigned gowthamkpr Apr 19, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 19, 2020
@mdanatg mdanatg added the comp:core issues related to core part of tensorflow label Apr 19, 2020
@mdanatg

mdanatg commented Apr 19, 2020

We're investigating; it seems that a newly added warning about function retracing fires more often than expected.

@kkimdev
Contributor

kkimdev commented Apr 20, 2020

The error message is not a new one, so this seems to come from the existing retracing-detection logic. I think the warning is working as intended (WAI), since the function really is being traced many times here. Perhaps Keras using experimental_relax_shapes is an option?

@mdanatg

mdanatg commented Apr 20, 2020

@omalleyt12 @fchollet

@zaccharieramzi
Contributor Author

@mdanatg do you have any news on this?

@mdanatg mdanatg added comp:keras Keras related issues and removed comp:autograph Autograph related issues comp:core issues related to core part of tensorflow labels May 8, 2020
@mdanatg

mdanatg commented May 8, 2020

@zaccharieramzi No fix yet. According to the code, the function that the warning mentions should be cached and traced only once.

@omalleyt12 any thoughts why the tracing happens so many times?

@zaccharieramzi
Contributor Author

@mdanatg OK, too bad. I have just one question though; maybe you have the answer.
Do you know if the fix provided by @ngc92, i.e.:

@tf.function(experimental_relax_shapes=True)
def predict(t):
    return model(t)

for t in predict_tensors:
    _ = predict(t)

still allows predict to benefit from a distribution strategy (typically MirroredStrategy)? My guess is that it does not, but I am not sure, and I am not sure how to test this on a single GPU (2 logical GPUs).
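For what it's worth, one way to get two logical devices for such a test is `tf.config.set_logical_device_configuration`; a rough sketch (shown on CPU here, but the same call works on a physical GPU by passing a `memory_limit` per logical device):

```python
import tensorflow as tf

# Split one physical device into two logical devices, then mirror across
# them. This must run before TensorFlow initializes its devices.
physical = tf.config.list_physical_devices('CPU')[0]
tf.config.set_logical_device_configuration(
    physical,
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])

strategy = tf.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Conv1D(8, 3)])
    model.build([None, None, 1])
```

Whether a hand-rolled tf.function wrapper still distributes across the replicas is exactly the open question here; this only sets up an environment in which to check.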

@mdanatg

mdanatg commented May 8, 2020

@guptapriya

@omalleyt12
Contributor

@zaccharieramzi Thanks for the issue! This should be fixed in the latest nightly.


@zaccharieramzi
Contributor Author

@omalleyt12 thanks! One question I didn't ask though: was this bug actually slowing anything down, or was I just getting annoyed by an unwanted warning?

@chopwoodwater

TF 2.3 still has this issue.

@RuralHunter

The problem still persists with version 2.4.1, and the workaround doesn't work on predict_on_batch:

    @tf.function(experimental_relax_shapes=True)
    def predict_on_batch(self, states):
        return self.model.predict_on_batch(states)

TF reports the error:

RuntimeError: Detected a call to `Model.predict_on_batch` inside a `tf.function`. `Model.predict_on_batch` is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.predict_on_batch` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.

@RuralHunter

OK, I found that my problem was caused by multiple threads calling predict_on_batch. I added an empty predict call before launching the threads, and the warning was gone.

@aakashba

aakashba commented Apr 25, 2021

I have this issue when using model.predict inside a loop over 5 different models. The warning also leads to:

nvidia-smi
Failed to initialize NVML: Driver/library version mismatch

@A150852

A150852 commented Sep 12, 2021

My observation: in a multiprocessing setting, invoking predict() causes this warning, and when processing large amounts of data it eventually errors out (possibly a memory leak). Setting experimental_relax_shapes=True for the function invoked by multiple processes resolves the issue. Using model(input) instead of model.predict(input) also resolves the issue. So the key issue seems to be retracing whenever the input shape changes. The issue persists even when using tensors as input instead of Python objects.
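The two mitigations that recur in this thread can be sketched together like this (note: in newer TF releases `experimental_relax_shapes` was deprecated in favor of `reduce_retracing`, so treat the flag name as version-dependent):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Conv1D(8, 3)])
model.build([None, None, 1])

# Mitigation 1: call the model directly instead of model.predict(),
# bypassing the predict machinery that retraces per input shape.
y1 = model(tf.random.normal([2, 12, 1]))

# Mitigation 2: wrap the call in a shape-relaxing tf.function, so
# differently shaped inputs can share a relaxed trace.
@tf.function(experimental_relax_shapes=True)
def infer(x):
    return model(x)

y2 = infer(tf.random.normal([3, 20, 1]))
```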
