
Export fake quantization function to ONNX #39502

Closed
skyw opened this issue Jun 4, 2020 · 5 comments
Labels
feature A request for a proper, new feature. module: onnx Related to torch.onnx oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@skyw
Contributor

skyw commented Jun 4, 2020

🚀 Feature

Export fake_quantize_per_tensor_affine and fake_quantize_per_channel_affine functions to ONNX.

Motivation

To support deploying QAT networks through ONNX to backends outside PyTorch, such as TensorRT, fake quantization needs to be exported as ONNX operators.

Pitch

Fake quantization will be broken into a pair of QuantizeLinear and DequantizeLinear ONNX operators.
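
The mapping could look roughly like the sketch below, written against the torch.onnx symbolic-function convention; the argument parsing, dtype choices, and registration here are illustrative rather than the actual implementation.

```python
import torch
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# Illustrative only: lower fake_quantize_per_tensor_affine to a
# QuantizeLinear/DequantizeLinear pair in the exported ONNX graph.
@parse_args("v", "f", "i", "i", "i")
def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
    scale = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float32))
    zero_point = g.op("Constant", value_t=torch.tensor(zero_point, dtype=torch.uint8))
    quantized = g.op("QuantizeLinear", x, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)

# Hypothetical registration for demonstration; a real implementation would
# live in the opset-10 symbolic registry instead.
register_custom_op_symbolic(
    "aten::fake_quantize_per_tensor_affine", fake_quantize_per_tensor_affine, 10)
```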

Alternatives

Additional context

Fake quantization is effectively a quantize followed by a dequantize. I have discussed this with @raghuramank100 and we agreed this is the right way to export.
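
As a quick numeric illustration of that equivalence (a sketch, not part of the proposal):

```python
import torch

x = torch.randn(4)
scale, zero_point, qmin, qmax = 0.1, 128, 0, 255

# Fake quantization as implemented in PyTorch.
fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, qmin, qmax)

# The same computation written as an explicit quantize followed by a
# dequantize, which is what a QuantizeLinear + DequantizeLinear pair does.
q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
dq = (q - zero_point) * scale

assert torch.allclose(fq, dq)
```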

Similar functionality has been added to tensorflow-onnx, onnx/tensorflow-onnx#919.

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof @jerryzh168 @jianyuh @dzhulgakov @raghuramank100 @jamesr66a

@Zehaos

Zehaos commented Jun 4, 2020

Good point! FakeQuantize is a good choice for exporting quantization params like scale and zero point.

@mrshenli mrshenli added feature A request for a proper, new feature. module: onnx Related to torch.onnx oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 4, 2020
@skyw
Contributor Author

skyw commented Jun 4, 2020

Thanks @Zehaos. What was discussed with @jamesr66a and @raghuramank100 is that I'll also contribute the PR, if the owners of the code are OK with it. I'd be happy to contribute.

@jamesr66a
Collaborator

Hi @skyw, this proposal looks good! Please feel free to contribute the PR

@skyw
Contributor Author

skyw commented Jun 9, 2020

Thanks, James. Filed PR #39738

facebook-github-bot pushed a commit that referenced this issue Jul 16, 2020
Summary:
As discussed in #39502.

This PR adds support for exporting `fake_quantize_per_tensor_affine` to a pair of `QuantizeLinear` and `DequantizeLinear`.

Exporting `fake_quantize_per_channel_affine` to ONNX depends on onnx/onnx#2772; another PR will follow once ONNX merges that change.

It will generate ONNX graph like this:
![image](https://user-images.githubusercontent.com/1697840/84180123-ddd90080-aa3b-11ea-81d5-eaf6f5f26715.png)


Pull Request resolved: #39738

Reviewed By: hl475

Differential Revision: D22517911

Pulled By: houseroad

fbshipit-source-id: e998b4012e11b0f181b193860ff6960069a91d70
SplitInfinity pushed a commit that referenced this issue Feb 18, 2021
Summary:
Fixes #39502

This PR adds support for exporting **fake_quantize_per_channel_affine** to a pair of QuantizeLinear and DequantizeLinear. Per tensor support was added by PR #39738.

The `axis` attribute of QuantizeLinear and DequantizeLinear, which is required for per-channel support, was added in opset 13 by onnx/onnx#2772.
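
For reference, the opset-13 form can be sketched by hand with the onnx helper API; this is an illustrative graph, not code from this PR.

```python
import onnx
from onnx import TensorProto, helper

# A per-channel QuantizeLinear/DequantizeLinear pair using the opset-13
# `axis` attribute, with one scale/zero-point entry per channel along axis 0.
scale = helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.1, 0.2, 0.3])
zero_point = helper.make_tensor("zero_point", TensorProto.INT8, [3], [0, 0, 0])
nodes = [
    helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["xq"], axis=0),
    helper.make_node("DequantizeLinear", ["xq", "scale", "zero_point"], ["y"], axis=0),
]
graph = helper.make_graph(
    nodes, "per_channel_fake_quant",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
    initializer=[scale, zero_point],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
```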

[update 1/20/2021]: opset 13 is now supported on master, so the added function is properly tested. The code has also been rebased onto the new master.

The function was also tested offline with the following code:
```python
import torch
from torch import quantization

from torchvision import models
qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant, weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

dummy_input = torch.randn(16, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

input_names = [ "actual_input_1" ]
output_names = [ "output1" ]

torch.onnx.export(qat_resnet18, dummy_input, "quant_model.onnx", verbose=True, opset_version=13)
```
It can generate the desired graph.
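
To confirm the exported graph offline, a small check along these lines should work (a sketch; the file name matches the export call above):

```python
import onnx

# Load the exported model and count the QuantizeLinear/DequantizeLinear
# pairs that the fake-quantize nodes were lowered to.
model = onnx.load("quant_model.onnx")
op_types = [node.op_type for node in model.graph.node]
print("QuantizeLinear:", op_types.count("QuantizeLinear"))
print("DequantizeLinear:", op_types.count("DequantizeLinear"))
```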

Pull Request resolved: #42835

Reviewed By: houseroad

Differential Revision: D26293823

Pulled By: SplitInfinity

fbshipit-source-id: 300498a2e24b7731b12fa2fbdea4e73dde80e7ea
malfet pushed a commit that referenced this issue Feb 18, 2021

Co-authored-by: Hao Wu <skyw@users.noreply.github.com>
@ZyrianovS

ZyrianovS commented Nov 18, 2021

On current master.

In test/onnx/test_models.py, TestModels gets its opset_version from torch.onnx.symbolic_helper, which defaults to 9. https://github.com/pytorch/pytorch/blob/master/test/onnx/test_models.py#L50
https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_helper.py#L847

```python
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
print(opset_version)  # 9
```

```
python test/onnx/test_models.py

Ran 32 tests in 27.295s

OK (skipped=10)
```

So 10 tests are skipped by default, including (I believe) test_fake_quant, test_qat_resnet_pertensor, and test_qat_resnet_per_channel, which are wrapped with @skipIfUnsupportedMinOpsetVersion for opset versions 10 and 13 and which I was trying to debug in my QAT pipeline.

If I set _default_onnx_opset_version in torch.onnx.symbolic_helper to 13 so that those QAT tests are not skipped, I get:

```
Ran 32 tests in 38.009s

FAILED (errors=3, skipped=5)

ERROR: test_fake_quant (__main__.TestModels)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/simon/pytorch/test/onnx/test_pytorch_common.py", line 51, in wrapper
    return func(self)
  File "test/onnx/test_models.py", line 184, in test_fake_quant
    self.exportTest(toC(FakeQuantNet()), toC(x))
  File "test/onnx/test_models.py", line 56, in exportTest
    graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 378, in _trace
    trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
  File "/home/simon/pytorch/torch/onnx/utils.py", line 232, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/simon/pytorch/torch/onnx/__init__.py", line 358, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 1058, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in wrapper
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in <listcomp>
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 94, in _parse_arg
    "for argument '{}' of node '{}', got '{}'.".format(arg_name, node_name, value.node().kind()))
RuntimeError: Expected node type 'onnx::Constant' for argument 'scale' of node 'fake_quantize_per_tensor_affine', got 'prim::Param'.

ERROR: test_qat_resnet_per_channel (__main__.TestModels)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/simon/pytorch/test/onnx/test_pytorch_common.py", line 51, in wrapper
    return func(self)
  File "test/onnx/test_models.py", line 226, in test_qat_resnet_per_channel
    self.exportTest(toC(qat_resnet50), toC(x))
  File "test/onnx/test_models.py", line 56, in exportTest
    graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 378, in _trace
    trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
  File "/home/simon/pytorch/torch/onnx/utils.py", line 232, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/simon/pytorch/torch/onnx/__init__.py", line 358, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 1058, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in wrapper
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in <listcomp>
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 94, in _parse_arg
    "for argument '{}' of node '{}', got '{}'.".format(arg_name, node_name, value.node().kind()))
RuntimeError: Expected node type 'onnx::Constant' for argument 'scale' of node 'fake_quantize_per_tensor_affine', got 'prim::Param'.

ERROR: test_qat_resnet_pertensor (__main__.TestModels)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/simon/pytorch/test/onnx/test_pytorch_common.py", line 51, in wrapper
    return func(self)
  File "test/onnx/test_models.py", line 205, in test_qat_resnet_pertensor
    self.exportTest(toC(qat_resnet50), toC(x))
  File "test/onnx/test_models.py", line 56, in exportTest
    graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 378, in _trace
    trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
  File "/home/simon/pytorch/torch/onnx/utils.py", line 232, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/simon/pytorch/torch/onnx/__init__.py", line 358, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/simon/pytorch/torch/onnx/utils.py", line 1058, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in wrapper
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 167, in <listcomp>
    for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)]
  File "/home/simon/pytorch/torch/onnx/symbolic_helper.py", line 94, in _parse_arg
    "for argument '{}' of node '{}', got '{}'.".format(arg_name, node_name, value.node().kind()))
RuntimeError: Expected node type 'onnx::Constant' for argument 'scale' of node 'fake_quantize_per_tensor_affine', got 'prim::Param'.
```

I was getting the same errors when I updated to torch 1.10.0 in my QAT training pipeline. Rolling back to torch 1.9.1 solved the issue.
cc @neginraoof @spandantiwari @skyw
