Even in eager mode, Keras is passing custom models non-eager tensors in 'fit' #26268
Note that even adding |
I believe you need to pass m = MyModel(dynamic=True). Otherwise, Keras assumes that your model is able to take symbolic tensors (and uses this for shape inference), even if you've asked it to compile in eager mode. This was added in d0373dc, which has a little more context. |
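The sketch below illustrates the eager/symbolic distinction with plain tf.function rather than Keras itself, so treat it as an analogy, not as the code path Keras uses for shape inference. The hypothetical function scale logs tf.executing_eagerly() whenever its Python body runs: a direct call receives a concrete eager tensor, while tf.function runs the body once during tracing with a symbolic graph tensor that has no value yet.

```python
import tensorflow as tf

# Records whether each run of the Python body was eager (True) or a graph trace (False).
modes = []


def scale(x):
    """Hypothetical toy function: logs the execution mode, then doubles its input."""
    modes.append(tf.executing_eagerly())
    return x * 2.0


# Direct call: x is an eager tensor with a concrete value.
eager_out = scale(tf.constant([1.0, 2.0]))

# Wrapped call: the Python body runs once while tracing, with a symbolic x.
graph_out = tf.function(scale)(tf.constant([1.0, 2.0]))

print(modes)
```

Because the traced body sees only a symbolic tensor, eager-only operations inside it, such as calling .numpy() on the input, fail there; that is the same kind of failure a custom model hits when Keras calls it with symbolic tensors.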
There is a bug here, though, in that Keras should probably tell you why this is failing. |
Thanks, that does the trick! There might also be a bit of a documentation issue here: I don't recall this being prominently highlighted in the new tutorials and overviews of eager mode/Keras. |
Thank you @jekbradbury for the quick response on this! We will definitely need to throw an error explaining the issue here. I will look into that. @malmaud I agree that this is not documented well. This concept is very new, and we are working to see if we can remove the need for it, hence the lack of documentation. |
Can we get some more explanation for how to add dynamic=True and what that does? |
I found a solution that works for me and posted it in this StackOverflow answer. tl;dr: set run_eagerly to True after compiling:
model.compile()
model.run_eagerly = True |
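A minimal sketch of that workaround, assuming TensorFlow 2.x with tf.keras (behavior differs across versions, and later comments in this thread report that it did not help on 2.3.0). MyModel, CALL_MODES and the toy data are hypothetical; call() uses inputs.numpy(), which only works on eager tensors, and records tf.executing_eagerly() on each invocation so you can see which calls ran eagerly.

```python
import numpy as np
import tensorflow as tf

# Records tf.executing_eagerly() for every invocation of MyModel.call().
CALL_MODES = []


class MyModel(tf.keras.Model):
    """Hypothetical custom model whose call() uses an eager-only operation."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        eager = tf.executing_eagerly()
        CALL_MODES.append(eager)
        outputs = self.dense(inputs)
        # Native Python control flow on a concrete value. inputs.numpy() only
        # works on eager tensors, so it is guarded in case Keras also calls
        # call() symbolically (e.g. for shape inference) before training.
        if eager and bool(np.all(inputs.numpy() == 0)):
            outputs = tf.zeros_like(outputs)
        return outputs


# Toy regression data (hypothetical).
x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = MyModel()
model.compile(optimizer="sgd", loss="mse")
model.run_eagerly = True  # the workaround: set after compile(), before fit()
history = model.fit(x, y, epochs=1, batch_size=4, verbose=0)
```

The eager-only branch is guarded because, depending on the Keras version, call() may still be invoked once with symbolic tensors before training starts (the shape-inference behavior mentioned earlier in the thread); with run_eagerly set, the training steps themselves should receive eager tensors. run_eagerly=True can also be passed directly to compile().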
@brett-daley It works! Thank you. |
I am having the exact same problem with keras version 2.3.0-tf and tensorflow version 2.2.0. Adding dynamic=True gives me this error for a functional model:
|
I was able to fix my issue by subclassing tf.keras.Model. Also, there was no need to pass |
@surGeonGG Can you be more specific, please? I assume you did more than just subclassing, because propagating the init via a super-call shouldn't make any difference. |
I am having the same issue when running this on 2.3.0. Doing this also doesn't resolve the issue:
model.compile()
model.run_eagerly = True |
Same as PrattJena: I'm running on 2.3.0 and have the same issue, even with run_eagerly = True |
I also have this same issue. Any pointers towards a solution? |
Hi there, we are checking to see if you still need help on this, as you are using an older version of TensorFlow, which is officially considered end of life. We recommend that you upgrade to the latest 2.x version and let us know if the issue still persists in newer versions. Please open a new issue for any help you need against 2.x, and we will get you the right help. This issue will be closed automatically 7 days from now. If you still need help with this issue, please provide us with more information. |
Hi Alfred,
This issue has been resolved. Thanks so much for checking in though.
(Reply by email to Alfred Sorten Wolf's message of Fri, Feb 19, 2021 at 7:10 PM.)
|
Closing this issue since it is resolved. Feel free to reopen if necessary. Thanks! |
This isn't solved. I was able to reproduce the error even after adding the suggested line after calling compile(). |
System information
Describe the current behavior
When calling keras.model.fit on a custom model, it seems the model is passed a graph-mode tensor instead of an eager tensor, even when in eager mode.
Describe the expected behavior
If in eager mode, the tensors passed to the call method of a custom model should be eager tensors. Otherwise, the advantages of eager mode, like the ability to use native control flow, are lost.
Code to reproduce the issue