[Inductor] Compiling on GPU with cpp_wrapper leads to ValueError: not enough values to unpack (expected 68, got 1) #115035

Closed
jon-chuang opened this issue Dec 3, 2023 · 4 comments
Assignees
Labels
high priority module: inductor oncall: pt2 triage review triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@jon-chuang
Collaborator

jon-chuang commented Dec 3, 2023

🐛 Describe the bug

Don't have a repro for now.

But it originates in self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()

I notice the tutorial only returns single values: https://pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html

Compiling for CUDA.

Versions

main

cc: @desertfire @jgong5

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler

@jon-chuang jon-chuang changed the title [Inductor] Compiling with cpp_wrapper leads to ValueError: not enough values to unpack (expected 68, got 1) [Inductor] Compiling on GPU with cpp_wrapper leads to ValueError: not enough values to unpack (expected 68, got 1) Dec 3, 2023
@vadimkantorov
Contributor

vadimkantorov commented Dec 3, 2023

For instance, this fails:

This example fails in bare python too... a, b = fn() ValueError: not enough values to unpack (expected 2, got 1)

The fix is either to replace b with *b, or to replace (1, 2,), with (1, 2,), None
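The snippet this comment refers to is not shown, so the following is only a minimal sketch of the plain-Python behaviour it describes. The names fn and fn_fixed are hypothetical; the sketch assumes a function whose return value ends in a trailing comma, which makes it a one-element tuple.

```python
def fn():
    # The trailing comma wraps (1, 2) in an outer one-element tuple.
    return (1, 2,),


def unpack_two():
    """Try to unpack fn() into two names and report the error text."""
    try:
        a, b = fn()
    except ValueError as err:
        return str(err)
    return None


# Unpacking one value into two targets raises ValueError.
message = unpack_two()

# Option 1: a starred target absorbs the missing values.
a, *b = fn()  # a is (1, 2); b is an empty list


def fn_fixed():
    # Option 2: return a second element so two targets can be filled.
    return (1, 2,), None


c, d = fn_fixed()
```

Either change makes the number of targets agree with the number of returned values, which is the same mismatch the ValueError in the issue title reports.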

@jon-chuang
Collaborator Author

jon-chuang commented Dec 4, 2023

EDIT: this is actually due to mistakenly using example_inputs as the inputs to the graph. Somehow these do not correspond to the actual inputs to the GraphModule forward...

This was introduced in #114067 @desertfire

@jon-chuang
Collaborator Author

Repro

import torch
from torch import nn
from torch._inductor import config

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10, device='cuda')
    
    def forward(self, x):
        return self.linear(x)
    

with torch.no_grad(), config.patch({"cpp_wrapper": True}):
    model = Model()
    model_opt = torch.compile(model)
    model_opt(torch.zeros(10, device="cuda"))

@desertfire
Contributor

desertfire commented Dec 4, 2023

EDIT: this is actually due to mistakenly using example_inputs as the inputs to the graph. Somehow these do not correspond to the actual inputs to the GraphModule forward...

This was introduced in #114067 @desertfire

You are right. Using real_inputs works for the AOTInductor mode, but not for JIT with cpp_wrapper, and our current test does not catch that. Something additional needs to be handled in your #115053, but let me comment there.

@shunting314 shunting314 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Dec 5, 2023
desertfire added a commit that referenced this issue Dec 20, 2023
Summary: fixes #115035, where in the cpp_wrapper JIT inductor, the input args should contain the lifted parameters.

ghstack-source-id: 35086f97da5f02a6e02e83f8eb369e37dfeaf712
Pull Request resolved: #116197
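The commit summary says the input args should contain the lifted parameters. As a rough illustration only, the sketch below uses plain Python rather than any Inductor code: compiled_call, lifted_params, and user_inputs are hypothetical names, and scalars stand in for tensors. It assumes the generated entry point unpacks every graph input, parameters included, in a single statement.

```python
def compiled_call(args):
    # Hypothetical stand-in for a generated wrapper: parameters that were
    # lifted out of the module become explicit graph inputs, so they are
    # unpacked together with the user-supplied input.
    weight, bias, x = args
    return weight * x + bias


lifted_params = [2.0, 0.5]  # stand-ins for linear.weight and linear.bias
user_inputs = [3.0]         # what the caller passes to forward()


def call_with_user_inputs_only():
    """Pass only the user inputs and report the resulting error text."""
    try:
        compiled_call(user_inputs)
    except ValueError as err:
        return str(err)
    return None


# Too few values: the wrapper expects parameters plus user inputs.
error_message = call_with_user_inputs_only()

# Prepending the lifted parameters gives the wrapper what it expects.
result = compiled_call(lifted_params + user_inputs)
```

With a larger model the expected count grows with the number of lifted parameters, which would be consistent with the "expected 68, got 1" in the issue title, though the thread does not state which model produced that count.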
pytorchmergebot pushed a commit that referenced this issue Dec 26, 2023
Summary: fixes #115035, where in the cpp_wrapper JIT inductor, the input args should contain the lifted parameters.

ghstack-source-id: cc8a92904893696d34382153fbbd65a1905c83de
Pull Request resolved: #116197