[NVFuser] don't decompose linear if we don't have shape info #75770
Conversation
[ghstack-poisoned]
🔗 Helpful links
💊 CI failures summary and remediations: as of commit 29db863 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
```cpp
  }
  return false;
}
// linear can't be fused, but for other reasons it needs to be parsible
```
Ah... sorry about this. We've separated fusion/profile node registration, so instead of special-casing block fusion we should have updated parser.cpp. Can't believe we haven't run into this earlier 😕
Let me push a patch to your PR.
```diff
- TORCH_INTERNAL_ASSERT(
-     mat0_size.has_value() && mat1_size.has_value(),
-     "concrete shape for linear input & weight are required");
+ if (!mat0_size.has_value() || !mat1_size.has_value()) {
```
Do you think it's ok to throw a one-time warning here instead? Or maybe even only warn in debug builds.
The assert was put here so that we would be well aware of missing profiling information.
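The warn-once behavior suggested above can be sketched in plain Python (a hypothetical illustration; the actual fix would use PyTorch's C++ warning macros in the parser, and `maybe_decompose_linear` is an invented name, not an API in the PR):

```python
import warnings

def maybe_decompose_linear(mat0_size, mat1_size):
    """Hypothetical sketch: instead of hard-asserting, warn and skip
    decomposition when concrete shape info is missing."""
    if mat0_size is None or mat1_size is None:
        # Python's default warning filter emits each distinct warning
        # only once per call site, mirroring a warn-once macro.
        warnings.warn("concrete shape for linear input & weight are "
                      "required; not decomposing linear")
        return False  # leave linear un-decomposed
    return True

print(maybe_decompose_linear(None, (4, 3)))    # missing shape info
print(maybe_decompose_linear((2, 4), (4, 3)))  # shapes known
```

The caller can then fall back to the un-decomposed `aten::linear` node rather than crashing when profiling information is absent.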
```python
inp = torch.rand((x, x + i)).cuda()
weight = torch.rand((x + 2, x + i)).cuda()
bias = torch.rand((x, x + 2)).cuda()
y += torch.sin(torch.nn.functional.linear(inp, weight, bias))
```
Out of curiosity, why would this case not have profiling information on linear output?
Across the two invocations of linear (i.e. i=0 and i=1), the inputs to linear have different shapes.
We should be unrolling this loop though... I guess the profiling happens first, then the unrolling.
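The shape mismatch across iterations can be sketched without torch (the value of `x` is a hypothetical stand-in; the actual test may use a different size):

```python
# Across loop iterations i = 0 and i = 1, the inputs to linear change
# shape, so the profiler cannot record one concrete shape for the node.
x = 4  # hypothetical size for illustration

shapes_seen = []
for i in range(2):
    inp_shape = (x, x + i)         # input to linear
    weight_shape = (x + 2, x + i)  # weight for linear
    shapes_seen.append((inp_shape, weight_shape))

# The two invocations observe different shapes, so only symbolic
# (non-concrete) sizes are available when the fuser asks for them.
print(shapes_seen[0] != shapes_seen[1])  # True
```

This is why the guard above has to handle the case where `mat0_size` / `mat1_size` carry no concrete value.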
```cpp
    nullptr);
}
```
Looks like we don't need to add profiling for linear now. Wondering if some other pass was doing that. Removing this as a whole should break TestCudaFuser.test_linear, but apparently it didn't :)
I'm thinking more like this one, to be on the safer side (waiting on a build to test this):
https://github.com/pytorch/pytorch/pull/75897/files#diff-9738ee51d55cdf479f41b018d61ca43ed2ae10c13a20e21ec45b7578e82a23d2R2457-R2460
Double-checked this on my local build. ProfileRegistry is profiling all differentiable nodes, so we don't have to register linear for profiling in nvfuser. But I still prefer to have it registered, just to be on the safe side.
Nevertheless, since this is not blocking, I'm stamping it.
Thanks for the help @jjsjann123 - tested your changes from #75897 and merged them into this commit.

Is there still anything blocking this, and should we ping the bot to merge it? We're trying to cherry-pick upstream fixes to our local branch.

@pytorchmergebot merge this

Merge failed due to Matched rule superuser, but it was not reviewed yet by any of: esantorella, NivekT, blefaudeux, ngimel, deeptigp, ...

@pytorchmergebot merge this

Hey @davidberard98.
Summary: Pull Request resolved: #75770
Approved by: https://github.com/jjsjann123, https://github.com/robieta
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ebb60a8b2f495ada503b6386aef0ad61f44b7fac
Reviewed By: mehtanirav
Differential Revision: D35721647
Pulled By: davidberard98
fbshipit-source-id: bea898927cd6dd740f8ceb09528b3d8dcb71c4cd
Stack from ghstack: