New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Register list/tuple/dict to format_argumment and refactor fx.Node format_argument in diagnostics #105263
Conversation
…ode format_argument in diagnostics [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105263
Note: Links to docs will display an error until the docs builds have been completed. ✅ 4 Unrelated FailuresAs of commit 2d6fe3b: UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…ode format_argument in diagnostics ghstack-source-id: e4a89ad6f1fe3d663ea3327a42ebc06071def218 Pull Request resolved: #105263
…factor fx.Node format_argument in diagnostics" [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: ed61ca152af9cfe493ddd9ed0ef3427348a92cf8 Pull Request resolved: #105263
Due to the need of supporting None args, we need to support None (torch.OptionalType) to get shape/dtype. To support pytorch/pytorch#105263 --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
…factor fx.Node format_argument in diagnostics" Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]). Need microsoft/onnxscript#890 ![dispatcher_sarif](https://github.com/pytorch/pytorch/assets/18010845/2567fac6-4154-4ce8-bc34-83950ef1c1d7) [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: 50421618df5cd977aeadd5aa68cfc0b0bcadd2e9 Pull Request resolved: #105263
…factor fx.Node format_argument in diagnostics" Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]). Need microsoft/onnxscript#890 ![dispatcher_sarif](https://github.com/pytorch/pytorch/assets/18010845/2567fac6-4154-4ce8-bc34-83950ef1c1d7) [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: b6ce5444604f9bb70def741efde620bfc07c6cf1 Pull Request resolved: #105263
…factor fx.Node format_argument in diagnostics" Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]). Need microsoft/onnxscript#890 ![dispatcher_sarif](https://github.com/pytorch/pytorch/assets/18010845/2567fac6-4154-4ce8-bc34-83950ef1c1d7) [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: 879f7976faf3264d61093be1c65a819cbb83fe03 Pull Request resolved: #105263
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice improvement. I would consider covering more complex data types which combine data structs, such as dict with tuples as values, etc
…factor fx.Node format_argument in diagnostics" Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]). Need microsoft/onnxscript#890 ![dispatcher_sarif](https://github.com/pytorch/pytorch/assets/18010845/2567fac6-4154-4ce8-bc34-83950ef1c1d7) [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: 0c6507f7624380ce7983069b0cae855621a526c4 Pull Request resolved: #105263
@BowenBao PTAL |
test/onnx/test_fx_op_consistency.py
Outdated
xfail( | ||
"nn.functional.adaptive_avg_pool2d", | ||
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"), | ||
reason=onnx_test_common.reason_dynamo_does_not_support("RecursionError: maximum recursion depth exceeded \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
decompose pass.. probably our bug not dynamo's?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm getting this error onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (AveragePool_0) Op (AveragePool) [ShapeInferenceError] Attribute strides has incorrect size
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just repro. Maybe try the latest commit of onnx-script. I had an update on avg_pool in torchlib.
|
||
@_format_argument.register | ||
def _torch_fx_symbolic_value( | ||
obj, # NOTE: functools.singledispatch does not support Union until 3.11, so we use Any here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if a silly approach of creating 3 different functions for it is viable :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should then allow format_argument
to be used instead of _format_nested_argument_by_dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LG w/ comment.
A few follow ups
- Add
format_argument
forSymbolicFunction
. Not sure what is the best format, but right now it is only showing class name. We might be interested in the overload, whether it is custom, and maybe param schema but that might be too verbose. - Avoid printing long tuple/list/dict. I think a few cases like decomp table, list of nodes of fx graph might fall into this category.
- Add
format_argument
for primitive types like int. This probably should go into baseformat_argument
inside the base diagnostic infra package.
…factor fx.Node format_argument in diagnostics" Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]). Need microsoft/onnxscript#890 ![dispatcher_sarif](https://github.com/pytorch/pytorch/assets/18010845/2567fac6-4154-4ce8-bc34-83950ef1c1d7) [ghstack-poisoned]
…ode format_argument in diagnostics ghstack-source-id: 758c2290f185fa83d31acc7d9b82761d9c99a521 Pull Request resolved: #105263
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Previous to this PR, the SARIF reports don't have detail on torch.fx.Node (shape/dtype), and don't unpack the tuple/list/dict. This PR provides thorough information of args/kwargs from torch in fx.graph expression: f32[64, 64, 2] (dtype[shape]).
Need microsoft/onnxscript#890