25 changes: 19 additions & 6 deletions exir/emit/_emitter.py
@@ -1640,13 +1640,26 @@ def placeholder(  # noqa: C901
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
-        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
-            if spec.extra_tensor_info is None:
-                spec.extra_tensor_info = ExtraTensorInfo(
-                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+
+        if is_mutable_buffer:
Contributor Author (JacobSzwejbka): This is the actual core logic change; the rest of the changes are mostly piping a flag around.
+            # Emit names if we are supposed to.
+            if self.emitter_state.emit_mutable_buffer_names:
+                if spec.extra_tensor_info is None:
Contributor: A little confused by this, as I haven't kept track of the tensor info changes. What does it mean if extra_tensor_info is None, and if it isn't, why do we overwrite the fqn again? Maybe add a small comment here too.

Contributor Author (@JacobSzwejbka, Apr 11, 2025): None just means no one has created one. If it's not None, they might not have populated the fqn, so I populate it. I guess I could check what it is before overwriting it, but the name is unique, so it should always be safe to do this.

extra_tensor_info is where we put optional info in the flatbuffer so that we don't regress the size of every tensor by too much in embedded cases that don't need it.

+                    spec.extra_tensor_info = ExtraTensorInfo(
+                        fully_qualified_name=fqn,
+                        location=TensorDataLocation.SEGMENT,
+                    )
+                else:
+                    spec.extra_tensor_info.fully_qualified_name = fqn
+            # If we aren't emitting the name then it needs to be memory planned.
+            elif spec.mem_id is None or spec.mem_offset is None:
+                raise InternalError(
+                    self._emit_node_specific_error(
+                        self.node,
+                        # [2:] to remove the b_ prefix buffers get
+                        f'Mutable buffer "{target[2:]}" must have a memory id and offset if we are emitting it without a name. Please either memory plan your mutable buffers or call to_executorch with config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)',
+                    )
                 )
-            else:
-                spec.extra_tensor_info.fully_qualified_name = fqn
 
         # From the fqn find the corresponding tensor
         real_tensor = None
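For orientation, a minimal sketch of the configuration the new error message points users toward. It mirrors the export flow used in the tests below; the import paths and the toy Counter module are assumptions for illustration, not code from this PR.

import torch
import torch.nn as nn
from torch.export import export

from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.passes import MemoryPlanningPass


class Counter(nn.Module):
    """Toy module with a mutable buffer, mirroring the test modules below."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count.add_(1)
        return x + self.count


edge = to_edge(export(Counter(), (torch.randn(1),), strict=True))

# Default behavior: mutable buffers are memory planned (alloc_mutable_buffers=True),
# so the emitter has a mem_id/mem_offset for them and no name has to be emitted.
# The alternative enabled by this PR: skip planning the buffer and emit its fully
# qualified name instead.
et = edge.to_executorch(
    config=ExecutorchBackendConfig(
        emit_mutable_buffer_names=True,
        memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
    )
)

# Doing neither (emit_mutable_buffer_names=False together with
# alloc_mutable_buffers=False) now raises InternalError at emit time, as exercised
# by test_emit_mutable_buffer_names_fails below.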
34 changes: 33 additions & 1 deletion exir/emit/test/test_emit.py
@@ -1838,8 +1838,40 @@ def forward(self, x):
         ep = to_edge(ep)
         # Lower the graph to executorch.
         ep = ep.to_executorch(
-            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+            config=ExecutorchBackendConfig(
Contributor: Maybe also add a __post_init__ to the ExecutorchBackendConfig dataclass that asserts when emit_mutable_buffer_names is False and alloc_mutable_buffers is also False.

Contributor Author (@JacobSzwejbka, Apr 11, 2025): I considered this, but since the MemoryPlanningPass is user-configurable there is no guarantee memory_planning_pass.alloc_mutable_buffers exists. So I just check the end result in the emitter.

(A sketch of what such a check could look like follows this file's diff.)

+                emit_mutable_buffer_names=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
+            )
         )
         for val in ep.executorch_program.execution_plan[0].values:
             if isinstance(val, Tensor) and val.extra_tensor_info:
                 self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
+                self.assertEqual(val.allocation_info, None)
+
+    def test_emit_mutable_buffer_names_fails(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        # Must emit mutable buffer names if we don't allocate mutable buffers
+        with self.assertRaises(InternalError):
+            ep.to_executorch(
+                config=ExecutorchBackendConfig(
+                    emit_mutable_buffer_names=False,
+                    memory_planning_pass=MemoryPlanningPass(
+                        alloc_mutable_buffers=False
+                    ),
+                )
+            )
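For reference, a rough sketch of the validation the reviewer suggested, and why it has to be defensive: a user-supplied memory planning pass is not guaranteed to expose alloc_mutable_buffers at all. The class and field names here are illustrative placeholders, not the actual ExecutorchBackendConfig, and this check is not part of the PR.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical, trimmed stand-in for ExecutorchBackendConfig."""

    emit_mutable_buffer_names: bool = False
    memory_planning_pass: Optional[object] = None

    def __post_init__(self) -> None:
        # A user-provided memory planning pass may not define
        # alloc_mutable_buffers, so fall back to True when it is absent.
        allocs = getattr(self.memory_planning_pass, "alloc_mutable_buffers", True)
        if not self.emit_mutable_buffer_names and not allocs:
            raise ValueError(
                "Mutable buffers must either be memory planned "
                "(alloc_mutable_buffers=True) or have their names emitted "
                "(emit_mutable_buffer_names=True)."
            )

As the author notes above, the merged change performs the equivalent check in the emitter instead, where the final memory-planning result is known.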
12 changes: 10 additions & 2 deletions exir/memory_planning.py
@@ -44,12 +44,14 @@ def __init__(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
         graph_signature: Optional[ExportGraphSignature] = None,
     ) -> None:
         self.graph_module = graph_module
         self.graph_signature = graph_signature
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
 
     @classmethod
     def mem_obj_id_match(
@@ -149,6 +151,7 @@ def verify_storage_reuse(
                 ignore_const=True,
                 ignore_graph_input=not self.alloc_graph_input,
                 ignore_graph_output=not self.alloc_graph_output,
+                ignore_mutable_buffers=not self.alloc_mutable_buffers,
                 do_assertion=False,
                 ignore_out_var_node=False,
                 dedup=True,
@@ -374,6 +377,7 @@ def collect_specs_from_nodes(  # noqa: C901
     graph_signature: Optional[ExportGraphSignature] = None,
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
+    ignore_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,
@@ -414,6 +418,9 @@ def collect_specs_from_nodes(  # noqa: C901
         if _is_inplace_node(node):
             continue
 
+        if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
+            continue
+
         if do_assertion:
             internal_assert(
                 node.op in ("placeholder", "output")
@@ -469,6 +476,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
+
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -1053,6 +1061,7 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffers: bool = True,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
@@ -1065,19 +1074,18 @@
     storage with tensors in the outer module.
     TODO: make these optimizations once we have some baseline working.
     """
 
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
 
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=False,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
+        ignore_mutable_buffers=not alloc_mutable_buffers,
     )
 
     # Get extra padding for XNNPACK if needed
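The new ignore_mutable_buffers filtering above leans on _is_mutable_buffer, whose body is not shown in this diff. Below is a sketch of how such a check can be written against the torch.export graph signature; it is an illustration under assumptions (function name, import path), not the helper's actual code.

from typing import Optional

import torch
from torch.export import ExportGraphSignature


def is_mutable_buffer_placeholder(
    node: torch.fx.Node, graph_signature: Optional[ExportGraphSignature]
) -> bool:
    """Illustrative check: a placeholder that feeds a buffer the graph mutates."""
    if graph_signature is None or node.op != "placeholder":
        return False
    # inputs_to_buffers maps placeholder names to buffer fully qualified names.
    if node.name not in graph_signature.inputs_to_buffers:
        return False
    buffer_fqn = graph_signature.inputs_to_buffers[node.name]
    # buffers_to_mutate maps mutated output names back to the buffers they write.
    return buffer_fqn in graph_signature.buffers_to_mutate.values()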
9 changes: 7 additions & 2 deletions exir/passes/memory_planning_pass.py
@@ -44,6 +44,7 @@ def __init__(
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
+        alloc_mutable_buffers: bool = True,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -54,10 +55,11 @@
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
-        self.memory_planning_algo = memory_planning_algo
+        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
         self.alignment = alignment
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -124,13 +126,15 @@ def run(
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
+
         _ = apply_algo(
-            self.memory_planning_algo,  # pyre-ignore[6]
+            self.memory_planning_algo,
             graph_module,
             self.alignment,
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
+            self.alloc_mutable_buffers,
         )
 
         # TODO: make the verifier do the work recursively to handle
@@ -139,6 +143,7 @@
             graph_module,
             self.alloc_graph_input,
             self.alloc_graph_output,
+            self.alloc_mutable_buffers,
             graph_signature,
         )
 
27 changes: 21 additions & 6 deletions exir/tests/test_memory_planning.py
@@ -241,6 +241,7 @@ def maketest(
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffer: bool = True,
     has_unused_graph_input: bool = False,
 ) -> Callable[..., None]:
     # parameterized.expand is not compatible with maketest. I'll just loop thru
@@ -282,10 +283,17 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         )(graph_module).graph_module
 
         self.verify_reuse(
-            graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
+            graph_module,
+            expect_reuse,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )
         self.verify_graph_input_output(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )
 
         self.verify_overlap_placeholders(has_unused_graph_input, graph_module)
@@ -306,6 +314,7 @@ def verify_reuse(
         expect_reuse: bool,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffer: bool,
     ) -> None:
         r"""
         Do sanity check and verify tensor storage reuse.
@@ -321,6 +330,7 @@
             graph_module,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            alloc_mutable_buffers=alloc_mutable_buffer,
         ).verify_storage_reuse()
 
         print(f"num_reuse_pairs is {num_reuse_pairs}")
@@ -334,9 +344,10 @@ def verify_graph_input_output(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
     ) -> None:
         Verifier(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, alloc_graph_input, alloc_graph_output, alloc_mutable_buffers
         ).verify_graph_input_output()
 
     def verify_overlap_placeholders(
@@ -404,13 +415,16 @@ def verify_overlap_placeholders(
         )
 
     def test_graph_input_output(self) -> None:
-        for alloc_graph_input, alloc_graph_output in itertools.product(
-            [True, False], [True, False]
-        ):
+        for (
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffers,
+        ) in itertools.product([True, False], [True, False], [True, False]):
             case = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
+                alloc_mutable_buffer=alloc_mutable_buffers,
             )
             case(self)
 
@@ -535,6 +549,7 @@ def test_multiple_pools(
             graph_module,
             alloc_graph_input=True,
             alloc_graph_output=True,
+            alloc_mutable_buffers=True,
         )
         verifier.verify_storage_reuse()
         verifier.verify_graph_input_output()