diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml
index d8024c0245a..1d63a41f989 100644
--- a/backends/cadence/aot/functions.yaml
+++ b/backends/cadence/aot/functions.yaml
@@ -468,3 +468,8 @@
   kernels:
     - arg_meta: null
      kernel_name: impl::generic::requantize_per_tensor_out
+
+- func: cadence::quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, Tensor indices, bool pruned_weights, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantized_embedding_byte_out
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 9266cc72970..2b78d81b156 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -320,7 +320,7 @@
     "float out_scale, int out_zero_point) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
+    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
 )
 lib.define(
@@ -514,7 +514,7 @@
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
+    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
@@ -2310,6 +2310,28 @@ def transposed_im2row_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_embedding_byte")
+def quantized_embedding_byte_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: torch.Tensor | None,
+    indices: torch.Tensor,
+    pruned_weights: bool = False,
+) -> torch.Tensor:
+    assert not pruned_weights
+    assert len(weight.shape) == 2
+    assert 1 <= len(weight_scales.shape) <= 2
+    if len(weight_scales.shape) == 2:
+        num_groups = weight_scales.shape[-1]
+        assert weight.shape[1] % num_groups == 0
+
+    if weight_zero_points is not None:
+        assert weight_zero_points.shape == weight_scales.shape
+
+    assert 1 <= len(indices.shape) <= 2
+    return torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32)
+
+
 @register_fake("cadence::where_Scalar")
 def where_Scalar_meta(
     condition: torch.Tensor,
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index ad1abb3ce4b..4f612e3bab4 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1572,3 +1572,34 @@ def transposed_im2row(
     # Optionally, flatten to (N, num_patches, patch_size) if needed
     patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
     return patches
+
+
+@impl(m, "quantized_embedding_byte")
+def quantized_embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: torch.Tensor | None,
+    indices: torch.Tensor,
+    pruned_weights: bool = False,
+) -> torch.Tensor:
+    if pruned_weights:
+        raise NotImplementedError("Pruned weights not supported")
+
+    # Cannot use torch.ops.quantized_decomposed.embedding_byte.dtype because
+    # it doesn't support num_groups == 1
+    num_groups = 1
+    if len(weight_scales.shape) == 2:
+        num_groups = weight_scales.shape[1]
+
+    group_size = weight.shape[1] // num_groups
+    weight = torch.ops.torchao.dequantize_affine.default(
+        input=weight,
+        block_size=(1, group_size),
+        scale=weight_scales,
+        zero_point=weight_zero_points,
+        input_dtype=weight.dtype,
+        quant_min=torch.iinfo(weight.dtype).min,
+        quant_max=torch.iinfo(weight.dtype).max,
+    )
+
+    return weight[indices]
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index d8a79454097..5856c9def66 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2306,3 +2306,116 @@ def test_transposed_im2row(
             torch.equal(output, expected_output),
             f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            (
+                "1_group",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                torch.tensor([0, 0, 0], dtype=torch.int8),
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "2_groups",
+                torch.tensor(
+                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
+                ),
+                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
+                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [0.0, 0.5, 1.0, 2.0],
+                        [10.0, 12.5, 15.0, 18.0],
+                        [3.0, 4.5, 6.0, 8.0],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_none_zero_point",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                None,
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_batch2",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                torch.tensor([0, 0, 0], dtype=torch.int8),
+                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "2_groups_batch2",
+                torch.tensor(
+                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
+                ),
+                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
+                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
+                torch.tensor([[0, 2, 1], [2, 1, 0]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [
+                            [0.0, 0.5, 1.0, 2.0],
+                            [10.0, 12.5, 15.0, 18.0],
+                            [3.0, 4.5, 6.0, 8.0],
+                        ],
+                        [
+                            [10.0, 12.5, 15.0, 18.0],
+                            [3.0, 4.5, 6.0, 8.0],
+                            [0.0, 0.5, 1.0, 2.0],
+                        ],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_none_zero_point_batch2",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                None,
+                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+        ]
+    )
+    def test_quantized_embedding_byte(
+        self,
+        name: str,
+        weight: torch.Tensor,
+        weight_scales: torch.Tensor,
+        weight_zero_points: torch.Tensor | None,
+        indices: torch.Tensor,
+        expected_out: torch.Tensor,
+    ) -> None:
+        self.assertTrue(
+            torch.equal(
+                torch.ops.cadence.quantized_embedding_byte(
+                    weight, weight_scales, weight_zero_points, indices
+                ),
+                expected_out,
+            )
+        )
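
For context, a minimal usage sketch of the new op (not part of the diff; it assumes backends.cadence.aot.ops_registrations has been imported so that torch.ops.cadence is populated, and it reuses the illustrative values from the "2_groups" test case above):

    import torch

    # 3 embedding rows of width 4, int8-quantized with 2 groups per row
    # (group_size = 4 // 2 = 2); scales and zero points are per (row, group).
    weight = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8)
    weight_scales = torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32)
    weight_zero_points = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8)
    indices = torch.tensor([0, 2, 1], dtype=torch.int64)

    # Each looked-up row is dequantized group-wise as (q - zero_point) * scale,
    # which is what the torchao.dequantize_affine call in the reference
    # implementation computes with block_size=(1, group_size).
    out = torch.ops.cadence.quantized_embedding_byte(
        weight, weight_scales, weight_zero_points, indices
    )
    # out[0] -> [(0-0)*0.5, (1-0)*0.5, (2-1)*1.0, (3-1)*1.0] == [0.0, 0.5, 1.0, 2.0]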