diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 483d8f18241..2a53c2dde7a 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -241,7 +241,7 @@ def quantized_linear_common( bias: torch.Tensor, in_zero_point: int, weight_zero_point: torch.Tensor | int, - out_multiplier: torch.Tensor | int, + out_multiplier: int, out_shift: int, out_zero_point: int, ) -> torch.Tensor: @@ -329,34 +329,30 @@ def variant( assert isinstance(weight_zero_point, int) assert isinstance(out_multiplier, int) assert isinstance(out_shift, int) - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - ) + _out_shift = out_shift + _out_multiplier = out_multiplier else: assert isinstance(out_shift, torch.Tensor) + assert isinstance(out_multiplier, torch.Tensor) if out_shift.numel() != 1: raise ValueError("out_shift must be a scalar") if out_shift.dtype != torch.int64: raise ValueError("out_shift must be an int64") - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - int(out_shift.item()), - out_zero_point, - ) + _out_shift = int(out_shift.item()) + _out_multiplier = int(out_multiplier[0].item()) + + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + _out_multiplier, + _out_shift, + out_zero_point, + ) return variant @@ -403,6 +399,112 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... +@impl(m, "quantized_matmul") +def quantized_matmul( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + """ + Quantized matmul operation. + + Args: + - X (Tensor): The activations tensor + - X_zero_point (int): The quantized mapping of zero for the input + - Y (Tensor): The weight tensor + - Y_zero_point (int): The quantized mapping of zero for the weight + - bias (Tensor): The bias tensor + - out_multiplier (int): The multiplier used to scale the output + - out_shift (int): The shift used to scale the output + - out_zero_point (int): The quantized mapping of zero for the output + - transposed (bool): Whether to transpose the weight tensor + """ + if bias is not None and not torch.all(bias == 0): + raise ValueError("bias must be None or all zeros since unused in out variant") + + # Looks weird, but quantized linear assumes weights are pre-transposed, + # hence we transpose only if `transposed` is False. 
+ if not transposed: + Y = Y.T + + return quantized_linear_common( + X, + Y, + bias or torch.zeros(1, dtype=torch.int32), + X_zero_point, + Y_zero_point, + out_multiplier, + out_shift, + out_zero_point, + ) + + +@impl(m, "quantized_matmul_asym8sxasym8s_asym8s") +def quantized_matmul_asym8sxasym8s_asym8s( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + +@impl(m, "quantized_matmul_asym8uxasym8u_asym8u") +def quantized_matmul_asym8uxasym8u_asym8u( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.uint8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.uint8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 03de587c3be..30b30e085dc 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -177,6 +177,8 @@ def test_quantized_add( 0, # out_zero_point torch.tensor([[-2]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -200,6 +202,8 @@ def test_quantized_add( 0, # out_zero_point torch.tensor([[-10, -30]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -225,6 +229,8 @@ def test_quantized_add( [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype ), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -248,6 +254,8 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-15, 25]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), @@ -271,6 +279,8 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-23, 17]], dtype=dtype), # expected_output False, + False, + False, ) for dtype in (torch.int8, torch.uint8) ], @@ -292,9 +302,34 @@ def test_quantized_add( 1, # out_zero_point torch.tensor([[-7, 13]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8)) ], + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int64 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[-7, 17]], dtype=dtype), # expected_output + 
per_tensor, + matmul, + transposed_matmul, + ) + for (matmul, transposed_matmul) in ((True, False), (True, True)) + for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8)) + ], ] ) def test_quantized_linear( @@ -308,7 +343,12 @@ def test_quantized_linear( out_zero_point: int, expected_output: torch.Tensor, per_tensor: bool, + matmul: bool, + transposed_matmul: bool, ) -> None: + if not per_tensor and matmul: + self.skipTest("Only per_tensor supported for matmul") + src = ( torch.arange(np.prod(src_shape)) .reshape(src_shape) @@ -319,7 +359,9 @@ def test_quantized_linear( .reshape(weight_shape) .to(expected_output.dtype) ) - bias = torch.arange(weight_shape[0]).to(torch.int32) + if matmul and not transposed_matmul: + weight = weight.T + if per_tensor: weight_zero_point = weight_zero_point[0] out_multiplier = out_multiplier[0] @@ -328,20 +370,34 @@ def test_quantized_linear( if per_tensor: match expected_output.dtype: case torch.int8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, - ) + if matmul: + linear_ops = ( + # Doesn't have per tensor name, but it is per tensor + torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, + ) case torch.uint8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, - ) + if matmul: + linear_ops = ( + torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + ) case _: - linear_ops = ( - torch.ops.cadence.quantized_linear.per_tensor, - torch.ops.cadence.quantized_fully_connected.per_tensor, - ) + if matmul: + linear_ops = (torch.ops.cadence.quantized_matmul,) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear.per_tensor, + torch.ops.cadence.quantized_fully_connected.per_tensor, + ) else: linear_ops = ( torch.ops.cadence.quantized_linear, @@ -349,17 +405,40 @@ def test_quantized_linear( ) for linear_op in linear_ops: - output = linear_op( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - typing.cast(torch.Tensor, None), + # Get the function name for linear_op for debugging + op_name = ( + linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op) ) + if matmul: + assert "quantized_matmul" in op_name + output = linear_op( + src, + in_zero_point, + weight, + weight_zero_point, + None, + out_multiplier, + out_shift, + out_zero_point, + transposed_matmul, + ) + else: + assert ( + "quantized_linear" in op_name + or "quantized_fully_connected" in op_name + ) + bias = torch.arange(weight_shape[0]).to(torch.int32) + output = linear_op( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + typing.cast(torch.Tensor, None), + ) self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
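
For reference, here is a minimal usage sketch of the new op, built from the int8 test case added above. The argument order (activations and weights interleaved with their zero points, `bias` required to be `None` or all zeros, and the `transposed` flag last) mirrors the call in `test_quantized_linear`; the import that registers the `cadence` ops is not part of this diff, so it is assumed below.

```python
import torch

# Assumes the Cadence reference ops above have been registered on the
# `cadence` library (e.g. by importing ref_implementations.py); the exact
# import path is not shown in this diff.
X = torch.arange(2, dtype=torch.int8).reshape(1, 2)  # [[0, 1]]
Y = torch.arange(4, dtype=torch.int8).reshape(2, 2)  # [[0, 1], [2, 3]]

out = torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s(
    X,
    2,          # X_zero_point
    Y,
    1,          # Y_zero_point
    None,       # bias: must be None (or all zeros) in this reference
    268435456,  # out_multiplier (0.125 * 2^31, per the test comment)
    1,          # out_shift
    1,          # out_zero_point
    True,       # transposed: Y is already in the pre-transposed (linear-weight) layout
)
# Per the new test case, this is expected to equal
# torch.tensor([[-7, 17]], dtype=torch.int8).
```

With `transposed=False`, the same result is expected when `Y` is passed in the plain matmul layout (i.e. `Y.T` of the tensor above), since the reference transposes it internally before delegating to `quantized_linear_common`.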