[onnx][quantization] Add JIT pass to insert permutes for conv ops #30679
Summary: Caffe2 expects quantized ops to be in NHWC format, while PyTorch inputs are in NCHW. Add a JIT pass that inserts an nchw2nhwc permute before each conv op and an nhwc2nchw permute after it. Use the graph rewriter to find consecutive redundant permutes and remove them from the graph.
Test Plan: python test/onnx/test_pytorch_onnx_caffe2_quantized.py TestQuantizedOps
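As a rough illustration of the layout change the summary describes, the sketch below applies the two permutations to a shape tuple in plain Python. The axis orders (0, 2, 3, 1) and (0, 3, 1, 2) are the usual NCHW/NHWC conversions; the helper names and the example shape are hypothetical and are not the operators added by this PR.

```python
# Hypothetical helpers for illustration only; not the PR's implementation.
NCHW_TO_NHWC = (0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
NHWC_TO_NCHW = (0, 3, 1, 2)  # N, H, W, C -> N, C, H, W


def permute_shape(shape, axes):
    """Return the shape obtained by reordering dimensions according to axes."""
    return tuple(shape[a] for a in axes)


def compose(first, second):
    """Axis order equivalent to applying `first`, then `second`."""
    return tuple(first[a] for a in second)


nchw = (1, 3, 224, 320)  # assumed example shape
nhwc = permute_shape(nchw, NCHW_TO_NHWC)
round_trip = permute_shape(nhwc, NHWC_TO_NCHW)
```

Because composing the two axis orders in either direction gives the identity order (0, 1, 2, 3), a back-to-back pair of these permutes can be dropped without changing the result, which is what makes the redundant-permute cleanup safe.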
Overall the PR is good. Some inline comments to address.
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict)

# Insert permutes before and after each conv op to ensure correct order.
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
Is it possible we can use _jit_pass_custom_pattern_based_rewrite_graph to insert permutes as well?
I did try that, but hit an assertion failure of the form:
terminate called after throwing an instance of 'c10::Error'
what(): n->owningGraph() == this && n->inBlockList() INTERNAL ASSERT FAILED at ../torch/csrc/jit/ir.h:1191
I suspect adding new ops to the graph using the rewriter isn't updating all the required dependencies correctly.
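The split discussed here, inserting permutes by walking the graph directly and using pattern matching only to remove redundant pairs, can be sketched on a toy model. This is a simplified illustration, not the PR's C++ pass: the graph is assumed to be a flat list of op names, the function names are hypothetical, and `quantized::nchw2nhwc` is an assumed name mirroring the `quantized::nhwc2nchw` symbol quoted in the diff.

```python
# Toy model: a graph is a flat list of op names executed in order.
# The pass functions below are hypothetical and only mimic the idea.
PRE = "quantized::nchw2nhwc"   # assumed name for the permute before a conv
POST = "quantized::nhwc2nchw"  # permute after a conv, as named in the diff


def insert_permutes(ops, is_conv=lambda op: "conv" in op):
    """Wrap every conv op with a layout permute before and after it."""
    out = []
    for op in ops:
        if is_conv(op):
            out.extend([PRE, op, POST])
        else:
            out.append(op)
    return out


def remove_redundant_permutes(ops):
    """Drop adjacent POST -> PRE pairs, which cancel each other out."""
    out = []
    for op in ops:
        if out and out[-1] == POST and op == PRE:
            out.pop()  # the pair is an identity, so remove both
        else:
            out.append(op)
    return out


ops = [
    "quantize_per_tensor",
    "quantized::conv2d",
    "quantized::conv2d_relu",
    "dequantize",
]
rewritten = remove_redundant_permutes(insert_permutes(ops))
```

In this toy run the two convs end up sharing one leading and one trailing permute, since the nhwc2nchw/nchw2nhwc pair between them is removed.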
op_node->insertInput(0, permute_node->output());

Node* permute_node_after = graph->create(
    Symbol::fromQualString("quantized::nhwc2nchw"), {input_node->output()});
Why use input_node->output() here? Shouldn't we use op_node->outputs()[0] instead?
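One way to read this comment: the trailing permute should consume the conv's result, and the conv's consumers should then be rewired to the permute. The thread does not record how the point was resolved, so the following is only a sketch of that wiring on a toy IR. The `Node` class and the `wrap_conv` and `chain` helpers are hypothetical, and `quantized::nchw2nhwc` is an assumed op name.

```python
# Toy IR for illustration; class and function names are hypothetical.
class Node:
    def __init__(self, kind, inputs=()):
        self.kind = kind
        self.inputs = list(inputs)


def wrap_conv(conv, consumers):
    """Insert permutes around `conv` and rewire its consumers."""
    pre = Node("quantized::nchw2nhwc", [conv.inputs[0]])
    conv.inputs[0] = pre
    # The trailing permute consumes the conv's output, not the conv's input.
    post = Node("quantized::nhwc2nchw", [conv])
    for user in consumers:
        user.inputs = [post if i is conv else i for i in user.inputs]
    return pre, post


def chain(node):
    """Follow first inputs back to the source; list kinds in execution order."""
    kinds = []
    while node is not None:
        kinds.append(node.kind)
        node = node.inputs[0] if node.inputs else None
    return kinds[::-1]


x = Node("input")
conv = Node("quantized::conv2d", [x])
out = Node("dequantize", [conv])
pre, post = wrap_conv(conv, [out])
```

If `post` were built from the conv's input instead, `chain(out)` would skip the conv entirely, which is the kind of miswiring the question appears to be probing.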
Looks good. Ship it.
This pull request has been merged in a51c5f5.
Summary: Pull Request resolved: pytorch#30679
Caffe2 expects quantized ops to be in NHWC format, while PyTorch inputs are in NCHW. Add a JIT pass that inserts an nchw2nhwc permute before each conv op and an nhwc2nchw permute after it. Use the graph rewriter to find consecutive redundant permutes and remove them from the graph.
Test Plan: python test/onnx/test_pytorch_onnx_caffe2_quantized.py TestQuantizedOps
Imported from OSS
Differential Revision: D18790518
fbshipit-source-id: 4dd39cf0b31b21f5586c0edfdce2260d4e245112
Summary:
Caffe2 expects quantized ops to be in NHWC format, while PyTorch inputs are in NCHW.
Add a JIT pass that inserts an nchw2nhwc permute before each conv op and an nhwc2nchw permute after it.
Use the graph rewriter to find consecutive redundant permutes and remove them from the graph.
Test Plan:
python test/onnx/test_pytorch_onnx_caffe2_quantized.py TestQuantizedOps
Differential Revision: D18790518