Prevent folding of mutable input to copy_ and put_ operators#4181

Merged
jerryzh168 merged 2 commits into pytorch:main from tom-arm:improve_buffer_folding
Apr 10, 2026

Conversation

@tom-arm (Contributor) commented Mar 26, 2026

  • Find mutable buffers for put_ operations as well as copy_.
  • Trace the graph to the buffer and mark nodes found as mutable to prevent folding.

* Find mutable buffers for put_ operations as well as copy_.
* Trace the graph to the buffer and mark nodes found as mutable
  to prevent folding.
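The two steps above (find the buffer mutated by a copy_/put_ call, then trace the graph back to mark every producer) can be sketched as follows. This is a minimal illustrative stand-in, not the torchao implementation: the `Node` class and the `MUTATING_OPS` set are simplified assumptions in place of `torch.fx.Node` and the real operator targets.

```python
# Sketch of the approach described above, using a simplified stand-in
# for a torch.fx graph (Node here is illustrative, not torch.fx.Node).

MUTATING_OPS = {"copy_", "put_"}  # in-place ops whose first arg is mutated


class Node:
    def __init__(self, name, target, args=()):
        self.name = name
        self.target = target
        self.args = args


def find_mutable_nodes(graph_nodes):
    """Mark the mutated input of copy_/put_ calls, plus every producer
    feeding it, as mutable so constant folding will skip them."""
    mutable = set()
    for node in graph_nodes:
        if node.target in MUTATING_OPS and node.args and isinstance(node.args[0], Node):
            buf = node.args[0]
            mutable.add(buf)
            # Walk backwards through the graph to collect all producers.
            stack = [buf]
            while stack:
                cur = stack.pop()
                for arg in cur.args:
                    if isinstance(arg, Node) and arg not in mutable:
                        mutable.add(arg)
                        stack.append(arg)
    return mutable


# buffer <- view(buffer) <- copy_(view, x): buffer and view must stay unfolded.
buffer = Node("buffer", "get_attr")
view = Node("view", "view", (buffer,))
x = Node("x", "placeholder")
copy = Node("copy", "copy_", (view, x))
print(sorted(n.name for n in find_mutable_nodes([buffer, view, x, copy])))
# → ['buffer', 'view']
```

Without the backward walk, only the direct input of the mutating op would be protected, and an upstream view or dequantize chain could still be folded into a constant out from under it.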
@pytorch-bot (Bot) commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4181

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit e4779ec with merge base 136cacb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 26, 2026
@tom-arm (Contributor, Author) commented Mar 26, 2026

Partly solves the issue raised in #4138: it should stop the mutable buffer from being folded into a constant.

Comment thread torchao/quantization/pt2e/constant_fold.py (Outdated)
* Replace _collect_mutable_chain with collect_producer_nodes
* Fix typo
@tom-arm tom-arm requested review from andrewor14 and vkuzo as code owners April 7, 2026 11:51
# all producers as mutable to prevent constant folding.
if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
    mutable_buffers.add(node.args[0])
    producer_nodes = collect_producer_nodes(node.args[0])
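For context, `collect_producer_nodes` (imported in this PR from `torch.ao.quantization.fx.utils`) walks backwards from a node and collects everything that feeds it. The following is a simplified sketch of that behavior, under the assumption that it returns None when the chain reaches a graph input (placeholder); the `Node` class here is a stand-in for `torch.fx.Node`, not the real utility.

```python
# Simplified sketch of collect_producer_nodes behavior
# (Node stands in for torch.fx.Node).

class Node:
    def __init__(self, name, op, args=()):
        self.name, self.op, self.args = name, op, args


def collect_producer_nodes(node):
    """Walk backwards from `node`, returning it plus every producer.
    Returns None if the chain reaches a graph input (placeholder),
    since such a chain is not a candidate for folding anyway."""
    nodes = [node]
    frontier = [node]
    while frontier:
        cur = frontier.pop()
        for arg in cur.args:
            if not isinstance(arg, Node):
                continue
            if arg.op == "placeholder":
                return None
            nodes.append(arg)
            frontier.append(arg)
    return nodes


# weight (get_attr) feeds a dequantize-style call: both are producers.
weight = Node("weight", "get_attr")
dequant = Node("dequant", "call_function", (weight,))
print([n.name for n in collect_producer_nodes(dequant)])
# → ['dequant', 'weight']
```

In the PR's usage, the returned producer list is what gets marked mutable so the constant-folding pass leaves the whole chain feeding the mutated buffer alone.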
Contributor commented:
oh didn't know this works as is, but great it works

@jerryzh168 (Contributor) left a comment:
Thanks for the fix!

@jerryzh168 jerryzh168 added the module: pt2e_quant pt2 export quantization (prepare_pt2e, convert_pt2e, quantizer) label Apr 9, 2026
import torch
import torch.utils._pytree as pytree
from torch._inductor.freezing_utils import maybe_set_is_frozen_param
from torch.ao.quantization.fx.utils import collect_producer_nodes
@jerryzh168 (Contributor) commented Apr 9, 2026:

btw, we can also copy this util to torchao, since we want to avoid a dep on torch.ao.quantization, but better to do this in a separate PR. Please feel free to put up a PR for this one as well; otherwise I can work on it a bit later.

@tom-arm (Contributor, Author) replied:

thanks for letting me know, I don't mind doing that for you.

so just copy the collect_producer_nodes function to torchao/quantization/utils.py?

Contributor replied:

thanks, that would be great, you can copy that here: https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/graph_utils.py

@tom-arm (Contributor, Author) replied:

I have raised a PR for this here: #4294

@jerryzh168 jerryzh168 merged commit 9b1b902 into pytorch:main Apr 10, 2026
20 checks passed
@tom-arm tom-arm deleted the improve_buffer_folding branch April 10, 2026 12:28

Labels

CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
module: pt2e_quant: pt2 export quantization (prepare_pt2e, convert_pt2e, quantizer)


2 participants