3 changes: 1 addition & 2 deletions backends/cadence/aot/TARGETS
@@ -117,6 +117,7 @@ runtime.python_library(
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/backends/cadence/aot:ref_implementations",
        "fbcode//executorch/backends/cadence/aot:utils",
    ],
)
@@ -425,7 +426,6 @@ python_unittest(
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
        ":ref_implementations",
    ],
)

@@ -628,7 +628,6 @@ python_unittest(
    deps = [
        ":typing_stubs",
        "//executorch/backends/cadence/aot:ops_registrations",
        "//executorch/backends/cadence/aot:ref_implementations",
        "//caffe2:torch",
    ]
)
123 changes: 120 additions & 3 deletions backends/cadence/aot/ops_registrations.py
@@ -6,8 +6,9 @@

# pyre-strict

import logging
from math import prod
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import (
@@ -21,6 +22,113 @@

lib = Library("cadence", "DEF")

# Track meta kernels that have been registered
_REGISTERED_META_KERNELS: set[str] = set()


# Keep a reference to the original register_fake so the wrapper below can delegate to it
_register_fake_original = register_fake

_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]


def _validate_ref_impl_exists() -> None:
    """
    Validates that all registered meta kernels have corresponding reference implementations.
    This is called at module initialization time after both files have been imported.
    """

    # Import here after module initialization to ensure ref_implementations has been loaded
    from executorch.backends.cadence.aot.ref_implementations import (
        get_registered_ref_implementations,
    )

    # Ops whose reference implementations intentionally do not live in
    # executorch.backends.cadence.aot.ref_implementations; add such ops here
    _SKIP_OPS = {
        "cadence::roi_align_box_processor",
    }

    # All of these should either
    # 1. be removed
    # 2. have a reference implementation added to ref_implementations.py
    _WARN_ONLY = {
        "cadence::quantized_w8a32_linear",
        "cadence::quantized_add",  # We should only support per_tensor variant, should remove
        "cadence::idma_store",
        "cadence::idma_load",
        "cadence::_softmax_f32_f32",
        "cadence::requantize",  # We should only support per_tensor variant, should remove
        "cadence::quantized_softmax.per_tensor",
        "cadence::quantize_per_tensor_asym8u",
        "cadence::quantize_per_tensor_asym8s",
        "cadence::dequantize_per_tensor_asym8u",
        "cadence::dequantize_per_tensor_asym32s",
        "cadence::dequantize_per_tensor_asym16u",
        "cadence::linalg_vector_norm",
        "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
        "cadence::quantized_w8a32_conv",
        "cadence::quantize_per_tensor_asym32s",
        "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
        "cadence::linalg_svd",
        "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
        "cadence::idma_copy",
        "cadence::quantize_per_tensor_asym16u",
        "cadence::dequantize_per_tensor_asym8s",
        "cadence::quantize_per_tensor_asym16s",
        "cadence::dequantize_per_tensor_asym16s",
        "cadence::quantized_softmax",
        "cadence::idma_wait",
        "cadence::quantized_w8a32_gru",
        "cadence::quantized_layer_norm",  # We should only support per_tensor variant, should remove
    }

    ref_impls = get_registered_ref_implementations()
    warn_impls = []
    error_impls = []
    for op_name in _REGISTERED_META_KERNELS:
        # Strip the namespace prefix if present (e.g., "cadence::quantized_add" -> "quantized_add")
        op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name

        if op_name_clean not in ref_impls:
            if op_name in _WARN_ONLY:
                warn_impls.append(op_name)
            elif op_name not in _SKIP_OPS:
                error_impls.append(op_name)

    if warn_impls:
        warn_msg = (
            f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
            + "\n".join(f"  - {op}" for op in warn_impls)
            + "\n\nPlease add reference implementations in ref_implementations.py using "
            + "@impl_tracked(m, '<op_name>')."
        )
        logging.warning(warn_msg)

    if error_impls:
        error_msg = (
            f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"
            + "\n".join(f"  - {op}" for op in error_impls)
            + "\n\nPlease add reference implementations in ref_implementations.py using "
            + "@impl_tracked(m, '<op_name>')."
        )

        raise RuntimeError(error_msg)


# Wrap register_fake to track all registrations
def register_fake(
    op_name: str,
) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
    """
    Wrapped version of register_fake that tracks all meta kernel registrations.
    This enables validation that all meta kernels have reference implementations.
    """
    global _REGISTERED_META_KERNELS
    _REGISTERED_META_KERNELS.add(op_name)
    return _register_fake_original(op_name)


lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
@@ -2406,7 +2514,9 @@ def idma_load_impl(
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
    return copy_idma_copy_impl(src, task_num, channel)
    res = copy_idma_copy_impl(src, task_num, channel)
    assert isinstance(res, torch.Tensor)
    return res


@register_fake("cadence::idma_store")
@@ -2415,7 +2525,9 @@ def idma_store_impl(
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
    return copy_idma_copy_impl(src, task_num, channel)
    res = copy_idma_copy_impl(src, task_num, channel)
    assert isinstance(res, torch.Tensor)
    return res


@register_fake("cadence::roi_align_box_processor")
@@ -2671,3 +2783,8 @@ def quantized_w8a32_gru_meta(
    b_h_scale: float,
) -> torch.Tensor:
    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)


# Validate that all meta kernels have reference implementations
# This is called at module import time to catch missing implementations early
_validate_ref_impl_exists()
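
The validation above relies on two helpers from ref_implementations.py that are referenced but not shown in this diff: impl_tracked and get_registered_ref_implementations. Below is a minimal sketch of what that tracking half could look like, assuming impl_tracked records the bare (namespace-free) op name and registers the Python kernel on the passed-in torch.library.Library; the module-level set name and the dispatch key are illustrative assumptions, not the actual implementation.

```python
# Hypothetical sketch of the tracking half assumed to live in ref_implementations.py.
from typing import Callable

from torch.library import Library

# Illustrative name; the real module may track registrations differently.
_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()


def impl_tracked(
    m: Library, op_name: str
) -> Callable[[Callable[..., object]], Callable[..., object]]:
    """Register a reference implementation on `m` and record its op name."""

    def decorator(fn: Callable[..., object]) -> Callable[..., object]:
        # Record the bare (namespace-free) name so ops_registrations.py can
        # compare it against the meta kernels it has registered.
        _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
        # Dispatch key here is an assumption for illustration.
        m.impl(op_name, fn, "CompositeExplicitAutograd")
        return fn

    return decorator


def get_registered_ref_implementations() -> set[str]:
    """Return the op names that have tracked reference implementations."""
    return _REGISTERED_REF_IMPLEMENTATIONS
```

Usage would then mirror the hint in the error message, e.g. `@impl_tracked(m, '<op_name>')` above each reference kernel, with `_validate_ref_impl_exists()` comparing the namespace-stripped `_REGISTERED_META_KERNELS` against this set at import time.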