fx quant: hook up ConvTranspose{n}d (#49717)
Summary:
Pull Request resolved: #49717

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds the corresponding support to FX graph mode.

Note: this currently only works with the `qnnpack` engine, because
quantized conv transpose does not support per-channel weights. Until that
is fixed, a future PR should throw an error when someone tries to quantize
a ConvTranspose model with per-channel weight observers.
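
For illustration, a minimal sketch of the end-to-end workflow this enables,
assuming the `prepare_fx`/`convert_fx` API from
`torch.quantization.quantize_fx`; the model and input shapes below are made
up for the example:

```
# Sketch: FX graph mode quantization of a ConvTranspose2d model.
# Assumes the prepare_fx/convert_fx API; model and shapes are illustrative.
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

torch.backends.quantized.engine = 'qnnpack'  # only engine supported for now

model = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)).eval()
qconfig_dict = {'': get_default_qconfig('qnnpack')}  # per-tensor observers

prepared = prepare_fx(model, qconfig_dict)  # insert observers
prepared(torch.randn(4, 1, 4, 4))           # calibrate on sample data
quantized = convert_fx(prepared)            # lowers to nnq.ConvTranspose2d
out = quantized(torch.randn(4, 1, 4, 4))
```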

Test Plan:
```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25674636

fbshipit-source-id: b6948156123ed55db77e6337bea10db956215ae6
vkuzo authored and facebook-github-bot committed Dec 28, 2020
1 parent fc559bd commit ea558b2
Showing 2 changed files with 41 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -64,7 +64,9 @@
)

from torch.testing._internal.common_quantized import (
    supported_qengines,
    override_qengines,
    override_quantized_engine,
)

from torch.testing._internal.common_distributed import skip_if_not_multigpu
@@ -78,6 +80,7 @@
import operator
import unittest
import io
from typing import Callable

class TestFuseFx(QuantizationTestCase):
    def test_fuse_conv_bn_relu(self):
@@ -2365,6 +2368,42 @@ def test_rnn(self):
            [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)

    def _test_conv_transpose_impl(
            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
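        # qnnpack is forced here because quantized conv transpose does not
        # support per-channel weights (see the PR summary above)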
        with override_quantized_engine('qnnpack'):
            # Create fp32 versions of FX and Eager models
            m1 = torch.nn.Sequential(float_cls(1, 1, 1))
            m2 = torch.nn.Sequential(float_cls(1, 1, 1))
            m2.load_state_dict(m1.state_dict())
            m2 = torch.quantization.QuantWrapper(m2)
            # FX graph
            q_result1 = self.checkGraphModeFxOp(
                m1, (data,), QuantType.STATIC,
                expected_node_occurrence={
                    ns.call_module(q_cls): 1,
                })
            # Eager
            m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
            m2.eval()
            m2p = torch.quantization.prepare(m2)
            m2p(data)
            m2q = torch.quantization.convert(m2p)
            q_result2 = m2q(data)
            # verify results match
            self.assertTrue(torch.allclose(q_result1, q_result2))

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_1d(self):
        self._test_conv_transpose_impl(
            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))

    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This Pytorch Build has not been built with or does not support QNNPACK")
    def test_conv_transpose_2d(self):
        self._test_conv_transpose_impl(
            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))


class TestQuantizeFxModels(QuantizationTestCase):
    def _test_model_impl(
2 changes: 2 additions & 0 deletions torch/quantization/fx/quantization_patterns.py
@@ -537,6 +537,8 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
    torch._ops.ops.quantized.instance_norm:
    ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
@register_quant_pattern(torch.nn.ConvTranspose1d)
@register_quant_pattern(torch.nn.ConvTranspose2d)
@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.Hardswish)
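For context, `register_quant_pattern` is a decorator that records the mapping
from a pattern (here, the float module class) to the quantize handler class it
decorates, which ConvTranspose now shares with ELU, LeakyReLU, Hardswish, and
the other modules in this list. A simplified sketch of the mechanism, not the
exact implementation in `torch/quantization/fx/pattern_utils.py`:

```
# Simplified sketch of the pattern registry (assumption: the real code in
# torch/quantization/fx/pattern_utils.py differs in detail).
from typing import Any, Callable, Dict

DEFAULT_QUANTIZATION_PATTERNS: Dict[Any, Callable] = {}

def register_quant_pattern(pattern: Any) -> Callable:
    def insert(handler_cls: Callable) -> Callable:
        # map the pattern (e.g. torch.nn.ConvTranspose2d) to its handler
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = handler_cls
        return handler_cls
    return insert
```

With the two decorators added above, FX graph mode quantization recognizes
`ConvTranspose1d` and `ConvTranspose2d` modules and routes them through the
same default handler as the other registered module patterns.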
