eager quant: remove fake_quant after add/mul nodes during QAT (#49213)
Summary:
Pull Request resolved: #49213

Changes the behavior of Eager mode quantization to remove observation after add_scalar/mul_scalar. The observed output is not used by the corresponding quantized ops, which derive their output qparams from the input, so removing the observer also removes one behavioral difference between Eager and FX graph modes.
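
For illustration, here is a minimal sketch of the new Eager mode behavior, based on the tests added below (the module `M` and tensor shapes are illustrative only): after convert, the output of `FloatFunctional.add_scalar` carries qparams derived from its input rather than from a separate observer.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.ff = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        # With this change, add_scalar is no longer followed by an
        # observer/fake_quant during QAT.
        return self.ff.add_scalar(self.quant(x), 1.0)

m = M()
m.qconfig = torch.quantization.default_qconfig
mp = torch.quantization.prepare_qat(m)
mp(torch.randn(4, 4))                 # run data through to update qparams
mq = torch.quantization.convert(mp)
out = mq(torch.randn(4, 4))
# The output qparams follow the input quant stub, not a dedicated observer.
print(mq.quant.scale, out.q_scale())
```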

Test Plan:
```
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add_qat
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul_qat
python test/test_quantization.py TestQuantizationAwareTraining.test_add_scalar_uses_input_qparams
python test/test_quantization.py TestQuantizationAwareTraining.test_mul_scalar_uses_input_qparams
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25486276

fbshipit-source-id: 34a5d6ce0d08739319ec0f8b197cfc1309d71040
vkuzo authored and facebook-github-bot committed Dec 17, 2020
1 parent 9045862 commit 36b2092
Showing 3 changed files with 102 additions and 4 deletions.
43 changes: 43 additions & 0 deletions test/quantization/test_quantize.py
@@ -1220,6 +1220,49 @@ def checkHooksIsPresent(model, before_convert=True):
torch.quantization.convert(model, inplace=True)
checkHooksIsPresent(model, False)

def test_add_scalar_uses_input_qparams(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.ff = torch.nn.quantized.FloatFunctional()

def forward(self, x):
x = self.quant(x)
x = self.ff.add_scalar(x, 1.0)
return x

m = M()
m.qconfig = torch.quantization.default_qconfig
mp = torch.quantization.prepare_qat(m)
mp(torch.randn(4, 4))
mq = torch.quantization.convert(mp)
res = mq(torch.randn(4, 4))
eps = 1e-5
self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)

def test_mul_scalar_uses_input_qparams(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.ff = torch.nn.quantized.FloatFunctional()

def forward(self, x):
x = self.quant(x)
x = self.ff.mul_scalar(x, 2.0)
return x

m = M()
m.qconfig = torch.quantization.default_qconfig
mp = torch.quantization.prepare_qat(m)
mp(torch.randn(4, 4))
mq = torch.quantization.convert(mp)
res = mq(torch.randn(4, 4))
eps = 1e-5
self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)


class TestEagerModeOps(QuantizationTestCase):
def _test_activation_op_impl(
self, float_module_class, quantized_module_class, extra_module_kwargs):
51 changes: 51 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -1458,6 +1458,57 @@ def test_quantized_mul_relu(self):
self._test_quantized_binary_op_relu_impl(
operator.mul, operator.imul, torch.ops.quantized.mul_relu)

# TODO(future PR): make more generic
def _test_quantized_add_mul_qat(self, model, expected_node_occurrence):
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
mp = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
self.checkGraphModuleNodes(
mp, expected_node_occurrence=expected_node_occurrence)

@skipIfNoFBGEMM
def test_quantized_add_qat(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
x = torch.add(x, 1.0)
x = self.conv1(x)
x = torch.add(x, 1.0)
x = torch.relu(x)
x = self.conv2(x)
return x

m = M()
expected_node_occurrence = {
ns.call_module(torch.quantization.FakeQuantize): 4,
}
self._test_quantized_add_mul_qat(m, expected_node_occurrence)

@skipIfNoFBGEMM
def test_quantized_mul_qat(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
x = torch.mul(x, 1.0)
x = self.conv1(x)
x = torch.mul(x, 1.0)
x = torch.relu(x)
x = self.conv2(x)
return x

m = M()
expected_node_occurrence = {
ns.call_module(torch.quantization.FakeQuantize): 4,
}
self._test_quantized_add_mul_qat(m, expected_node_occurrence)

@skipIfNoFBGEMM
def test_quantized_cat(self):
""" quantization of the output of cat will be depend on the
12 changes: 8 additions & 4 deletions torch/nn/quantized/modules/functional_modules.py
@@ -50,7 +50,8 @@ def add(self, x, y):
def add_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = torch.add(x, y)
-r = self.activation_post_process(r)
+# Note: this operation is not observed because the observation is not
+# needed for the quantized op.
return r

r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
@@ -64,7 +65,8 @@ def mul(self, x, y):
def mul_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = torch.mul(x, y)
-r = self.activation_post_process(r)
+# Note: this operation is not observed because the observation is not
+# needed for the quantized op.
return r

r"""Operation equivalent to ``torch.cat``"""
@@ -203,7 +205,8 @@ def add(self, x, y):
def add_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = ops.quantized.add_scalar(x, y)
-r = self.activation_post_process(r)
+# Note: this operation is not observed because the observation is not
+# needed for the quantized op.
return r

r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
@@ -217,7 +220,8 @@ def mul(self, x, y):
def mul_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = ops.quantized.mul_scalar(x, y)
-r = ops.quantized.mul_scalar(x, y)
-r = self.activation_post_process(r)
+r = ops.quantized.mul_scalar(x, y)
+# Note: this operation is not observed because the observation is not
+# needed for the quantized op.
return r

r"""Operation equivalent to ``torch.ops.quantized.cat``"""
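
To illustrate why no observer is needed here, a hedged sketch (not part of this diff; it relies on the behavior exercised by the tests above, where the quantized scalar ops derive their output qparams from the input):

```python
import torch

# Quantize an input tensor with known qparams.
x = torch.quantize_per_tensor(
    torch.randn(4, 4), scale=0.1, zero_point=128, dtype=torch.quint8)

# The quantized scalar op computes its output qparams from the input's,
# so no observed statistics are required.
y = torch.ops.quantized.mul_scalar(x, 2.0)
print(x.q_scale(), y.q_scale())  # per the mul_scalar test above, ~0.1 and ~0.2
```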
