Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions backends/cadence/aot/memory_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
]


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class GenerateIdmaConstraints(PassBase):
    """Emit memory-placement constraints for the iDMA out-variant ops.

    - idma_wait.out is a pure alias: its output must share placement with
      its input tensor.
    - idma_load.out outputs and idma_store.out inputs are pinned to an
      absolute memory bank (currently hard-coded to bank 1; presumably a
      DTCM bank — TODO in original to pick the correct one).
    """

    def __init__(self, constraint: MemConstraints) -> None:
        # Constraint collector shared with the memory planner.
        self.constraint = constraint

    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
        graph = graph_module.graph

        # idma_wait is just an alias op: tie its output's placement to the
        # placement of the tensor it waits on.
        for wait_node in graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_wait.out
        ):
            self.constraint.add_relative_placement_constraint(
                wait_node.args[0], wait_node
            )

        # TODO: set correct dtcm bank here.
        dtcm_mem_id = 1

        # Pin the destination tensor of each idma_load to the DTCM bank.
        for load_node in graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_load.out
        ):
            self.constraint.add_absolute_placement_constraint(
                load_node, dtcm_mem_id, None
            )

        # Pin the source tensor of each idma_store to the DTCM bank.
        for store_node in graph.find_nodes(
            op="call_function", target=torch.ops.cadence.idma_store.out
        ):
            self.constraint.add_absolute_placement_constraint(
                store_node.args[0], dtcm_mem_id, None
            )


# The class to generate all the constraints that will be passed on to the memory
# planning algorithm.
class GenerateMemConstraints:
Expand All @@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
list[ConstraintsGenPass],
[
GenerateIdmaConstraints,
GenerateMemoryViewConstraints,
GenerateSliceAndSelectNopConstraints,
GenerateCatNopConstraints,
Expand Down
4 changes: 3 additions & 1 deletion backends/cadence/aot/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,9 @@ def run(
# True.
mem_planning = MemoryPlanningPass(
self.algo,
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
# Always allow lifetime and storage overlap.
# At opt level 0, we need overlap for idma wait.
allow_lifetime_and_storage_overlap=True,
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
)
Expand Down
6 changes: 6 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,13 @@
# Post memory planning, we check that outputs/inputs for the load/store are in
# DTCM and replace idma_load/idma_store with idma_copy.
lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
# Out-variant of idma_load: per torch op-schema convention, `Tensor(a!) out`
# marks `out` as mutated in place, and the returned `Tensor(a!)` aliases it.
lib.define(
"idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
# Out-variant of idma_store; same in-place `out` aliasing convention as above.
lib.define(
"idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
)

# Non-blocking iDMA copy.
lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
Expand Down
Loading