Give linear an explicit autograd formula always #162411
base: gh/ezyang/3144/base
Conversation
Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162411
Note: Links to docs will display an error until the docs builds have been completed.
❌ 16 New Failures as of commit 25a4078 with merge base 8171d60. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
There are still test failures though.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by:
// dL/dW = sum_batch ( (dL/dy)ᵀ @ x )
// Use einsum to contract over all leading dims without reshaping:
if (output_mask[1]) {
  grad_weight = at::einsum("...o,...i->oi", {grad_output, self}); // [out, in]
wouldn't this go into decomposition, since linear_backward is CompositeImplicitAutograd now?
Yes. So another PR we have to do is make einsum not decompose, but THAT is likely to be a lot more controversial. Another reason why making views work is "better" (if you can swing it).
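As a sanity check on the einsum formulation quoted above, here is a minimal Python sketch (not part of the PR; shapes are arbitrary) confirming that contracting over the ellipsis dims gives the same grad_weight as the usual flatten-plus-matmul approach for an input with leading batch dims:

```python
# Sketch only: the einsum grad_weight vs. the flatten-and-matmul equivalent.
import torch

x = torch.randn(4, 3, 5)          # (*, in_features)
grad_out = torch.randn(4, 3, 7)   # (*, out_features)

# einsum sums over all leading (ellipsis) dims without reshaping
gw_einsum = torch.einsum("...o,...i->oi", grad_out, x)   # (out, in)

# equivalent: flatten the leading dims, then a single matmul
gw_matmul = grad_out.reshape(-1, 7).t() @ x.reshape(-1, 5)

torch.testing.assert_close(gw_einsum, gw_matmul)
```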
- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
  input, weight, bias: "grad.defined() ? linear_backward(input, grad, weight, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
  result: auto_linear
Curious, what does this line do?
It's for forward-mode AD; it says that this function is linear, and thus its forward AD formula is trivial (apply the same function).
Sounds OK as long as there is no perf hit; this is a pretty hot path.
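To make the forward-AD point above concrete, here is a minimal sketch (not part of the PR; torch.func.jvp and the shapes are just illustrative) of what the JVP of F.linear computes:

```python
# Sketch only: forward-mode AD of F.linear.
# linear(x, w, b) = x @ w.T + b, so with tangents (tx, tw, tb) the JVP is
#   tx @ w.T + x @ tw.T + tb
# i.e. the same linear structure applied to the tangents.
import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 5), torch.randn(7, 5), torch.randn(7)
tx, tw, tb = torch.randn(2, 5), torch.randn(7, 5), torch.randn(7)

_, jvp_out = torch.func.jvp(F.linear, (x, w, b), (tx, tw, tb))
torch.testing.assert_close(jvp_out, tx @ w.T + x @ tw.T + tb)
```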
auto self_ = moveBatchDimToFront(self, self_bdim);
auto weight_ = moveBatchDimToFront(weight, weight_bdim);
auto bias_ = bias.has_value() ? std::make_optional<Tensor>(moveBatchDimToFront(*bias, bias_bdim)) : std::nullopt;
return std::make_tuple( at::linear(self_, weight_, bias_), 0 );
Oh, linear supports arbitrary batch dimensions on the weights?
Your backward formula seems to say no? :D
whoops :)
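For context, a minimal sketch (not from the PR; shapes are arbitrary) of the batched-weight case this functorch batching rule has to handle, i.e. vmap over F.linear with per-sample weights and biases:

```python
# Sketch only: vmap-ing F.linear over a batch of weights/biases.
import torch
import torch.nn.functional as F

x = torch.randn(8, 5)      # shared input, (batch, in_features)
w = torch.randn(4, 7, 5)   # 4 per-sample weight matrices
b = torch.randn(4, 7)      # 4 per-sample biases

out = torch.func.vmap(F.linear, in_dims=(None, 0, 0))(x, w, b)  # (4, 8, 7)

# reference: apply each weight/bias pair separately
ref = torch.stack([F.linear(x, w[i], b[i]) for i in range(4)])
torch.testing.assert_close(out, ref)
```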
// dL/dW = sum_batch ( (dL/dy)ᵀ @ x )
// Use einsum to contract over all leading dims without reshaping:
if (output_mask[1]) {
  grad_weight = at::einsum("...o,...i->oi", {grad_output, self}); // [out, in]
What is the perf hit of this for a regular nn.Linear() layer?
This is potentially actually pretty bad lol. And it doesn't even do what I want, because I want the einsum to also show up as its own operator LOL.
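One way to get a first read on that perf question (a sketch only, not from the PR; shapes are arbitrary and CPU-only) is to time the einsum grad_weight against the plain matmul it replaces for a regular 2D nn.Linear-shaped backward:

```python
# Sketch only: timing the einsum grad_weight vs. the classic matmul form.
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(256, 1024)         # (batch, in_features)
grad_out = torch.randn(256, 4096)  # (batch, out_features)

t_matmul = benchmark.Timer(
    stmt="grad_out.t() @ x",
    globals={"x": x, "grad_out": grad_out},
)
t_einsum = benchmark.Timer(
    stmt='torch.einsum("...o,...i->oi", grad_out, x)',
    globals={"torch": torch, "x": x, "grad_out": grad_out},
)
print(t_matmul.timeit(100))
print(t_einsum.timeit(100))
```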
Also, I expect a lot more changes in the PT2 compilation stack to handle the new op and remove special casing for the linear decomp.
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang ezyang@meta.com