
[inductor] cpp codegen using already deleted header #126128

Closed
xmfan opened this issue May 14, 2024 · 3 comments
Labels: module: inductor, oncall: pt2, triaged

Comments

xmfan (Member) commented May 14, 2024

🐛 Describe the bug

Inductor seems to be using a header filename generated in a previous test run, but temporary directories are cleared between tests.

The issue boils down to the fresh_inductor_cache context manager not cleaning up properly on exit:

import torch
from torch._inductor.utils import fresh_inductor_cache

def fn():
    x = torch.randn(1, 10)
    y = torch.randn(10, 1)
    return torch.mm(x, y).sum()

def fn2():
    x = torch.randn(10, 100)
    y = torch.randn(100, 10)
    return torch.mm(x, y).sum()

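# Inside the context manager, Inductor writes its artifacts to a temporary cache dir.
with fresh_inductor_cache():
    torch.compile(fn)()

# After exit, that temporary dir is deleted, but compiling fn2 still uses
# the header generated inside the context manager.
torch.compile(fn2)()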

This makes the compiled autograd tests flaky. To reproduce:

TORCHINDUCTOR_COMPILE_THREADS=1 pytest test/inductor/test_compiled_autograd.py -k 'test_torch_compile or test_access_saved_tensor_twice_without_recomputation_works'

Versions

main

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

jgong5 (Collaborator) commented May 14, 2024

The header in question is cpp_prefix.h, which is named after its hash. It should be generated before the generated C++ code is compiled. This path is already exercised by other UTs in test_torchinductor.py and test_cpu_repro.py. Is the issue specific to test_compiled_autograd.py?

xmfan (Member, Author) commented May 14, 2024

I only have a repro with test_compiled_autograd.py. Do you have a code pointer to where we generate the header name?

xmfan (Member, Author) commented May 14, 2024

@jgong5 I found a smaller repro and have a fix:

import torch
from torch._inductor.utils import fresh_inductor_cache

def fn():
    x = torch.randn(1, 10)
    y = torch.randn(10, 1)
    return torch.mm(x, y).sum()

def fn2():
    x = torch.randn(10, 100)
    y = torch.randn(100, 10)
    return torch.mm(x, y).sum()

with fresh_inductor_cache():
    torch.compile(fn)()

torch.compile(fn2)()

xmfan added a commit that referenced this issue May 14, 2024
FIXES #126128

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
xmfan added a commit that referenced this issue May 14, 2024
FIXES #126128.

Right now, we only clear the cache when entering the context manager, so the state is bad after exit unless fresh_inductor_cache is entered again, which is usually the case in tests.

Cue the compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd: TestCompiledAutograd uses the context manager, but TestAutogradWithCompiledAutograd doesn't.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
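One way to express the fix described in this commit message, as a hypothetical standalone sketch: a context manager that resets in-memory state on exit as well as on enter. `fresh_cache`, `get_header`, `_cache_dir`, and `_memo` are illustrative names, not PyTorch APIs, and the actual change in pytorch#126146 may differ in detail.

```python
import contextlib
import os
import shutil
import tempfile

# Hypothetical stand-ins for process-wide compiler state (not Inductor's API).
_cache_dir = tempfile.mkdtemp()  # directory new artifacts are written to
_memo = {}                       # in-memory map: artifact name -> path


@contextlib.contextmanager
def fresh_cache():
    """Use a fresh temp dir; clear the memo on enter AND on exit."""
    global _cache_dir
    saved_dir = _cache_dir
    tmp = tempfile.mkdtemp()
    _memo.clear()  # clear on enter (what the original code already did)
    _cache_dir = tmp
    try:
        yield tmp
    finally:
        _memo.clear()  # clear on exit too, so no entry points into tmp
        _cache_dir = saved_dir
        shutil.rmtree(tmp, ignore_errors=True)


def get_header():
    """Memoized: write prefix.h into the current cache dir on first use."""
    if "prefix" not in _memo:
        path = os.path.join(_cache_dir, "prefix.h")
        with open(path, "w") as f:
            f.write("// prefix\n")
        _memo["prefix"] = path
    return _memo["prefix"]


with fresh_cache():
    inside = get_header()
outside = get_header()
print(os.path.exists(outside))  # True: regenerated in the outer cache dir
```

Removing the `_memo.clear()` call in the `finally` block reproduces the bug: `outside` would equal `inside` and point into the deleted directory.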
xmfan added the triaged label May 14, 2024
xmfan self-assigned this May 14, 2024
xmfan added four more commits that referenced this issue May 16, 2024, each with the same commit message as above.
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024
FIXES pytorch#126128.

Right now, we only clear the cache when entering the context manager, so the state is bad after exit unless fresh_inductor_cache is entered again, which is usually the case in tests.

Cue the compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd: TestCompiledAutograd uses the context manager, but TestAutogradWithCompiledAutograd doesn't.

Pull Request resolved: pytorch#126146
Approved by: https://github.com/jgong5, https://github.com/oulgen
ghstack dependencies: pytorch#126144