From 376ea1917a349e204902d90d9edd33e5c92154d1 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Wed, 17 Sep 2025 12:56:40 -0700
Subject: [PATCH] Migrate old ref impl to new ref impl (#14357)

Summary: Migrated old ref impl to new ref impl, fixed some interface bugs caught in migration

Reviewed By: mcremon-meta

Differential Revision: D82566217
---
 backends/cadence/aot/TARGETS                |  1 +
 backends/cadence/aot/ref_implementations.py | 88 +++++++++++--------
 .../aot/tests/test_ref_implementations.py   | 65 ++++++++++----
 3 files changed, 102 insertions(+), 52 deletions(-)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 0ec09bf4f9e..af1e052a68e 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -130,6 +130,7 @@ runtime.python_library(
     deps = [
         "fbcode//caffe2:torch",
         "fbcode//executorch/exir:scalar_type",
+        "fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
     ],
 )
 
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 2a53c2dde7a..c2e74e024f8 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -6,16 +6,17 @@
 
 # pyre-strict
 
-
 from typing import Callable
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
 
-
 m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
+torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
 
 qdtype_map: dict[ScalarType, torch.dtype] = {
     ScalarType.QINT8: torch.qint8,
@@ -38,7 +39,7 @@ def quantize_per_tensor(
 
     Args:
     - input_tensor (Tensor): input tensor
-    - scale (float): Inverse of quantization scale. Derived from the ratio
+    - scale (float): Quantization scale. Derived from the ratio
         between the min/max of the floating-point tensor and the min/max of
         the quantized range, and then inverted.
     - zero_point (int): The point which represents 0 in the quantized
@@ -64,10 +65,13 @@ def quantize_per_tensor(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
 
-    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
-    return torch.max(
-        torch.min(quantized, torch.tensor(quant_max)),
-        torch.tensor(quant_min),
+    return torch.ops.quantized_decomposed.quantize_per_tensor(
+        input_tensor,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        dtype,
     )
 
 
@@ -97,7 +101,7 @@ def dequantize_per_tensor(
         is already provided.
     - quant_max (int): The largest value in the quantized domain. Unused since scale
         is already provided.
-    - dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
+    - dtype (torch.dtype): The type of the input tensor.
     """
     supported_quant_types = [
         torch.int8,
@@ -108,23 +112,15 @@ def dequantize_per_tensor(
     ]
     if input_tensor.dtype not in supported_quant_types:
         raise ValueError(f"Input dtype must be one of {supported_quant_types}")
-    supported_dequant_types = [
-        torch.float,
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ]
-    if dtype not in supported_dequant_types:
-        raise ValueError(
-            f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
-        )
-
-    # Needed to prevent underflow in cases where the zero_point is larger than
-    # the quantized value.
-    if not input_tensor.dtype.is_signed:
-        input_tensor = input_tensor.to(torch.int32)
-
-    return (input_tensor - zero_point).to(dtype) * scale
+    if input_tensor.dtype != dtype:
+        raise ValueError("Input dtype must match dtype")
+
+    # Use the reference implementation from torch quantized_decomposed library
+    # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
+    # difference, since there's no rounding algorithm (just arithmetic).
+    return torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_tensor, scale, zero_point, quant_min, quant_max, dtype
+    )
 
 
 @impl(m, "quantized_add.per_tensor")
@@ -180,12 +176,10 @@ def quantized_add_per_tensor(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
-    out_scale_inv = 1 / out_scale
-
     # q_min/q_max are unused args
     return quantize_per_tensor(
         dequant_X + dequant_Y,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
@@ -259,8 +253,7 @@ def quantized_linear_common(
     - out_zero_point (int): The quantized mapping of zero for the output
     - offset (Tensor): Unused
     """
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
-    out_scale_inv = 1 / out_scale
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
 
     N, K = weight.shape
 
@@ -281,7 +274,7 @@ def quantized_linear_common(
     )
     return quantize_per_tensor(
         out,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
@@ -399,6 +392,17 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "fully_connected")
+def fully_connected(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+) -> torch.Tensor:
+    if input_tensor.shape[0] != 1:
+        raise ValueError("Fully connected linear only supports batch size of 1")
+    return F.linear(input_tensor, weight, bias)
+
+
 @impl(m, "quantized_matmul")
 def quantized_matmul(
     X: torch.Tensor,
@@ -538,7 +542,7 @@ def quantized_layer_norm_per_tensor(
     )
 
     float_input_tensor = dequantize_per_tensor(
-        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+        input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
     )
     out = torch.nn.functional.layer_norm(
         float_input_tensor, normalized_shape, weight, bias, eps=eps
@@ -546,7 +550,7 @@
 
     return quantize_per_tensor(
         out,
-        1 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -615,7 +619,7 @@ def quantized_conv_per_tensor(
 
     return quantize_per_tensor(
         float_out,
-        1.0 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -942,8 +946,10 @@ def quantized_relu_common(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")
 
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
-    dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
+    dequantized_X = torch.where(
+        X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
+    ).to(torch.float32)
     return quantize_per_tensor(
         dequantized_X,
         out_scale,
@@ -1068,3 +1074,13 @@ def requantize(
         out_quant_max,
         dtype,
     )
+
+
+@impl(m, "rms_norm")
+def rms_norm(
+    X: torch.Tensor,
+    normalized_shape: tuple[int],
+    W: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 30b30e085dc..88742437ac7 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -36,12 +36,11 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        inv_scale = 1.0 / scale
-        zero_point = round(-f_min * inv_scale) + q_min
+        zero_point = round(-f_min * 1 / scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)
 
         output = torch.ops.cadence.quantize_per_tensor(
-            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, scale, zero_point, q_min, q_max, target_dtype
         )
 
         self.assertEqual(
@@ -85,7 +84,7 @@ def test_dequantize_per_tensor(
         expected_output = torch.tensor([expected_value], dtype=torch.float32)
 
         output = torch.ops.cadence.dequantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+            input_tensor, scale, zero_point, q_min, q_max, input_tensor.dtype
         )
 
         self.assertEqual(
@@ -175,7 +174,7 @@ def test_quantized_add(
                 ),  # out_multiplier (0.5 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
-                torch.tensor([[-2]], dtype=dtype),  # expected_output
+                torch.tensor([[0]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -200,7 +199,7 @@ def test_quantized_add(
                 ),  # out_multiplier (0.5 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
-                torch.tensor([[-10, -30]], dtype=dtype),  # expected_output
+                torch.tensor([[-2, -8]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -208,6 +207,28 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
             )
         ],
+        *[
+            (
+                torch.Size([1, 3]),  # src_shape: 1 sample, 3 input features
+                torch.Size(
+                    [2, 3]
+                ),  # weight_shape: 2 output features, 3 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0, 0], dtype=dtype),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int64),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[0, 0]], dtype=dtype),  # expected_output
+                per_tensor,
+                False,
+                False,
+            )
+            for (per_tensor, dtype) in (
+                (False, torch.uint8),
+                (True, torch.uint8),
+            )
+        ],
@@ -226,7 +247,7 @@ def test_quantized_add(
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
                 torch.tensor(
-                    [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
+                    [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
                 ),  # expected_output
                 per_tensor,
                 False,
@@ -235,7 +256,6 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
-                (True, torch.uint8),
             )
         ],
         # Test case 4: Non-zero zero points
@@ -252,7 +272,7 @@ def test_quantized_add(
                 ),  # out_multiplier (1.0 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 1,  # out_zero_point
-                torch.tensor([[-15, 25]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -260,7 +280,7 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
-                (True, torch.uint8),
+                # (True, torch.uint8),
             )
         ],
         # Test case 5: Non-uniform weight zero points
@@ -277,12 +297,12 @@ def test_quantized_add(
                 ),  # out_multiplier (1.0 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 1,  # out_zero_point
-                torch.tensor([[-23, 17]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 False,
                 False,
                 False,
             )
-            for dtype in (torch.int8, torch.uint8)
+            for dtype in (torch.int8,)
         ],
         # Test case 6: Non-zero out_shift (shift=1)
         *[
@@ -300,7 +320,7 @@ def test_quantized_add(
                     [1], dtype=torch.int64
                 ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
-                torch.tensor([[-7, 13]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 2]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -322,13 +342,13 @@ def test_quantized_add(
                     [1], dtype=torch.int64
                 ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
-                torch.tensor([[-7, 17]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 2]], dtype=dtype),  # expected_output
                 per_tensor,
                 matmul,
                 transposed_matmul,
             )
             for (matmul, transposed_matmul) in ((True, False), (True, True))
-            for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
+            for (per_tensor, dtype) in ((True, torch.int8),)
         ],
     ]
 )
@@ -1045,7 +1065,20 @@ def test_quantized_conv_per_tensor(
                     [4, 2, 0, -2], dtype=dtype
                 ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
             )
-            for dtype in [torch.int8, torch.uint8]
+            for dtype in [torch.int8]
         ],
+        *[
+            (
+                "positive_with_shift_unsigned",
+                torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                1,  # X_zero_point
+                5,  # out_zero_point
+                1073741824,  # out_multiplier (0.5 * 2^31)
+                1,  # out_shift (multiply by 2^1 = 2)
+                dtype,  # dtype
+                torch.tensor([4, 2, 0, 0], dtype=dtype),
+            )
+            for dtype in [torch.uint8]
+        ],
         # Test case 4: Non-per-tensor
         *[