From d84c4be082e7923ddefe45dc3e942c63304e279b Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Wed, 3 Sep 2025 13:01:18 -0700 Subject: [PATCH 1/2] Add constraints for cadence idma ops. (#12597) Summary: Add memory planning constraints for idma ops: 1. idma load: output needs to be in DTCM 2. idma store: input needs to be in DTCM 3. idma wait: output aliases the input Reviewed By: zonglinpeng Differential Revision: D77232760 --- backends/cadence/aot/memory_constraints.py | 32 ++++++++++++++++++++++ backends/cadence/aot/memory_planning.py | 4 ++- backends/cadence/aot/ops_registrations.py | 6 ++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index 8e784cd2779..0eaaa8987c6 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints( ] +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class GenerateIdmaConstraints(PassBase): + """Generate constraints for idma ops.""" + + def __init__(self, constraint: MemConstraints) -> None: + self.constraint = constraint + + def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]: + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.cadence.idma_wait.out + ): + # This is just an alias op. + self.constraint.add_relative_placement_constraint(node.args[0], node) + + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.cadence.idma_load.out + ): + # TODO: set correct dtcm bank here. + mem_id = 1 + self.constraint.add_absolute_placement_constraint(node, mem_id, None) + + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.cadence.idma_store.out + ): + # TODO: set correct dtcm bank here. + mem_id = 1 + self.constraint.add_absolute_placement_constraint( + node.args[0], mem_id, None + ) + + # The class to generate all the constraints that will be passed on to the memory # planning algorithm. class GenerateMemConstraints: @@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: constraint_gen_passes: Sequence[ConstraintsGenPass] = cast( list[ConstraintsGenPass], [ + GenerateIdmaConstraints, GenerateMemoryViewConstraints, GenerateSliceAndSelectNopConstraints, GenerateCatNopConstraints, diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index ecf3fcef01c..7608c708125 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -423,7 +423,9 @@ def run( # True. mem_planning = MemoryPlanningPass( self.algo, - allow_lifetime_and_storage_overlap=(self.opt_level >= 2), + # Always allow lifetime and storage overlap. + # At opt level 0, we need overlap for idma wait. + allow_lifetime_and_storage_overlap=True, alloc_graph_input=self.alloc_graph_input, alloc_graph_output=self.alloc_graph_output, ) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index d64bc7d83ce..68091e2d521 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -304,7 +304,13 @@ # Post memory planning, we check that outputs/inputs for the load/store are in # DTCM and replace idma_load/idma_store with idma_copy. lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor") +lib.define( + "idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor") +lib.define( + "idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)" +) # Non-blocking iDMA copy. lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor") From 461971cd42abc44c9e7a6f4abcf8e82a011811cb Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Wed, 3 Sep 2025 13:01:18 -0700 Subject: [PATCH 2/2] Post memory planning passes. Summary: Adds passes that run after memory planning. 1. Convert idma load/store to idma copy. This saves code space. 2. Remove dtcm-to-dtcm idma copy. This saves cycles. Reviewed By: zonglinpeng Differential Revision: D77232764 --- backends/cadence/aot/memory_planning.py | 37 ++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index 7608c708125..5ff39580a23 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -9,7 +9,7 @@ import collections import itertools import logging -from typing import Iterable, Optional, Sequence +from typing import Callable, Iterable, Optional, Sequence, TypeAlias import torch from executorch.backends.cadence.aot.memory_constraints import MemConstraints @@ -26,6 +26,8 @@ from executorch.exir import ExecutorchProgramManager from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier +from executorch.exir.pass_base import PassBase +from executorch.exir.pass_manager import PassManager from executorch.exir.passes import MemoryPlanningPass from executorch.exir.tensor import TensorSpec from tabulate import tabulate @@ -359,6 +361,35 @@ def print_memory_planning_info( ) +class SimplifyIdmaOpsPass(PassBase): + """Replace idma_load and idma_store with idma_copy.""" + + def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]: + modified = False + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.cadence.idma_load.out + ): + modified = True + node.target = torch.ops.cadence.idma_copy.out + node.args = (node.args[0], *node.args[2:]) + + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.cadence.idma_store.out + ): + modified = True + node.target = torch.ops.cadence.idma_copy.out + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, modified) + + +ConstraintGenPassType: TypeAlias = Callable[ + [MemConstraints], + Callable[[torch.fx.GraphModule], Optional[PassResult]], +] + + class CadenceMemoryPlanning: def __init__( self, @@ -431,4 +462,8 @@ def run( ) mem_planning.run(graph_module, graph_signature) + graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])( + graph_module + ).graph_module + return PassResult(graph_module, True)