From fa22fd2cbe915ff634c3edeaf214569a0cdf6343 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 10 Oct 2025 23:21:53 -0700
Subject: [PATCH] Support for batched matmul (#14956)

Summary: Matmul was relying on the linear infra, which didn't support a
batched second argument. This adds support.

Reviewed By: hsharma35

Differential Revision: D84279595
---
 backends/cadence/aot/ref_implementations.py   | 32 ++++++++++---------
 .../aot/tests/test_ref_implementations.py     | 25 ++++++++++++++-
 2 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 6a13a4424da..ed9bb438a9e 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -62,7 +62,7 @@ def quantize_per_tensor(
     ]
     if dtype not in supported_quant_types:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_quant_types}"
         )
 
     return torch.ops.quantized_decomposed.quantize_per_tensor(
@@ -264,7 +264,7 @@ def quantized_linear_common(
     supported_dtypes = [torch.int8, torch.uint8, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"
         )
 
     out = torch.nn.functional.linear(
@@ -427,25 +427,27 @@ def quantized_matmul(
     - out_multiplier (int): The multiplier used to scale the output
     - out_shift (int): The shift used to scale the output
     - out_zero_point (int): The quantized mapping of zero for the output
-    - transposed (bool): Whether to transpose the weight tensor
+    - transposed (bool): Whether Y is transposed.
     """
     if bias is not None and not torch.all(bias == 0):
         raise ValueError("bias must be None or all zeros since unused in out variant")
 
-    # Looks weird, but quantized linear assumes weights are pre-transposed,
-    # hence we transpose only if `transposed` is False.
-    if not transposed:
-        Y = Y.T
+    if transposed:
+        Y = Y.transpose(-1, -2)
 
-    return quantized_linear_common(
-        X,
-        Y,
-        bias or torch.zeros(1, dtype=torch.int32),
-        X_zero_point,
-        Y_zero_point,
-        out_multiplier,
-        out_shift,
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
+
+    out = torch.matmul(
+        (X - X_zero_point).float(),
+        (Y - Y_zero_point).float(),
+    )
+    return quantize_per_tensor(
+        out,
+        out_scale,
         out_zero_point,
+        torch.iinfo(X.dtype).min,
+        torch.iinfo(X.dtype).max,
+        X.dtype,
     )

diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index f679bae9485..259752f3893 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -350,6 +350,29 @@ def test_quantized_add(
                 for (matmul, transposed_matmul) in ((True, False), (True, True))
                 for (per_tensor, dtype) in ((True, torch.int8),)
             ],
+            *[
+                (
+                    torch.Size([2, 1, 2]),  # src_shape: batch of 2, 1 sample, 2 input features
+                    torch.Size(
+                        [2, 2, 2]
+                    ),  # weight_shape: batch of 2, 2 output features, 2 input features
+                    2,  # in_zero_point
+                    torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                    torch.tensor(
+                        [268435456], dtype=torch.int32
+                    ),  # out_multiplier (0.125 * 2^31)
+                    torch.tensor(
+                        [1], dtype=torch.int32
+                    ),  # out_shift (shift=1, doubles the scale)
+                    1,  # out_zero_point
+                    torch.tensor([[[1, 2]], [[0, -1]]], dtype=dtype),  # expected_output
+                    per_tensor,
+                    matmul,
+                    transposed_matmul,
+                )
+                for (matmul, transposed_matmul) in ((True, False), (True, True))
+                for (per_tensor, dtype) in ((True, torch.int8),)
+            ],
         ]
     )
     def test_quantized_linear(
@@ -380,7 +403,7 @@ def test_quantized_linear(
             .to(expected_output.dtype)
         )
         if matmul and not transposed_matmul:
-            weight = weight.T
+            weight = weight.transpose(-1, -2)
 
         if per_tensor:
             weight_zero_point = weight_zero_point[0]
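
For context, the new quantized_matmul path amounts to: subtract the zero
points, run a float torch.matmul (which broadcasts over leading batch dims),
then requantize with a Q31 fixed-point scale. The sketch below restates that
path in a self-contained form; quantized_matmul_sketch and
quantize_per_tensor_sketch are hypothetical names, the round-then-clamp
behavior of the stand-in quantizer is assumed to match
torch.ops.quantized_decomposed.quantize_per_tensor (which the module's
quantize_per_tensor helper dispatches to), and the sign convention on
out_multiplier is copied verbatim from the patch.

import torch


def quantize_per_tensor_sketch(x, scale, zero_point, qmin, qmax, dtype):
    # Stand-in for the module's quantize_per_tensor helper:
    # round(x / scale) + zero_point, clamped to the quantized range.
    return (torch.round(x / scale) + zero_point).clamp(qmin, qmax).to(dtype)


def quantized_matmul_sketch(
    X, X_zero_point, Y, Y_zero_point,
    out_multiplier, out_shift, out_zero_point, transposed=False,
):
    if transposed:
        # transpose(-1, -2) swaps only the two matrix dims, so any leading
        # batch dims survive; .T on a 3-D tensor would reverse all dims.
        Y = Y.transpose(-1, -2)

    # Q31 fixed-point requantization scale, sign convention as in the patch.
    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2 ** out_shift))

    # torch.matmul broadcasts over leading batch dims, which is what the
    # old torch.nn.functional.linear path could not do for a batched Y.
    acc = torch.matmul((X - X_zero_point).float(), (Y - Y_zero_point).float())
    return quantize_per_tensor_sketch(
        acc,
        out_scale,
        out_zero_point,
        torch.iinfo(X.dtype).min,
        torch.iinfo(X.dtype).max,
        X.dtype,
    )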
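A hypothetical batched call, with shapes borrowed from the new test case (the
input values here are invented for illustration):

X = torch.tensor([[[3, 4]], [[2, 3]]], dtype=torch.int8)                  # [2, 1, 2]
Y = torch.tensor([[[1, 2], [3, 4]], [[0, 1], [1, 0]]], dtype=torch.int8)  # [2, 2, 2]

out = quantized_matmul_sketch(
    X, 2, Y, 1,
    out_multiplier=torch.tensor([268435456], dtype=torch.int32),  # 0.125 * 2^31
    out_shift=1,
    out_zero_point=1,
)
print(out.shape)  # torch.Size([2, 1, 2]): one [1, 2] result per batch entry
print(out)        # [[[0, -1]], [[1, 1]]] under the patch's negated-multiplier convention

This is also why the test switches from weight.T to weight.transpose(-1, -2):
.T reverses all dims of a 3-D tensor (and is deprecated for anything but
2-D), while transpose(-1, -2) swaps just the matrix dims and leaves the batch
dim alone.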