From cbc1d3112cf426b6fce249e6ce189f48d4f4c275 Mon Sep 17 00:00:00 2001 From: Hyungkeun Park Date: Mon, 23 Mar 2026 07:57:42 +0000 Subject: [PATCH 1/4] fix: check both partitions for cycles in GroupBasedPartitioner._can_merge_partitions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation only checked downstream dependencies from p2, assuming p2 always precedes p1 in topological order. This assumption breaks when partition groups contain nodes spanning wide topological ranges — for example, when dynamic quantization inserts a shared `choose_qparams` node consumed by GEMM ops in different sequential transformer decoder layers. In that case the two groups *interleave* in topological order, and the single-direction check misses cycles flowing from p1 through external nodes back into p2. This change: 1. Collects external users from *both* p1 and p2 (combined_nodes) instead of only p2. 2. Adds a `validate_partition` safety net that performs a direct BFS on the live graph edges, catching any cycle the pre-computed `_DependencyViewer` might miss. Fixes `AssertionError: Invalid partition, found dependency cycles` when lowering cross-attention transformer decoders (e.g. DETR) with `XnnpackDynamicallyQuantizedPartitioner`. --- .../group_partitioner.py | 29 +++-- exir/backend/test/test_group_partitioner.py | 112 ++++++++++++++++++ 2 files changed, 132 insertions(+), 9 deletions(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 2594bbe05c4..191df174831 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -97,20 +97,24 @@ def __init__( def _can_merge_partitions(self, p1, p2, partitions_by_id): """Check if merging two partitions would create a cycle.""" + from torch.fx.passes.utils.fuser_utils import validate_partition + p1_nodes = set(partitions_by_id[p1].nodes.keys()) p2_nodes = set(partitions_by_id[p2].nodes.keys()) combined_nodes = p1_nodes.union(p2_nodes) + # Check external users from BOTH partitions. The original code only + # checked p2 under the assumption that p2 is always topologically + # before p1. However, when partition groups contain nodes that span + # wide topological ranges (e.g. due to shared dynamic-quantization + # choose_qparams nodes), the two partitions can *interleave* in + # topological order, making the single-direction check insufficient. + # + # We still only need to collect the *direct* external users (not + # transitive ones), because dependency_viewer.downstreams_of already + # returns the full transitive closure. user_nodes = [] - # topologically, p2_nodes comes before p1_nodes, so we only - # need to check the downstream nodes of p2. - # Additionally, we don't need to check all the downstream nodes - # of p2, we only need to check the nodes directly outside of p2. - # example: - # partition[a --> b --> c] --> d --> e --> f - # we don't need to check [d, e, f] we only need to check [d] because - # the downstream users of [d] will include [e, f] - for node in p2_nodes: + for node in combined_nodes: for user in node.users: if user not in combined_nodes: user_nodes.append(user) @@ -121,6 +125,13 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): if any(n in combined_nodes for n in downstream_nodes): return False + # Final safety net: validate_partition performs a direct BFS on the + # live graph edges, catching any cycle the pre-computed + # dependency_viewer might miss (e.g. when the graph was transformed + # after the viewer was built). + if not validate_partition(list(combined_nodes)): + return False + return True def _process_all_nodes( diff --git a/exir/backend/test/test_group_partitioner.py b/exir/backend/test/test_group_partitioner.py index e629e240be5..1c56ccf961d 100644 --- a/exir/backend/test/test_group_partitioner.py +++ b/exir/backend/test/test_group_partitioner.py @@ -1672,3 +1672,115 @@ def forward(self, x): # With allows_single_node_partition=True, we should have partitions self.assertGreater(len(partitions_with_single), 0) + + def test_interleaved_groups_no_false_merge(self): + """ + Test that _can_merge_partitions correctly rejects merges when two + partition groups interleave in topological order. + + This reproduces a real-world failure with + XnnpackDynamicallyQuantizedPartitioner on transformer decoder models + where cross-attention K/V projections across multiple decoder layers + share the same encoder ``memory`` input. Dynamic quantization inserts + a shared ``choose_qparams`` node for that input, causing the DSJ phase + to create partition groups whose nodes span wide topological ranges. + When GroupBasedPartitioner later tries to merge these groups, the + original single-direction downstream check missed the cycle because it + assumed p2 is entirely before p1 — which is false for interleaved + groups. + + The model is a minimal two-layer cross-attention decoder: + + .. code-block:: text + + query ──→ layer0(query, memory) ──→ layer1(x, memory) ──→ output + ↑ ↑ + memory ───────────┴─────────────────────────┘ + (shared K/V input across layers) + """ + import math + + class DecoderLayer(torch.nn.Module): + def __init__(self, d: int = 256): + super().__init__() + self.q_proj = torch.nn.Linear(d, d, bias=False) + self.k_proj = torch.nn.Linear(d, d, bias=False) + self.v_proj = torch.nn.Linear(d, d, bias=False) + self.out_proj = torch.nn.Linear(d, d, bias=False) + self.ffn1 = torch.nn.Linear(d, d * 2, bias=False) + self.ffn2 = torch.nn.Linear(d * 2, d, bias=False) + self.norm1 = torch.nn.LayerNorm(d) + self.norm2 = torch.nn.LayerNorm(d) + + def forward( + self, x: torch.Tensor, mem: torch.Tensor + ) -> torch.Tensor: + q = self.q_proj(x) + k = self.k_proj(mem) + v = self.v_proj(mem) + attn = torch.bmm( + q, k.transpose(-2, -1) + ) / math.sqrt(q.size(-1)) + attn = torch.softmax(attn, dim=-1) + out = self.out_proj(torch.bmm(attn, v)) + x = self.norm1(x + out) + x = self.norm2( + x + self.ffn2(torch.relu(self.ffn1(x))) + ) + return x + + class TwoLayerDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer0 = DecoderLayer() + self.layer1 = DecoderLayer() + + def forward( + self, query: torch.Tensor, memory: torch.Tensor + ) -> torch.Tensor: + x = self.layer0(query, memory) + x = self.layer1(x, memory) + return x + + from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + ) + from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + ) + from executorch.exir import to_edge_transform_and_lower + + model = TwoLayerDecoder().eval() + query = torch.randn(1, 10, 256) + memory = torch.randn(1, 20, 256) + + exported = torch.export.export( + model, (query, memory), strict=False + ) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + ) + prepared = prepare_pt2e(exported.module(), quantizer) + with torch.no_grad(): + prepared(query, memory) + converted = convert_pt2e(prepared) + + re_exported = torch.export.export( + converted, (query, memory), strict=False + ) + + # Before the fix this raised: + # AssertionError: Invalid partition, found dependency cycles + to_edge_transform_and_lower( + re_exported, + partitioner=[XnnpackDynamicallyQuantizedPartitioner()], + ) From 10ea9c1a65e6beebf185e972b6c3c0e654c9a0c8 Mon Sep 17 00:00:00 2001 From: Hyungkeun-Park-Nota Date: Tue, 14 Apr 2026 05:45:59 +0000 Subject: [PATCH 2/4] fix: fix lint errors in test_interleaved_groups_no_false_merge - Replace torch.ao.quantization imports with torchao.quantization.pt2e.quantize_pt2e and executorch.backends.xnnpack.quantizer.xnnpack_quantizer - Fix UFMT formatting issues in forward() signatures and torch.bmm/norm2 calls --- exir/backend/test/test_group_partitioner.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/exir/backend/test/test_group_partitioner.py b/exir/backend/test/test_group_partitioner.py index 1c56ccf961d..d475a9038cb 100644 --- a/exir/backend/test/test_group_partitioner.py +++ b/exir/backend/test/test_group_partitioner.py @@ -1712,21 +1712,15 @@ def __init__(self, d: int = 256): self.norm1 = torch.nn.LayerNorm(d) self.norm2 = torch.nn.LayerNorm(d) - def forward( - self, x: torch.Tensor, mem: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, mem: torch.Tensor) -> torch.Tensor: q = self.q_proj(x) k = self.k_proj(mem) v = self.v_proj(mem) - attn = torch.bmm( - q, k.transpose(-2, -1) - ) / math.sqrt(q.size(-1)) + attn = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) attn = torch.softmax(attn, dim=-1) out = self.out_proj(torch.bmm(attn, v)) x = self.norm1(x + out) - x = self.norm2( - x + self.ffn2(torch.relu(self.ffn1(x))) - ) + x = self.norm2(x + self.ffn2(torch.relu(self.ffn1(x)))) return x class TwoLayerDecoder(torch.nn.Module): @@ -1735,18 +1729,16 @@ def __init__(self): self.layer0 = DecoderLayer() self.layer1 = DecoderLayer() - def forward( - self, query: torch.Tensor, memory: torch.Tensor - ) -> torch.Tensor: + def forward(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor: x = self.layer0(query, memory) x = self.layer1(x, memory) return x - from torch.ao.quantization.quantize_pt2e import ( + from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, ) - from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) From fef81a589614cf3b0ea078fe96c75900000290a4 Mon Sep 17 00:00:00 2001 From: Hyungkeun-Park-Nota Date: Wed, 15 Apr 2026 01:22:45 +0000 Subject: [PATCH 3/4] fix: apply ufmt formatting to test_interleaved_groups_no_false_merge --- exir/backend/test/test_group_partitioner.py | 28 ++++++++------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/exir/backend/test/test_group_partitioner.py b/exir/backend/test/test_group_partitioner.py index d475a9038cb..2af59fb763f 100644 --- a/exir/backend/test/test_group_partitioner.py +++ b/exir/backend/test/test_group_partitioner.py @@ -1729,46 +1729,38 @@ def __init__(self): self.layer0 = DecoderLayer() self.layer1 = DecoderLayer() - def forward(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor: + def forward( + self, query: torch.Tensor, memory: torch.Tensor + ) -> torch.Tensor: x = self.layer0(query, memory) x = self.layer1(x, memory) return x - from torchao.quantization.pt2e.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, ) from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( - XNNPACKQuantizer, get_symmetric_quantization_config, - ) - - from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackDynamicallyQuantizedPartitioner, + XNNPACKQuantizer, ) from executorch.exir import to_edge_transform_and_lower + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e model = TwoLayerDecoder().eval() query = torch.randn(1, 10, 256) memory = torch.randn(1, 20, 256) - exported = torch.export.export( - model, (query, memory), strict=False - ) + exported = torch.export.export(model, (query, memory), strict=False) quantizer = XNNPACKQuantizer().set_global( - get_symmetric_quantization_config( - is_per_channel=True, is_dynamic=True - ) + get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True) ) prepared = prepare_pt2e(exported.module(), quantizer) with torch.no_grad(): prepared(query, memory) converted = convert_pt2e(prepared) - re_exported = torch.export.export( - converted, (query, memory), strict=False - ) + re_exported = torch.export.export(converted, (query, memory), strict=False) # Before the fix this raised: # AssertionError: Invalid partition, found dependency cycles From 50a24470140a8eea31c5e3886ca5e90f12711b54 Mon Sep 17 00:00:00 2001 From: Hyungkeun-Park-Nota Date: Thu, 16 Apr 2026 01:24:26 +0000 Subject: [PATCH 4/4] fix: add missing deps to test_group_partitioner buck target --- exir/backend/test/BUCK | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exir/backend/test/BUCK b/exir/backend/test/BUCK index 057aaf4caa3..12c8fb1015e 100644 --- a/exir/backend/test/BUCK +++ b/exir/backend/test/BUCK @@ -426,5 +426,9 @@ fbcode_target(_kind = runtime.python_test, deps = [ "//caffe2:torch", "//executorch/exir/backend/canonical_partitioners:group_partitioner_lib", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//executorch/exir:lib", + "//pytorch/ao:torchao", ], )