Implement PReLU in a compositional way #91238
Conversation
The PReLU implementation was all over the place. This led to a number of bugs like #68760. We port it to TI and split it into broadcasting operations (CompositeImplicit) and kernel fusion (via TI). This makes the CPU and CUDA parts agree and heavily simplifies the code. It also includes a number of fixes, like using opmath_t consistently and a more efficient backward implementation that does not copy the data in the non-contiguous case. Fixes #68760 [ghstack-poisoned]
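In rough Python terms, the split described above looks like this (a sketch only; `_prelu_pointwise` is an illustrative name — the operator the PR actually adds is `_prelu_kernel`, implemented as a fused TensorIterator kernel, not Python):

```python
import torch
import torch.nn.functional as F

def _prelu_pointwise(x, w):
    # Pointwise piece: shapes are assumed to already broadcast against each other.
    return torch.where(x > 0, x, w * x)

def prelu(x, weight):
    # Broadcasting piece (CompositeImplicit in the PR): a 1-element weight
    # broadcasts everywhere; otherwise the weight has one entry per channel
    # (dim 1 of the input) and is reshaped to line up with that dimension.
    if weight.numel() != 1:
        shape = [1] * x.dim()
        shape[1] = weight.numel()
        weight = weight.reshape(shape)
    return _prelu_pointwise(x, weight)

x = torch.randn(2, 3, 4)
w = torch.full((3,), 0.25)
assert torch.allclose(prelu(x, w), F.prelu(x, w))  # should agree with the ATen op
```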
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91238
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure. As of commit 1633602: FLAKY - the following jobs failed, but were likely due to flakiness present on master.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@soulitzer @albanD you have been tagged as codeowners; could you please review this? The only issue with the current approach is that I was not able to port the MKLDNN and MPS implementations to just implement the simpler new kernel, so those likely still carry the old bugs. If the relevant points of contact from Intel / Apple port them, we should be able to remove all the related entries from derivatives.yaml and native_functions.yaml. Quantized fails as well, no idea why :(
I'm not sure why this has to be more complex than functions like amax, which we can implement without a pair of functions?
MkldnnCPU: mkldnn_prelu
CPU: prelu_cpu
CUDA: prelu_cuda
MPS: prelu_mps
cc @kulinseth if you are interested in this one.
LGTM (on functorch batching rule front) :)
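As a point of reference for the batching-rule discussion, the kind of call that should now batch cleanly is roughly the following (a usage sketch, assuming `torch.func.vmap`, or `functorch.vmap` on older builds):

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

x = torch.randn(8, 3, 5)   # a batch of 8 samples
w = torch.tensor([0.25])   # shared 1-element PReLU weight

# Map prelu over the batch dimension while sharing the weight.
out = vmap(lambda xi: F.prelu(xi, w))(x)

# With a 1-element weight, per-sample application matches the batched call.
assert torch.allclose(out, F.prelu(x, w))
```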
aten/src/ATen/native/Activation.cpp
at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
}
Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) {
  auto options = self.options().dtype(promoteTypes(self.scalar_type(), weight.scalar_type()));
I'm not sure we want type-promoting behavior for prelu, and we don't have it currently. Moreover, this isn't even the right type-promoting behavior, as it would allow a 0-d weight to dictate the output type for a 1-d+ input.
For some reason TI does not accept undefined tensors. Structured kernels bypass this by being tightly coupled to TI. I don't know whether there is a good way to let TI handle type promotion outside of structured kernels.
That being said, yes, let's drop type promotion altogether!
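A quick illustration of the point about 0-d weights, using plain PyTorch type-promotion helpers (standard library behavior, not code from this PR):

```python
import torch

x = torch.randn(4, dtype=torch.float32)      # 1-d input
w = torch.tensor(0.25, dtype=torch.float64)  # 0-d weight

# Raw dtype promotion ignores dimensionality, so the 0-d weight would
# drag the result up to float64:
print(torch.promote_types(x.dtype, w.dtype))  # torch.float64

# Standard tensor type promotion gives 0-d tensors lower priority, so the
# 1-d input's dtype wins:
print(torch.result_type(x, w))                # torch.float32
```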
torch/_decomp/decompositions.py
@register_decomposition(aten._prelu_kernel)
@pw_cast_for_opmath
pw_cast_for_opmath is not needed here; it only leads to more complicated IR.
The previous kernel used opmath_t and has a dedicated path for BFloat16. Should I drop those as well? (I think we should.)
We are outputting either self or weight * self; in both cases there's no difference whether they are pre-converted to float or not here. self is a copy either way, and weight * self would do whatever is needed in the multiplication.
I'll clean up all that code then.
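To make the point concrete, here is roughly what the extra casts amount to for this decomposition (a sketch; dtype choices are illustrative, not taken from the PR):

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
w = torch.full((1,), 0.25, dtype=torch.bfloat16)

# What a pw_cast_for_opmath-style wrapper would trace: upcast, compute in
# float32, downcast back to bfloat16.
with_cast = torch.where(x.float() > 0, x.float(), w.float() * x.float()).to(torch.bfloat16)

# Without the wrapper, the decomposition is just the pointwise op in the
# original dtype; for a single multiply the extra casts buy nothing.
plain = torch.where(x > 0, x, w * x)

# Expected to match: a single bf16 multiply already rounds from the float32 product.
print(torch.allclose(with_cast.float(), plain.float()))
```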
torch/_decomp/decompositions.py
@register_decomposition(aten._prelu_kernel_backward)
@pw_cast_for_opmath
similarly here
@albanD Re. amax: it's a bit trickier, because here we are implementing the backward via TI as well. I kept it this way to not regress performance, as the backward already had its own dedicated kernel. Now, arguably, we would be better off dropping that and simply implementing the backward compositionally, as the difference in performance would be minimal. Happy to do that if you think that's the way to go.
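For reference, the compositional backward being discussed would boil down to something like the sketch below (the PR as merged keeps a fused TI kernel for this, and the CompositeImplicit wrapper still reduces the weight gradient over the broadcast dimensions afterwards):

```python
import torch

def prelu_pointwise_backward(grad, x, w):
    # d out / d x is 1 where x > 0 and w elsewhere;
    # d out / d w is 0 where x > 0 and x elsewhere (elementwise, pre-reduction).
    grad_x = torch.where(x > 0, grad, w * grad)
    grad_w = torch.where(x > 0, torch.zeros_like(x), x * grad)
    return grad_x, grad_w
```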
@yanbing-j I think the code would break in the case where input is 1-D and weight is 0-D; then the code …
No, I don't think we can do that, since oneDNN expects the weight to have the same number of dimensions as the input. See this: https://oneapi-src.github.io/oneDNN/v2/dev_guide_prelu.html
We can fix this one for the case when weight is a scalar. But it would be in ATen/native/mkldnn/Prelu.cpp, not in ideep, I guess, since ideep doesn't support 0-dim tensors. @yanbing-j
If ideep does not support 0-dim tensors, then the code there was correct all along. About the ideep fix: the thing is that ideep does some manual resizing that's no longer necessary after this PR, so that code can be removed. Note that oneDNN would still get tensors with the same number of dimensions, so nothing would change there; the only difference is that the resizing would now be done at the ATen level rather than at the ideep level.
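A sketch of the kind of rank-matching reshape that now happens at the ATen level (the helper name is hypothetical, not the actual function in the PR):

```python
import torch

def reshape_weight_like_input(weight, input):
    # Give the (0-d or per-channel) weight the same number of dimensions as
    # the input, so a backend such as oneDNN that expects matching ranks can
    # consume it directly without resizing on the ideep side.
    shape = [1] * input.dim()
    if weight.numel() != 1:
        shape[1] = weight.numel()  # line up with the channel dimension
    return weight.reshape(shape)

x = torch.randn(2, 3, 4)
print(reshape_weight_like_input(torch.tensor(0.25), x).shape)      # [1, 1, 1]
print(reshape_weight_like_input(torch.full((3,), 0.25), x).shape)  # [1, 3, 1]
```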
Can I please get reviews from: … I think the quantisation and the MKLDNN changes are minimal, so they don't need much of a review, or they can be reviewed by non-domain experts.
Thanks for all your work getting this cleaned up! I'm good with it on my end, though I think ideally we would avoid adding any ops.
Without adding _prelu_kernel / _prelu_kernel_backward, I think we could still get de-duplication of the broadcasting logic by dispatching CPU / CUDA to the same function and using REGISTER_DISPATCH from there. However, I do understand that this doesn't fix vmap, simplify writing batch rules, or simplify derivative computation, which I think are valuable. So imo the tradeoffs are worth it.
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
SGTM
Great coordinated reviews. Even better that both of them are approvals :D @pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / linux-focal-rocm5.3-py3.8 / test (default, 1, 2, linux.rocm.gpu). Details for Dev Infra team: raised by workflow job.
Yes. We have changed the weight dimension to match the input dimension in https://github.com/pytorch/pytorch/pull/91238/files#diff-3dc8af925ec5d52a7cc5722279de60aeac5458097166318bcf6bfd58c420e629L710. The scenario of 1-D input and 0-D weight is supported. @jgong5 @lezcano
Correct. The resizing is now done at the ATen level; if the dimensions of src and weight are the same, the resizing at the ideep level is not invoked, so it can be regarded as a double check. @lezcano
Hi @lezcano, looks like the quantization issue is fixed, right? #91238 (comment)
Yeah, all good on the quantisation end. Thank you for checking in though, @Xia-Weiwen
@pytorchbot merge -f "flaky tests"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
### Motivation
Add `prelu` to the lower precision cast policy on AutocastCPU to fix #95365.
Before: within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16), `prelu` cannot handle different datatypes for `input` and `weight` and raises a RuntimeError. This scenario is common in autocast, e.g., with `autocast` to `bf16`, the `op` before `prelu` may produce a `bf16` output that becomes `prelu`'s input while `prelu`'s weight is still `fp32`, which then raises a RuntimeError.
After: within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16), `prelu` is forced to run with the `bf16` data type. Before #91238, when the input was `bf16`, the weight was forcibly cast to `bf16`; after #91238, that scenario raises a RuntimeError instead. There is no precision loss, since the previously workable path also cast to `bf16`, and this also aligns with the Autocast CUDA whitelist.
Pull Request resolved: #95366
Approved by: https://github.com/ngimel, https://github.com/lezcano, https://github.com/leslie-fang-intel
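A usage sketch of the scenario described above (expected behavior once prelu is on the lower-precision cast list; shapes and values are arbitrary):

```python
import torch
import torch.nn.functional as F

w = torch.full((3,), 0.25)  # fp32 weight, as held by an nn.PReLU module

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    # matmul runs in bf16 under autocast, so the activation feeding prelu is bf16.
    h = torch.mm(torch.randn(2, 8), torch.randn(8, 3))
    # With prelu on the lower-precision cast policy, input and weight are both
    # cast to bf16 and this runs; previously the dtype mismatch raised a RuntimeError.
    out = F.prelu(h, w)

print(out.dtype)  # torch.bfloat16
```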
Stack from ghstack (oldest at bottom):
The PReLU implementation was all over the place. This led to a number of bugs like #68760. We fix it by:
- Keeping the weird broadcasting logic it has as a CompositeImplicit kernel that calls into a second kernel
- Making this second kernel just a good ol' pointwise kernel
- Implementing the derivative for the pointwise kernel via TI as well, for speed
- Implementing the second derivative for the pointwise kernel and the forward AD derivatives compositionally

This fixes a number of issues:
- We don't perform copies any more when the inputs are not contiguous
- The derivatives are now correct
- We fix vmap and many other functorch-related issues
- CPU and CUDA now share the relevant broadcasting logic
- The implementation is about 1/3 the length

Fixes #68760
Fixes #89895
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10