diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index f496b23a82b..40ae6d23085 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -748,13 +748,12 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_relu")
-def quantized_relu(
+def quantized_relu_common(
     X: torch.Tensor,
-    X_zero_point: torch.Tensor,
+    X_zero_point: torch.Tensor | int,
     out_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized ReLU operation followed by requantization.
@@ -770,7 +769,7 @@ def quantized_relu(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")
 
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
     dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
     return quantize_per_tensor(
         dequantized_X,
@@ -782,6 +781,79 @@ def quantized_relu(
     )
 
 
+def quantized_relu_variant(
+    per_tensor: bool,
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized relu variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            X: torch.Tensor,
+            X_zero_point: torch.Tensor | int,
+            out_zero_point: int,
+            out_multiplier: torch.Tensor | int,
+            out_shift: torch.Tensor | int,
+        ) -> torch.Tensor:
+            if per_tensor:
+                if dtype and X.dtype != dtype:
+                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
+
+                assert isinstance(out_shift, int)
+                assert isinstance(out_multiplier, int)
+                _out_shift = out_shift
+                _out_multiplier = out_multiplier
+            else:
+                assert isinstance(out_multiplier, torch.Tensor)
+                if out_multiplier.numel() > 1:
+                    raise ValueError("Only scalar out_multiplier is supported")
+
+                assert isinstance(out_shift, torch.Tensor)
+                if out_shift.numel() > 1:
+                    raise ValueError("Only scalar out_shift is supported")
+
+                assert isinstance(X_zero_point, torch.Tensor)
+                if X_zero_point.shape != X.shape:
+                    raise ValueError(
+                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
+                    )
+
+                _out_multiplier = int(out_multiplier.item())
+                _out_shift = int(out_shift.item())
+
+            return quantized_relu_common(
+                X,
+                X_zero_point,
+                out_zero_point,
+                _out_multiplier,
+                _out_shift,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_relu")
+@quantized_relu_variant(False)
+def quantized_relu() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu.per_tensor")
+@quantized_relu_variant(True)
+def quantized_relu_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
+@quantized_relu_variant(True, torch.int8)
+def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
+@quantized_relu_variant(True, torch.uint8)
+def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index a08327a7646..04b3e8e75ba 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -884,73 +884,124 @@ def test_quantized_conv_per_tensor(
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
-            (
-                "basic_int8",
-                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
-                0,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [0, 0, 0, -2], dtype=torch.int8
-                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
-            ),
+            *[
+                (
+                    "basic_int8",
+                    torch.tensor([-1, 0, 1, 3], dtype=dtype),  # input
+                    0,  # X_zero_point (scalar broadcast)
+                    0,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [0, 0, 0, -2], dtype=dtype
+                    ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 2: uint8 with non-zero zero point
-            (
-                "uint8_with_zp",
-                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
-                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
-                64,  # out_zero_point
-                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.uint8,  # dtype
-                torch.tensor(
-                    [64, 64, 64, 63], dtype=torch.uint8
-                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
-            ),
+            *[
+                (
+                    "uint8_with_zp",
+                    torch.tensor([126, 128, 130, 132], dtype=dtype),  # input
+                    128,  # X_zero_point
+                    64,  # out_zero_point
+                    536870912,  # out_multiplier (0.25 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [64, 64, 64, 63], dtype=dtype
+                    ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+                )
+                for dtype in [torch.uint8]
+            ],
             # Test case 3: All negative values (should all become zero after ReLU)
-            (
-                "all_negative_int8",
-                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
-                torch.tensor([0], dtype=torch.int8),  # X_zero_point
-                10,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0]),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [10, 10, 10], dtype=torch.int8
-                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
-            ),
+            *[
+                (
+                    "all_negative_int8",
+                    torch.tensor([-5, -3, -1], dtype=dtype),  # input
+                    0,  # X_zero_point
+                    10,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    0,  # out_shift
+                    dtype,  # dtype
+                    torch.tensor(
+                        [10, 10, 10], dtype=dtype
+                    ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+                )
+                for dtype in [torch.int8]
+            ],
             # Test case 4: All positive values with shift (scale becomes -0.25)
-            (
-                "positive_with_shift",
-                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
-                torch.tensor([1], dtype=torch.int8),  # X_zero_point
-                5,  # out_zero_point
-                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
-                torch.int8,  # dtype
-                torch.tensor(
-                    [4, 2, 0, -2], dtype=torch.int8
-                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
-            ),
+            *[
+                (
+                    "positive_with_shift",
+                    torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                    1,  # X_zero_point
+                    5,  # out_zero_point
+                    1073741824,  # out_multiplier (0.5 * 2^31)
+                    1,  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor(
+                        [4, 2, 0, -2], dtype=dtype
+                    ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+                )
+                for dtype in [torch.int8, torch.uint8]
+            ],
+            # Test case 5: Non-per-tensor
+            *[
+                (
+                    "non_per_tensor",
+                    torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype),  # input
+                    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point
+                    5,  # out_zero_point
+                    torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                    torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                    dtype,  # dtype
+                    torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
+                )
+                for dtype in [torch.int8]
+            ],
         ]
     )
     def test_quantized_relu(
         self,
         name: str,
         X: torch.Tensor,
-        X_zero_point: torch.Tensor,
+        X_zero_point: torch.Tensor | int,
         out_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: torch.Tensor | int,
+        out_shift: torch.Tensor | int,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = torch.ops.cadence.quantized_relu(
-            X, X_zero_point, out_zero_point, out_multiplier, out_shift
-        )
+
+        if isinstance(X_zero_point, int):
+            assert isinstance(out_multiplier, int)
+            assert isinstance(out_shift, int)
+
+            match dtype:
+                case torch.int8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
+                    )
+                case torch.uint8:
+                    quantized_relu = (
+                        torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
+                    )
+                case _:
+                    quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+
+            output = quantized_relu(
+                X,
+                X_zero_point,
+                out_zero_point,
+                out_multiplier,
+                out_shift,
+            )
+        else:
+            output = torch.ops.cadence.quantized_relu(
+                X, X_zero_point, out_zero_point, out_multiplier, out_shift
+            )
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
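
Illustrative sketch, not part of the diff: the per-tensor requantization arithmetic exercised by the "basic_int8" case above, reproduced standalone. The scale derivation is copied from quantized_relu_common; the final rounding and clamping are an assumption about how quantize_per_tensor maps the result back to int8 (round half to even, clamp to [-128, 127]), so this only mirrors the math spelled out in the test comments rather than calling the cadence op.

import torch

# Inputs from the "basic_int8" test case.
X = torch.tensor([-1, 0, 1, 3], dtype=torch.int8)
X_zero_point = 0
out_zero_point = 0
out_multiplier = 1073741824  # 0.5 * 2^31, a Q1.31 fixed-point multiplier
out_shift = 0

# Same scale derivation as quantized_relu_common: negate the Q1.31 multiplier
# and apply the power-of-two shift -> -0.5 here.
out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)

# ReLU relative to the input zero point, then requantize (assumed: round half
# to even and clamp to the int8 range).
relu_x = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
out = torch.round(relu_x.to(torch.float32) * out_scale + out_zero_point)
out = out.clamp(-128, 127).to(torch.int8)  # -> [0, 0, 0, -2], the expected output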