Prevent folding of mutable input to copy_ and put_ operators#4181
Prevent folding of mutable input to copy_ and put_ operators#4181jerryzh168 merged 2 commits intopytorch:mainfrom
Conversation
tom-arm
commented
Mar 26, 2026
- Find mutable buffers for put_ operations as well as copy_.
- Trace the graph to the buffer and mark nodes found as mutable to prevent folding.
* Find mutable buffers for put_ operations as well as copy_. * Trace the graph to the buffer and mark nodes found as mutable to prevent folding.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4181
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit e4779ec with merge base 136cacb ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Partly solves the issue raised here: #4138. It should stop the mutable buffer being folded into a constant |
* Replace _collect_mutable_chain with collect_producer_nodes * Fix typo
| # all producers as mutable to prevent constant folding. | ||
| if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node): | ||
| mutable_buffers.add(node.args[0]) | ||
| producer_nodes = collect_producer_nodes(node.args[0]) |
There was a problem hiding this comment.
oh didn't know this works as is, but great it works
| import torch | ||
| import torch.utils._pytree as pytree | ||
| from torch._inductor.freezing_utils import maybe_set_is_frozen_param | ||
| from torch.ao.quantization.fx.utils import collect_producer_nodes |
There was a problem hiding this comment.
btw, we can also copy this util to torchao, since we want to avoid dep on torch.ao.quantization, but better to do this in a separate PR, please feel free to put up PR for this one as well, otherwise I can work on this a bit later
There was a problem hiding this comment.
thanks for letting me know, I don't mind doing that for you.
so just copy the collect_producer_nodes function to torchao/quantization/utils.py?
There was a problem hiding this comment.
thanks, that would be great, you can copy that here: https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/graph_utils.py