Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cadence/aot/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
":fuse_ops",
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
Expand Down
6 changes: 4 additions & 2 deletions backends/cadence/aot/decompose_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from torch.fx.node import Argument


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class DecomposeAtenApproxGeluPass(ExportPass):
"""
Decompose the aten gelu op with an approximate arg to a series of simpler ops
Decompose the aten gelu op with an approximate arg to a series of simpler ops.
This is an optimization - gelu has a portable kernel fallback, but decomposing
may be more efficient on some backends.
"""

def call_operator(
Expand Down
9 changes: 7 additions & 2 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,15 @@
- arg_meta: null
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_out
kernel_name: impl::generic::quantized_max_pool2d_nchw_out

- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_nhwc_out

- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
Expand Down
5 changes: 4 additions & 1 deletion backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,10 @@ def can_fuse_for_chain(
return False

    # checking that permute2(permute1(identity)) == identity, modulo unitary dimensions
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
producer_input = cast(torch.fx.Node, producer.args[0])
if "val" not in producer_input.meta:
return False
input_shape = producer_input.meta["val"].shape
ident_dims = list(range(len(input_shape)))
# this mapping helps to handle both transpose and permutations
f: dict[Any, Callable] = {
Expand Down
55 changes: 51 additions & 4 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,16 @@ def register_fake(
)

lib.define(
"quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
"quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
Expand Down Expand Up @@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d")
def quantized_max_pool2d_meta(
@register_fake("cadence::quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw_meta(
input: torch.Tensor,
kernel_size: list[int],
stride: list[int],
Expand Down Expand Up @@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Shape-inference (fake) kernel for cadence::quantized_max_pool2d_nhwc.

    Computes only the output shape of a 2D max pool over an NHWC-layout
    input, using the standard pooling arithmetic:
        out = round((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    where round() is ceil when ceil_mode is set, floor otherwise.

    Args:
        input: 4D tensor in NHWC layout (N, H, W, C).
        kernel_size: [kH, kW] pooling window.
        stride: [sH, sW] window step.
        padding: [pH, pW] implicit padding on each side.
        dilation: [dH, dW] spacing between window elements.
        ceil_mode: round output sizes up instead of down.

    Returns:
        An uninitialized tensor of shape (N, H_out, W_out, C) with the
        input's dtype.
    """
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"

    batch = input.size(0)
    height_in = input.size(1)
    width_in = input.size(2)
    channels = input.size(3)

    def _pooled_size(extent: int, dim: int) -> int:
        # Span available to the pooling windows along this dimension.
        numerator = (
            extent + 2 * padding[dim] - dilation[dim] * (kernel_size[dim] - 1) - 1
        )
        # Exact integer floor/ceil division: true division followed by
        # int()/ceil() can pick up float rounding error for large extents.
        if ceil_mode:
            return -(-numerator // stride[dim]) + 1
        return numerator // stride[dim] + 1

    height_out = _pooled_size(height_in, 0)
    width_out = _pooled_size(width_in, 1)

    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
src: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from executorch.backends.cadence.aot.remove_ops import (
CadenceRemoveNops,
RemoveNopSliceOrViewOpPass,
RemovePermutesAroundElementwiseOps,
RemoveRedundantOps,
)
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
Expand Down Expand Up @@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
CadenceSimplifyOpsInGraph.passes,
FinalizePipeline,
FuseFullThenReshapePass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
CompileTimeTypeDispatchPass,
Expand Down
7 changes: 5 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_max_pool2d.default
return torch.ops.cadence.quantized_max_pool2d_nchw.default


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
Expand Down Expand Up @@ -498,7 +498,10 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_max_pool2d.default
return torch.ops.cadence.quantized_max_pool2d_nchw.default


# This is a base class for ReLU


# This is a base class for ReLU, since it can be used with two different aten ops
Expand Down
35 changes: 33 additions & 2 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,8 +1868,8 @@ def rms_norm(
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)


@impl_tracked(m, "quantized_max_pool2d")
def quantized_max_pool2d(
@impl_tracked(m, "quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw(
input: torch.Tensor,
kernel_size: list[int],
stride: list[int],
Expand Down Expand Up @@ -1897,6 +1897,37 @@ def quantized_max_pool2d(
)


@impl_tracked(m, "quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Quantized max pooling in NHWC layout.

    Max pooling is layout-agnostic per element, so this delegates to the
    NCHW reference kernel: permute the channels-last input to
    channels-first, pool, then permute the result back to channels-last.
    """
    # NHWC (N, H, W, C) -> NCHW (N, C, H, W); contiguous() materializes the
    # permutation so the delegate operates on a plain dense tensor.
    channels_first = input.permute(0, 3, 1, 2).contiguous()

    pooled = quantized_max_pool2d_nchw(
        channels_first,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
    )

    # NCHW (N, C, H_out, W_out) -> NHWC (N, H_out, W_out, C).
    return pooled.permute(0, 2, 3, 1).contiguous()


@impl_tracked(m, "where_Scalar")
def where_Scalar(
condition: torch.Tensor,
Expand Down
25 changes: 15 additions & 10 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import torch
import torch.fx

from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
get_arg,
register_cadence_pass,
RemoveOrReplacePassInterface,
set_arg,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
Expand All @@ -33,7 +34,7 @@
from torch.fx.node import Node


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveCloneOpsTransformImported(ExportPass):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
finalize_passes: List[PassType] = [
Expand All @@ -44,7 +45,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveDetachCopyPass(RemoveOrReplacePassInterface):
@property
def targets(self) -> list[EdgeOpOverload]:
Expand All @@ -66,7 +67,7 @@ class RemoveRedundantOps:
]


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
@property
def targets(self) -> list[EdgeOpOverload]:
Expand Down Expand Up @@ -120,11 +121,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
return False


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopExpandOpPass(RemoveOrReplacePassInterface):
"""
For an expand op, if the operator shape matches the expand shape, then the
expand is a nop.
expand is a nop. This is an optimization that removes unnecessary ops.
"""

@property
Expand All @@ -143,9 +144,9 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
return False


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveToOpsPass(RemoveOrReplacePassInterface):
# aten.to.* as of now are all nops
# aten.to.* ops are no-ops in inference - this is an optimization
@property
def targets(self) -> list[EdgeOpOverload]:
return [
Expand Down Expand Up @@ -264,11 +265,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
return True


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface):
"""

alias_copy is a no-op and can be removed.
This is an optimization that removes unnecessary ops.
"""

@property
Expand Down Expand Up @@ -412,6 +413,9 @@ class Subgraph:
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
exir_ops.edge.cadence.requantize.per_tensor,
exir_ops.edge.cadence.quantized_add.per_tensor,
# Ops that require special handling.
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.mean.dim,
Expand Down Expand Up @@ -804,6 +808,7 @@ class CommonRemovePasses:
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveSqueezeViewBeforeElementwiseOps,
RemoveCatFromSliceCopyPass,
RemoveCloneOpsTransformImported,
Expand Down
Loading
Loading