23 changes: 1 addition & 22 deletions backends/cadence/aot/ops_registrations.py
@@ -6,7 +6,6 @@

# pyre-strict

import logging
from math import prod
from typing import Callable, Optional, Tuple

@@ -49,36 +48,16 @@ def _validate_ref_impl_exists() -> None:
"cadence::roi_align_box_processor",
}

# All of these should either
# 1. be removed
# 2. have a reference implementation added to ref_implementations.py
_WARN_ONLY = {
"cadence::quantized_softmax.per_tensor",
"cadence::quantized_softmax",
}

ref_impls = get_registered_ref_implementations()
warn_impls = []
error_impls = []
for op_name in _REGISTERED_META_KERNELS:
# Strip the namespace prefix if present (e.g., "cadence::" -> "")
op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name

if op_name_clean not in ref_impls:
if op_name in _WARN_ONLY:
warn_impls.append(op_name)
elif op_name not in _SKIP_OPS:
if op_name not in _SKIP_OPS:
error_impls.append(op_name)

if warn_impls:
warn_msg = (
f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
+ "\n".join(f" - {op}" for op in warn_impls)
+ "\n\nPlease add reference implementations in ref_implementations.py using "
+ "@impl_tracked(m, '<op_name>')."
)
logging.warning(warn_msg)

if error_impls:
error_msg = (
f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"
92 changes: 92 additions & 0 deletions backends/cadence/aot/ref_implementations.py
@@ -2054,3 +2054,95 @@ def softmax_f32_f32(
assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
assert not half_to_float, "half_to_float is not supported"
return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)


def quantized_softmax_per_tensor_common(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
in_scale: float,
in_zero_point: int,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
"""
Quantized softmax operation.

Args:
- input_tensor (Tensor): The quantized input tensor
- mask (Tensor, optional): Mask tensor (not yet supported; must be None)
- dim (int): The dimension along which softmax is computed
- in_scale (float): The scale of the input quantization
- in_zero_point (int): The zero point of the input quantization
- out_scale (float): The scale of the output quantization
- out_zero_point (int): The zero point of the output quantization
"""
# TODO: T228751479 - Add support for mask parameter in softmax
assert mask is None
supported_dtypes = [torch.int8, torch.uint8, torch.int16]
if input_tensor.dtype not in supported_dtypes:
raise ValueError(
f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
)

float_input_tensor = dequantize_per_tensor(
input_tensor,
in_scale,
in_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)

softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim)

return quantize_per_tensor(
softmax_output,
out_scale,
out_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)


@impl_tracked(m, "quantized_softmax.per_tensor")
def quantized_softmax_per_tensor(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
in_scale: float,
in_zero_point: int,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
return quantized_softmax_per_tensor_common(
input_tensor,
mask,
dim,
in_scale,
in_zero_point,
out_scale,
out_zero_point,
)


@impl_tracked(m, "quantized_softmax")
def quantized_softmax(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
in_scale: torch.Tensor,
in_zero_point: torch.Tensor,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
return quantized_softmax_per_tensor_common(
input_tensor,
mask,
dim,
float(in_scale.item()),
int(in_zero_point.item()),
out_scale,
out_zero_point,
)
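For reference, the dequantize -> softmax -> requantize flow that quantized_softmax_per_tensor_common implements can be reproduced with plain torch. The sketch below is illustrative only: it assumes round-and-clamp requantization to the dtype range (using torch.round, which rounds half to even) rather than calling the backend's quantize_per_tensor/dequantize_per_tensor helpers, and it reuses the numbers from the basic_int8_dim_1 test case added below.

import torch

# Inputs taken from the basic_int8_dim_1 test case in test_ref_implementations.py.
x_q = torch.tensor([[10, 20, 30]], dtype=torch.int8)
in_scale, in_zero_point = 0.1, 0
out_scale, out_zero_point = 0.004, 0

# Dequantize: (q - zero_point) * scale -> [[1.0, 2.0, 3.0]]
x_f = (x_q.to(torch.float32) - in_zero_point) * in_scale

# Float softmax along dim=1 -> approximately [[0.0900, 0.2447, 0.6652]]
y_f = torch.nn.functional.softmax(x_f, dim=1)

# Requantize: round(y / scale) + zero_point, clamped to the int8 range.
qinfo = torch.iinfo(torch.int8)
y_q = torch.clamp(
    torch.round(y_f / out_scale) + out_zero_point, qinfo.min, qinfo.max
).to(torch.int8)
print(y_q)  # tensor([[ 23,  61, 127]], dtype=torch.int8)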
132 changes: 132 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -3079,3 +3079,135 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
self.assertIn(
"Hidden dimension must be a multiple of 4", str(context.exception)
)

@expand(
[
(
"basic_int8_dim_1",
torch.tensor([[10, 20, 30]], dtype=torch.int8),
None,
1,
0.1,
0,
0.004,
0,
torch.int8,
torch.tensor([[23, 61, 127]], dtype=torch.int8),
),
(
"uint8_with_zero_points",
torch.tensor([[128, 130, 132]], dtype=torch.uint8),
None,
1,
0.1,
128,
0.004,
128,
torch.uint8,
torch.tensor([[195, 210, 228]], dtype=torch.uint8),
),
(
"basic_int16",
torch.tensor([[100, 200, 300]], dtype=torch.int16),
None,
1,
0.01,
0,
0.004,
0,
torch.int16,
torch.tensor([[23, 61, 166]], dtype=torch.int16),
),
(
"multi_row_int8",
torch.tensor([[10, 20, 30], [5, 10, 15]], dtype=torch.int8),
None,
1,
0.1,
0,
0.004,
0,
torch.int8,
torch.tensor([[23, 61, 127], [47, 77, 127]], dtype=torch.int8),
),
(
"softmax_dim_0",
torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
None,
0,
0.1,
0,
0.004,
0,
torch.int8,
torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
),
]
)
def test_quantized_softmax_per_tensor(
self,
name: str,
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
in_scale: float,
in_zero_point: int,
out_scale: float,
out_zero_point: int,
dtype: torch.dtype,
expected_output: torch.Tensor,
) -> None:
output = torch.ops.cadence.quantized_softmax.per_tensor(
input_tensor,
mask,
dim,
in_scale,
in_zero_point,
out_scale,
out_zero_point,
)

# Verify output properties
self.assertEqual(
output.dtype, dtype, f"Output dtype should be {dtype} in {name}"
)
self.assertEqual(
output.shape,
input_tensor.shape,
f"Output shape should match input shape in {name}",
)

# Verify output matches expected values (allowing for small quantization errors)
# For softmax, we expect outputs to be in [0, 1] range when dequantized
self.assertTrue(
torch.allclose(
output.to(torch.float32),
expected_output.to(torch.float32),
rtol=0.05,
atol=5.0,
),
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
)

def test_quantized_softmax(self) -> None:
# Test quantized_softmax (default variant with tensor scale/zero_point)
input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8)
in_scale = torch.tensor([0.1])
in_zero_point = torch.tensor([0])
output = torch.ops.cadence.quantized_softmax(
input_tensor,
None, # mask
1, # dim
in_scale,
in_zero_point,
0.004, # out_scale
0, # out_zero_point
)

# Verify output properties
self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8")
self.assertEqual(
output.shape,
input_tensor.shape,
"Output shape should match input shape",
)
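The expected tensors in the parameterized cases above follow directly from the chosen scales. As a sanity check, a minimal sketch for the softmax_dim_0 case, assuming the same round-and-clamp requantization convention as the reference implementation:

import torch

x_q = torch.tensor([[10, 20], [30, 40]], dtype=torch.int8)
x_f = x_q.to(torch.float32) * 0.1              # dequantize (zero_point 0): [[1., 2.], [3., 4.]]
y_f = torch.nn.functional.softmax(x_f, dim=0)  # each column: approximately [0.1192, 0.8808]
y_q = torch.clamp(torch.round(y_f / 0.004), -128, 127).to(torch.int8)
print(y_q)  # tensor([[ 30,  30], [127, 127]], dtype=torch.int8)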