fx quant: do not observe bias on F.linear (#49628)
Summary:
Pull Request resolved: #49628

Ensures that the linear bias is not observed in an `F.linear` call. This
should give a small speedup in PTQ, and will improve numerics for QAT
whenever `F.linear` is used.

Note: the implementation is slightly more verbose than the conv handling
because bias is a keyword argument of `F.linear`, whereas the conv
functionals take it as a positional argument.
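
As a quick illustration of that note (not part of this commit; the module
`M` below is hypothetical), tracing a call to `F.linear` with FX shows the
bias arriving in `node.kwargs` rather than `node.args`:

```
import torch
import torch.fx
import torch.nn.functional as F

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
        self.b = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return F.linear(x, self.w, bias=self.b)

traced = torch.fx.symbolic_trace(M())
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is F.linear:
        print(node.args)    # (x, w): input and weight are positional
        print(node.kwargs)  # {'bias': b}: bias shows up as a keyword argument
```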

Test Plan:
```
python test/test_quantization.py TestQuantizeFxOps.test_linear_functional_bias_not_observed
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25653532

fbshipit-source-id: c93501bf6b55cbe4a11cfdad6f79313483133a39
vkuzo authored and facebook-github-bot committed Dec 23, 2020
1 parent c3a7591 commit 19f972b
Showing 2 changed files with 28 additions and 12 deletions.
14 changes: 14 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -1387,6 +1387,20 @@ def test_conv2d_functional(self):
             expected_node_occurrence=expected_node_occurrence,
         )
 
+    def test_linear_functional_bias_not_observed(self):
+        data = (torch.rand((1, 4), dtype=torch.float),)
+        for bias in [True, False]:
+            linear = torch.nn.Linear(4, 4, bias=bias)
+            # There should be 3 observers: input, weight, and output activation.
+            expected_node_occurrence = {
+                ns.call_module(torch.quantization.HistogramObserver): 2,
+                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
+            }
+            self.checkGraphModeFxOp(
+                linear, data, QuantType.STATIC,
+                prepare_expected_node_occurrence=expected_node_occurrence,
+            )
+
     @skipIfNoFBGEMM
     def test_quantized_conv_relu(self):
         """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
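
As a rough cross-check of the expected observer count (a sketch under
assumptions, not the commit's test: `LinearFunctional` is an invented module,
and `prepare_fx(model, qconfig_dict)` is the signature from this era of the
API), preparing a functional linear and counting observer submodules should
give 3, with nothing attached to the bias:

```
import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class LinearFunctional(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
        self.b = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return F.linear(x, self.w, bias=self.b)

qconfig_dict = {'': get_default_qconfig('fbgemm')}
prepared = prepare_fx(LinearFunctional().eval(), qconfig_dict)
# Count the observer modules inserted by prepare; with this change the
# bias is skipped, so 3 remain: input, weight, and output activation.
num_observers = sum('Observer' in type(m).__name__ for m in prepared.modules())
print(num_observers)  # expected: 3
```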
26 changes: 14 additions & 12 deletions torch/quantization/fx/quantize.py
@@ -249,21 +249,23 @@ def node_arg_is_weight(node: Node, arg: Any) -> bool:
         return True
     return False
 
-# A dictionary for querying the weight index for a given op
-# TODO(future PR): handle linear
-BIAS_INDEX_DICT = {
-    torch.nn.functional.conv1d : [2],
-    torch.nn.functional.conv2d : [2],
-    torch.nn.functional.conv3d : [2],
+CONV_OPS_WITH_BIAS = {
+    torch.nn.functional.conv1d,
+    torch.nn.functional.conv2d,
+    torch.nn.functional.conv3d,
 }
+CONV_BIAS_ARG_INDEX = 2
 
 def node_arg_is_bias(node: Node, arg: Any) -> bool:
-    if isinstance(node, Node) and node.op == 'call_function' and \
-            node.target in BIAS_INDEX_DICT:
-        for i, node_arg in enumerate(node.args):
-            if arg is node_arg and i in \
-                    BIAS_INDEX_DICT[node.target]:  # type: ignore
-                return True
+    if isinstance(node, Node) and node.op == 'call_function':
+        if node.target in CONV_OPS_WITH_BIAS:
+            for i, node_arg in enumerate(node.args):
+                if arg is node_arg and i == CONV_BIAS_ARG_INDEX:
+                    return True
+        elif node.target is torch.nn.functional.linear:
+            for kwarg_name, kwarg_value in node.kwargs.items():
+                if kwarg_name == 'bias' and arg is kwarg_value:
+                    return True
     return False
 
 # weight prepacking ops
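
To see the new predicate in action on its own (a self-contained sketch:
only `node_arg_is_bias` mirrors the diff above, while `ConvM` and the
driver loop are invented for illustration), here the conv branch matches a
positionally passed bias:

```
from typing import Any

import torch
import torch.fx
import torch.nn.functional as F
from torch.fx import Node

CONV_OPS_WITH_BIAS = {
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
}
CONV_BIAS_ARG_INDEX = 2

def node_arg_is_bias(node: Node, arg: Any) -> bool:
    # Same logic as the diff above: conv bias is positional arg 2,
    # linear bias is matched by identity against the 'bias' kwarg.
    if isinstance(node, Node) and node.op == 'call_function':
        if node.target in CONV_OPS_WITH_BIAS:
            for i, node_arg in enumerate(node.args):
                if arg is node_arg and i == CONV_BIAS_ARG_INDEX:
                    return True
        elif node.target is torch.nn.functional.linear:
            for kwarg_name, kwarg_value in node.kwargs.items():
                if kwarg_name == 'bias' and arg is kwarg_value:
                    return True
    return False

class ConvM(torch.nn.Module):  # hypothetical module exercising the conv branch
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4, 3))
        self.b = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return F.conv1d(x, self.w, self.b)  # bias is positional arg index 2

traced = torch.fx.symbolic_trace(ConvM())
for node in traced.graph.nodes:
    for arg in list(node.args) + list(node.kwargs.values()):
        if node_arg_is_bias(node, arg):
            print(f'{arg} is the bias of {node}: no observer will be inserted')
```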
