
Conversation


@angelayi (Contributor) commented Sep 18, 2025

Fixes #163294 (retracing set_grad HOO creates an empty submod)

The code `with torch.set_grad_enabled(enable_grad)` calls `torch._C._set_grad_enabled` three times -- (1) when [initializing set_grad_enabled](https://github.com/pytorch/pytorch/blob/bb7c9a2d4127ff178fe8787caf070d7072d015b6/torch/autograd/grad_mode.py#L187C9-L187C35), (2) when [entering the context](https://github.com/pytorch/pytorch/blob/bb7c9a2d4127ff178fe8787caf070d7072d015b6/torch/autograd/grad_mode.py#L194), and (3) when [exiting the context](https://github.com/pytorch/pytorch/blob/bb7c9a2d4127ff178fe8787caf070d7072d015b6/torch/autograd/grad_mode.py#L197).
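
For reference, a simplified sketch of the context-manager behavior (paraphrasing the linked `grad_mode.py`, not the exact source):

```
import torch

class set_grad_enabled_sketch:
    def __init__(self, mode: bool):
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)       # call (1): at construction
        self.mode = mode

    def __enter__(self):
        torch._C._set_grad_enabled(self.mode)  # call (2): on entry, redundant with (1)

    def __exit__(self, *exc):
        torch._C._set_grad_enabled(self.prev)  # call (3): restore the previous mode
```

Calls (1) and (2) set the same mode back to back, which is why the trace below contains the duplicate.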

This results in the retraced export module having a duplicate `torch._C._set_grad_enabled` call:

```
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    add_1 = torch.ops.aten.add.Tensor(add, 2);  add = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    add_2 = torch.ops.aten.add.Tensor(add_1, 3);  add_1 = None
    return (add_2,)
```
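
A minimal repro in that shape (a hypothetical sketch; the module and constants are chosen to match the graph above, not taken from the PR's tests):

```
import torch

class M(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        with torch.set_grad_enabled(False):
            x = x + 2
        return x + 3

ep = torch.export.export(M(), (torch.randn(4),))
# Retracing the already-exported module surfaces the duplicate
# torch._C._set_grad_enabled(False) in the second trace.
ep2 = torch.export.export(ep.module(), (torch.randn(4),))
print(ep2.graph_module.code)
```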

When export runs the `replace_set_grad_with_hop_pass`, it looks through the graph for `torch._C._set_grad_enabled` calls and creates subgraphs. The duplicate `torch._C._set_grad_enabled` produced an empty submod in the graph, which caused the error reported in [this post](https://fb.workplace.com/groups/1028545332188949/posts/1844720036398281/?comment_id=1862175381319413).
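
One plausible shape for the fix (not necessarily this PR's actual implementation) is to drop back-to-back `torch._C._set_grad_enabled` nodes with identical arguments before the HOP pass partitions the graph. A rough torch.fx sketch:

```
import torch
import torch.fx

def dedup_set_grad_enabled(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    def is_set_grad(n: torch.fx.Node) -> bool:
        return n.op == "call_function" and n.target is torch._C._set_grad_enabled

    for node in list(gm.graph.nodes):
        # node.prev is the immediately preceding node in graph order; two
        # adjacent identical grad-mode switches make the second one a no-op.
        if is_set_grad(node) and is_set_grad(node.prev) and node.args == node.prev.args:
            node.replace_all_uses_with(node.prev)
            gm.graph.erase_node(node)
    gm.recompile()
    return gm
```

With the duplicate removed, the HOP pass no longer sees an empty region between two grad-mode switches, so no empty submod is created.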

@angelayi requested a review from yushangdi September 18, 2025 21:18
@angelayi requested a review from zou3519 as a code owner September 18, 2025 21:18

pytorch-bot bot commented Sep 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163295

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ec1548e with merge base bb7c9a2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@angelayi added the topic: not user facing label Sep 18, 2025
@angelayi (Contributor, Author)

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Sep 19, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@angelayi (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Pull Request resolved: pytorch#163295
Approved by: https://github.com/yushangdi
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
