[Inductor] Compiling on GPU with cpp_wrapper leads to ValueError: not enough values to unpack (expected 68, got 1)
#115035
ValueError: not enough values to unpack (expected 68, got 1)
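For context, this `ValueError` is ordinary Python tuple unpacking: the generated wrapper entry point unpacks a fixed number of call arguments (68 here, i.e. the user inputs plus the lifted parameters) but receives only 1. A minimal sketch of the same failure mode, with the arity shrunk to 3 for illustration:

```python
# Minimal sketch: unpacking fewer values than the left-hand side
# expects raises the same ValueError the cpp_wrapper entry point hits.
def wrapper_entry(args):
    # the wrapper expects a fixed arity; 3 stands in for 68 here
    a, b, c = args
    return a

try:
    wrapper_entry((1,))  # only one value supplied
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 1)
```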
This example fails in bare Python too... It either should replace
EDIT: this is actually due to mistakenly using [...]. This was introduced in #114067 @desertfire
Repro:

```python
import torch
from torch import nn
from torch._inductor import config

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10, device='cuda')

    def forward(self, x):
        return self.linear(x)

with torch.no_grad(), config.patch({"cpp_wrapper": True}):
    model = Model()
    model_opt = torch.compile(model)
    model_opt(torch.zeros(10, device="cuda"))
```
You are right. Using real_inputs works for the AOTInductor mode, but not for JIT with cpp_wrapper, and our current test does not catch that. There is something additional that needs to be handled in your #115053, but let me comment there.
Summary: fixes #115035, where in the cpp_wrapper JIT inductor, the input args should contain the lifted parameters. [ghstack-poisoned]
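To illustrate what "lifted parameters" means in the fix summary above (a hedged sketch, not the actual Inductor code; the function and variable names below are illustrative): when a module is traced for compilation, its parameters can be lifted into explicit graph inputs, so the compiled callable must be invoked with the parameters in addition to the user inputs.

```python
import torch
from torch import nn

model = nn.Linear(10, 10)

# After lifting, parameters are explicit inputs of the graph, so the
# call site must pass weight and bias alongside the user input x.
# If only x were passed, unpacking the expected inputs would fail.
def lifted_forward(weight, bias, x):
    return torch.nn.functional.linear(x, weight, bias)

x = torch.zeros(2, 10)
params = dict(model.named_parameters())
out = lifted_forward(params["weight"], params["bias"], x)
assert torch.allclose(out, model(x))
```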
🐛 Describe the bug
Don't have a repro for now, but it originates in `self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()`.
I notice the tutorial only returns single values: https://pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html
Compiling for CUDA.
Versions
main
cc: @desertfire @jgong5
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler