282 changes: 149 additions & 133 deletions backends/cadence/aot/replace_ops.py
@@ -34,6 +34,7 @@
     CadencePassAttribute,
     none_throws,
     register_cadence_pass,
+    RemoveOrReplacePassInterface,
 )
 from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -115,84 +116,84 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceSafeSoftmaxWithSoftmax(ExportPass):  # keep
+class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface):  # keep
     """
     Replace _safe_softmax with _softmax
     """
 
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != torch.ops.aten._safe_softmax.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [torch.ops.aten._safe_softmax.default]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Add False for the half_to_float argument of softmax
-        softmax_args = list(args) + [False]
+        softmax_args = tuple(list(node.args) + [False])
 
-        return super().call_operator(
-            torch.ops.aten._softmax.default,
-            tuple(softmax_args),
-            kwargs,
-            meta,
-        )
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                torch.ops.aten._softmax.default,
+                args=softmax_args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
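
For intuition: `_softmax(x, dim, False)` matches `_safe_softmax(x, dim)` whenever no softmax row is entirely masked; `_safe_softmax` only differs by returning zeros (rather than NaN) for all-`-inf` rows. A quick eager check (a sketch; assumes a PyTorch build recent enough to expose `aten._safe_softmax`):

import torch

x = torch.randn(2, 8)
safe = torch.ops.aten._safe_softmax(x, -1)
plain = torch.ops.aten._softmax(x, -1, False)  # False = half_to_float
assert torch.allclose(safe, plain)  # equal here: no row is fully masked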


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
+class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface):
     """
     Replace the pt2 quantization ops with cadence quantization ops.
     We do not link kernels to the PT2 quantization ops, so we need to
     replace them with cadence ops at all optimization levels.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
-        if op != ns.quantized_decomposed.quantize_per_tensor.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        ]
 
-        return super().call_operator(
-            ns.cadence.quantize_per_tensor.default,
-            args,
-            kwargs,
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                ns.cadence.quantize_per_tensor.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
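
Why `targets` lists both namespaces: the pass can run before or after edge-dialect lowering, so the same logical op may appear as either `torch.ops.*` or `exir_ops.edge.*`, and the `ns` dispatch keeps the replacement in the node's own dialect. A minimal sketch of that dispatch (`cadence_namespace_for` is an illustrative helper name; the real pass does this inline):

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

def cadence_namespace_for(target: object):
    # Same-dialect replacement: edge nodes get edge ops, aten nodes get torch ops.
    return exir_ops.edge if isinstance(target, EdgeOpOverload) else torch.ops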


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
+class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface):
     """
     Replace the pt2 dequantization ops with cadence dequantization ops.
     We do not link kernels to the PT2 quantization ops, so we need to
     replace them with cadence ops at all optimization levels.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
-        if op != ns.quantized_decomposed.dequantize_per_tensor.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        ]
 
-        return super().call_operator(
-            ns.cadence.dequantize_per_tensor.default,
-            args,
-            kwargs,
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                ns.cadence.dequantize_per_tensor.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -232,18 +233,34 @@ def call_operator(


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceFunctionallyEquivalentOpTargets(ExportPass):
+class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface):
     """
     Replace an op with a functionally equivalent op by just switching the op
     target, but without incurring any change to the op args.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in functionally_equivalent_op_targets:
-            return super().call_operator(op, args, kwargs, meta)
-        return super().call_operator(
-            functionally_equivalent_op_targets[op], args, kwargs, meta
-        )
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return list(functionally_equivalent_op_targets.keys())
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        target_op = functionally_equivalent_op_targets[node.target]
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                target_op,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+
+        # RemoveOrReplacePassInterface calls eliminate_dead_code, but that does
+        # not remove impure nodes (nodes with side effects). Whether erasing
+        # them is generally safe is unclear, so instead of modifying the
+        # interface, just erase these nodes in this pass.
+        node.graph.erase_node(node)
+        return True
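
The replace-then-erase pattern above, reduced to a runnable torch.fx toy (a sketch; `operator.sub` stands in for a functionally equivalent target):

import operator
import torch
from torch import fx

def f(x):
    return x + 1  # traced as call_function(operator.add)

gm = fx.symbolic_trace(f)
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(operator.sub, node.args, node.kwargs)
        new_node.meta = node.meta
        node.replace_all_uses_with(new_node)
        # Erase explicitly rather than relying on eliminate_dead_code(),
        # which skips nodes fx considers impure.
        gm.graph.erase_node(node)
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.zeros(2))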


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -1438,82 +1455,95 @@ def call_operator(self, op, args, kwargs, meta):


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceScalarTensorWithFullPass(ExportPass):
+class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface):
     """
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
     scalar_tensor is not supported, so this is an opt_level=0 pass.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.scalar_tensor.default,
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             torch.ops.aten.scalar_tensor.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+            exir_ops.edge.aten.scalar_tensor.default,
+        ]
 
-        return super().call_operator(
-            exir_ops.edge.aten.full.default,
-            (
-                [1],
-                args[0],
-            ),
-            {"dtype": torch.float32},
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=(
+                    [1],
+                    node.args[0],
+                ),
+                kwargs={"dtype": torch.float32},
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
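
Eagerly, the rewrite preserves the value but changes the rank from 0-d to a one-element tensor, with the dtype pinned to float32 (a sketch):

import torch

v = 2.5
a = torch.ops.aten.scalar_tensor(v, dtype=torch.float32)  # shape []
b = torch.ops.aten.full([1], v, dtype=torch.float32)      # shape [1]
assert a.item() == b.item()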


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceFullLikeWithFullPass(ExportPass):
+class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface):
     """
     aten.full_like can be replaced by aten.full with the shape of the arg tensor.
     full_like is not supported, so this is an opt_level=0 pass.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in {
-            exir_ops.edge.aten.full_like.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.full_like.default]
 
-        # Get the shape of the "like" tensor, and pass that in to the full op.
-        return super().call_operator(
-            exir_ops.edge.aten.full.default,
-            (
-                args[0].to_tensor().shape,
-                args[1],
-            ),
-            {},
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        input_arg = node.args[0]
+        assert isinstance(input_arg, torch.fx.Node)
+        shape = input_arg.meta["val"].shape
+        fill_value = node.args[1]
+
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=(shape, fill_value),
+                kwargs={},
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
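
The equivalence being exploited, in eager terms (a sketch); note the new implementation reads the shape from the FakeTensor in `meta["val"]` rather than materializing the input tensor:

import torch

t = torch.randn(2, 3)
assert torch.equal(torch.full_like(t, 7.0), torch.full(t.shape, 7.0))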


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceInfArgInFullWithValuePass(ExportPass):
+class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface):
     """
     aten.full allows "-inf" and "inf" as inputs. The profiler cannot
     handle that, so replace them with the maximum value of the type.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in {
-            exir_ops.edge.aten.full.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.full.default]
 
-        new_args = list(args)
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
-        if args[1] == float("-inf"):
+        new_args = list(node.args)
+        fill_value = node.args[1]
+        if fill_value == float("-inf"):
             new_args[1] = torch.finfo(torch.float32).min
-        elif args[1] == float("inf"):
+        elif fill_value == float("inf"):
             new_args[1] = torch.finfo(torch.float32).max
+        else:
+            return False
 
-        return super().call_operator(op, tuple(new_args), kwargs, meta)
+        new_args = tuple(new_args)
 
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=new_args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
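
The concrete values substituted for the infinities (a sketch):

import torch

print(torch.finfo(torch.float32).min)  # ≈ -3.4028235e+38
print(torch.finfo(torch.float32).max)  # ≈  3.4028235e+38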


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -1713,26 +1743,6 @@ def call_operator(
return super().call_operator(op, args, kwargs, meta)


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
-    """
-    Replace the aten gelu op with an approximate arg with an approximate gelu op.
-    """
-
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.gelu.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        return super().call_operator(op, args, kwargs, meta)


# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceSplitWithSlicePass(ExportPass):
@@ -2122,18 +2132,25 @@ class CommonReplacePasses:


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass):
+class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface):
     """
     Replace aten linalg svd op with cadence custom op.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten._linalg_svd.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten._linalg_svd.default]
 
-        return super().call_operator(
-            exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.cadence.linalg_svd.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
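
The aten op being retargeted, called eagerly (a sketch; the cadence custom op is assumed to keep the same (U, S, Vh) contract):

import torch

A = torch.randn(4, 3)
U, S, Vh = torch.ops.aten._linalg_svd(A, full_matrices=False)
assert torch.allclose(U @ torch.diag(S) @ Vh, A, atol=1e-5)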


# This class encapsulates all the functions that replace/switch one op in the
@@ -2165,6 +2182,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
         ReplaceAtenAvgPoolWithCadenceAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
-        ReplaceAtenApproxGeluWithApproxGeluPass,
         ReplaceMulTensorWithMulAndFullOpsPass,
     ]