Fix dependency cycle in GroupBasedPartitioner._can_merge_partitions#18397

Merged
GregoryComer merged 4 commits into pytorch:main from Hyungkeun-Park-Nota:fix/group-partitioner-cycle-detection
Apr 16, 2026

Conversation

@Hyungkeun-Park-Nota
Contributor

@Hyungkeun-Park-Nota Hyungkeun-Park-Nota commented Mar 23, 2026

Summary

GroupBasedPartitioner._can_merge_partitions() checks downstream dependencies only from p2, on the assumption that p2 always precedes p1 in topological order. That assumption fails when partition groups contain nodes spanning wide topological ranges, producing false-negative cycle detection and, ultimately, AssertionError: Invalid partition, found dependency cycles when fuse_as_graphmodule runs.

Root cause: Dynamic quantization inserts choose_qparams nodes that are shared across multiple GEMM ops consuming the same activation. The DSJ (Disjoint Set Join) phase merges these ops into groups whose nodes interleave in topological order. When _merge_partitions later tries to combine two such interleaved groups, the single-direction check (p2 only) misses the cycle path from p1 → external → p2.

Fix:

  1. Check external users from both p1 and p2 (combined_nodes) instead of only p2.
  2. Add a validate_partition() safety net (BFS on live graph edges) to catch any cycle the pre-computed _DependencyViewer might miss.
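The direction-sensitivity of the old check can be illustrated on a toy dependency graph. The sketch below is not the actual ExecuTorch implementation; `can_merge` and the plain-dict `users` representation are simplifications for illustration:

```python
# Illustrative sketch, NOT the real _can_merge_partitions: a merge of p1
# and p2 is unsafe if any path leaves the combined partition and re-enters
# it, since fusing would then create a dependency cycle.
def can_merge(p1, p2, users):
    """users maps each node name to the nodes consuming its output."""
    combined = p1 | p2
    # Seed a search with every *external* user of the combined partition,
    # i.e. users of nodes from BOTH p1 and p2, not only p2.
    frontier = [u for n in combined for u in users.get(n, ()) if u not in combined]
    seen = set(frontier)
    while frontier:
        node = frontier.pop()
        if node in combined:
            return False  # an external path re-enters the partition: cycle
        for user in users.get(node, ()):
            if user not in seen:
                seen.add(user)
                frontier.append(user)
    return True

# Chain a -> b -> c, with a and c grouped together and b left outside:
# walking downstream from {"c"} alone finds nothing, but walking from both
# partitions reveals the path a -> b -> c back into the merged group.
users = {"a": ["b"], "b": ["c"]}
assert can_merge({"a"}, {"c"}, users) is False  # merging would trap b in a cycle
assert can_merge({"a"}, {"b"}, users) is True   # safe: nothing re-enters
```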

Reproduction

The issue is triggered when lowering a cross-attention transformer decoder with XnnpackDynamicallyQuantizedPartitioner. Multiple decoder layers share the same encoder output for K/V projections, causing choose_qparams sharing → DSJ group interleaving → false merge → dependency cycle.

Minimal reproduction (no external dependencies beyond PyTorch + ExecuTorch):

import math, torch, torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d=256):
        super().__init__()
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.out_proj = nn.Linear(d, d, bias=False)
        self.ffn1 = nn.Linear(d, d * 2, bias=False)
        self.ffn2 = nn.Linear(d * 2, d, bias=False)
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)

    def forward(self, x, mem):
        q, k, v = self.q_proj(x), self.k_proj(mem), self.v_proj(mem)
        attn = torch.softmax(torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)), dim=-1)
        x = self.norm1(x + self.out_proj(torch.bmm(attn, v)))
        return self.norm2(x + self.ffn2(torch.relu(self.ffn1(x))))

class TwoLayerDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = DecoderLayer()
        self.layer1 = DecoderLayer()

    def forward(self, query, memory):
        return self.layer1(self.layer0(query, memory), memory)

# Export → dynamic quant → lower
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackDynamicallyQuantizedPartitioner
from executorch.exir import to_edge_transform_and_lower

model = TwoLayerDecoder().eval()
q, m = torch.randn(1, 10, 256), torch.randn(1, 20, 256)

exported = torch.export.export(model, (q, m), strict=False)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))
prepared = prepare_pt2e(exported.module(), quantizer)
with torch.no_grad():
    prepared(q, m)  # calibrate observers
converted = convert_pt2e(prepared)
re_exported = torch.export.export(converted, (q, m), strict=False)

# Before fix: AssertionError: Invalid partition, found dependency cycles
to_edge_transform_and_lower(re_exported, partitioner=[XnnpackDynamicallyQuantizedPartitioner()])

Test plan

  • Added test_interleaved_groups_no_false_merge in exir/backend/test/test_group_partitioner.py
  • Verified the test fails without the fix and passes with the fix
  • Existing test_group_partitioner.py tests pass

cc @JacobSzwejbka @angelayi @GregoryComer @digantdesai @cbilgin

@pytorch-bot

pytorch-bot Bot commented Mar 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18397

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Unrelated Failures

As of commit 50a2447 with merge base 520566c:


👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Mar 23, 2026

Hi @Hyungkeun-Park-Nota!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@Hyungkeun-Park-Nota
Contributor Author

@pytorchbot label "release notes: xnnpack"

@pytorch-bot Bot added the release notes: xnnpack label on Mar 23, 2026
@meta-cla

meta-cla Bot commented Mar 23, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla Bot added the CLA Signed label on Mar 23, 2026
@nil-is-all added the module: exir label on Mar 23, 2026
@Hyungkeun-Park-Nota force-pushed the fix/group-partitioner-cycle-detection branch from bc3bec6 to 73425f1 on March 24, 2026 03:40
@nil-is-all
Contributor

@JacobSzwejbka bringing this to your attention

@Hyungkeun-Park-Nota force-pushed the fix/group-partitioner-cycle-detection branch from 73425f1 to 4971367 on March 28, 2026 07:20
…erge_partitions

The previous implementation only checked downstream dependencies from
p2, assuming p2 always precedes p1 in topological order.  This
assumption breaks when partition groups contain nodes spanning wide
topological ranges — for example, when dynamic quantization inserts a
shared `choose_qparams` node consumed by GEMM ops in different
sequential transformer decoder layers.  In that case the two groups
*interleave* in topological order, and the single-direction check
misses cycles flowing from p1 through external nodes back into p2.

This change:
1. Collects external users from *both* p1 and p2 (combined_nodes)
   instead of only p2.
2. Adds a `validate_partition` safety net that performs a direct BFS
   on the live graph edges, catching any cycle the pre-computed
   `_DependencyViewer` might miss.

Fixes `AssertionError: Invalid partition, found dependency cycles`
when lowering cross-attention transformer decoders (e.g. DETR) with
`XnnpackDynamicallyQuantizedPartitioner`.
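The safety net described in point 2 amounts to a breadth-first reachability check over the graph's live user edges. A minimal sketch follows; the name `validate_partition` mirrors the commit message, but the plain-dict graph and signature here are hypothetical, unlike the real helper, which walks fx graph nodes:

```python
from collections import deque

# Hypothetical sketch of a validate_partition-style safety net: re-check a
# tentatively merged partition against live edges, so a stale pre-computed
# dependency view (like _DependencyViewer) cannot hide a cycle.
def validate_partition(partition, users):
    """Return False if some path exits `partition` and later re-enters it."""
    # Start from every edge that leaves the partition.
    queue = deque(
        user
        for node in partition
        for user in users.get(node, ())
        if user not in partition
    )
    visited = set(queue)
    while queue:
        node = queue.popleft()
        if node in partition:
            return False  # the outside world can reach back in: invalid merge
        for user in users.get(node, ()):
            if user not in visited:
                visited.add(user)
                queue.append(user)
    return True

# Run as a last line of defense after the fast check accepts a merge:
users = {"a": ["b"], "b": ["c"]}
assert validate_partition({"a", "c"}, users) is False  # a -> b -> c re-enters
assert validate_partition({"a", "b"}, users) is True
```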
@Hyungkeun-Park-Nota force-pushed the fix/group-partitioner-cycle-detection branch from 4971367 to cbc1d31 on March 30, 2026 06:59
@nil-is-all added the module: xnnpack label on Apr 1, 2026
- Replace torch.ao.quantization imports with torchao.quantization.pt2e.quantize_pt2e
  and executorch.backends.xnnpack.quantizer.xnnpack_quantizer
- Fix UFMT formatting issues in forward() signatures and torch.bmm/norm2 calls
@Hyungkeun-Park-Nota force-pushed the fix/group-partitioner-cycle-detection branch from b695175 to 10ea9c1 on April 14, 2026 05:47
@GregoryComer
Member

@Hyungkeun-Park-Nota Thanks for the contribution. Can you fix the lints? Once CI is green, I can go ahead and merge.

@meta-codesync
Contributor

meta-codesync Bot commented Apr 14, 2026

@GregoryComer has imported this pull request. If you are a Meta employee, you can view this in D100871526.

@GregoryComer
Member

I just re-triggered CI. @Hyungkeun-Park-Nota can you make one other small change? We'll need to add a few deps to the buck build. We should be able to merge after that.

In exir/backend/test/BUCK, can you add these dependencies to the test_group_partitioner target?

  • //executorch/backends/xnnpack/partition:xnnpack_partitioner
  • //executorch/backends/xnnpack/quantizer:xnnpack_quantizer
  • //executorch/exir:lib
  • //pytorch/ao:torchao

Thanks!

@GregoryComer
Member

The mypy lint failure is pre-existing.

@GregoryComer GregoryComer merged commit 3998693 into pytorch:main Apr 16, 2026
159 of 164 checks passed