From 445c7a478ec360b55be34398707728b2f14b6db3 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Wed, 15 Oct 2025 10:35:38 -0700 Subject: [PATCH] Cadence: Warning if reference kernels not implemented for registered ops (#15130) Summary: I ran into a problem where some new ops were checked into ops_registrations.py without an associated ref implementation. Any current meta kernels that don't have a reference we will warn on, but anything new will error out if no reference is provided. Reviewed By: ethansfng Differential Revision: D84650725 --- backends/cadence/aot/TARGETS | 3 +- backends/cadence/aot/ops_registrations.py | 123 +++++++++++++++++++- backends/cadence/aot/ref_implementations.py | 118 +++++++++++-------- 3 files changed, 189 insertions(+), 55 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 4497b425557..b0e7101c9d2 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -117,6 +117,7 @@ runtime.python_library( ], deps = [ "fbcode//caffe2:torch", + "fbcode//executorch/backends/cadence/aot:ref_implementations", "fbcode//executorch/backends/cadence/aot:utils", ], ) @@ -425,7 +426,6 @@ python_unittest( "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/passes:lib", - ":ref_implementations", ], ) @@ -628,7 +628,6 @@ python_unittest( deps = [ ":typing_stubs", "//executorch/backends/cadence/aot:ops_registrations", - "//executorch/backends/cadence/aot:ref_implementations", "//caffe2:torch", ] ) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index f827488adfb..572a19ca872 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -6,8 +6,9 @@ # pyre-strict +import logging from math import prod -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch from executorch.backends.cadence.aot.utils import ( @@ -21,6 +22,113 @@ lib = Library("cadence", "DEF") +# Track meta kernels that have been registered +_REGISTERED_META_KERNELS: set[str] = set() + + +# Original register_fake function to use for registrations +_register_fake_original = register_fake + +_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...] + + +def _validate_ref_impl_exists() -> None: + """ + Validates that all registered meta kernels have corresponding reference implementations. + This is called at module initialization time after both files have been imported. + """ + + # Import here after module initialization to ensure ref_implementations has been loaded + from executorch.backends.cadence.aot.ref_implementations import ( + get_registered_ref_implementations, + ) + + # If reference implementation should not be in + # executorch.backends.cadence.aot.ref_implementations, add here + _SKIP_OPS = { + "cadence::roi_align_box_processor", + } + + # All of these should either + # 1. be removed + # 2. 
have a reference implementation added to ref_implementations.py + _WARN_ONLY = { + "cadence::quantized_w8a32_linear", + "cadence::quantized_add", # We should only support per_tensor variant, should remove + "cadence::idma_store", + "cadence::idma_load", + "cadence::_softmax_f32_f32", + "cadence::requantize", # We should only support per_tensor variant, should remove + "cadence::quantized_softmax.per_tensor", + "cadence::quantize_per_tensor_asym8u", + "cadence::quantize_per_tensor_asym8s", + "cadence::dequantize_per_tensor_asym8u", + "cadence::dequantize_per_tensor_asym32s", + "cadence::dequantize_per_tensor_asym16u", + "cadence::linalg_vector_norm", + "cadence::quantized_conv2d_nchw", # We should only support per_tensor variant, should remove + "cadence::quantized_w8a32_conv", + "cadence::quantize_per_tensor_asym32s", + "cadence::quantized_relu", # We should only support per_tensor variant, should remove + "cadence::linalg_svd", + "cadence::quantized_conv2d_nhwc", # We should only support per_tensor variant, should remove + "cadence::idma_copy", + "cadence::quantize_per_tensor_asym16u", + "cadence::dequantize_per_tensor_asym8s", + "cadence::quantize_per_tensor_asym16s", + "cadence::dequantize_per_tensor_asym16s", + "cadence::quantized_softmax", + "cadence::idma_wait", + "cadence::quantized_w8a32_gru", + "cadence::quantized_layer_norm", # We should only support per_tensor variant, should remove + } + + ref_impls = get_registered_ref_implementations() + warn_impls = [] + error_impls = [] + for op_name in _REGISTERED_META_KERNELS: + # Strip the namespace prefix if present (e.g., "cadence::" -> "") + op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name + + if op_name_clean not in ref_impls: + if op_name in _WARN_ONLY: + warn_impls.append(op_name) + elif op_name not in _SKIP_OPS: + error_impls.append(op_name) + + if warn_impls: + warn_msg = ( + f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n" + + "\n".join(f" - {op}" for op in warn_impls) + + "\n\nPlease add reference implementations in ref_implementations.py using " + + "@impl_tracked(m, '')." + ) + logging.warning(warn_msg) + + if error_impls: + error_msg = ( + f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n" + + "\n".join(f" - {op}" for op in error_impls) + + "\n\nPlease add reference implementations in ref_implementations.py using " + + "@impl_tracked(m, '')." + ) + + raise RuntimeError(error_msg) + + +# Wrap register_fake to track all registrations +def register_fake( + op_name: str, +) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]: + """ + Wrapped version of register_fake that tracks all meta kernel registrations. + This enables validation that all meta kernels have reference implementations. 
+ """ + global _REGISTERED_META_KERNELS + _REGISTERED_META_KERNELS.add(op_name) + return _register_fake_original(op_name) + + lib.define( "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" ) @@ -2406,7 +2514,9 @@ def idma_load_impl( task_num: int = 0, channel: int = 0, ) -> torch.Tensor: - return copy_idma_copy_impl(src, task_num, channel) + res = copy_idma_copy_impl(src, task_num, channel) + assert isinstance(res, torch.Tensor) + return res @register_fake("cadence::idma_store") @@ -2415,7 +2525,9 @@ def idma_store_impl( task_num: int = 0, channel: int = 0, ) -> torch.Tensor: - return copy_idma_copy_impl(src, task_num, channel) + res = copy_idma_copy_impl(src, task_num, channel) + assert isinstance(res, torch.Tensor) + return res @register_fake("cadence::roi_align_box_processor") @@ -2671,3 +2783,8 @@ def quantized_w8a32_gru_meta( b_h_scale: float, ) -> torch.Tensor: return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype) + + +# Validate that all meta kernels have reference implementations +# This is called at module import time to catch missing implementations early +_validate_ref_impl_exists() diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index b91f585fb16..90f39089edc 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -18,6 +18,24 @@ m = Library("cadence", "IMPL", "CompositeExplicitAutograd") torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") +# Registry to track all ops with reference implementations +_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set() + + +# Custom impl wrapper that tracks registrations +def impl_tracked( + lib: Library, op_name: str +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Wrapper around impl that tracks registered ops.""" + _REGISTERED_REF_IMPLEMENTATIONS.add(op_name) + return impl(lib, op_name) + + +def get_registered_ref_implementations() -> set[str]: + """Get all ops that have reference implementations.""" + return _REGISTERED_REF_IMPLEMENTATIONS.copy() + + qdtype_map: dict[ScalarType, torch.dtype] = { ScalarType.QINT8: torch.qint8, ScalarType.QUINT8: torch.quint8, @@ -25,7 +43,7 @@ } -@impl(m, "quantize_per_tensor") +@impl_tracked(m, "quantize_per_tensor") def quantize_per_tensor( input_tensor: torch.Tensor, scale: float, @@ -75,7 +93,7 @@ def quantize_per_tensor( ) -@impl(m, "dequantize_per_tensor") +@impl_tracked(m, "dequantize_per_tensor") def dequantize_per_tensor( input_tensor: torch.Tensor, scale: float, @@ -123,7 +141,7 @@ def dequantize_per_tensor( ) -@impl(m, "quantized_add.per_tensor") +@impl_tracked(m, "quantized_add.per_tensor") def quantized_add_per_tensor( X: torch.Tensor, X_scale: float, @@ -187,7 +205,7 @@ def quantized_add_per_tensor( ) -@impl(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor") def quantized_add_asym8sxasym8s_asym8s_per_tensor( X: torch.Tensor, X_scale: float, @@ -208,7 +226,7 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor( ) -@impl(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor") def quantized_add_asym8uxasym8u_asym8u_per_tensor( X: torch.Tensor, X_scale: float, @@ -352,47 +370,47 @@ def variant( return decorator -@impl(m, "quantized_linear") +@impl_tracked(m, "quantized_linear") @quantized_linear_variant(False, False) 
def quantized_linear() -> torch.Tensor: ... -@impl(m, "quantized_linear.per_tensor") +@impl_tracked(m, "quantized_linear.per_tensor") @quantized_linear_variant(True, False) def quantized_linear_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") @quantized_linear_variant(True, False, torch.int8, torch.int8) def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") @quantized_linear_variant(True, False, torch.uint8, torch.uint8) def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected") +@impl_tracked(m, "quantized_fully_connected") @quantized_linear_variant(False, True) def quantized_fully_connected() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected.per_tensor") +@impl_tracked(m, "quantized_fully_connected.per_tensor") @quantized_linear_variant(True, True) def quantized_fully_connected_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") @quantized_linear_variant(True, True, torch.int8, torch.int8) def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") @quantized_linear_variant(True, True, torch.uint8, torch.uint8) def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "fully_connected") +@impl_tracked(m, "fully_connected") def fully_connected( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -403,7 +421,7 @@ def fully_connected( return F.linear(input_tensor, weight, bias) -@impl(m, "quantized_matmul") +@impl_tracked(m, "quantized_matmul") def quantized_matmul( X: torch.Tensor, X_zero_point: int, @@ -451,7 +469,7 @@ def quantized_matmul( ) -@impl(m, "quantized_matmul_asym8sxasym8s_asym8s") +@impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s") def quantized_matmul_asym8sxasym8s_asym8s( X: torch.Tensor, X_zero_point: int, @@ -481,7 +499,7 @@ def quantized_matmul_asym8sxasym8s_asym8s( ) -@impl(m, "quantized_matmul_asym8uxasym8u_asym8u") +@impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u") def quantized_matmul_asym8uxasym8u_asym8u( X: torch.Tensor, X_zero_point: int, @@ -511,7 +529,7 @@ def quantized_matmul_asym8uxasym8u_asym8u( ) -@impl(m, "quantized_layer_norm.per_tensor") +@impl_tracked(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, X_scale: float, @@ -629,7 +647,7 @@ def quantized_conv_per_tensor( ) -@impl(m, "quantized_conv2d_nchw.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw.per_tensor") def quantized_conv2d_nchw_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -685,7 +703,7 @@ def quantized_conv2d_nchw_per_tensor( ) -@impl(m, "quantized_conv2d_nhwc.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor") def quantized_conv2d_nhwc_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -847,95 +865,95 @@ def variant( return decorator -@impl(m, "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", 
torch.int8, torch.int8) def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> ( torch.Tensor ): ... -@impl(m, "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> ( torch.Tensor ): ... -@impl(m, "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> ( torch.Tensor ): ... -@impl(m, "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> ( torch.Tensor ): ... 
-@impl(m, "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True) def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True) def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True) def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True) def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "convolution") +@impl_tracked(m, "convolution") def convolution( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -981,7 +999,7 @@ def convolution( return conv_out -@impl(m, "transposed_convolution") +@impl_tracked(m, "transposed_convolution") def transposed_convolution( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -1039,7 +1057,7 @@ def transposed_convolution( return conv_out -@impl(m, "avg_pool2d") +@impl_tracked(m, "avg_pool2d") def avg_pool2d( input_tensor: torch.Tensor, kernel_size: tuple[int, int], @@ -1155,22 +1173,22 @@ def variant( return decorator -@impl(m, "quantized_relu.per_tensor") +@impl_tracked(m, "quantized_relu.per_tensor") @quantized_relu_variant() def quantized_relu_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_relu_asym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_relu_asym8s_asym8s.per_tensor") @quantized_relu_variant(torch.int8) def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_relu_asym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_relu_asym8u_asym8u.per_tensor") @quantized_relu_variant(torch.uint8) def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ... 
-@impl(m, "requantize.per_tensor") +@impl_tracked(m, "requantize.per_tensor") def requantize_per_tensor( input: torch.Tensor, in_scale: float, @@ -1208,7 +1226,7 @@ def requantize_per_tensor( ) -@impl(m, "rms_norm") +@impl_tracked(m, "rms_norm") def rms_norm( X: torch.Tensor, normalized_shape: tuple[int], @@ -1218,7 +1236,7 @@ def rms_norm( return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) -@impl(m, "where_Scalar") +@impl_tracked(m, "where_Scalar") def where_Scalar( condition: torch.Tensor, if_true: float, @@ -1230,7 +1248,7 @@ def where_Scalar( return torch.where(condition, if_true, if_false) -@impl(m, "rope") +@impl_tracked(m, "rope") def rope( input_tensor: torch.Tensor, sin_tensor: torch.Tensor, @@ -1278,7 +1296,7 @@ def rope( return rotated.view(original_shape) -@impl(m, "im2row") +@impl_tracked(m, "im2row") def im2row( input_tensor: torch.Tensor, kernel_size: tuple[int, int], @@ -1370,7 +1388,7 @@ def im2row( return patches -@impl(m, "im2row.per_tensor") +@impl_tracked(m, "im2row.per_tensor") def im2row_per_tensor( input_tensor: torch.Tensor, kernel_size: tuple[int, int], @@ -1391,7 +1409,7 @@ def im2row_per_tensor( ) -@impl(m, "transposed_im2row") +@impl_tracked(m, "transposed_im2row") def transposed_im2row( input_tensor: torch.Tensor, kernel_size: tuple[int, int], @@ -1547,7 +1565,7 @@ def transposed_im2row( return patches -@impl(m, "quantized_embedding_byte") +@impl_tracked(m, "quantized_embedding_byte") def quantized_embedding_byte( weight: torch.Tensor, weight_scales: torch.Tensor,