
[inductor] TorchInductor does not correctly recognize the grad status of model code #125474

Open
xuzhao9 opened this issue May 3, 2024 · 3 comments
Labels
module: inductor, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@xuzhao9 (Contributor) commented May 3, 2024

🐛 Describe the bug

At `if config.freezing and not torch.is_grad_enabled():` in torch/_inductor/compile_fx.py, TorchInductor finds that `torch.is_grad_enabled()` is True even when the compiled code runs under `with torch.no_grad():`.

Reproduction script:

import torch
import torch.nn.functional as F

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)

    def forward(self, x):
        return F.relu(self.conv1(x))

model = TestModule()
example_inputs = torch.randn([1, 1, 10, 10]).cpu()

@torch.compile
def run(model, example_inputs):
    with torch.no_grad():
        # Check inductor's grad status at
        # https://github.com/pytorch/pytorch/blob/30610251ec7b8f7e0507df06c3aadbcf90658e0e/torch/_inductor/compile_fx.py#L1371
        # If we print `torch._inductor.config.cpp.weight_prepack`, torchinductor
        # believes torch.is_grad_enabled() is False; otherwise, torchinductor
        # believes torch.is_grad_enabled() is True.
        # print("torch._inductor.config.cpp.weight_prepack", torch._inductor.config.cpp.weight_prepack)
        model(example_inputs)

run(model, example_inputs)
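
For reference, `config.freezing` is off by default, so the branch above is only taken when freezing is enabled, either via the config flag or the TORCHINDUCTOR_FREEZING=1 environment variable. A minimal sketch, assuming the repro script above is in scope:

import torch._inductor.config as inductor_config

# Enable freezing so compile_fx reaches the
# `config.freezing and not torch.is_grad_enabled()` check.
# Equivalent to setting TORCHINDUCTOR_FREEZING=1 in the environment.
inductor_config.freezing = True

run(model, example_inputs)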

Versions

The latest nightly release (20240503).

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

@xuzhao9 (Contributor, Author) commented May 3, 2024

This was reported by a torchbench user: pytorch/benchmark#2253

@bdhirsh (Contributor) commented May 3, 2024

The simplest workaround is to move the no_grad() outside of your compiled region:

with torch.no_grad():
    run(model, example_inputs)

But I think we should be able to fix this. The problem is that the compiled region enables no_grad at its very beginning, so when torch.compile is invoked we have not yet entered the no_grad region. When we actually trace the graph, though, we will find that none of the graph's outputs require gradients.

AOTAutograd has some similar logic to recognize when it is compiling for inference vs. training.

It's a bit annoying, though, since inductor chooses which fw_compiler to use before invoking AOTAutograd. So we would have to change the "fw_compiler" that we pass in from inductor to dynamically switch between the inference and forward variants, depending on what AOTAutograd discovers.
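
To illustrate the idea, here is a hypothetical sketch, not inductor's actual code: `inference_compiler` and `fw_compiler` stand in for the two variants inductor would otherwise choose between up front, and the wrapper defers the decision until AOTAutograd hands over the traced graph, at which point the grad status of the outputs is known.

import torch
from torch.fx import GraphModule, Node

def make_dynamic_fw_compiler(inference_compiler, fw_compiler):
    def dispatch(gm: GraphModule, example_inputs):
        # An FX graph has a single `output` node whose args hold the
        # graph outputs.
        output_node = next(n for n in gm.graph.nodes if n.op == "output")
        outputs = output_node.args[0]
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)

        def requires_grad(o):
            # The fake-tensor metadata recorded during tracing tells us
            # whether this output carries gradients.
            if not isinstance(o, Node):
                return False
            val = o.meta.get("val")
            return isinstance(val, torch.Tensor) and val.requires_grad

        if any(requires_grad(o) for o in outputs):
            return fw_compiler(gm, example_inputs)
        # No output requires grad: effectively inference, so a
        # freezing-enabled variant would be safe to use.
        return inference_compiler(gm, example_inputs)

    return dispatch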

@xuzhao9 (Contributor, Author) commented May 3, 2024

Thanks @bdhirsh for the response. I will keep this issue open since it is a valid issue with a simple script to reproduce. On the torchbench side, we will work around it by moving the no_grad() context outside the compiled region.

@yanboliang added the triaged label May 6, 2024
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this issue May 6, 2024
Summary:
Related Inductor issue: pytorch/pytorch#125474.

Inductor cannot correctly handle the `torch.no_grad()` context when it is inside the model code. To best leverage inductor for inference, we are removing the `torch.no_grad()` context requirement from the model code. Instead, the `pick_grad` helper function manages the grad context at the framework level.

Fixes #2253

Pull Request resolved: #2256

Reviewed By: aaronenyeshi

Differential Revision: D57009780

Pulled By: xuzhao9

fbshipit-source-id: 08998e0da99105e1b4fcd5ca2cb3ba2764513f69
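
For context, here is a minimal sketch of a `pick_grad`-style helper, assuming it dispatches on a train/eval mode string (the actual torchbench helper may differ):

import torch

def pick_grad(mode: str):
    # Decide the grad context outside the compiled region, so inductor
    # observes the correct grad status when compilation is triggered.
    return torch.enable_grad() if mode == "train" else torch.no_grad()

# Usage with the repro script from the issue above:
with pick_grad("eval"):
    run(model, example_inputs)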