fx quant: do not insert observers at quantized inputs (#49239)
Summary:
Pull Request resolved: #49239

Context: the existing implementation of `input_quantized_idxs` is convert-only.
Therefore, during prepare an observer is still inserted between the graph
input and the first quantized node.  This is a problem during QAT, because
the observer placed at the input is a fake_quant that starts with scale=1 and
zp=0.  These initial parameters do not match the quantization parameters of
the already-quantized graph input, which can lead to incorrect numerics.

Fix: do not insert an observer for an input that is marked as quantized.
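
For reference, a minimal sketch of how a caller marks a graph input as already quantized, mirroring the tests below. The module paths follow the FX graph mode quantization API at the time of this commit (~PyTorch 1.8); newer releases expose a different API, so treat this as an illustration rather than the canonical usage.

```python
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

m = M().eval()
qconfig_dict = {'': get_default_qconfig('fbgemm')}
# Declare that graph input 0 arrives already quantized.  With this fix,
# prepare_fx no longer places an observer (or a fake_quant, in the QAT flow)
# at that input.
prepare_custom_config_dict = {'input_quantized_idxs': [0]}
mp = prepare_fx(m, qconfig_dict,
                prepare_custom_config_dict=prepare_custom_config_dict)
mp(torch.randn(1, 1, 4, 4))  # calibration still runs on fp32 data, as in the tests
mq = convert_fx(mp)          # the converted model now expects a quantized input tensor
```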

Test Plan:
```
python test/test_quantization.py TestQuantizeFx
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25499486

fbshipit-source-id: 303b49cc9d95a9fd06fef3b0859c08be34e19d8a
vkuzo authored and facebook-github-bot committed Dec 17, 2020
1 parent 92df870 commit 7542076
Showing 2 changed files with 37 additions and 10 deletions.
34 changes: 24 additions & 10 deletions test/quantization/test_quantize_fx.py
@@ -1160,7 +1160,8 @@ def forward(self, x):
             M().eval(), (data,), quant_type, expected_node_list=node_list)
 
     def _test_quantized_inputs_outputs(
-            self, prepare_custom_config_dict, count_check):
+            self, prepare_custom_config_dict, prepare_count_check,
+            convert_count_check):
         """
         Test the option to have inputs and outputs of the graph quantized
         """
@@ -1182,48 +1183,61 @@ def forward(self, x):
         mp = torch.quantization.quantize_fx.prepare_fx(
             m, qconfig_dict,
             prepare_custom_config_dict=prepare_custom_config_dict)
+        self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
         mp(torch.randn(1, 1, 4, 4))
         mq = torch.quantization.quantize_fx.convert_fx(mp)
-        self.checkGraphModuleNodes(mq, expected_node_occurrence=count_check)
+        self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)
 
     def test_quantized_input_quantized_output(self):
         prepare_custom_config_dict = {
             'input_quantized_idxs': [0], 'output_quantized_idxs': [0]}
-        count_check = {
+        prepare_count_check = {
+            ns.call_module(torch.quantization.MinMaxObserver): 2,
+        }
+        convert_count_check = {
             ns.call_function(torch.quantize_per_tensor): 0,
             ns.call_method('dequantize'): 0,
         }
         self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, count_check)
+            prepare_custom_config_dict, prepare_count_check, convert_count_check)
 
     def test_fp32_input_quantized_output(self):
         prepare_custom_config_dict = {
             'output_quantized_idxs': [0]}
-        count_check = {
+        prepare_count_check = {
+            ns.call_module(torch.quantization.MinMaxObserver): 3,
+        }
+        convert_count_check = {
             ns.call_function(torch.quantize_per_tensor): 1,
             ns.call_method('dequantize'): 0,
         }
         self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, count_check)
+            prepare_custom_config_dict, prepare_count_check, convert_count_check)
 
     def test_quantized_input_fp32_output(self):
         prepare_custom_config_dict = {
             'input_quantized_idxs': [0]}
-        count_check = {
+        prepare_count_check = {
+            ns.call_module(torch.quantization.MinMaxObserver): 2,
+        }
+        convert_count_check = {
             ns.call_function(torch.quantize_per_tensor): 0,
             ns.call_method('dequantize'): 1,
        }
         self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, count_check)
+            prepare_custom_config_dict, prepare_count_check, convert_count_check)
 
     def test_fp32_input_fp32_output(self):
         prepare_custom_config_dict = {}
-        count_check = {
+        prepare_count_check = {
+            ns.call_module(torch.quantization.MinMaxObserver): 3,
+        }
+        convert_count_check = {
             ns.call_function(torch.quantize_per_tensor): 1,
             ns.call_method('dequantize'): 1,
         }
         self._test_quantized_inputs_outputs(
-            prepare_custom_config_dict, count_check)
+            prepare_custom_config_dict, prepare_count_check, convert_count_check)
 
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
13 changes: 13 additions & 0 deletions torch/quantization/fx/quantize.py
@@ -398,6 +398,10 @@ def load_arg(a):
         get_new_observer_name = get_new_attr_name_with_prefix(
             'activation_post_process_')
 
+        placeholder_node_seen_cnt = 0
+        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
+            "input_quantized_idxs", [])
+
         result_node : Optional[Node] = None
         for node in model.graph.nodes:
             if node.op == 'output':
@@ -427,6 +431,15 @@ def load_arg(a):
                         matched_nodes)
             else:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
+
+            if node.op == 'placeholder':
+                # skip adding observers at the graph input if the input is
+                # overriden to be quantized
+                cur_placeholder_node_idx = placeholder_node_seen_cnt
+                placeholder_node_seen_cnt += 1
+                if cur_placeholder_node_idx in input_quantized_idxs:
+                    continue
+
             insert_observer_for_input_arg_of_observed_node(
                 node, observed_node_names_set, quants,
                 model, self.activation_post_process_map, env,
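
For illustration only (not part of the patch): a small sketch of how the positional indices in `input_quantized_idxs` line up with placeholder nodes, which is why the patch above counts placeholders in graph order. The `TwoInputs` module is made up for this example.

```python
import torch
import torch.fx

class TwoInputs(torch.nn.Module):
    def forward(self, x, y):
        return x + y

graph_module = torch.fx.symbolic_trace(TwoInputs())
input_quantized_idxs = [0]  # hypothetical config: only the first input is quantized

# Placeholders appear in graph order, so index 0 is the first forward() argument.
placeholder_idx = 0
for node in graph_module.graph.nodes:
    if node.op != 'placeholder':
        continue
    if placeholder_idx in input_quantized_idxs:
        print(node.target, '-> quantized input, observer would be skipped')
    else:
        print(node.target, '-> fp32 input, observer would be inserted')
    placeholder_idx += 1
```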
