New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Support quantized::linear_relu
#109755
[ONNX] Support quantized::linear_relu
#109755
Conversation
gustavla
commented
Sep 21, 2023
•
edited
edited
- Adds support for quantized::linear_relu
- Adds weight unpacking pattern matcher
- Adds to export for opset 10 and 13.
- Adds QAT test modeled after conv2d+relu fusion test
- Adds support for quantized::linear_relu - Adds weight packing pattern matcher - Adds to export for opset 10 and 13. - Adds QAT test modeled after conv2d+relu fusion
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109755
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 641fe7c with merge base 4e3b032 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This has me confused, because the bias and shapes didn't match the in_feature/out_feature of the Conv2d that I was basing the Linear code on. This is because it was defining the ops incorrectly and then re-writing the weights swapping the inputs/outputs channels. It never complained that the original definition had a mismatch. I have fixed this for Linear now, but also fixed it for three cases of Conv2d.
Note, I also spotted some typos in the Conv2d definitions and have also pushed a fix for them. This didn't cause any issues, because the weights were replaced anyway, thus changing the definition of the op. However, it is confusing and should be corrected. Note that the in/out channels are swapped between how a Conv2d is constructed and the actual weights, but currently in the code they use the same order:
|
@@ -13091,7 +13091,7 @@ class M(torch.nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.quant = torch.ao.quantization.QuantStub() | |||
self.conv = torch.nn.Conv2d(2, 4, 3, stride=2) | |||
self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment about these typos. If you expand the context below, the weights are replaced as follows:
model.conv.weight = torch.nn.Parameter(
_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
)
But this is not consistent with Conv2d(2, 4, 3)
, it is consistent with Conv2d(4, 2, 3)
, which is what I'm changing it to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64 / test (default, 1, 3, macos-m1-12) Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 0 checks: Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
quantized::linear_relu