Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions backends/cadence/aot/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/exir:pass_base",
"//executorch/backends/test:graph_builder",
],
)

Expand All @@ -238,11 +237,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
":graph_builder",
"fbcode//caffe2:torch",
"fbcode//executorch/exir:lib",
"fbcode//executorch/exir:pass_base",
"fbcode//executorch/exir/verification:verifier",
"//executorch/backends/test:program_builder",
],
)

Expand All @@ -253,7 +248,7 @@ fbcode_target(_kind = python_unittest,
],
typing = True,
deps = [
":program_builder",
"//executorch/backends/test:program_builder",
"//caffe2:torch",
"//later:lib",
],
Expand Down Expand Up @@ -397,7 +392,7 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
":type_dispatch",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand Down Expand Up @@ -437,7 +432,7 @@ fbcode_target(_kind = python_unittest,
deps = [
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -458,7 +453,7 @@ fbcode_target(_kind = python_unittest,
":replace_ops",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -479,7 +474,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
":typing_stubs",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -500,7 +495,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:fuse_ops",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir/dialects:lib",
Expand All @@ -521,7 +516,7 @@ fbcode_target(_kind = python_unittest,
":compiler",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:remove_ops",
Expand All @@ -541,7 +536,7 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:simplify_ops",
Expand All @@ -561,7 +556,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:fuse_ops",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:reorder_ops",
Expand All @@ -574,6 +569,7 @@ fbcode_target(_kind = runtime.python_library,
srcs = [
"memory_planning_algo.py",
],
typing = True,
deps = [
":memory_constraints",
":pass_utils",
Expand Down Expand Up @@ -618,6 +614,23 @@ fbcode_target(_kind = runtime.python_library,
],
)

# Unit tests for the memory planning algorithm and its placement constraints.
# NOTE(review): supports_static_listing = False — presumably the test file
# generates cases dynamically (e.g. parameterized), so Buck must execute it
# to discover tests; confirm against tests/test_memory_planning_algo.py.
fbcode_target(_kind = python_unittest,
    name = "test_memory_planning_algo",
    srcs = [
        "tests/test_memory_planning_algo.py",
    ],
    supports_static_listing = False,
    typing = True,
    deps = [
        ":memory_constraints",
        ":memory_planning",
        ":memory_planning_algo",
        ":utils",
        "//caffe2:torch",
        "//executorch/exir:tensor",
    ],
)

fbcode_target(_kind = python_unittest,
name = "test_memory_passes",
srcs = [
Expand All @@ -631,11 +644,11 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
":ops_registrations",
":pass_utils",
":program_builder",
"//executorch/backends/test:program_builder",
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/exir/tests:models",
],
)
Expand All @@ -647,8 +660,7 @@ fbcode_target(_kind = python_unittest,
],
typing = True,
deps = [
":program_builder",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/runtime:runtime",
"//later:lib",
Expand Down Expand Up @@ -678,7 +690,7 @@ fbcode_target(_kind = python_unittest,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:pass_base",
"//pytorch/ao:torchao",
Expand All @@ -693,7 +705,7 @@ fbcode_target(_kind = python_unittest,
typing = True,
deps = [
":ops_registrations",
":program_builder",
"//executorch/backends/test:program_builder",
":to_out_var_pass",
"//caffe2:torch",
"//executorch/exir:lib",
Expand Down
137 changes: 4 additions & 133 deletions backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,137 +6,8 @@

# pyre-strict

import logging
from typing import Optional, Sequence, Union
# This module has moved to executorch.backends.test.graph_builder.
# This re-export exists for backward compatibility.
from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder

import torch
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Target
from torch.utils import _pytree as pytree


class GraphBuilder(ExportPass):
    """Utility class for creating a graph module with user-specified ops.

    This class allows us to create test graph modules with any ops we want
    directly, rather than relying on decomposition or passes.

    Usage:
        builder = GraphBuilder()
        # To insert placeholders, use builder.placeholder.
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        # To insert an op, use builder.call_operator.
        op = builder.call_operator(
            some_op
            (x, other_args, ...),
        )
        # Insert outputs as a list of ProxyValues using builder.output.
        builder.output([op])
        # Get GraphModule from builder.
        gm = builder.get_graph_module()
    """

    def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None:
        # Separate ExportPass instance used by call_submodule so submodule
        # processing does not disturb this builder's own tracer state.
        self.exporter = ExportPass()
        self.tracer: ExportPass.ExportTracer = self.ExportTracer(
            self, torch.fx.graph.CodeGen()
        )
        # Fake tensors let us build graphs without running real kernels;
        # allow_non_fake_inputs lets callers hand real tensors to
        # placeholder(), which converts them via from_tensor().
        self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode(
            allow_fallback_kernels=False,
            allow_non_fake_inputs=True,
        )
        self.tracer.fake_tensor_mode = self.fake_tensor_mode

        # This will be called to create nodes in tracer.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )

    # pyre-ignore[14]: Inconsistent override.
    def placeholder(
        self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
    ) -> ProxyValue:
        """Insert a graph input named `target`.

        Real tensors are converted to fake tensors in this builder's
        FakeTensorMode before the placeholder node is created.
        """
        if not isinstance(fake_tensor, FakeTensor):
            fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
        logging.debug(f"Creating placeholder {target} => {fake_tensor.shape}")
        placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
        return placeholder

    # pyre-ignore[14]: Inconsistent override.
    def output(self, results: list[ProxyValue]) -> ProxyValue:
        """Insert the graph's output node from a list of ProxyValues."""
        logging.debug(f"Creating outputs {results}")
        return super().output(results, NodeMetadata({}))

    def get_graph_module(self) -> torch.fx.GraphModule:
        """Materialize the graph built so far as a torch.fx.GraphModule."""
        return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: Optional[dict[str, Argument]] = None,
        meta: Optional[NodeMetadata] = None,
    ) -> ProxyValue:
        """Insert a call to `op`; kwargs and metadata default to empty."""
        if meta is None:
            meta = NodeMetadata({})
        if kwargs is None:
            kwargs = {}
        return super().call_operator(op, args, kwargs, meta)

    def call_submodule(
        self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
    ) -> PassResult:
        # NOTE(review): `inputs` is ignored here; the submodule is processed
        # by a fresh ExportPass rather than interpreted with these inputs.
        return ExportPass().call(graph_module)

    def call_getitem(
        self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None
    ) -> ProxyValue:
        """Insert a getitem node (e.g. to unpack a multi-output op)."""
        return super().call_getitem(value, key, meta or NodeMetadata({}))

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Create nodes with the fake tensor mode active and the Python
        # dispatcher enabled, so node meta values are fake tensors.
        with self.fake_tensor_mode, enable_python_dispatcher():
            return super()._fx(kind, target, args, kwargs, meta)


def single_op_builder(
    placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
    op: Target,
    args: Sequence[Argument],
    kwargs: Optional[dict[str, Argument]] = None,
) -> torch.fx.GraphModule:
    """Build a graph module whose graph contains exactly one op call.

    Args:
        placeholders: Tensors that become the graph's inputs. Tensor
            references inside `args`/`kwargs` must come from this sequence.
        op: The op to be inserted.
        args: The args to be passed to the op.
        kwargs: The kwargs to be passed to the op.

    Returns:
        A graph module with a single op.
    """
    builder = GraphBuilder()
    # Map each input tensor to the ProxyValue of its placeholder node so the
    # tensor references in args/kwargs can be rewritten below.
    tensor_to_proxy = {}
    for idx, tensor in enumerate(placeholders):
        tensor_to_proxy[tensor] = builder.placeholder(f"p_{idx}", tensor)
    proxy_args, proxy_kwargs = pytree.tree_map_only(
        (torch.Tensor, FakeTensor),
        lambda t: tensor_to_proxy[t],
        (args, kwargs),
    )
    result = builder.call_operator(op, proxy_args, proxy_kwargs)
    builder.output([result])
    return builder.get_graph_module()
# Public surface of this shim module: only the names re-exported for
# backward compatibility.
__all__ = ["GraphBuilder", "single_op_builder"]
11 changes: 11 additions & 0 deletions backends/cadence/aot/memory_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,17 @@ def add_absolute_placement_constraint(
)
)

def set_absolute_placement_constraint(
self, spec: TensorSpec, constraint: AbsolutePlacementConstraint
) -> None:
"""Set an absolute placement constraint for `spec` by spec identity.

Overwrites any existing constraint for the same spec. Range validation
of pinned_memory_id is the caller's responsibility (depends on the
planner's MemoryConfig).
"""
self._absolute_placement_constraints[id(spec)] = constraint

def get_absolute_placement_constraint(
self, spec: TensorSpec
) -> Optional[AbsolutePlacementConstraint]:
Expand Down
34 changes: 34 additions & 0 deletions backends/cadence/aot/memory_planning_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,40 @@ def plan_with_constraints(
) -> None:
"""Callable interface for ET memory planning."""

# Promote specs with a pre-set mem_id to AbsolutePlacementConstraint so
# the planner honours the pinned memory tier and only assigns the offset.
# This is used by planned-temporary alloc nodes whose spec.mem_id is set
# by the AOT pass before planning runs.
#
# mem_id semantics:
# None — not yet assigned (default); planner picks freely
# <= 0 — sentinel for "unassigned/unpinned"; planner picks freely
# [1, num_memories) — valid tier; promoted to constraint below
#
# Materialize to list because collect_specs_from_nodes returns a
# generator and we iterate twice (promotion here, constraint
# collection in spec_and_abs_constraints below).
specs = list(specs)
for spec in specs:
if (
spec.mem_id is not None
and isinstance(spec.mem_id, int)
and spec.mem_id > 0
and placement_constraints.get_absolute_placement_constraint(spec)
is None
):
num_memories = self.get_num_memories()
assert 1 <= spec.mem_id < num_memories, (
f"Pre-set spec.mem_id={spec.mem_id} is invalid. "
f"Memory IDs must be in range [1, {num_memories}) for this planner configuration. "
f"Check that the spec.mem_id was set correctly in the AOT pass, "
f"or verify your MemoryConfig defines enough memory tiers."
)
placement_constraints.set_absolute_placement_constraint(
spec,
AbsolutePlacementConstraint(pinned_memory_id=spec.mem_id),
)

spec_and_abs_constraints = {
spec: placement_constraints.get_absolute_placement_constraint(spec)
for spec in specs
Expand Down
Loading
Loading