Unwanted tf.function retracing when using variable-length inputs #38561
Comments
As per my understanding, if the input tensor's shape or dtype changes (if it is not constant), then the function gets retraced.

Yes, this is totally true, but I am not using … Moreover, the behaviour I am describing is for version 2.2.0rc2, and the doc is still for 2.1, where there is no issue.
You can see the current doc here.

As a workaround, you could try to wrap the Keras model in an explicit tf.function call, like this:
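A minimal sketch of such a wrapper, assuming a small stand-in model (the model, shapes, and names here are illustrative, not from the thread):

```python
import tensorflow as tf

# Stand-in model; any Keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

# Wrapping the call in an explicit tf.function keeps tracing under our
# control; experimental_relax_shapes lets TensorFlow fall back to a
# shape-generic trace instead of retracing for every new input length.
@tf.function(experimental_relax_shapes=True)
def predict(t):
    return model(t)

# Variable-length batches that would otherwise retrace on every call.
for n in (3, 5, 7):
    _ = predict(tf.zeros([n, 2]))
```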
Following is the output of TF 2.1.0; the output seems to be the same.
@gurushantj yes, you are right. I don't know why I thought this was in 2.1; it's actually in 2.0 that the problem is gone. Still, the documentation regarding the re-tracing is about the same. @ngc92 I tried your workaround but got the following error:
Could you please validate the following and let me know: disable the eager execution setting and pass …
Was able to reproduce the issue with TF v2.1, TF v2.2.0rc3, and TF-nightly. Please find the attached gist. Thanks!
@zaccharieramzi,
@amahendrakar I am not sure what I am supposed to see in that comment. The issue you linked suggests that this should be dealt with.
@ngc92 still got an error:
Is this what you want to do? Note that you are no longer using any features of … Also, in TF 2.2 there is support for custom …, i.e. just wrapping the default provided function in something that relaxes shapes.
@ngc92 yes, this is a fair workaround. However, there are cases where you would want to use … The second option you provided didn't work straight out of the box, but you can try things in TF 2.2 in Colab: https://colab.research.google.com/drive/1MfRPQyRhjrF7he7fymoIEG7k64YCd0Da You will notice that in the case of …
We're investigating. It seems that a newly added warning about function retracing fires more often than expected.
The error message is not a new one, so this seems to come from the existing retracing detection logic. I think the warning is WAI (working as intended), as it is tracing many times here. Perhaps Keras is using …
@mdanatg do you have any news on this?
@zaccharieramzi No fix yet. According to the code, the function that the warning talks about should be cached and only traced once. @omalleyt12 any thoughts on why the tracing happens so many times?
@mdanatg OK, too bad. I just have one question, though; maybe you have the answer:

```python
@tf.function(experimental_relax_shapes=True)
def predict(t):
    return model(t)

for t in predict_tensors:
    _ = predict(t)
```

still allows …
@zaccharieramzi Thanks for the issue! This should be fixed in the latest nightly.
@omalleyt12 thanks! One question I didn't ask, though: was this bug slowing anything down, or am I just getting annoyed by an unwanted warning?
TF 2.3 still has this issue.
The problem still persists with version 2.4.1, and the workaround doesn't work on predict_on_batch; TF reports an error:
OK, I found that my problem was caused by multiple threads calling the predict_on_batch function. I added an empty predict before launching the threads, and the warning was gone.
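A rough sketch of that warm-up trick, assuming an illustrative model (the model, shapes, and thread count are stand-ins, not from the thread):

```python
import threading

import numpy as np
import tensorflow as tf

# Stand-in model and input batch.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
batch = np.zeros((4, 8), dtype=np.float32)

# Warm-up: one call on the main thread forces the initial trace, so
# the worker threads reuse the cached function instead of each one
# triggering a retrace concurrently.
_ = model.predict_on_batch(batch)

results = []

def worker():
    # All threads now hit the already-traced function.
    results.append(model.predict_on_batch(batch))

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```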
I have this issue when using model.predict inside a loop over 5 different models. The warning also leads to: nvidia-smi …
My observation: in a multiprocessing setting, invoking predict() causes this warning, and when processing large amounts of data it eventually errors out (possibly a memory leak). Setting experimental_relax_shapes=True for the function being invoked by multiple processes resolves the issue. Also, using model(input) instead of model.predict(input) resolves the issue. So the key issue seems to be retracing whenever the input shape changes. The issue persists even when using tensors as input instead of Python objects.
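The two mitigations mentioned above can be sketched as follows (the model and input are stand-ins for illustration):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
x = np.zeros((3, 5), dtype=np.float32)

# model.predict() runs Keras' batched inference loop and returns a
# NumPy array; repeated calls with varying inputs can retrace it.
y_predict = model.predict(x)

# Calling the model directly runs it as a layer and returns a tensor,
# sidestepping the per-call predict function entirely.
y_direct = model(tf.constant(x))
```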
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
Describe the current behavior
A lot of warnings about tf.function retracing appear when using a Keras model in a loop with variable-length inputs.
Describe the expected behavior
I would like there to be no retracing when it isn't needed (for example, for a fully convolutional model).
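One way to get a single shape-generic trace for a fully convolutional model is to pin an input_signature with None dimensions; a sketch under that assumption (the model and sizes are illustrative):

```python
import tensorflow as tf

# A tiny fully convolutional model: it accepts any spatial size.
model = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 3, padding="same")])

# None dimensions in the signature yield one shape-generic trace,
# so variable-size inputs never trigger a retrace.
@tf.function(input_signature=[tf.TensorSpec([None, None, None, 1], tf.float32)])
def predict(t):
    return model(t)

# Different spatial sizes all reuse the single trace.
for size in (8, 16, 32):
    _ = predict(tf.zeros([1, size, size, 1]))
```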
Standalone code to reproduce the issue
Other info / logs
Logs:
This issue was originally described here, and some other people have had trouble with training as well.
When switching back to 2.1, the problem is gone.