diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index 9e6379a29cec..07891be5e420 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -1220,6 +1220,49 @@ def checkHooksIsPresent(model, before_convert=True):
         torch.quantization.convert(model, inplace=True)
         checkHooksIsPresent(model, False)
 
+    def test_add_scalar_uses_input_qparams(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.ff = torch.nn.quantized.FloatFunctional()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.ff.add_scalar(x, 1.0)
+                return x
+
+        m = M()
+        m.qconfig = torch.quantization.default_qconfig
+        mp = torch.quantization.prepare_qat(m)
+        mp(torch.randn(4, 4))
+        mq = torch.quantization.convert(mp)
+        res = mq(torch.randn(4, 4))
+        eps = 1e-5
+        self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)
+
+    def test_mul_scalar_uses_input_qparams(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.ff = torch.nn.quantized.FloatFunctional()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.ff.mul_scalar(x, 2.0)
+                return x
+
+        m = M()
+        m.qconfig = torch.quantization.default_qconfig
+        mp = torch.quantization.prepare_qat(m)
+        mp(torch.randn(4, 4))
+        mq = torch.quantization.convert(mp)
+        res = mq(torch.randn(4, 4))
+        eps = 1e-5
+        self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)
+
+
 class TestEagerModeOps(QuantizationTestCase):
     def _test_activation_op_impl(
             self, float_module_class, quantized_module_class, extra_module_kwargs):
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 7c6c548f2594..11b1d668e844 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1458,6 +1458,57 @@ def test_quantized_mul_relu(self):
         self._test_quantized_binary_op_relu_impl(
             operator.mul, operator.imul, torch.ops.quantized.mul_relu)
 
+    # TODO(future PR): make more generic
+    def _test_quantized_add_mul_qat(self, model, expected_node_occurrence):
+        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
+        mp = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        self.checkGraphModuleNodes(
+            mp, expected_node_occurrence=expected_node_occurrence)
+
+    @skipIfNoFBGEMM
+    def test_quantized_add_qat(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 1)
+                self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = torch.add(x, 1.0)
+                x = self.conv1(x)
+                x = torch.add(x, 1.0)
+                x = torch.relu(x)
+                x = self.conv2(x)
+                return x
+
+        m = M()
+        expected_node_occurrence = {
+            ns.call_module(torch.quantization.FakeQuantize): 4,
+        }
+        self._test_quantized_add_mul_qat(m, expected_node_occurrence)
+
+    @skipIfNoFBGEMM
+    def test_quantized_mul_qat(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 1)
+                self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = torch.mul(x, 1.0)
+                x = self.conv1(x)
+                x = torch.mul(x, 1.0)
+                x = torch.relu(x)
+                x = self.conv2(x)
+                return x
+
+        m = M()
+        expected_node_occurrence = {
+            ns.call_module(torch.quantization.FakeQuantize): 4,
+        }
+        self._test_quantized_add_mul_qat(m, expected_node_occurrence)
+
     @skipIfNoFBGEMM
     def test_quantized_cat(self):
         """ quantization of the output of cat will be depend on the
diff --git a/torch/nn/quantized/modules/functional_modules.py b/torch/nn/quantized/modules/functional_modules.py
index 959d60f89609..b9fab962d563 100644
--- a/torch/nn/quantized/modules/functional_modules.py
+++ b/torch/nn/quantized/modules/functional_modules.py
@@ -50,7 +50,8 @@ def add(self, x, y):
     def add_scalar(self, x, y):
         # type: (Tensor, float) -> Tensor
         r = torch.add(x, y)
-        r = self.activation_post_process(r)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
@@ -64,7 +65,8 @@ def mul(self, x, y):
     def mul_scalar(self, x, y):
         # type: (Tensor, float) -> Tensor
         r = torch.mul(x, y)
-        r = self.activation_post_process(r)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.cat``"""
@@ -203,7 +205,8 @@ def add(self, x, y):
     def add_scalar(self, x, y):
         # type: (Tensor, float) -> Tensor
         r = ops.quantized.add_scalar(x, y)
-        r = self.activation_post_process(r)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
        return r
 
     r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
@@ -217,7 +220,8 @@ def mul(self, x, y):
     def mul_scalar(self, x, y):
         # type: (Tensor, float) -> Tensor
         r = ops.quantized.mul_scalar(x, y)
-        r = self.activation_post_process(r)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
         return r
 
     r"""Operation equivalent to ``torch.ops.quantized.cat``"""
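A note on the qparam propagation the tests above rely on: the quantized scalar
kernels derive the output quantization parameters directly from the input's,
rather than from an observer, which is why the activation_post_process calls
are removed in functional_modules.py. A minimal standalone sketch of that
behavior, assuming the stock torch.ops.quantized.add_scalar / mul_scalar
kernels of the PyTorch version this patch targets and illustrative qparams
chosen here (scale=0.1, zero_point=128):

    import torch

    x = torch.randn(4, 4)
    # quantize_per_tensor(input, scale, zero_point, dtype)
    xq = torch.quantize_per_tensor(x, 0.1, 128, torch.quint8)

    # add_scalar keeps the input scale when the shift stays in range (the
    # zero_point absorbs the +1.0), matching test_add_scalar_uses_input_qparams.
    r_add = torch.ops.quantized.add_scalar(xq, 1.0)
    assert abs(r_add.q_scale() - xq.q_scale()) < 1e-5

    # mul_scalar by a positive constant c multiplies the scale by c,
    # matching test_mul_scalar_uses_input_qparams (c = 2.0).
    r_mul = torch.ops.quantized.mul_scalar(xq, 2.0)
    assert abs(r_mul.q_scale() - 2 * xq.q_scale()) < 1e-5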