[quant] Add support for quantize_per_channel in the reference flow with decomposed tensor (#89270)

Summary:
As the title says: after this PR we can produce quantize_per_channel and dequantize_per_channel ops (typically used for quantizing weights)
in the reference flow using decomposed tensors.
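
For illustration (not part of the original commit message): a minimal sketch of the decomposed per-channel ops this flow emits for a linear weight. The op signatures are assumed from the qparams ordering in the convert.py diff below, and the explicit _decomposed import that registers the ops is an assumption about where they lived at the time of this PR.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401, registers torch.ops.quantized_decomposed.* (assumed location)

weight = torch.randn(10, 5)
# per-output-channel qparams for a linear weight (channel axis 0)
scales = weight.abs().amax(dim=1) / 127.0
zero_points = torch.zeros(10, dtype=torch.int64)

q_w = torch.ops.quantized_decomposed.quantize_per_channel(
    weight, scales, zero_points, 0, -128, 127, torch.int8)
dq_w = torch.ops.quantized_decomposed.dequantize_per_channel(
    q_w, scales, zero_points, 0, -128, 127, torch.int8)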

Test Plan:
python test/test_quantization.py -k test__convert_to_reference_decomposed_fx_per_channel_quant

Pull Request resolved: #89270
Approved by: https://github.com/vkuzo
jerryzh168 authored and pytorchmergebot committed Nov 23, 2022
1 parent c651944 commit 39772a6
Showing 2 changed files with 47 additions and 3 deletions.
31 changes: 31 additions & 0 deletions test/quantization/fx/test_quantize_fx.py
@@ -42,6 +42,7 @@
QuantWrapper,
default_qconfig,
default_dynamic_qconfig,
default_per_channel_qconfig,
default_qat_qconfig,
default_reuse_input_qconfig,
default_symmetric_qnnpack_qconfig,
@@ -5377,6 +5378,36 @@ def forward(self, x):
res = m(*example_inputs)
self.assertEqual(res, res_ref)

def test__convert_to_reference_decomposed_fx_per_channel_quant(self):
class M(torch.nn.Module):
def forward(self, x, weight, bias):
return F.linear(x, weight, bias)

m = M().eval()
qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
.set_object_type(F.linear, default_per_channel_qconfig)
example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,))
m = prepare_fx(m, qconfig_mapping, example_inputs)
m(*example_inputs)
m_ref = copy.deepcopy(m)
m_ref = convert_to_reference_fx(m_ref)
m = _convert_to_reference_decomposed_fx(m)
expected_occurrence = {
# for input and output activations
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
# for weight
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
}
self.checkGraphModuleNodes(
m,
expected_node_occurrence=expected_occurrence)
# make sure it runs
res_ref = m_ref(*example_inputs)
res = m(*example_inputs)
self.assertEqual(res, res_ref)

def test_change_backend_config_for_fixed_qparam_ops(self):
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
on BackendConfig
19 changes: 16 additions & 3 deletions torch/ao/quantization/fx/convert.py
@@ -146,8 +146,23 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
quantize_op : Optional[Callable] = None
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined]
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
raise NotImplementedError("decomposed quantize_per_channel op not implemented yet")
ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined]
quantize_op = torch.ops.quantized_decomposed.quantize_per_channel
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel
quant_min = activation_post_process.quant_min
quant_max = activation_post_process.quant_max
dtype_ = to_underlying_dtype(dtype)
qparams = {
"_scale_": scale,
"_zero_point_": zero_point,
"_axis_": ch_axis,
"_quant_min_": quant_min,
"_quant_max_": quant_max,
"_dtype_": dtype_
}
else:
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor
scale = float(scale)
zero_point = int(zero_point)
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
@@ -160,7 +175,6 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
"_quant_max_": quant_max,
"_dtype_": dtype_
}
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor

# 2. replace activation_post_process node with quantize and dequantize
with graph.inserting_before(node):
@@ -182,7 +196,6 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor
dequantized_node = graph.call_function(
dequantize_op,
tuple(dq_inputs),
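
For orientation (not part of the diff): a small, self-contained sketch of the quantize/dequantize node pair the updated pass creates for a per-channel weight, mirroring how dq_inputs reuses the quantize node's qparams. The graph is built by hand purely for illustration; the real pass registers tensor qparams as module attributes instead of taking them as placeholders, and the _decomposed import location is an assumption.

import torch
import torch.fx as fx
import torch.ao.quantization.fx._decomposed  # noqa: F401, registers the decomposed ops (assumed location)

g = fx.Graph()
weight = g.placeholder("weight")
scales = g.placeholder("scales")
zero_points = g.placeholder("zero_points")
# the same qparams feed both nodes, as in dq_inputs = [quantized_node] + quantize_op_inputs[1:]
qparams = (scales, zero_points, 0, -128, 127, torch.int8)
q = g.call_function(torch.ops.quantized_decomposed.quantize_per_channel, (weight,) + qparams)
dq = g.call_function(torch.ops.quantized_decomposed.dequantize_per_channel, (q,) + qparams)
g.output(dq)

gm = fx.GraphModule(torch.nn.Module(), g)
w = torch.randn(10, 5)
out = gm(w, w.abs().amax(dim=1) / 127.0, torch.zeros(10, dtype=torch.int64))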
