-
Notifications
You must be signed in to change notification settings - Fork 19.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Catch optree exception when changing backends #20886
base: master
Are you sure you want to change the base?
Catch optree exception when changing backends #20886
Conversation
7df740a
to
f90a7a4
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #20886 +/- ##
==========================================
- Coverage 82.24% 82.24% -0.01%
==========================================
Files 561 561
Lines 52647 52655 +8
Branches 8136 8136
==========================================
+ Hits 43302 43306 +4
- Misses 7341 7345 +4
Partials 2004 2004
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Calling E.g., when attempting to make a simple mnist model:
getting past that, attempting to call
Switching back to jax, recreating the same model:
I think it might be easier to restructure the project to avoid needing |
optree.register_pytree_node( | ||
ListWrapper, | ||
lambda x: (x, None), | ||
lambda metadata, children: ListWrapper(list(children)), | ||
namespace="keras", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doing unregister first would be better than ignoring arbitrary exceptions.
with contextlib.suppress(ValueError):
optree.unregister_pytree_node(
ListWrapper,
namespace="keras",
)
optree.register_pytree_node(
ListWrapper,
lambda x: (x, None),
lambda metadata, children: ListWrapper(list(children)),
namespace="keras",
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the note! If I recall correctly, unregistering was unsupported for the types when I tried while drafting this PR. There was a new optree release recently, so maybe this changed. It's worth a try. I agree that approach would be better if it works.
Also, this set of changes is already on master after #21049
This PR fixes an exception raised when
keras.config.set_backend("tensorflow")
is called.To reproduce:
Example python session with all output: