diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml
index 1d63a41f989..d8024c0245a 100644
--- a/backends/cadence/aot/functions.yaml
+++ b/backends/cadence/aot/functions.yaml
@@ -468,8 +468,3 @@
   kernels:
     - arg_meta: null
      kernel_name: impl::generic::requantize_per_tensor_out
-
-- func: cadence::quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, Tensor indices, bool pruned_weights, *, Tensor(a!) out) -> Tensor(a!)
-  kernels:
-    - arg_meta: null
-      kernel_name: impl::generic::quantized_embedding_byte_out
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 2b78d81b156..9266cc72970 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -320,7 +320,7 @@
     "float out_scale, int out_zero_point) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
 )
 lib.define(
@@ -514,7 +514,7 @@
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
@@ -2310,28 +2310,6 @@ def transposed_im2row_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
-@register_fake("cadence::quantized_embedding_byte")
-def quantized_embedding_byte_meta(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: torch.Tensor | None,
-    indices: torch.Tensor,
-    pruned_weights: bool = False,
-) -> torch.Tensor:
-    assert not pruned_weights
-    assert len(weight.shape) == 2
-    assert 1 <= len(weight_scales.shape) <= 2
-    if len(weight_scales.shape) == 2:
-        num_groups = weight_scales.shape[-1]
-        assert weight.shape[1] % num_groups == 0
-
-    if weight_zero_points is not None:
-        assert weight_zero_points.shape == weight_scales.shape
-
-    assert 1 <= len(indices.shape) <= 2
-    return torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32)
-
-
 @register_fake("cadence::where_Scalar")
 def where_Scalar_meta(
     condition: torch.Tensor,
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 4f612e3bab4..ad1abb3ce4b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1572,34 +1572,3 @@ def transposed_im2row(
     # Optionally, flatten to (N, num_patches, patch_size) if needed
     patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
     return patches
-
-
-@impl(m, "quantized_embedding_byte")
-def quantized_embedding_byte(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: torch.Tensor | None,
-    indices: torch.Tensor,
-    pruned_weights: bool = False,
-) -> torch.Tensor:
-    if pruned_weights:
-        raise NotImplementedError("Pruned weights not supported")
-
-    # Cannot use torch.ops.quantized_decomposed.embedding_byte.dtype because
-    # it doesn't support num_groups == 1
-    num_groups = 1
-    if len(weight_scales.shape) == 2:
-        num_groups = weight_scales.shape[1]
-
-    group_size = weight.shape[1] // num_groups
-    weight = torch.ops.torchao.dequantize_affine.default(
-        input=weight,
-        block_size=(1, group_size),
-        scale=weight_scales,
-        zero_point=weight_zero_points,
-        input_dtype=weight.dtype,
-        quant_min=torch.iinfo(weight.dtype).min,
-        quant_max=torch.iinfo(weight.dtype).max,
-    )
-
-    return weight[indices]
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 5856c9def66..d8a79454097 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2306,116 +2306,3 @@ def test_transposed_im2row(
             torch.equal(output, expected_output),
             f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
         )
-
-    @expand(
-        [
-            (
-                "1_group",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                torch.tensor([0, 0, 0], dtype=torch.int8),
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "2_groups",
-                torch.tensor(
-                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
-                ),
-                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
-                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [0.0, 0.5, 1.0, 2.0],
-                        [10.0, 12.5, 15.0, 18.0],
-                        [3.0, 4.5, 6.0, 8.0],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_none_zero_point",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                None,
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_batch2",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                torch.tensor([0, 0, 0], dtype=torch.int8),
-                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "2_groups_batch2",
-                torch.tensor(
-                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
-                ),
-                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
-                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
-                torch.tensor([[0, 2, 1], [2, 1, 0]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [
-                            [0.0, 0.5, 1.0, 2.0],
-                            [10.0, 12.5, 15.0, 18.0],
-                            [3.0, 4.5, 6.0, 8.0],
-                        ],
-                        [
-                            [10.0, 12.5, 15.0, 18.0],
-                            [3.0, 4.5, 6.0, 8.0],
-                            [0.0, 0.5, 1.0, 2.0],
-                        ],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_none_zero_point_batch2",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                None,
-                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-        ]
-    )
-    def test_quantized_embedding_byte(
-        self,
-        name: str,
-        weight: torch.Tensor,
-        weight_scales: torch.Tensor,
-        weight_zero_points: torch.Tensor | None,
-        indices: torch.Tensor,
-        expected_out: torch.Tensor,
-    ) -> None:
-        self.assertTrue(
-            torch.equal(
-                torch.ops.cadence.quantized_embedding_byte(
-                    weight, weight_scales, weight_zero_points, indices
-                ),
-                expected_out,
-            )
-        )
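
For reference, the operator this patch deletes (while tightening both schemas so weight_zero_points becomes a required Tensor rather than Tensor?) amounts to per-row, per-group affine dequantization of a byte-quantized embedding table, followed by a plain row gather. The sketch below is a minimal standalone rendering of those semantics in plain PyTorch, not the deleted kernel itself: the helper name quantized_embedding_byte_sketch is hypothetical, and manual broadcasting stands in for the torch.ops.torchao.dequantize_affine call the removed reference implementation used. A 1-D scales tensor is treated as one group per row and a None zero-point as all zeros, matching the checks in the removed meta kernel.

import torch


def quantized_embedding_byte_sketch(
    weight: torch.Tensor,  # (num_embeddings, embedding_dim), e.g. int8
    weight_scales: torch.Tensor,  # (num_embeddings,) or (num_embeddings, num_groups)
    weight_zero_points: torch.Tensor | None,  # same shape as weight_scales, or None
    indices: torch.Tensor,  # (seq,) or (batch, seq), int64
) -> torch.Tensor:
    # A 1-D scales tensor means a single quantization group per row.
    num_groups = weight_scales.shape[1] if weight_scales.dim() == 2 else 1
    group_size = weight.shape[1] // num_groups
    rows = weight.shape[0]

    # Reshape scale/zero-point to (rows, num_groups, 1) so each group's
    # parameters broadcast across that group's group_size columns.
    scales = weight_scales.reshape(rows, num_groups, 1).to(torch.float32)
    if weight_zero_points is None:
        zero_points = torch.zeros_like(scales)
    else:
        zero_points = weight_zero_points.reshape(rows, num_groups, 1).to(torch.float32)

    # Affine dequantization, (q - zero_point) * scale, applied per group.
    grouped = weight.reshape(rows, num_groups, group_size).to(torch.float32)
    dequantized = ((grouped - zero_points) * scales).reshape(rows, weight.shape[1])

    # The embedding lookup itself is just indexing into the dequantized table.
    return dequantized[indices]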
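
Run against the "2_groups" vectors from the deleted test, the sketch reproduces the expected output:

weight = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8)
scales = torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32)
zero_points = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8)
indices = torch.tensor([0, 2, 1], dtype=torch.int64)

expected = torch.tensor(
    [[0.0, 0.5, 1.0, 2.0], [10.0, 12.5, 15.0, 18.0], [3.0, 4.5, 6.0, 8.0]],
    dtype=torch.float32,
)
assert torch.equal(
    quantized_embedding_byte_sketch(weight, scales, zero_points, indices), expected
)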