Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Quant][Inductor] Enable dequant promotion inside inductor #104590

Conversation

leslie-fang-intel
Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel commented Jul 4, 2023

Stack from ghstack (oldest at bottom):

Summary
Enable the dequant pattern promotion pass in inductor. Since in the qconv weight prepack pass, we will match the dequant->conv2d pattern. If the dequant pattern has multi user nodes, it will fail to be matched.
Taking the example of

        conv1
       /     \
   conv2    conv3

After quantization flow, it will generate pattern as

      dequant1 
          |
        conv1
          |
        quant2 
          |
       dequant2
       /     \
   conv2    conv3

We need to duplicate dequant2 into dequant2 and dequant3, in order to make dequant2->conv2 and dequant3->conv3 pattern matched.

Test Plan

python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 4, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104590

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit ff55f12 with merge base 97a291f (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

leslie-fang-intel added a commit that referenced this pull request Jul 4, 2023
ghstack-source-id: ce32fa370be90637f06f3598b103cf8cc255fea5
Pull Request resolved: #104590
@leslie-fang-intel leslie-fang-intel changed the title Enable dequant promotion inside inductor [Quant][Inductor] Enable dequant promotion inside inductor Jul 4, 2023
@leslie-fang-intel leslie-fang-intel marked this pull request as draft July 4, 2023 08:41
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Jul 6, 2023
ghstack-source-id: ce32fa370be90637f06f3598b103cf8cc255fea5
Pull Request resolved: pytorch#104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Jul 6, 2023
ghstack-source-id: 911e81e54cddb952a967388864e550e805abf739
Pull Request resolved: #104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Jul 6, 2023
ghstack-source-id: 85b1c839a0930c1ee428ffe1e98acbf9cbaa882c
Pull Request resolved: #104590
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Jul 6, 2023
ghstack-source-id: 85b1c839a0930c1ee428ffe1e98acbf9cbaa882c
Pull Request resolved: pytorch#104590
leslie-fang-intel added a commit that referenced this pull request Jul 6, 2023
ghstack-source-id: 0a45da5b3e5e8c2517a0903337b9f222406fcfc8
Pull Request resolved: #104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Jul 6, 2023
ghstack-source-id: ca1ab17fe3d157cd7c9c131223932eb0b7d15acb
Pull Request resolved: #104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Jul 7, 2023
ghstack-source-id: ca1ab17fe3d157cd7c9c131223932eb0b7d15acb
Pull Request resolved: pytorch#104590
leslie-fang-intel added a commit that referenced this pull request Jul 7, 2023
ghstack-source-id: ef586e03bf5af85ad7f191b712fd6dbc73eb6a6c
Pull Request resolved: #104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Jul 7, 2023
ghstack-source-id: ef586e03bf5af85ad7f191b712fd6dbc73eb6a6c
Pull Request resolved: pytorch#104590
leslie-fang-intel added a commit that referenced this pull request Jul 7, 2023
ghstack-source-id: 876008389d339c717c25096c402f717c5839faee
Pull Request resolved: #104590
**Summary**
Enable the dequant node promotion pass in inductor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Aug 22, 2023
ghstack-source-id: bf12b7057e73abef38f9b7663099908a2456ce7f
Pull Request resolved: pytorch#104590
**Summary**
Enable the `dequant pattern` promotion pass in inductor. Since in the qconv weight prepack pass, we will match the `dequant->conv2d` pattern. If the `dequant pattern` has multi user nodes, it will fail to be matched.
Taking the example of
```
        conv1
       /     \
   conv2    conv3
```
After quantization flow, it will generate pattern as
```
      dequant1 
          |
        conv1
          |
        quant2 
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, in order to make `dequant2->conv2` and  `dequant3->conv3`  pattern matched.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Aug 25, 2023
ghstack-source-id: abf5033d4d4a7820bd4c17380f340320b0af014d
Pull Request resolved: pytorch#104590
**Summary**
Enable the `dequant pattern` promotion pass in inductor. Since in the qconv weight prepack pass, we will match the `dequant->conv2d` pattern. If the `dequant pattern` has multi user nodes, it will fail to be matched.
Taking the example of
```
        conv1
       /     \
   conv2    conv3
```
After quantization flow, it will generate pattern as
```
      dequant1 
          |
        conv1
          |
        quant2 
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, in order to make `dequant2->conv2` and  `dequant3->conv3`  pattern matched.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Aug 25, 2023
ghstack-source-id: 021dcdd27c054e541d34a2478e45bf90779c8fe0
Pull Request resolved: pytorch#104590
**Summary**
Enable the `dequant pattern` promotion pass in inductor. Since in the qconv weight prepack pass, we will match the `dequant->conv2d` pattern. If the `dequant pattern` has multi user nodes, it will fail to be matched.
Taking the example of
```
        conv1
       /     \
   conv2    conv3
```
After quantization flow, it will generate pattern as
```
      dequant1 
          |
        conv1
          |
        quant2 
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, in order to make `dequant2->conv2` and  `dequant3->conv3`  pattern matched.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
**Summary**
Enable the `dequant pattern` promotion pass in inductor. Since in the qconv weight prepack pass, we will match the `dequant->conv2d` pattern. If the `dequant pattern` has multi user nodes, it will fail to be matched.
Taking the example of
```
        conv1
       /     \
   conv2    conv3
```
After quantization flow, it will generate pattern as
```
      dequant1 
          |
        conv1
          |
        quant2 
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, in order to make `dequant2->conv2` and  `dequant3->conv3`  pattern matched.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
@leslie-fang-intel
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Aug 25, 2023
… inside inductor (#105455)

**Summary**
Enable the `dequant-conv2d-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_unary
```

Pull Request resolved: #105455
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590
pytorchmergebot pushed a commit that referenced this pull request Aug 25, 2023
…rn fusion inside inductor (#105456)

**Summary**
Enable the `dequant-conv2d-binary_postop(add)-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_binary
```

Pull Request resolved: #105456
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455
pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2023
…ol2d) (#105639)

**Summary**
In this PR, we mainly enable 2 things.

- Enable the skeleton of quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add quantization recipe of `maxpool2d` and annotate it as input./output share observer.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```

Pull Request resolved: #105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2023
**Summary**
Enable the `dq-maxpool2d-q` pattern match and lower into `torch.ops.quantized.max_pool2d`.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qmaxpool2d
python -m pytest test_quantized_op.py -k test_max_pool2d_pt2e
```

Pull Request resolved: #105906
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639
pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2023
**Summary**
After oneDNN 3.1 upgrade, we don't need to do the weight scale reciprocal calculation. So, remove the redundant reciprocal calculation to optimize QConv performance and using IDeep version API to implement it in this PR:

- This QConv implementation expects to work functionally both with current IDeep version and the following IDeep upgrade in PR: #107565.
- With the following IDeep upgrade in PR: #107565, the QConv has better performance since the redundant reciprocal calculation are removed.

Pull Request resolved: #105996
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639, #105906
pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2023
…ht scale reciprocal calculation (#107565)

**Summary**
Upgrade IDeep which includes 1 IDeep change as IDeep PR: intel/ideep#226

- For IDeep PR: intel/ideep#226 which has done 2 things:

  - Remove the redundant QConv weight scale reciprocal calculation.
  - Pump IDEEP_VERSION_REVISION version from 0 to 1.

  So only QConv related calculation will be impacted and we already use IDeep version API in #105996 to make the corresponding change in PyTorch.

Pull Request resolved: #107565
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639, #105906, #105996
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
**Summary**
Enable the `dequant pattern` promotion pass in inductor. Since in the qconv weight prepack pass, we will match the `dequant->conv2d` pattern. If the `dequant pattern` has multi user nodes, it will fail to be matched.
Taking the example of
```
        conv1
       /     \
   conv2    conv3
```
After quantization flow, it will generate pattern as
```
      dequant1
          |
        conv1
          |
        quant2
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, in order to make `dequant2->conv2` and  `dequant3->conv3`  pattern matched.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```

Pull Request resolved: #104590
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
… inside inductor (#105455)

**Summary**
Enable the `dequant-conv2d-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_unary
```

Pull Request resolved: #105455
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
…rn fusion inside inductor (#105456)

**Summary**
Enable the `dequant-conv2d-binary_postop(add)-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_binary
```

Pull Request resolved: #105456
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
…ol2d) (#105639)

**Summary**
In this PR, we mainly enable 2 things.

- Enable the skeleton of quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add quantization recipe of `maxpool2d` and annotate it as input./output share observer.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```

Pull Request resolved: #105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
**Summary**
Enable the `dq-maxpool2d-q` pattern match and lower into `torch.ops.quantized.max_pool2d`.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qmaxpool2d
python -m pytest test_quantized_op.py -k test_max_pool2d_pt2e
```

Pull Request resolved: #105906
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
**Summary**
After oneDNN 3.1 upgrade, we don't need to do the weight scale reciprocal calculation. So, remove the redundant reciprocal calculation to optimize QConv performance and using IDeep version API to implement it in this PR:

- This QConv implementation expects to work functionally both with current IDeep version and the following IDeep upgrade in PR: #107565.
- With the following IDeep upgrade in PR: #107565, the QConv has better performance since the redundant reciprocal calculation are removed.

Pull Request resolved: #105996
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639, #105906
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
…ht scale reciprocal calculation (#107565)

**Summary**
Upgrade IDeep which includes 1 IDeep change as IDeep PR: intel/ideep#226

- For IDeep PR: intel/ideep#226 which has done 2 things:

  - Remove the redundant QConv weight scale reciprocal calculation.
  - Pump IDEEP_VERSION_REVISION version from 0 to 1.

  So only QConv related calculation will be impacted and we already use IDeep version API in #105996 to make the corresponding change in PyTorch.

Pull Request resolved: #107565
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639, #105906, #105996
@facebook-github-bot facebook-github-bot deleted the gh/leslie-fang-intel/54/head branch August 29, 2023 14:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants