Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions backends/cadence/aot/memory_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
]


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class GenerateIdmaConstraints(PassBase):
"""Generate constraints for idma ops."""

def __init__(self, constraint: MemConstraints) -> None:
self.constraint = constraint

def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.cadence.idma_wait.out
):
# This is just an alias op.
self.constraint.add_relative_placement_constraint(node.args[0], node)

for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.cadence.idma_load.out
):
# TODO: set correct dtcm bank here.
mem_id = 1
self.constraint.add_absolute_placement_constraint(node, mem_id, None)

for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.cadence.idma_store.out
):
# TODO: set correct dtcm bank here.
mem_id = 1
self.constraint.add_absolute_placement_constraint(
node.args[0], mem_id, None
)


# The class to generate all the constraints that will be passed on to the memory
# planning algorithm.
class GenerateMemConstraints:
Expand All @@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
list[ConstraintsGenPass],
[
GenerateIdmaConstraints,
GenerateMemoryViewConstraints,
GenerateSliceAndSelectNopConstraints,
GenerateCatNopConstraints,
Expand Down
41 changes: 39 additions & 2 deletions backends/cadence/aot/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import collections
import itertools
import logging
from typing import Iterable, Optional, Sequence
from typing import Callable, Iterable, Optional, Sequence, TypeAlias

import torch
from executorch.backends.cadence.aot.memory_constraints import MemConstraints
Expand All @@ -26,6 +26,8 @@

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.pass_base import PassBase
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
Expand Down Expand Up @@ -359,6 +361,35 @@ def print_memory_planning_info(
)


class SimplifyIdmaOpsPass(PassBase):
"""Replace idma_load and idma_store with idma_copy."""

def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
modified = False
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.cadence.idma_load.out
):
modified = True
node.target = torch.ops.cadence.idma_copy.out
node.args = (node.args[0], *node.args[2:])

for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.cadence.idma_store.out
):
modified = True
node.target = torch.ops.cadence.idma_copy.out

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, modified)


ConstraintGenPassType: TypeAlias = Callable[
[MemConstraints],
Callable[[torch.fx.GraphModule], Optional[PassResult]],
]


class CadenceMemoryPlanning:
def __init__(
self,
Expand Down Expand Up @@ -423,10 +454,16 @@ def run(
# True.
mem_planning = MemoryPlanningPass(
self.algo,
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
# Always allow lifetime and storage overlap.
# At opt level 0, we need overlap for idma wait.
allow_lifetime_and_storage_overlap=True,
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
)
mem_planning.run(graph_module, graph_signature)

graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])(
graph_module
).graph_module

return PassResult(graph_module, True)
6 changes: 6 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,13 @@
# Post memory planning, we check that outputs/inputs for the load/store are in
# DTCM and replace idma_load/idma_store with idma_copy.
lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
lib.define(
"idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
lib.define(
"idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
)

# Non-blocking iDMA copy.
lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")
Expand Down
Loading