fx quant: do not observe bias on F.conv (#49623)
Summary:
Pull Request resolved: #49623

(not ready for review)

Ensures that conv bias is not observed in a `F.conv{n}d` call.

Test Plan: Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25652856

fbshipit-source-id: 884f87be1948d3e049a557d79bec3c90aec34340
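
To make the intended behavior concrete, here is a minimal sketch (not part of the commit) that prepares a module calling F.conv2d with FX graph mode quantization and counts the observers that get inserted. It assumes the torch.quantization.quantize_fx.prepare_fx API of this era (newer releases also take an example_inputs argument) and the default fbgemm qconfig; the module and variable names are invented for the example. With this change, observers land on the input activation, the weight, and the output activation, but not on the bias argument.

import torch
import torch.nn.functional as F
from torch.quantization import (
    HistogramObserver, PerChannelMinMaxObserver, get_default_qconfig)
from torch.quantization.quantize_fx import prepare_fx

class FunctionalConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(1, 1, 1, 1))
        self.bias = torch.nn.Parameter(torch.randn(1))

    def forward(self, x):
        # bias is the third positional argument of the F.conv2d
        # call_function node; after this change it stays unobserved
        return F.conv2d(x, self.weight, self.bias)

model = FunctionalConv2d().eval()
prepared = prepare_fx(model, {"": get_default_qconfig("fbgemm")})

# Expect 2 HistogramObservers (input and output activation) and
# 1 PerChannelMinMaxObserver (weight), matching the counts asserted by
# the new test_conv2d_functional test below, and none for the bias.
n_act = sum(isinstance(m, HistogramObserver) for m in prepared.modules())
n_wgt = sum(isinstance(m, PerChannelMinMaxObserver) for m in prepared.modules())
print(n_act, n_wgt)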
vkuzo authored and facebook-github-bot committed Dec 23, 2020
1 parent b414123 commit c3a7591
Showing 3 changed files with 68 additions and 17 deletions.
26 changes: 22 additions & 4 deletions test/quantization/test_quantize_fx.py
@@ -1346,12 +1346,12 @@ def forward(self, x):
             self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
 
     @skipIfNoFBGEMM
-    def test_quantized_conv(self):
+    def test_conv_module(self):
         conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
 
-        class Conv(torch.nn.Module):
+        class ConvWrapper(torch.nn.Module):
             def __init__(self, dim):
-                super(Conv, self).__init__()
+                super(ConvWrapper, self).__init__()
                 self.conv = conv_module[dim](3, 3, 3).float()
 
             def forward(self, x):
@@ -1366,9 +1366,27 @@ def forward(self, x):
         }
         for dim, quant_type in options:
             model = self.checkGraphModeFxOp(
-                Conv(dim), self.img_data_dict[dim], quant_type,
+                ConvWrapper(dim), self.img_data_dict[dim], quant_type,
                 quantized_nodes[dim])
 
+    @skipIfNoFBGEMM
+    def test_conv2d_functional(self):
+        for bias in [True, False]:
+            conv = torch.nn.Conv2d(1, 1, 1, bias=bias)
+            # There should be 3 observers: after input, weight and activation.
+            # No observer after bias.
+            prepare_expected_node_occurrence = {
+                ns.call_module(torch.quantization.HistogramObserver): 2,
+                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
+            }
+            expected_node_occurrence = \
+                {ns.call_function(torch.ops.quantized.conv2d): 1}
+            self.checkGraphModeFxOp(
+                conv, (torch.randn(4, 1, 4, 4),), QuantType.STATIC,
+                prepare_expected_node_occurrence=prepare_expected_node_occurrence,
+                expected_node_occurrence=expected_node_occurrence,
+            )
+
     @skipIfNoFBGEMM
     def test_quantized_conv_relu(self):
         """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
47 changes: 38 additions & 9 deletions torch/quantization/fx/quantize.py
@@ -234,10 +234,38 @@ def insert_observer_for_input_arg_of_observed_node(
 
 # A dictionary for querying the weight index for a given op
 WEIGHT_INDEX_DICT = {
     torch.nn.functional.conv1d : [1],
     torch.nn.functional.conv2d : [1],
     torch.nn.functional.conv3d : [1],
     torch.nn.functional.linear : [1],
 }
 
+def node_arg_is_weight(node: Node, arg: Any) -> bool:
+    if isinstance(node, Node) and node.op == 'call_function' and \
+            node.target in WEIGHT_INDEX_DICT:
+        for i, node_arg in enumerate(node.args):
+            if arg is node_arg and i in \
+                    WEIGHT_INDEX_DICT[node.target]:  # type: ignore
+                return True
+    return False
+
+# A dictionary for querying the bias index for a given op
+# TODO(future PR): handle linear
+BIAS_INDEX_DICT = {
+    torch.nn.functional.conv1d : [2],
+    torch.nn.functional.conv2d : [2],
+    torch.nn.functional.conv3d : [2],
+}
+
+def node_arg_is_bias(node: Node, arg: Any) -> bool:
+    if isinstance(node, Node) and node.op == 'call_function' and \
+            node.target in BIAS_INDEX_DICT:
+        for i, node_arg in enumerate(node.args):
+            if arg is node_arg and i in \
+                    BIAS_INDEX_DICT[node.target]:  # type: ignore
+                return True
+    return False
+
 # weight prepacking ops
 WEIGHT_PREPACK_OPS = {
     torch._ops.ops.quantized.linear_prepack,
@@ -956,15 +984,16 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
 
         def visit(node, matched_pattern, qconfig):
             def visit_arg(arg):
-                is_weight = False
-                if isinstance(node, Node) and node.op == 'call_function' and \
-                        node.target in WEIGHT_INDEX_DICT:
-                    for i, node_arg in enumerate(node.args):
-                        if arg is node_arg and i in \
-                                WEIGHT_INDEX_DICT[node.target]:  # type: ignore
-                            is_weight = True
-                if qconfig is not None and \
-                        (activation_is_statically_quantized(qconfig) or is_weight):
+                is_weight = node_arg_is_weight(node, arg)
+                is_bias = node_arg_is_bias(node, arg)
+                is_activation = not (is_weight or is_bias)
+                should_add_handler = qconfig is not None and (
+                    (is_activation and
+                     activation_is_statically_quantized(qconfig)) or
+                    (is_weight and weight_is_statically_quantized(qconfig))
+                )
+
+                if should_add_handler:
                     act_post_process_ctr = qconfig.weight if is_weight else \
                         qconfig.activation
                     # overwrite the constructor from qconfig
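For illustration only (not part of the commit): a sketch of how the indices in WEIGHT_INDEX_DICT and BIAS_INDEX_DICT line up with the arguments of a traced F.conv2d node, assuming plain torch.fx symbolic tracing; the module name is invented for the example.

import torch
import torch.fx
import torch.nn.functional as F

class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(1, 1, 1, 1))
        self.bias = torch.nn.Parameter(torch.randn(1))

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias)

gm = torch.fx.symbolic_trace(TinyConv())
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is F.conv2d:
        # node.args is (input, weight, bias): index 1 is listed in
        # WEIGHT_INDEX_DICT[F.conv2d] and index 2 in BIAS_INDEX_DICT[F.conv2d],
        # so node_arg_is_bias is True only for the bias argument and
        # _find_quants above skips inserting an observer for it.
        for i, arg in enumerate(node.args):
            print(i, arg)
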
12 changes: 8 additions & 4 deletions torch/testing/_internal/common_quantization.py
@@ -672,6 +672,13 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
         if not quant_type == QuantType.DYNAMIC:
             prepared(*inputs)
 
+        if print_debug_info:
+            print()
+            print('quant type:\n', quant_type)
+            print('original model:\n', model)
+            print()
+            print('prepared model:\n', prepared)
+
         self.checkGraphModuleNodes(
             prepared, prepare_expected_node,
             prepare_expected_node_occurrence, prepare_expected_node_list)
@@ -685,10 +692,7 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
         qgraph_to_check = qgraph_debug if debug else qgraph
         if print_debug_info:
             print()
-            print('quant type:', quant_type)
-            print('original model:', model)
-            print()
-            print('quantized model:', qgraph_to_check)
+            print('quantized model:\n', qgraph_to_check)
             self.printGraphModule(qgraph_to_check)
             print()
         self.checkGraphModuleNodes(
