diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index f65f9cb3398..d025c8e6029 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -458,9 +458,7 @@ def get_deps(
         a bool indicating if the deps are valid and a list of all the
         dep nodes. This handles the src partition for
         """
-        if self.src_partitions is None:
-            # Cache src partitions so we don't have to recompute them every time
-            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
+        self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
 
         # src_partition is None if node is not in source partition,
         # otherwise gives us the linear source partition it belongs to
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index 5f3581b6aeb..d20a6003f3f 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -113,3 +113,14 @@ runtime.python_test(
         "//executorch/examples/xnnpack:models",  # @manual
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_partitioner",
+    srcs = ["test_xnnpack_partitioner.py"],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
diff --git a/backends/xnnpack/test/test_xnnpack_partitioner.py b/backends/xnnpack/test/test_xnnpack_partitioner.py
index 8cd9eb92d56..894fab4098f 100644
--- a/backends/xnnpack/test/test_xnnpack_partitioner.py
+++ b/backends/xnnpack/test/test_xnnpack_partitioner.py
@@ -9,8 +9,13 @@
 import unittest
 
 import torch
+import torch.nn.functional as F
+
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import to_edge, to_edge_transform_and_lower
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 from torch.export import export
 
 
@@ -82,3 +87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
 
         log_contents = log_capture_string.getvalue()
         self.assertNotIn("DEPRECATION WARNING", log_contents)
+
+    def test_multi_method_partitioning_with_shared_weights(self):
+        """
+        Test that multi-method models with shared weights are correctly partitioned.
+        Verify that:
+        1. Both methods are fully lowered to XNNPACK.
+        2. Constants are not duplicated between named data and constant buffers.
+        3. The program executes correctly.
+        """
+
+        class MultiMethodModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(8, 16)
+                self.linear2 = torch.nn.Linear(16, 8)
+
+            def forward(self, x):
+                return self.linear2(F.sigmoid(self.linear(x)))
+
+            def forward_2(self, x):
+                return self.linear2(F.relu(self.linear(x)))
+
+            def example_inputs(self):
+                return (torch.randn(1, 8),)
+
+        model = MultiMethodModel()
+
+        # Get eager reference outputs.
+        example_inputs = model.example_inputs()
+        with torch.no_grad():
+            fwd1_eager = model.forward(*example_inputs)
+            fwd2_eager = model.forward_2(*example_inputs)
+
+        # Export both methods.
+        ep_fwd = export(model, model.example_inputs(), strict=True)
+        # Patch forward, as torch.export only traces the 'forward' method.
+        model.forward = model.forward_2
+        ep_fwd_2 = export(model, model.example_inputs(), strict=True)
+
+        # Convert to edge and lower to ExecuTorch.
+        edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2})
+        lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))
+        executorch = lowered.to_executorch()
+
+        # Check that both graphs are fully delegated.
+        nodes_1 = list(lowered._edge_programs["forward"].graph.nodes)
+        nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes)
+        self.assertEqual(len(nodes_1), 5)
+        self.assertEqual(len(nodes_2), 5)
+        expected_node_names = [
+            "x",
+            "lowered_module_0",
+            "executorch_call_delegate",
+            "getitem",
+            "output_1",
+        ]
+        for n in expected_node_names:
+            self.assertTrue(any(node.name == n for node in nodes_1))
+            self.assertTrue(any(node.name == n for node in nodes_2))
+
+        # Check that weights are not duplicated.
+        self.assertEqual(len(executorch._named_data.pte_data), 4)
+        self.assertEqual(len(executorch._named_data.buffers), 4)
+        self.assertEqual(len(executorch._named_data.external_data), 0)
+
+        # Check that there are no constant buffers (besides the placeholder).
+        self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1)
+
+        # Check for model correctness.
+        executorch_module = _load_for_executorch_from_buffer(executorch.buffer)
+        fwd1_et = executorch_module.run_method("forward", example_inputs)
+        fwd2_et = executorch_module.run_method("forward_2", example_inputs)
+        self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], rtol=1e-3))
+        self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], rtol=1e-3))
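
Note on the gemm_configs.py change above: get_source_partitions matches
patterns against one specific fx.Graph, so a result cached on the config
instance by the first method of a multi-method program is stale when the same
partitioner instance is reused on a second method's graph. Below is a minimal
sketch of the two behaviors, with illustrative names (CachingConfig and
FreshConfig are not the real partitioner API):

    from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

    class CachingConfig:
        """Old behavior: cache on the instance; stale when reused across graphs."""

        def __init__(self, modules):
            self.modules = modules
            self.src_partitions = None

        def get_deps(self, graph):
            if self.src_partitions is None:
                # First graph seen wins: a second method's graph would be
                # matched against partitions computed for the first one.
                self.src_partitions = get_source_partitions(graph, self.modules)
            return self.src_partitions

    class FreshConfig:
        """New behavior: recompute per call, so each graph gets its own match."""

        def __init__(self, modules):
            self.modules = modules

        def get_deps(self, graph):
            return get_source_partitions(graph, self.modules)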