Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
Expand Down Expand Up @@ -124,3 +125,31 @@ def call_shape_operator(
shape_meta.data[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
# Call the super (ArmPass) call operator with updated meta
return self.call_operator(op, args, kwargs, shape_meta, updated)

def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]):
"""Return a scalar value for the current pass stage.

In transform-for-annotation passes this returns the Python scalar
directly. In later passes it materializes a `(1,)` `aten.full` node
using the output dtype/device from `meta["val"]` when available.

"""

if self.is_tfa_pass:
return value

kwargs = {}
if "val" in meta:
val = meta["val"]
if isinstance(val, tuple):
val = val[0]
kwargs = {"device": val.device, "dtype": val.dtype}

return ArmPass.call_operator(
self,
op=exir_ops.edge.aten.full.default,
args=((1,), value),
kwargs=kwargs,
meta=meta,
updated=True,
)
12 changes: 6 additions & 6 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
DecomposeDivTensorModePass(tfa_pass=True),
DecomposeWhereScalarOtherPass(tfa_pass=True),
RewriteInplaceArithmeticPass(tfa_pass=True),
DecomposeAddSubAlphaPass(tfa_pass=True),
DecomposeLeakyReLUPass(tfa_pass=True),
DecomposeGroupNormPass(tfa_pass=True),
DecomposeLayerNormPass(tfa_pass=True),
DecomposeVarPass(tfa_pass=True),
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
]
)

