[inductor] cpp codegen using already deleted header #126128
The header file about
I only have a repro with test_compiled_autograd.py; do you have a code pointer to where we generate the header name?
@jgong5 found a smaller repro, and I have a fix:

```python
import torch
from torch._inductor.utils import fresh_inductor_cache


def fn():
    x = torch.randn(1, 10)
    y = torch.randn(10, 1)
    return torch.mm(x, y).sum()


def fn2():
    x = torch.randn(10, 100)
    y = torch.randn(100, 10)
    return torch.mm(x, y).sum()


with fresh_inductor_cache():
    torch.compile(fn)()
torch.compile(fn2)()
```
FIXES pytorch#126128. Right now, we only clear the cache on context-manager enter, so after exit the state is stale unless `fresh_inductor_cache` is entered again. That is usually fine in tests, which typically enter it again before compiling. Cue the compiled autograd tests when going from `TestCompiledAutograd` to `TestAutogradWithCompiledAutograd`: `TestCompiledAutograd` uses the context manager, but `TestAutogradWithCompiledAutograd` doesn't.

Pull Request resolved: pytorch#126146
Approved by: https://github.com/jgong5, https://github.com/oulgen
ghstack dependencies: pytorch#126144
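A minimal sketch of the idea behind the fix, assuming a toy module-level cache: state is cleared on exit as well as on enter, so nothing memoized inside the context can keep pointing into the deleted temporary directory afterwards. `fresh_cache`, `clear_caches`, `precompiled_header_path`, and `_cache_dir` are hypothetical stand-ins, not Inductor's real internals, and this is not the code from pytorch#126146.

```python
import functools
import os
import shutil
import tempfile
from contextlib import contextmanager

# Hypothetical global cache-directory setting with a persistent default.
_cache_dir = tempfile.mkdtemp(prefix="default_cache_")


@functools.lru_cache(maxsize=None)
def precompiled_header_path() -> str:
    """Hypothetical helper: writes a header once, then memoizes its path."""
    path = os.path.join(_cache_dir, "prefix.h")
    with open(path, "w") as f:
        f.write("// generated header\n")
    return path


def clear_caches() -> None:
    """Hypothetical: reset every in-memory cache that may hold on-disk paths."""
    precompiled_header_path.cache_clear()


@contextmanager
def fresh_cache():
    global _cache_dir
    old_dir = _cache_dir
    tmp = tempfile.mkdtemp()
    clear_caches()  # drop entries that point at the previous directory
    _cache_dir = tmp
    try:
        yield tmp
    finally:
        clear_caches()  # the fix: also drop entries that point into tmp
        _cache_dir = old_dir
        shutil.rmtree(tmp)


with fresh_cache():
    inside = precompiled_header_path()

# The path is recomputed against the restored directory instead of the deleted one.
after = precompiled_header_path()
print(after != inside, os.path.exists(after))  # True True
```

Without the second `clear_caches()` call, `after` would equal `inside` and point at a file that no longer exists, which is the failure mode described in this issue.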
🐛 Describe the bug
Inductor seems to be using a header filename generated from a previous test run, but temporary directories are cleared between tests.
The issue boils down to the `fresh_inductor_cache` wrapper not cleaning up properly on exit.
This makes the compiled autograd tests flaky:

```bash
TORCHINDUCTOR_COMPILE_THREADS=1 pytest test/inductor/test_compiled_autograd.py -k 'test_torch_compile or test_access_saved_tensor_twice_without_recomputation_works'
```
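To illustrate the failure mode, here is a toy reproduction in plain Python, assuming a memoized helper that writes a generated header into the current cache directory and remembers its path. All names (`fresh_cache_enter_only`, `precompiled_header_path`, `_cache_dir`) are hypothetical; the point is only that a cache cleared on enter but not on exit keeps returning a path inside a directory that has already been deleted.

```python
import functools
import os
import shutil
import tempfile
from contextlib import contextmanager

# Hypothetical stand-in for a global "cache directory" setting.
_cache_dir = tempfile.gettempdir()


@functools.lru_cache(maxsize=None)
def precompiled_header_path() -> str:
    """Hypothetical helper: writes a header once, then memoizes its path."""
    path = os.path.join(_cache_dir, "prefix.h")
    with open(path, "w") as f:
        f.write("// generated header\n")
    return path


@contextmanager
def fresh_cache_enter_only():
    """Buggy pattern: memoized state is cleared on enter but not on exit."""
    global _cache_dir
    old_dir = _cache_dir
    tmp = tempfile.mkdtemp()
    precompiled_header_path.cache_clear()  # cleared on enter only
    _cache_dir = tmp
    try:
        yield tmp
    finally:
        _cache_dir = old_dir
        shutil.rmtree(tmp)  # deletes the temp dir, including prefix.h


with fresh_cache_enter_only():
    inside = precompiled_header_path()
    assert os.path.exists(inside)

# After exit, the memoized path still points into the deleted directory.
stale = precompiled_header_path()
print(stale == inside, os.path.exists(stale))  # True False
```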
Versions
main
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire