backends/cadence/aot/program_builder.py (39 changes: 32 additions & 7 deletions)

@@ -11,7 +11,9 @@
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 from torch import Tensor
 from torch._export.verifier import Verifier
+from torch._ops import OpOverload
 from torch.export import ExportedProgram
+from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature
 from torch.export.graph_signature import (
     ExportGraphSignature,
     InputKind,
@@ -20,6 +22,7 @@
     OutputSpec,
     TensorArgument,
 )
+from torch.utils import _pytree as pytree


 class IrMode(Enum):
@@ -30,12 +33,19 @@ class IrMode(Enum):
 class ProgramBuilder(GraphBuilder):
     """Utility class to build a program from a graph module."""

-    def __init__(self, mode: Optional[IrMode] = None) -> None:
+    def __init__(
+        self,
+        mode: Optional[IrMode] = None,
+        _core_aten_ops_exception_list: Optional[list[OpOverload]] = None,
+    ) -> None:
         self.input_specs: list[InputSpec] = []
         self.output_specs: list[OutputSpec] = []
         self.constants: dict[str, Tensor] = {}
         self.state_dict: dict[str, Tensor] = {}
         self.mode: IrMode = mode or IrMode.EXIR
+        self._core_aten_ops_exception_list: list[OpOverload] = (
+            _core_aten_ops_exception_list or []
+        )
         super().__init__()

     def insert_input_spec(
@@ -80,27 +90,42 @@ def get_verifiers(self) -> Optional[list[Verifier]]:
             return None
         return [
             EXIREdgeDialectVerifier(
-                edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
+                edge_compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _core_aten_ops_exception_list=self._core_aten_ops_exception_list,
+                ),
+                core_aten_ops_exception_list=self._core_aten_ops_exception_list,
                 class_only=True,
             )
         ]

     def get_program(self) -> ExportedProgram:
         gm = self.get_graph_module()
+        graph_signature = ExportGraphSignature(self.input_specs, self.output_specs)
+        in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1]
+        out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1]
         return ExportedProgram(
             root=gm,
             graph=gm.graph,
-            graph_signature=ExportGraphSignature(
-                input_specs=self.input_specs, output_specs=self.output_specs
-            ),
+            graph_signature=graph_signature,
             # pyre-ignore[6]: Incompatible parameter type.
             constants=self.constants,
             state_dict=self.state_dict,
             range_constraints={},
-            module_call_graph=[],
+            module_call_graph=[
+                ModuleCallEntry(
+                    "",
+                    ModuleCallSignature(
+                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
+                    ),
+                )
+            ],
             # pyre-ignore[6]: Incompatible parameter type.
             verifiers=self.get_verifiers(),
         )

     def get_edge_program(self) -> EdgeProgramManager:
-        return EdgeProgramManager(self.get_program())
+        return EdgeProgramManager(
+            self.get_program(),
+            core_aten_ops_exception_list=self._core_aten_ops_exception_list,
+        )
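
For reviewers who want to exercise the new constructor argument end to end, here is a minimal sketch. It is not part of this PR, and it assumes GraphBuilder's placeholder/call_operator/output helpers behave as their names suggest:

    import torch
    from executorch.backends.cadence.aot.program_builder import ProgramBuilder
    from executorch.exir.dialects._ops import ops as exir_ops

    # Ops listed here are exempted from core-aten IR verification; ProgramBuilder
    # now threads the list into EXIREdgeDialectVerifier, EdgeCompileConfig, and
    # EdgeProgramManager.
    builder = ProgramBuilder(
        _core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default],
    )
    x = builder.placeholder("x", torch.randn(4, 5))
    svd = builder.call_operator(
        exir_ops.edge.aten._linalg_svd.default, (x, False, True)
    )
    builder.output([svd])
    edge_manager = builder.get_edge_program()  # exception list is forwarded here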
backends/cadence/aot/replace_ops.py (16 changes: 16 additions & 0 deletions)

@@ -2242,10 +2242,26 @@ def call_operator(self, op, args, kwargs, meta):
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass):
+    """
+    Replace aten linalg svd op with cadence custom op.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten._linalg_svd.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        return super().call_operator(
+            exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
     passes = [
+        ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
        ReplaceEmptyTensorsWithFullPass,
         ReplaceFunctionallyEquivalentOpTargets,
         ReplacePermuteWithTransposePass,
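
Mechanically, ExportPass re-traces the graph module and routes every operator node through call_operator, so returning super().call_operator with a new target re-emits the node against the Cadence op while preserving args, kwargs, and metadata. This only works if cadence.linalg_svd keeps aten._linalg_svd's schema. A hedged invocation sketch, where gm is assumed to be a graph module containing an edge aten._linalg_svd node (e.g. one built with ProgramBuilder above):

    from executorch.exir.pass_base import PassResult

    p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass()
    result = p(gm)  # ExportPass.__call__ returns an Optional[PassResult]
    assert isinstance(result, PassResult)
    # Every aten._linalg_svd call site now targets cadence.linalg_svd with the
    # same (input, full_matrices, compute_uv) args and driver kwarg.
    rewritten = result.graph_module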
backends/cadence/aot/tests/test_replace_ops_passes.py (36 changes: 36 additions & 0 deletions)

@@ -22,6 +22,7 @@
     ReplaceAddMMWithLinearPass,
     ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithCadenceConvolutionPass,
+    ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithChannelLastConvPass,
@@ -2045,3 +2046,38 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
             len(avg_pool2d_nodes),
             0,
         )
+
+
+class TestReplaceLinalgSvdPass(unittest.TestCase):
+    @expand(
+        [
+            ("2x2", (2, 2)),
+            ("3x3", (3, 3)),
+            ("4x5", (4, 5)),
+            ("10x10", (10, 10)),
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_aten_linalg_svd_with_cadence_linalg_svd(
+        self, _: str, shape: Tuple[int, int]
+    ) -> None:
+        x = torch.randn(shape, dtype=torch.float32)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten._linalg_svd.default,
+            args=(x, False, True),
+            kwargs={"driver": None},
+        )
+
+        p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass()
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+
+        # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten._linalg_svd.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default),
+            1,
+        )
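
As background for the literal arguments in the test (taken from the core-aten schema, not from this diff): args=(x, False, True) bind full_matrices=False and compute_uv=True, and kwargs={"driver": None} fills the keyword-only driver argument:

    import torch

    # _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True,
    #             *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
    x = torch.randn(4, 5)
    U, S, Vh = torch.ops.aten._linalg_svd(x, False, True, driver=None)
    k = min(x.shape)  # reduced factorization because full_matrices=False
    assert U.shape == (4, k) and S.shape == (k,) and Vh.shape == (k, 5)
    # U @ diag(S) @ Vh reconstructs x up to float32 round-off.
    torch.testing.assert_close(U @ torch.diag(S) @ Vh, x, atol=1e-4, rtol=1e-4)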