Skip to content

Commit

Permalink
[ONNX] Relax not exist assertion for 'register_pytree_node'
Browse files Browse the repository at this point in the history
To not conflict with potential existing workaround or solution outside of exporter.
Latest huggingface/transformers main (>4.31) patches PyTorch PyTree with support over `ModelOutput` class.
`_PyTreeExtensionContext` is kept to support prior versions of transformers.

ghstack-source-id: 70e50b0e6cda42888a117188a9e5c18f3f40acd7
Pull Request resolved: #107245
  • Loading branch information
BowenBao committed Aug 15, 2023
1 parent bf197bf commit 5049896
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ def register_pytree_node(
Raises:
AssertionError: If the custom python type is already registered.
"""
assert (
if (
class_type not in pytree.SUPPORTED_NODES
and class_type not in self._extensions
), "PyTree node already registered"
):
# PyTree node already registered.
# E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
# https://github.com/huggingface/transformers/pull/25358.
return
self._extensions[class_type] = (flatten_func, unflatten_func)

def _register_huggingface_model_output_extension(self):
Expand Down

0 comments on commit 5049896

Please sign in to comment.