3 changes: 3 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -82,6 +82,8 @@ python_library(
],
deps = [
":utils",
":ops_registrations",
":ref_implementations",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
@@ -614,6 +616,7 @@ python_unittest(
typing = True,
deps = [
":typing_stubs",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:ref_implementations",
"//caffe2:torch",
]
68 changes: 66 additions & 2 deletions backends/cadence/aot/ops_registrations.py
@@ -873,6 +873,11 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -917,6 +922,11 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -961,6 +971,11 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1005,6 +1020,11 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1049,6 +1069,11 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -1093,6 +1118,11 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -1137,6 +1167,11 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1181,6 +1216,11 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1225,6 +1265,11 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -1269,6 +1314,11 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
@@ -1313,6 +1363,11 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.int8
and weight.dtype == torch.int8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1357,6 +1412,11 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert (
input.dtype == torch.uint8
and weight.dtype == torch.uint8
and bias.dtype == torch.int32
)
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
@@ -1389,7 +1449,7 @@ def quantized_layer_norm_meta(
input: torch.Tensor,
X_scale: torch.Tensor,
X_zero_point: torch.Tensor,
normalized_shape: int,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
@@ -1404,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta(
input: torch.Tensor,
X_scale: float,
X_zero_point: int,
normalized_shape: int,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
@@ -1711,6 +1771,7 @@ def quantized_fully_connected_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1733,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1755,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1777,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
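For context, a minimal standalone sketch (not part of this diff; the shapes are made up) of what the new checks in the conv meta kernels enforce: the asym8sxsym8s variants assert int8 activations and weights with an int32 bias, NCHW variants unpack the weight as `out_channels, _, *kernel_size`, and NHWC variants unpack it as `out_channels, *kernel_size, _`.

import torch

# Illustrative shapes only: O=16 output channels, I=3 input channels, 3x3 kernel.
input_nchw = torch.zeros(1, 3, 8, 8, dtype=torch.int8)    # N, C, H, W
weight_nchw = torch.zeros(16, 3, 3, 3, dtype=torch.int8)  # O, I, kH, kW
bias = torch.zeros(16, dtype=torch.int32)

# Same dtype check the asym8sxsym8s meta kernels now assert.
assert (
    input_nchw.dtype == torch.int8
    and weight_nchw.dtype == torch.int8
    and bias.dtype == torch.int32
)

# NCHW weight layout: input channels second, kernel dims last.
out_channels, _, *kernel_size = weight_nchw.shape
assert (out_channels, kernel_size) == (16, [3, 3])

# NHWC weight layout: kernel dims in the middle, input channels last.
weight_nhwc = torch.zeros(16, 3, 3, 3, dtype=torch.int8)  # O, kH, kW, I
out_channels, *kernel_size, _ = weight_nhwc.shape
assert (out_channels, kernel_size) == (16, [3, 3])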
131 changes: 129 additions & 2 deletions backends/cadence/aot/pass_utils.py
@@ -5,9 +5,15 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import Callable, List, Optional, Set, Type, Union
import logging
from functools import partial
from operator import attrgetter
from torch.utils._python_dispatch import _disable_current_modes

from typing import Any, Callable, cast, List, Optional, Set, Type, Union

import executorch.backends.cadence.aot.ops_registrations # noqa
import executorch.backends.cadence.aot.ref_implementations # noqa

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -16,6 +22,8 @@
from executorch.exir.pass_base import PassBase, PassResult

from torch._ops import OpOverloadPacket
from torch.fx import GraphModule
from torch.utils._pytree import PyTree


# Is an overlap in tensor lifetime and storage allowed at the current opt level?
@@ -114,6 +122,125 @@ def op_counts_match(
return False
return True

def validate_pass(
    tolerance: float = 1e-5,
    log_differences: bool = False,
    fail_on_mismatch: bool = True,
) -> Callable[[type[PassBase]], type[PassBase]]:
    """Decorator factory: wraps a pass class so that each run numerically
    validates the transformed graph against the original graph on concrete
    inputs derived from placeholder metadata."""

def decorator(pass_class: type[PassBase]) -> type[PassBase]:
class WrappedPass(pass_class):
def call(self, graph_module: GraphModule) -> PassResult:
# Ensure we're not in fake tensor mode for actual execution
with _disable_current_modes():
# Get inputs for the graph module
original_inputs = self._get_concrete_inputs(graph_module)

if original_inputs is None:
                        raise RuntimeError(
                            f"Could not extract concrete inputs for {pass_class.__name__}"
                        )

# Run original graph and collect outputs
with torch.no_grad():
original_outputs = graph_module(*original_inputs)

# Apply the transformation
result = super().call(graph_module)

# Run transformed graph and collect outputs
with torch.no_grad():
transformed_outputs = result.graph_module(*original_inputs)

# Compare outputs
self._compare_outputs(
original_outputs,
transformed_outputs,
pass_class.__name__,
tolerance,
log_differences,
fail_on_mismatch
)

return result

def _get_concrete_inputs(self, graph_module: GraphModule) -> Optional[List[torch.Tensor]]:
"""Extract concrete tensor inputs from the graph module metadata."""
inputs = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
if "val" in node.meta:
val = node.meta["val"]
if hasattr(val, "constant") and val.constant is not None:
inputs.append(val.constant.detach().clone())
elif isinstance(val, torch.Tensor):
# Create a concrete tensor with the same properties
concrete_tensor = torch.testing.make_tensor(val.shape, dtype=val.dtype, device='cpu')
# concrete_tensor = torch.randn(val.shape, dtype=val.dtype)
if hasattr(val, 'device'):
concrete_tensor = concrete_tensor.to(val.device)
inputs.append(concrete_tensor)
else:
raise ValueError(f"Unsupported type for {node.name}: {type(val)}")
else:
raise ValueError(f"Missing 'val' in node metadata for {node.name}")
return inputs

def _compare_outputs(
self,
original: Any,
transformed: Any,
pass_name: str,
tolerance: float,
log_differences: bool,
fail_on_mismatch: bool
) -> None:
"""Compare outputs and optionally log/fail on differences."""
if isinstance(original, torch.Tensor) and isinstance(transformed, torch.Tensor):
if not torch.allclose(original, transformed, atol=tolerance, rtol=tolerance):
max_diff = torch.max(torch.abs(original - transformed)).item()
message = f"{pass_name}: Output mismatch detected. Max difference: {max_diff}"

                        if log_differences:
                            logging.warning(message)
                            logging.warning(
                                f"Original shape: {original.shape}, "
                                f"Transformed shape: {transformed.shape}"
                            )

if fail_on_mismatch:
raise ValueError(message)
else:
                        if log_differences:
                            logging.info(
                                f"{pass_name}: Outputs match within tolerance {tolerance}"
                            )

elif isinstance(original, (list, tuple)) and isinstance(transformed, (list, tuple)):
if len(original) != len(transformed):
message = f"{pass_name}: Output count mismatch. Original: {len(original)}, Transformed: {len(transformed)}"
                        if log_differences:
                            logging.warning(message)
if fail_on_mismatch:
raise ValueError(message)
else:
for i, (orig_item, trans_item) in enumerate(zip(original, transformed)):
self._compare_outputs(
orig_item, trans_item, f"{pass_name}[{i}]",
tolerance, log_differences, fail_on_mismatch
)
else:
                if log_differences:
                    logging.info(
                        f"{pass_name}: Non-tensor outputs, skipping numerical comparison"
                    )

# Preserve the original class name and documentation
WrappedPass.__name__ = pass_class.__name__
WrappedPass.__qualname__ = pass_class.__qualname__
WrappedPass.__doc__ = pass_class.__doc__

return cast(type[PassBase], WrappedPass) # type: ignore[return-value]

return decorator



# Testing utils
# Return the compute/function nodes in the graph
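Below is a minimal usage sketch of the new `validate_pass` decorator. It is hypothetical and not part of the diff: `NoOpPass` and the traced `double` function are illustrative stand-ins. The wrapper reads concrete input shapes and dtypes from each placeholder's `node.meta["val"]`, runs the graph before and after the pass, and raises on a numerical mismatch.

import torch

from executorch.backends.cadence.aot.pass_utils import validate_pass
from executorch.exir.pass_base import PassBase, PassResult


@validate_pass()
class NoOpPass(PassBase):
    """Hypothetical pass that leaves the graph unchanged; it exists only to
    exercise the validation wrapper."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        return PassResult(graph_module, False)


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0


gm = torch.fx.symbolic_trace(double)
# The wrapper expects placeholder nodes to carry tensor metadata in meta["val"],
# as export-based pipelines provide; attach it by hand for this sketch.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["val"] = torch.empty(4, 8)

result = NoOpPass().call(gm)  # compares original vs. transformed outputs
assert result.graph_module is gm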