From 147b3495e227dea92926f79eef28de784acbc01b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 29 Aug 2023 07:12:51 +0000 Subject: [PATCH] [quant][pt2e] Add reference representation for dynamic quantized linear (#108073) Summary: att Test Plan: python test/test_quantization.py TestQuantizePT2E.test_representation_dynamic_linear buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_dynamic_linear' Reviewed By: kimishpatel Differential Revision: D48703076 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108073 Approved by: https://github.com/andrewor14 --- test/quantization/pt2e/test_quantize_pt2e.py | 45 ++++++++-- .../pt2e/representation/rewrite.py | 87 +++++++++++++++++++ 2 files changed, 124 insertions(+), 8 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 81df21939ea2..86d14b51a979 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -552,6 +552,7 @@ def _test_representation( quantizer: Quantizer, ref_node_occurrence: Dict[ns, int], non_ref_node_occurrence: Dict[ns, int], + fixed_output_tol: float = None, output_scale_idx: int = 3, ) -> torch.nn.Module: """ TODO: need to implement output checking based on output_scale once @@ -581,17 +582,22 @@ def _test_representation( self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence) pt2e_quant_output_copy = model_copy(*example_inputs) - idx = 0 - for n in model_copy.graph.nodes: - if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default: - idx += 1 - if idx == output_scale_idx: - output_scale = n.args[1] - assert output_scale is not None + + output_tol = None + if fixed_output_tol is not None: + output_tol = fixed_output_tol + else: + idx = 0 + for n in model_copy.graph.nodes: + if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default: + idx += 1 + if idx == output_scale_idx: + output_tol = n.args[1] + assert output_tol is not None # make sure the result is off by one at most in the quantized integer representation self.assertTrue( - torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_scale + 1e-5) + torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5) ) @skipIfNoQNNPACK @@ -2148,6 +2154,29 @@ def forward(self, x): non_ref_node_occurrence={} ) + def test_representation_dynamic_linear(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + fixed_output_tol=1e-4, + ) + def test_representation_conv2d(self): class M(torch.nn.Module): def __init__(self): diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index c7c27925e54d..26a5d0128b68 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -85,6 +85,72 @@ def _reference_quantized_linear( return out_i8 +_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, 
+ 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), +) + + +def _qdq_dynamic_quantized_linear( + x_fp32, x_quant_min, x_quant_max, x_eps, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8) + x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + +def _reference_dynamic_quantized_linear( + x_fp32, x_quant_min, x_quant_max, x_eps, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8) + # decomposed representation for quantize_per_tensor + # TODO: use out_dtype(mul, ...) here when the op is ready + x_fp32 = x_fp32 / x_scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x_fp32 = torch.round(x_fp32) # fp32 + x_i32 = x_fp32.to(dtype=torch.int32) # int32 + x_i32 = x_i32 + x_zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32 + x_i8 = x_i32.to(dtype=torch.int8) + + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None) + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + out_fp32 = acc_i32 * (x_scale * weight_scale) + return out_fp32 + + _QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), torch.randn(1, dtype=torch.float), @@ -465,6 +531,27 @@ class _RewriteInfo: replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None _REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _qdq_dynamic_quantized_linear, + _reference_dynamic_quantized_linear, + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + -128: 1, + 127: 2, + torch.finfo(torch.float32).eps: 3 + } + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + -128: 1, + 127: 2, + torch.finfo(torch.float32).eps: 3 + } + ), + ), _RewriteInfo( _QUANTIZED_LINEAR_EXAMPLE_INPUTS, _qdq_quantized_linear,
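
The new fixed_output_tol branch in _test_representation keeps the existing behaviour of deriving the tolerance from the scale of the output_scale_idx-th quantize_per_tensor node, but lets a test pass a fixed bound instead; test_representation_dynamic_linear passes fixed_output_tol=1e-4, presumably because the dynamically quantized linear returns fp32 and has no output quantize node whose scale could be read. A tiny standalone illustration of the resulting check, with made-up values:

import torch

# Made-up outputs; the inequality mirrors the updated assertion in _test_representation.
pt2e_quant_output = torch.tensor([0.10, 0.20, 0.30])
pt2e_quant_output_copy = torch.tensor([0.10008, 0.19995, 0.30012])
fixed_output_tol = 1e-4  # what test_representation_dynamic_linear passes
assert torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= 2 * fixed_output_tol + 1e-5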
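
To sanity-check the numerics of _qdq_dynamic_quantized_linear versus _reference_dynamic_quantized_linear outside the graph-rewrite machinery, here is a minimal eager-mode sketch under simplifying assumptions: symmetric per-tensor int8 weights (zero point 0), choose_qparams replaced by an explicit min/max scale computation, out_dtype(...) replaced by ordinary int32 tensor ops, and the bias rounded rather than truncated when folded into int32. The helper names below are illustrative only and are not part of this patch.

import torch

def _quantize_activation(x_fp32, qmin=-128, qmax=127):
    # Dynamically derive int8 qparams from the observed range (always include 0), then quantize.
    eps = torch.finfo(torch.float32).eps
    x_min = min(x_fp32.min().item(), 0.0)
    x_max = max(x_fp32.max().item(), 0.0)
    scale = max((x_max - x_min) / (qmax - qmin), eps)
    zero_point = int(round(qmin - x_min / scale))
    x_i8 = torch.clamp(torch.round(x_fp32 / scale) + zero_point, qmin, qmax).to(torch.int8)
    return x_i8, scale, zero_point

def qdq_dynamic_linear_sketch(x_fp32, weight_i8, weight_scale, bias_fp32):
    # q/dq path: fake-quantize the activation, dequantize the weight, run a float linear.
    x_i8, x_scale, x_zp = _quantize_activation(x_fp32)
    x_dq = (x_i8.to(torch.float32) - x_zp) * x_scale
    w_dq = weight_i8.to(torch.float32) * weight_scale
    return torch.nn.functional.linear(x_dq, w_dq, bias_fp32)

def reference_dynamic_linear_sketch(x_fp32, weight_i8, weight_scale, bias_fp32):
    # integer path: accumulate (x - zp) * w in int32 and fold the bias in at x_scale * weight_scale.
    x_i8, x_scale, x_zp = _quantize_activation(x_fp32)
    x_i32 = x_i8.to(torch.int32) - x_zp
    w_i32 = weight_i8.to(torch.int32)                            # symmetric weights, zero point 0
    acc_i32 = (x_i32.unsqueeze(1) * w_i32.unsqueeze(0)).sum(-1)  # out[n, o] = sum_k x[n, k] * w[o, k]
    acc_i32 = acc_i32 + torch.round(bias_fp32 / (x_scale * weight_scale)).to(torch.int32)
    return acc_i32 * (x_scale * weight_scale)                    # rescale the accumulator back to fp32

x = torch.randn(2, 5)
w_i8 = torch.randint(-127, 128, (5, 5), dtype=torch.int8)
w_scale, bias = 0.005, torch.randn(5)
qdq_out = qdq_dynamic_linear_sketch(x, w_i8, w_scale, bias)
ref_out = reference_dynamic_linear_sketch(x, w_i8, w_scale, bias)
# The two paths agree up to the int32 bias rounding, i.e. on the order of x_scale * weight_scale.
print(torch.max(torch.abs(qdq_out - ref_out)).item())

The design point this illustrates is the same one the actual rewrite relies on: once the bias is expressed as an int32 quantity in units of x_scale * weight_scale, the whole linear can be evaluated with integer arithmetic and a single fp32 rescale at the end, which is why the pattern always passes bias=None to the integer linear and adds the bias to the int32 accumulator instead.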