From 82974ade9a64c1b6e5417ed97db36b1011c44bd9 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 22 Dec 2020 09:35:04 -0800
Subject: [PATCH] fx quant: split linear test cases

Summary:

1. Separates the module and functional linear test cases.
2. Combines the test case which tests for linear bias observation into the
   main linear test case, as requested in
   https://github.com/pytorch/pytorch/pull/49628.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_linear_module
python test/test_quantization.py TestQuantizeFxOps.test_linear_functional
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 09ef7d72e358bb9c4be9ed0543600b3560b4df9c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49740
---
 test/quantization/test_quantize_fx.py | 92 +++++++++++++++------------
 1 file changed, 51 insertions(+), 41 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index ce21e6c42469..974ff4dd7a99 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1281,7 +1281,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
     """
     @skipIfNoFBGEMM
-    def test_linear(self):
+    def test_linear_module(self):
         class ModuleLinear(torch.nn.Module):
             def __init__(self, has_relu=False, f_relu=False):
                 super(ModuleLinear, self).__init__()
@@ -1297,27 +1297,9 @@ def __init__(self, has_relu=False, f_relu=False):
             def forward(self, x):
                 return self.relu(self.linear(x))
 
-        class FuncLinear(torch.nn.Module):
-            def __init__(self, has_relu=False, f_relu=False):
-                super(FuncLinear, self).__init__()
-                self.w = torch.randn(4, 30)
-                self.b = torch.randn(4)
-                if has_relu:
-                    if f_relu:
-                        self.relu = F.relu
-                    else:
-                        self.relu = torch.nn.ReLU()
-                else:
-                    self.relu = torch.nn.Identity()
-
-            def forward(self, x):
-                return self.relu(F.linear(x, self.w, self.b))
-
         data = (torch.rand((1, 30), dtype=torch.float),)
         options = itertools.product(
             [(ModuleLinear(has_relu=False), True)],
-            # TODO: enable after raw `tensor` is supported in fx
-            # (FuncLinear(has_relu=False), False)],
             self.all_quant_types)
         quantized_nodes = {
             # is_module
@@ -1328,12 +1310,6 @@ def forward(self, x):
                 # note that we are checking the final result
                 QuantType.QAT: ns.call_module(nnq.Linear),
             },
-            False: {
-                # quant_type:
-                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
-                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
-                QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
-            }
         }
         for (model, is_module), quant_type in options:
             self.checkGraphModeFxOp(
@@ -1342,10 +1318,58 @@ def forward(self, x):
         for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
             for model, quantized_node in [
                     (ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]:
-                # TODO: support functional linear + relu fusion
-                # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]:
                 self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
 
+    @skipIfNoFBGEMM
+    def test_linear_functional(self):
+
+        class FuncLinear(torch.nn.Module):
+            def __init__(self, use_bias):
+                super(FuncLinear, self).__init__()
+                self.w = torch.randn(4, 30)
+                self.b = torch.randn(4)
+                self.use_bias = use_bias
+
+            def forward(self, x):
+                if self.use_bias:
+                    x = F.linear(x, self.w, self.b)
+                else:
+                    x = F.linear(x, self.w)
+                return x
+
+        data = (torch.rand((1, 30), dtype=torch.float),)
+        quant_type_to_qlinear_fun = {
+            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
+            QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
+            QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
+        }
+        quant_type_to_prepare_expected_node_occurrence = {
+            QuantType.DYNAMIC: {},
+            # There should be 3 observers: after input, weight and activation.
+            QuantType.STATIC: {
+                ns.call_module(torch.quantization.HistogramObserver): 2,
+                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
+            },
+            # There should be 3 observers: after input, weight and activation.
+            QuantType.QAT: {
+                ns.call_module(torch.quantization.FakeQuantize): 3,
+            },
+        }
+        options = itertools.product(
+            (QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT),
+            (True, False),  # use_bias
+        )
+        for quant_type, use_bias in options:
+            model = FuncLinear(use_bias)
+            qlinear_fun = quant_type_to_qlinear_fun[quant_type]
+            prepare_expected_node_occurrence = \
+                quant_type_to_prepare_expected_node_occurrence[quant_type]
+            self.checkGraphModeFxOp(
+                model, data, quant_type, qlinear_fun,
+                prepare_expected_node_occurrence=prepare_expected_node_occurrence)
+
+        # TODO(future PR): test for Linear + ReLU fusion
+
     @skipIfNoFBGEMM
     def test_conv_module(self):
         conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
@@ -1388,20 +1412,6 @@ def test_conv2d_functional(self):
                 expected_node_occurrence=expected_node_occurrence,
             )
 
-    def test_linear_functional_bias_not_observed(self):
-        data = (torch.rand((1, 4), dtype=torch.float),)
-        for bias in [True, False]:
-            linear = torch.nn.Linear(4, 4, bias=bias)
-            # There should be 3 observers: after input, weight and activation.
-            expected_node_occurrence = {
-                ns.call_module(torch.quantization.HistogramObserver): 2,
-                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
-            }
-            self.checkGraphModeFxOp(
-                linear, data, QuantType.STATIC,
-                prepare_expected_node_occurrence=expected_node_occurrence,
-            )
-
     @skipIfNoFBGEMM
     def test_quantized_conv_relu(self):
         """tests for conv1d_relu/conv2d_relu/conv3d_relu"""