fx quant: split linear test cases
Summary:

1. Separates the module and functional linear test cases.
2. Combines the test case that checks linear bias observation into
the main linear test case, as requested in
#49628.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_linear_module
python test/test_quantization.py TestQuantizeFxOps.test_linear_functional
```
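For reference, these tests drive FX graph mode quantization end to end. Below is a minimal sketch of that flow for a functional linear, assuming the `torch.quantization.quantize_fx` API as it existed around this commit; the model, qconfig choice, and shapes are illustrative only and not part of this diff (in particular, whether a raw tensor weight lowers this way depends on the state of the FX quantization stack at this commit):

```
import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

class FuncLinear(torch.nn.Module):
    # functional linear with raw tensor weight/bias, mirroring the test module
    def __init__(self):
        super().__init__()
        self.w = torch.randn(4, 30)
        self.b = torch.randn(4)

    def forward(self, x):
        return F.linear(x, self.w, self.b)

model = FuncLinear().eval()
# the default fbgemm qconfig uses HistogramObserver for activations and
# PerChannelMinMaxObserver for weights, matching the expected observer counts
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)        # inserts observers
prepared(torch.rand(1, 30, dtype=torch.float))    # calibration pass
quantized = convert_fx(prepared)                  # F.linear -> torch.ops.quantized.linear
```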

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 09ef7d72e358bb9c4be9ed0543600b3560b4df9c
Pull Request resolved: #49740
vkuzo committed Dec 22, 2020
1 parent a00c177 commit 82974ad
92 changes: 51 additions & 41 deletions test/quantization/test_quantize_fx.py
@@ -1281,7 +1281,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
"""
@skipIfNoFBGEMM
def test_linear(self):
def test_linear_module(self):
class ModuleLinear(torch.nn.Module):
def __init__(self, has_relu=False, f_relu=False):
super(ModuleLinear, self).__init__()
@@ -1297,27 +1297,9 @@ def __init__(self, has_relu=False, f_relu=False):
def forward(self, x):
return self.relu(self.linear(x))

class FuncLinear(torch.nn.Module):
def __init__(self, has_relu=False, f_relu=False):
super(FuncLinear, self).__init__()
self.w = torch.randn(4, 30)
self.b = torch.randn(4)
if has_relu:
if f_relu:
self.relu = F.relu
else:
self.relu = torch.nn.ReLU()
else:
self.relu = torch.nn.Identity()

def forward(self, x):
return self.relu(F.linear(x, self.w, self.b))

data = (torch.rand((1, 30), dtype=torch.float),)
options = itertools.product(
[(ModuleLinear(has_relu=False), True)],
# TODO: enable after raw `tensor` is supported in fx
# (FuncLinear(has_relu=False), False)],
self.all_quant_types)
quantized_nodes = {
# is_module
@@ -1328,12 +1310,6 @@ def forward(self, x):
# note that we are checking the final result
QuantType.QAT: ns.call_module(nnq.Linear),
},
False: {
# quant_type:
QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
}
}
for (model, is_module), quant_type in options:
self.checkGraphModeFxOp(
@@ -1342,10 +1318,58 @@ def forward(self, x):
for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
for model, quantized_node in [
(ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]:
# TODO: support functional linear + relu fusion
# (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]:
self.checkGraphModeFxOp(model, data, quant_type, quantized_node)

@skipIfNoFBGEMM
def test_linear_functional(self):

class FuncLinear(torch.nn.Module):
def __init__(self, use_bias):
super(FuncLinear, self).__init__()
self.w = torch.randn(4, 30)
self.b = torch.randn(4)
self.use_bias = use_bias

def forward(self, x):
if self.use_bias:
x = F.linear(x, self.w, self.b)
else:
x = F.linear(x, self.w)
return x

data = (torch.rand((1, 30), dtype=torch.float),)
quant_type_to_qlinear_fun = {
QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
}
quant_type_to_prepare_expected_node_occurrence = {
QuantType.DYNAMIC: {},
# There should be 3 observers: after input, weight and activation.
QuantType.STATIC: {
ns.call_module(torch.quantization.HistogramObserver): 2,
ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
},
# There should be 3 observers: after input, weight and activation.
QuantType.QAT: {
ns.call_module(torch.quantization.FakeQuantize): 3,
},
}
options = itertools.product(
(QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT),
(True, False), # use_bias
)
for quant_type, use_bias in options:
model = FuncLinear(use_bias)
qlinear_fun = quant_type_to_qlinear_fun[quant_type]
prepare_expected_node_occurrence = \
quant_type_to_prepare_expected_node_occurrence[quant_type]
self.checkGraphModeFxOp(
model, data, quant_type, qlinear_fun,
prepare_expected_node_occurrence=prepare_expected_node_occurrence)

# TODO(future PR): test for Linear + ReLU fusion

@skipIfNoFBGEMM
def test_conv_module(self):
conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
@@ -1388,20 +1412,6 @@ def test_conv2d_functional(self):
expected_node_occurrence=expected_node_occurrence,
)

def test_linear_functional_bias_not_observed(self):
data = (torch.rand((1, 4), dtype=torch.float),)
for bias in [True, False]:
linear = torch.nn.Linear(4, 4, bias=bias)
# There should be 3 observers: after input, weight and activation.
expected_node_occurrence = {
ns.call_module(torch.quantization.HistogramObserver): 2,
ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
}
self.checkGraphModeFxOp(
linear, data, QuantType.STATIC,
prepare_expected_node_occurrence=expected_node_occurrence,
)

@skipIfNoFBGEMM
def test_quantized_conv_relu(self):
"""tests for conv1d_relu/conv2d_relu/conv3d_relu"""