Skip to content

Commit

Permalink
[ONNX] Relax not exist assertion for 'register_pytree_node'
Browse files Browse the repository at this point in the history
To not conflict with potential existing workaround or solution outside of exporter.
Latest huggingface/transformers main (>4.31) patches PyTorch PyTree with support over `ModelOutput` class.
`_PyTreeExtensionContext` is kept to support prior versions of transformers.

ghstack-source-id: 518a747bdb894e65a82c898161f512a9c8768dcc
Pull Request resolved: #107245
  • Loading branch information
BowenBao committed Aug 15, 2023
1 parent bf197bf commit 742ae58
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ def register_pytree_node(
Raises:
AssertionError: If the custom python type is already registered.
"""
assert (
class_type not in pytree.SUPPORTED_NODES
and class_type not in self._extensions
), "PyTree node already registered"
if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
# PyTree node already registered.
# E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
# https://github.com/huggingface/transformers/pull/25358.
return
self._extensions[class_type] = (flatten_func, unflatten_func)

def _register_huggingface_model_output_extension(self):
Expand Down

0 comments on commit 742ae58

Please sign in to comment.