29 changes: 20 additions & 9 deletions exir/backend/canonical_partitioners/group_partitioner.py
@@ -97,20 +97,24 @@ def __init__(

def _can_merge_partitions(self, p1, p2, partitions_by_id):
"""Check if merging two partitions would create a cycle."""
+        from torch.fx.passes.utils.fuser_utils import validate_partition
+
p1_nodes = set(partitions_by_id[p1].nodes.keys())
p2_nodes = set(partitions_by_id[p2].nodes.keys())
combined_nodes = p1_nodes.union(p2_nodes)

+        # Check external users from BOTH partitions. The original code only
+        # checked p2 under the assumption that p2 is always topologically
+        # before p1. However, when partition groups contain nodes that span
+        # wide topological ranges (e.g. due to shared dynamic-quantization
+        # choose_qparams nodes), the two partitions can *interleave* in
+        # topological order, making the single-direction check insufficient.
+        #
+        # We still only need to collect the *direct* external users (not
+        # transitive ones), because dependency_viewer.downstreams_of already
+        # returns the full transitive closure.
user_nodes = []
-        # topologically, p2_nodes comes before p1_nodes, so we only
-        # need to check the downstream nodes of p2.
-        # Additionally, we don't need to check all the downstream nodes
-        # of p2, we only need to check the nodes directly outside of p2.
-        # example:
-        # partition[a --> b --> c] --> d --> e --> f
-        # we don't need to check [d, e, f] we only need to check [d] because
-        # the downstream users of [d] will include [e, f]
-        for node in p2_nodes:
+        for node in combined_nodes:
for user in node.users:
if user not in combined_nodes:
user_nodes.append(user)
@@ -121,6 +125,13 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
if any(n in combined_nodes for n in downstream_nodes):
return False

+        # Final safety net: validate_partition performs a direct BFS on the
+        # live graph edges, catching any cycle the pre-computed
+        # dependency_viewer might miss (e.g. when the graph was transformed
+        # after the viewer was built).
+        if not validate_partition(list(combined_nodes)):
+            return False

return True

def _process_all_nodes(
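
To make the interleaving comment in the first hunk concrete, here is a standalone sketch — not part of this PR — of why scanning only p2's external users can miss a cycle once partitions interleave. The toy function g, the helper names, and the plain BFS standing in for dependency_viewer.downstreams_of are all illustrative assumptions:

import torch
from torch.fx import symbolic_trace


def g(x):
    a = x + 1  # belongs to p1
    b = a * 2  # external node wedged between the partitions
    c = b + 3  # belongs to p2
    return c


gm = symbolic_trace(g)
n = {node.name: node for node in gm.graph.nodes}
p1, p2 = {n["add"]}, {n["add_1"]}  # p1 = {a}, p2 = {c}
combined = p1 | p2


def external_users(nodes):
    # Direct users outside the combined set, as in the patched loop.
    return [u for nd in nodes for u in nd.users if u not in combined]


def reaches_combined(start):
    # Plain BFS over user edges; a stand-in for the precomputed
    # transitive closure that dependency_viewer.downstreams_of returns.
    seen, stack = set(), [start]
    while stack:
        nd = stack.pop()
        if nd in combined:
            return True
        if nd in seen:
            continue
        seen.add(nd)
        stack.extend(nd.users)
    return False


# Old single-direction check: p2's only external user is the graph output,
# which reaches nothing in the combined set, so the merge is wrongly allowed.
print(any(reaches_combined(u) for u in external_users(p2)))        # False
# Fixed check over BOTH partitions: b (a user of p1's node) reaches p2,
# so merging p1 and p2 around the external node b is rejected.
print(any(reaches_combined(u) for u in external_users(combined)))  # True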
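Similarly, a minimal sketch of what the new safety net buys: torch.fx's validate_partition (the function imported in the first hunk and called in the second) rejects a candidate partition via a direct BFS over live graph edges. The toy function f is again an illustrative assumption, not code from this PR:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.fuser_utils import validate_partition


def f(x):
    a = x + 1     # would-be partition member
    b = a * 2     # external node between the two members
    return b + a  # would-be partition member that depends on the external node


gm = symbolic_trace(f)
nodes = {node.name: node for node in gm.graph.nodes}

# {a, b + a} is invalid: the path add -> mul -> add_1 leaves the partition
# and re-enters it, so fusing the two nodes would create a cycle.
print(validate_partition([nodes["add"], nodes["add_1"]]))  # False
# {a, b} is a clean producer/consumer pair and passes.
print(validate_partition([nodes["add"], nodes["mul"]]))    # True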
4 changes: 4 additions & 0 deletions exir/backend/test/BUCK
@@ -426,5 +426,9 @@ fbcode_target(_kind = runtime.python_test,
deps = [
"//caffe2:torch",
"//executorch/exir/backend/canonical_partitioners:group_partitioner_lib",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
"//executorch/exir:lib",
"//pytorch/ao:torchao",
],
)
96 changes: 96 additions & 0 deletions exir/backend/test/test_group_partitioner.py
@@ -1672,3 +1672,99 @@ def forward(self, x):

# With allows_single_node_partition=True, we should have partitions
self.assertGreater(len(partitions_with_single), 0)

def test_interleaved_groups_no_false_merge(self):
"""
Test that _can_merge_partitions correctly rejects merges when two
partition groups interleave in topological order.

This reproduces a real-world failure with
XnnpackDynamicallyQuantizedPartitioner on transformer decoder models
where cross-attention K/V projections across multiple decoder layers
share the same encoder ``memory`` input. Dynamic quantization inserts
a shared ``choose_qparams`` node for that input, causing the DSJ phase
to create partition groups whose nodes span wide topological ranges.
When GroupBasedPartitioner later tries to merge these groups, the
original single-direction downstream check missed the cycle because it
assumed p2 is entirely before p1 — which is false for interleaved
groups.

The model is a minimal two-layer cross-attention decoder:

.. code-block:: text

            query ──→ layer0(query, memory) ──→ layer1(x, memory) ──→ output
                             ↑                         ↑
            memory ──────────┴─────────────────────────┘
                     (shared K/V input across layers)
"""
import math

class DecoderLayer(torch.nn.Module):
def __init__(self, d: int = 256):
super().__init__()
self.q_proj = torch.nn.Linear(d, d, bias=False)
self.k_proj = torch.nn.Linear(d, d, bias=False)
self.v_proj = torch.nn.Linear(d, d, bias=False)
self.out_proj = torch.nn.Linear(d, d, bias=False)
self.ffn1 = torch.nn.Linear(d, d * 2, bias=False)
self.ffn2 = torch.nn.Linear(d * 2, d, bias=False)
self.norm1 = torch.nn.LayerNorm(d)
self.norm2 = torch.nn.LayerNorm(d)

def forward(self, x: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
q = self.q_proj(x)
k = self.k_proj(mem)
v = self.v_proj(mem)
attn = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
attn = torch.softmax(attn, dim=-1)
out = self.out_proj(torch.bmm(attn, v))
x = self.norm1(x + out)
x = self.norm2(x + self.ffn2(torch.relu(self.ffn1(x))))
return x

class TwoLayerDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer0 = DecoderLayer()
self.layer1 = DecoderLayer()

def forward(
self, query: torch.Tensor, memory: torch.Tensor
) -> torch.Tensor:
x = self.layer0(query, memory)
x = self.layer1(x, memory)
return x

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = TwoLayerDecoder().eval()
query = torch.randn(1, 10, 256)
memory = torch.randn(1, 20, 256)

exported = torch.export.export(model, (query, memory), strict=False)

quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)
prepared = prepare_pt2e(exported.module(), quantizer)
with torch.no_grad():
prepared(query, memory)
converted = convert_pt2e(prepared)

re_exported = torch.export.export(converted, (query, memory), strict=False)

# Before the fix this raised:
# AssertionError: Invalid partition, found dependency cycles
to_edge_transform_and_lower(
re_exported,
partitioner=[XnnpackDynamicallyQuantizedPartitioner()],
)
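
As a hedged follow-up, a reviewer could append a probe after the lowering call to surface the sharing the docstring describes (graph_module and node.users are standard ExportedProgram/torch.fx API; exact node names and user counts depend on the quantizer and torchao versions):

# Illustrative probe, not an assertion: list the dynamic-quantization
# choose_qparams nodes and how many consumers each feeds. A node with more
# than one user is the shared qparams computation blamed above for the
# interleaved partition groups.
for node in re_exported.graph_module.graph.nodes:
    if "choose_qparams" in str(node.target):
        print(node.name, "->", len(node.users), "users")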