Expand All @@ -573,16 +579,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_passes(
[
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
DecomposeAddSubAlphaPass(tfa_pass=True),
DecomposeGroupNormPass(tfa_pass=True),
DecomposeLayerNormPass(tfa_pass=True),
DecomposeVarPass(tfa_pass=True),
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
DecomposeNotEqualPass(tfa_pass=True),
DecomposeCosineSimilarityPass(tfa_pass=True),
DecomposeGluPass(tfa_pass=True),
DecomposeDivPass(tfa_pass=True),
DecomposeLeakyReLUPass(tfa_pass=True),
DecomposeLinalgVectorNormPass(tfa_pass=True),
DecomposeSqrtPass(tfa_pass=True),
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
Expand Down
32 changes: 32 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,38 @@ def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
return NodeMetadata(plain_meta_dict)


def insert_scalar(
graph: torch.fx.Graph,
value: int | float,
meta: NodeMetadata | dict,
from_node: torch.fx.Node,
is_tfa_pass: bool = False,
) -> torch.fx.Node | int | float:
"""Insert an `aten.full` scalar node for direct graph-rewrite passes."""

if is_tfa_pass:
return value

kwargs = {}
val = None
if "val" in meta:
val = meta["val"]
if isinstance(val, tuple):
val = val[0]
kwargs = {"device": val.device, "dtype": val.dtype}

scalar = create_node(
graph=graph,
op_target=exir_ops.edge.aten.full.default,
args=((1,), value),
kwargs=kwargs,
from_node=from_node,
)
if val is not None:
scalar.meta["val"] = torch.full((1,), value, **kwargs)
return scalar


def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
"""Returns a FakeTensor from the meta field of 'node'.

Expand Down
15 changes: 2 additions & 13 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,20 @@ def _get_ops(op):
if op is exir_ops.edge.aten.add.Tensor:
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.add.Tensor,
)
return (
torch.ops.aten.mul.Tensor,
torch.ops.aten.full.default,
torch.ops.aten.add.Tensor,
)
if op in _SUB_OPS:
if op is exir_ops.edge.aten.sub.Tensor:
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.sub.Tensor,
)
return (
torch.ops.aten.mul.Tensor,
torch.ops.aten.full.default,
torch.ops.aten.sub.Tensor,
)
raise RuntimeError(f"Unsupported operator {op}")
Expand All @@ -72,19 +68,12 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
if not _should_decompose(alpha):
return super().call_operator(op, args, kwargs, meta, updated)

mul_op, full_op, binary_op = _get_ops(op)
mul_op, binary_op = _get_ops(op)
lhs, rhs = args

alpha_full = super().call_operator(
full_op,
((1,), float(alpha)),
{"device": meta["val"].device, "dtype": meta["val"].dtype},
meta,
updated=True,
)
scaled_rhs = super().call_operator(
mul_op,
(rhs, alpha_full),
(rhs, super().call_scalar(alpha, meta)),
{},
meta,
updated=True,
Expand Down
19 changes: 5 additions & 14 deletions backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def get_decomposition(op) -> tuple:
exir_ops.edge.aten.gt.Scalar,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.neg.default,
)

Expand Down Expand Up @@ -79,15 +78,12 @@ def _build_polynomial(
"""Helper function to build polynomial from coefficients and
variable.
"""
full_like_op, add_op, mul_op_scalar, mul_op = (
exir_ops.edge.aten.full_like.default,
add_op, mul_op_scalar, mul_op = (
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.mul.Tensor,
)
result = super().call_operator(
full_like_op, (variable, coefficients[0]), {}, meta, True
)
result = super().call_scalar(coefficients[0], meta)
for coeff in coefficients[1:]:
result = super().call_operator(
add_op,
Expand Down Expand Up @@ -150,7 +146,6 @@ def call_operator(self, op, args, kwargs, meta):
gt_op,
lt_op,
sub_op,
full_like_op,
neg_op,
) = get_decomposition(op)

Expand Down Expand Up @@ -179,7 +174,7 @@ def call_operator(self, op, args, kwargs, meta):

# Step 2: Compute the transformed approximation for large values
# Calculate z = -0.5 * (|x| - 1)
tmp_ones = super().call_operator(full_like_op, (x_abs, one), {}, meta, True)
tmp_ones = super().call_scalar(one, meta)
tmp = super().call_operator(sub_op, (x_abs, tmp_ones), {}, meta, True)
z = super().call_operator(mul_op_scalar, (tmp, neg_half), {}, meta, True)

Expand All @@ -201,9 +196,7 @@ def call_operator(self, op, args, kwargs, meta):
t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True)

diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True)
tmp_neg_ones = super().call_operator(
full_like_op, (diff, neg_one), {}, meta, True
)
tmp_neg_ones = super().call_scalar(neg_one, meta)
asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True)

asin_unsigned = self._combine_branches(
Expand All @@ -218,9 +211,7 @@ def call_operator(self, op, args, kwargs, meta):

if op in edge_acos_op:
# If x <= 0.5: acos(x) = pi/2 - asin(x)
const_tensor = super().call_operator(
full_like_op, (x, pi_over_2), {}, meta, True
)
const_tensor = super().call_scalar(pi_over_2, meta)
acos_small = super().call_operator(
sub_op, (const_tensor, asin), {}, meta, True
)
Expand Down
12 changes: 4 additions & 8 deletions backends/arm/_passes/decompose_erfinv_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def get_erfinv_decomposition(op) -> tuple:
if op in edge_erfinv_ops:
# Ordered by first use in call_operator below.
return (
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.abs.default,
Expand Down Expand Up @@ -140,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
x = args[0]

(
op_full_like,
op_lt_t,
op_where,
op_abs,
Expand Down Expand Up @@ -179,12 +177,10 @@ def call_operator(self, op, args, kwargs, meta):
CORR_MAX = 0.5
TWO_OVER_SQRT_PI = 1.1283791670955126

# ---- zeros / ones (tensor-shaped) ----
zeros = super().call_operator(op_full_like, (x, 0.0), {}, meta, updated=True)
ones = super().call_operator(op_full_like, (x, 1.0), {}, meta, updated=True)
neg_ones = super().call_operator(
op_full_like, (x, -1.0), {}, meta, updated=True
)
# ---- zeros / ones constants ----
zeros = super().call_scalar(0.0, meta)
ones = super().call_scalar(1.0, meta)
neg_ones = super().call_scalar(-1.0, meta)

# ---- s = sign(x): -1 for x<0 else +1 ----
x_lt0 = super().call_operator(op_lt_t, (x, zeros), {}, meta, updated=True)
Expand Down
40 changes: 7 additions & 33 deletions backends/arm/_passes/decompose_gelu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ def _get_gelu_ops(op) -> tuple:

if op in edge_gelu:
return (
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.erf.default,
)
if op in torch_gelu:
return (
torch.ops.aten.full.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.tanh.default,
Expand Down Expand Up @@ -98,30 +96,18 @@ def call_operator(self, op, args, kwargs, meta):
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)
add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)

input = get_node_arg(args, 0)
# If approximate is default (none) it does not appear in kwargs
approximate = get_node_arg(kwargs, "approximate", "none")

shape = meta["val"].size()
dtype = meta["val"].dtype

FULL_0_5 = super().call_operator(
full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
)
FULL_1 = super().call_operator(
full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
)
FULL_0_5 = super().call_scalar(0.5, meta)
FULL_1 = super().call_scalar(1, meta)

if approximate == "none":
# Constant mirrors ExecuTorch implementation for parity.
FULL_SQRT1_2 = super().call_operator(
full_op,
([1] * len(shape), 0.70710678118654752440),
{"dtype": dtype},
meta,
)
FULL_SQRT1_2 = super().call_scalar(0.70710678118654752440, meta)

op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
op2 = super().call_operator(erf_op, (op1,), {}, meta)
Expand All @@ -131,21 +117,9 @@ def call_operator(self, op, args, kwargs, meta):

elif approximate == "tanh":
# Constants mirror ExecuTorch implementation for parity.
FULL_SQRT2 = super().call_operator(
full_op,
([1] * len(shape), 1.41421356237309504880),
{"dtype": dtype},
meta,
)
FULL_2_SQRTPI = super().call_operator(
full_op,
([1] * len(shape), 1.12837916709551257390),
{"dtype": dtype},
meta,
)
FULL_CUBE_COEFF = super().call_operator(
full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
)
FULL_SQRT2 = super().call_scalar(1.41421356237309504880, meta)
FULL_2_SQRTPI = super().call_scalar(1.12837916709551257390, meta)
FULL_CUBE_COEFF = super().call_scalar(0.044715, meta)

# Mirrors ExecuTorch implementations for calculating this value
SQRT_MUL = super().call_operator(
Expand Down
Loading
Loading