diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 2594bbe05c4..191df174831 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -97,20 +97,24 @@ def __init__( def _can_merge_partitions(self, p1, p2, partitions_by_id): """Check if merging two partitions would create a cycle.""" + from torch.fx.passes.utils.fuser_utils import validate_partition + p1_nodes = set(partitions_by_id[p1].nodes.keys()) p2_nodes = set(partitions_by_id[p2].nodes.keys()) combined_nodes = p1_nodes.union(p2_nodes) + # Check external users from BOTH partitions. The original code only + # checked p2 under the assumption that p2 is always topologically + # before p1. However, when partition groups contain nodes that span + # wide topological ranges (e.g. due to shared dynamic-quantization + # choose_qparams nodes), the two partitions can *interleave* in + # topological order, making the single-direction check insufficient. + # + # We still only need to collect the *direct* external users (not + # transitive ones), because dependency_viewer.downstreams_of already + # returns the full transitive closure. user_nodes = [] - # topologically, p2_nodes comes before p1_nodes, so we only - # need to check the downstream nodes of p2. - # Additionally, we don't need to check all the downstream nodes - # of p2, we only need to check the nodes directly outside of p2. - # example: - # partition[a --> b --> c] --> d --> e --> f - # we don't need to check [d, e, f] we only need to check [d] because - # the downstream users of [d] will include [e, f] - for node in p2_nodes: + for node in combined_nodes: for user in node.users: if user not in combined_nodes: user_nodes.append(user) @@ -121,6 +125,13 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): if any(n in combined_nodes for n in downstream_nodes): return False + # Final safety net: validate_partition performs a direct BFS on the + # live graph edges, catching any cycle the pre-computed + # dependency_viewer might miss (e.g. when the graph was transformed + # after the viewer was built). + if not validate_partition(list(combined_nodes)): + return False + return True def _process_all_nodes( diff --git a/exir/backend/test/BUCK b/exir/backend/test/BUCK index 057aaf4caa3..12c8fb1015e 100644 --- a/exir/backend/test/BUCK +++ b/exir/backend/test/BUCK @@ -426,5 +426,9 @@ fbcode_target(_kind = runtime.python_test, deps = [ "//caffe2:torch", "//executorch/exir/backend/canonical_partitioners:group_partitioner_lib", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//executorch/exir:lib", + "//pytorch/ao:torchao", ], ) diff --git a/exir/backend/test/test_group_partitioner.py b/exir/backend/test/test_group_partitioner.py index e629e240be5..2af59fb763f 100644 --- a/exir/backend/test/test_group_partitioner.py +++ b/exir/backend/test/test_group_partitioner.py @@ -1672,3 +1672,99 @@ def forward(self, x): # With allows_single_node_partition=True, we should have partitions self.assertGreater(len(partitions_with_single), 0) + + def test_interleaved_groups_no_false_merge(self): + """ + Test that _can_merge_partitions correctly rejects merges when two + partition groups interleave in topological order. + + This reproduces a real-world failure with + XnnpackDynamicallyQuantizedPartitioner on transformer decoder models + where cross-attention K/V projections across multiple decoder layers + share the same encoder ``memory`` input. Dynamic quantization inserts + a shared ``choose_qparams`` node for that input, causing the DSJ phase + to create partition groups whose nodes span wide topological ranges. + When GroupBasedPartitioner later tries to merge these groups, the + original single-direction downstream check missed the cycle because it + assumed p2 is entirely before p1 — which is false for interleaved + groups. + + The model is a minimal two-layer cross-attention decoder: + + .. code-block:: text + + query ──→ layer0(query, memory) ──→ layer1(x, memory) ──→ output + ↑ ↑ + memory ───────────┴─────────────────────────┘ + (shared K/V input across layers) + """ + import math + + class DecoderLayer(torch.nn.Module): + def __init__(self, d: int = 256): + super().__init__() + self.q_proj = torch.nn.Linear(d, d, bias=False) + self.k_proj = torch.nn.Linear(d, d, bias=False) + self.v_proj = torch.nn.Linear(d, d, bias=False) + self.out_proj = torch.nn.Linear(d, d, bias=False) + self.ffn1 = torch.nn.Linear(d, d * 2, bias=False) + self.ffn2 = torch.nn.Linear(d * 2, d, bias=False) + self.norm1 = torch.nn.LayerNorm(d) + self.norm2 = torch.nn.LayerNorm(d) + + def forward(self, x: torch.Tensor, mem: torch.Tensor) -> torch.Tensor: + q = self.q_proj(x) + k = self.k_proj(mem) + v = self.v_proj(mem) + attn = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + attn = torch.softmax(attn, dim=-1) + out = self.out_proj(torch.bmm(attn, v)) + x = self.norm1(x + out) + x = self.norm2(x + self.ffn2(torch.relu(self.ffn1(x)))) + return x + + class TwoLayerDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer0 = DecoderLayer() + self.layer1 = DecoderLayer() + + def forward( + self, query: torch.Tensor, memory: torch.Tensor + ) -> torch.Tensor: + x = self.layer0(query, memory) + x = self.layer1(x, memory) + return x + + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + ) + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + from executorch.exir import to_edge_transform_and_lower + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + model = TwoLayerDecoder().eval() + query = torch.randn(1, 10, 256) + memory = torch.randn(1, 20, 256) + + exported = torch.export.export(model, (query, memory), strict=False) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True) + ) + prepared = prepare_pt2e(exported.module(), quantizer) + with torch.no_grad(): + prepared(query, memory) + converted = convert_pt2e(prepared) + + re_exported = torch.export.export(converted, (query, memory), strict=False) + + # Before the fix this raised: + # AssertionError: Invalid partition, found dependency cycles + to_edge_transform_and_lower( + re_exported, + partitioner=[XnnpackDynamicallyQuantizedPartitioner()], + )