From ba97db369ea95b754f6637a5de232ca76656ab00 Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Tue, 2 Sep 2025 17:34:06 -0700
Subject: [PATCH] Add constraints for cadence idma ops. (#12597)

Summary:
Add memory planning constraints for idma ops:
1. idma load: output needs to be in DTCM
2. idma store: input needs to be in DTCM
3. idma wait: output aliases the input

Reviewed By: zonglinpeng

Differential Revision: D77232760
---
 backends/cadence/aot/memory_constraints.py | 32 ++++++++++++++++++++++
 backends/cadence/aot/memory_planning.py    |  4 ++-
 backends/cadence/aot/ops_registrations.py  |  6 ++++
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py
index 8e784cd2779..0eaaa8987c6 100644
--- a/backends/cadence/aot/memory_constraints.py
+++ b/backends/cadence/aot/memory_constraints.py
@@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
     ]
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class GenerateIdmaConstraints(PassBase):
+    """Generate constraints for idma ops."""
+
+    def __init__(self, constraint: MemConstraints) -> None:
+        self.constraint = constraint
+
+    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_wait.out
+        ):
+            # This is just an alias op.
+            self.constraint.add_relative_placement_constraint(node.args[0], node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_load.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(node, mem_id, None)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_store.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(
+                node.args[0], mem_id, None
+            )
+
+
 # The class to generate all the constraints that will be passed on to the memory
 # planning algorithm.
 class GenerateMemConstraints:
@@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
             list[ConstraintsGenPass],
             [
+                GenerateIdmaConstraints,
                 GenerateMemoryViewConstraints,
                 GenerateSliceAndSelectNopConstraints,
                 GenerateCatNopConstraints,
diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py
index ecf3fcef01c..7608c708125 100644
--- a/backends/cadence/aot/memory_planning.py
+++ b/backends/cadence/aot/memory_planning.py
@@ -423,7 +423,9 @@ def run(
         # True.
         mem_planning = MemoryPlanningPass(
             self.algo,
-            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
+            # Always allow lifetime and storage overlap.
+            # At opt level 0, we need overlap for idma wait.
+            allow_lifetime_and_storage_overlap=True,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index d64bc7d83ce..68091e2d521 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -304,7 +304,13 @@
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
 lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Non-blocking iDMA copy.
 lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
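
A sketch to make the constraint semantics concrete. The AbsolutePlacement and
RelativePlacement classes and the tensor names below are hypothetical stand-ins
for illustration only, not the real MemConstraints API that
GenerateIdmaConstraints uses above.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class AbsolutePlacement:
        # Pin a tensor to a specific memory bank.
        tensor: str
        mem_id: int  # 1 stands in for DTCM; picking the real bank is still a TODO

    @dataclass(frozen=True)
    class RelativePlacement:
        # Place a tensor at the same storage as its anchor (an alias at offset 0).
        tensor: str
        anchor: str

    # For a load -> wait -> store chain:
    #   loaded = cadence.idma_load.out(src, out=loaded)    # DMA into DTCM
    #   done   = cadence.idma_wait.out(loaded, out=done)   # block until the DMA finishes
    #   stored = cadence.idma_store.out(done, out=stored)  # DMA back out of DTCM
    # the pass would record:
    constraints = [
        AbsolutePlacement(tensor="loaded", mem_id=1),       # idma_load output in DTCM
        RelativePlacement(tensor="done", anchor="loaded"),  # idma_wait output aliases its input
        AbsolutePlacement(tensor="done", mem_id=1),         # idma_store input in DTCM
    ]

Note that "done" carries both constraints, and they are consistent: aliasing
"loaded" already lands it in DTCM. The aliasing is also why memory_planning.py
now passes allow_lifetime_and_storage_overlap=True unconditionally: the pair
loaded/done shares storage while both are live, which the planner at opt level
0 would otherwise reject as a conflict.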
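The new .out schemas follow the standard destination-passing convention: the
caller supplies the output buffer, which is exactly the handle the memory
planner needs in order to pin it. Below is a standalone sketch of the same
pattern with torch.library, assuming a hypothetical "mylib" namespace and
illustrative meta kernels (the Cadence registrations use their own lib object
and kernels):

    import torch
    from torch.library import Library

    # Hypothetical namespace standing in for the Cadence op library.
    mylib = Library("mylib", "DEF")
    mylib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
    mylib.define(
        "idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
    )

    # Meta kernels so the ops shape-propagate during export. The functional
    # variant allocates its result; the out= variant writes into the caller's
    # buffer, so the planner can decide where that buffer lives.
    @torch.library.impl(mylib, "idma_load", "Meta")
    def idma_load_meta(src, task_num=0, channel=0):
        return torch.empty_like(src)

    @torch.library.impl(mylib, "idma_load.out", "Meta")
    def idma_load_out_meta(src, task_num=0, channel=0, *, out):
        return out

As the comment in ops_registrations.py notes, once planning has run, loads and
stores whose outputs/inputs landed in DTCM are rewritten to idma_copy.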