
Conversation

Collaborator

@lezcano lezcano commented Dec 21, 2022

Stack from ghstack (oldest at bottom):

The PReLU implementation was all over the place. This led to a number
of bugs like #68760. We fix it by:

  • Keeping the weird broadcasting logic it has as a CompositeImplicit kernel that calls into a second kernel
  • Making this second kernel just a good ol' pointwise kernel (see the sketch right after this list)
  • Implementing the derivative for the pointwise kernel via TI as well, for speed
  • Implementing the second derivative for the pointwise kernel and the forward-AD derivatives compositionally
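For reference, a minimal Python sketch of this split (names are illustrative and error checking is omitted; the real kernels live in ATen C++ and use TensorIterator):

```python
import torch

def prelu_composite(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # CompositeImplicit wrapper: line the weight up with the channel
    # dimension (dim 1) so the pointwise kernel can simply broadcast.
    if weight.numel() != 1:
        shape = [1] * inp.dim()
        shape[1] = weight.numel()
        weight = weight.reshape(shape)
    return prelu_pointwise(inp, weight)

def prelu_pointwise(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # The second kernel is a plain pointwise op: x if x > 0 else w * x.
    return torch.where(inp > 0, inp, weight * inp)
```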

This fixes a number of issues:

  • We don't perform copies any more when the inputs are not contiguous
  • The derivatives are now correct
  • We fix vmap and many other functorch-related issues.
  • CPU and CUDA now share the relevant broadcasting logic
  • The implementation is about 1/3 the length.

Fixes #68760
Fixes #89895

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

The PReLU implementation was all over the place. This led to a number
of bugs like #68760. We port it
to TI and split it into broadcasting operations (CompositeImplicit) and
kernel fusion (via TI). This makes the CPU and CUDA parts agree and
heavily simplifies the code. It also includes a number of fixes, like
using opmath_t consistently and a more efficient backward implementation
in the non-contiguous case that avoids copying the data.

Fixes #68760

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Dec 21, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91238

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 1633602:

FLAKY - The following jobs failed but were likely due to flakiness present on master:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 21, 2022
@lezcano lezcano added module: nn Related to torch.nn topic: not user facing topic category and removed module: cpu CPU specific problem (e.g., perf, algorithm) labels Dec 21, 2022
@lezcano
Collaborator Author

lezcano commented Dec 21, 2022

@soulitzer @albanD you have been tagged as codeowners, but could you please review it?

The only issue with the current approach is that I was not able to port the MKLDNN and the MPS implementations to just implement the simpler new kernel, so those should still be buggy. If the relevant PoCs from Intel / Apple port these, we should be able to remove all the relevant entries from derivatives.yaml and native_functions.yaml.

Quantized fails as well, no idea why :(

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 21, 2022
@lezcano lezcano requested a review from kshitij12345 December 21, 2022 14:36
Collaborator

@albanD albanD left a comment


I'm not sure why this has to be more complex than functions like amax, which we can implement without a pair of functions?

MkldnnCPU: mkldnn_prelu
CPU: prelu_cpu
CUDA: prelu_cuda
MPS: prelu_mps
Collaborator

cc @kulinseth if you are interested in this one.

Collaborator

@kshitij12345 kshitij12345 left a comment


LGTM (on functorch batching rule front) :)

at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
}
Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) {
auto options = self.options().dtype(promoteTypes(self.scalar_type(), weight.scalar_type()));
Collaborator

@ngimel ngimel Dec 21, 2022


I'm not sure we want type promoting behavior for prelu, and we don't have it currently. Moreover, this isn't even the right type promoting behavior, as it would allow 0d weight to dictate output type for a 1d+ input.

Collaborator Author


For some reason TI does not accept undefined tensors. Structured kernels bypass this by being tightly knit to TI. I don't know whether there is a good way to let TI handle type promotion outside of structured kernels.

That being said, yes, let's drop type promotion altogether!
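As a quick illustration of the point about 0d weights (my example, assuming current PyTorch type-promotion rules): a 0-d tensor does not dictate the result dtype of a higher-dimensional operand, whereas promoteTypes on the raw scalar types would.

```python
import torch

x = torch.ones(3, dtype=torch.float32)
w = torch.tensor(2.0, dtype=torch.float64)    # 0-d "weight"

print((x * w).dtype)                          # torch.float32: the 0-d tensor does not win
print(torch.promote_types(x.dtype, w.dtype))  # torch.float64: what the dropped code would have picked
```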



@register_decomposition(aten._prelu_kernel)
@pw_cast_for_opmath
Collaborator


pw_cast_for_opmath is not needed here; it only leads to more complicated IR.

Collaborator Author


The previous kernel used opmath_t and had a dedicated path for BFloat16. Should I drop those as well? (I think we should.)

Collaborator


We output either self or weight * self. In both cases it makes no difference whether they are pre-converted to float here: self is a copy either way, and weight * self does whatever is needed inside the multiplication.
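To make that concrete, a sketch of the decomposition once the cast is dropped (illustrative only; in torch/_decomp this would carry @register_decomposition(aten._prelu_kernel), omitted here so the snippet runs standalone):

```python
import torch
from torch import Tensor

def _prelu_kernel_decomp(self: Tensor, weight: Tensor) -> Tensor:
    # The output is either `self` (a copy) or `weight * self`; the
    # multiplication does whatever internal upcasting it needs, so
    # pre-casting the inputs to opmath would only bloat the IR.
    return torch.where(self > 0, self, weight * self)
```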

Collaborator Author


I'll clean up all that code then.



@register_decomposition(aten._prelu_kernel_backward)
@pw_cast_for_opmath
Collaborator


similarly here

@lezcano
Collaborator Author

lezcano commented Dec 22, 2022

@albanD Re. amax. It's a bit trickier, because here we are implementing the backward via TI as well. I kept it this way to not regress performance, as the backward already had its own dedicated kernel.

Now, arguably, we would be better off dropping that and simply implementing the backward compositionally, as the difference in performance would be minimal. Happy to do that if you think that's the way to go.
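For reference, a purely compositional backward would look roughly like this (my sketch; it assumes the weight has already been broadcast against the input as in the forward, and leaves the per-channel reduction of grad_weight to the composite wrapper):

```python
import torch
from torch import Tensor
from typing import Tuple

def prelu_backward_composite(grad: Tensor, inp: Tensor, weight: Tensor) -> Tuple[Tensor, Tensor]:
    positive = inp > 0
    # d(out)/d(input) is 1 where input > 0 and weight elsewhere.
    grad_input = torch.where(positive, grad, weight * grad)
    # d(out)/d(weight) is 0 where input > 0 and input elsewhere (still unreduced).
    grad_weight = torch.where(positive, torch.zeros_like(grad), inp * grad)
    return grad_input, grad_weight
```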

@lezcano lezcano requested a review from zou3519 as a code owner December 22, 2022 10:54
lezcano added a commit that referenced this pull request Dec 22, 2022
ghstack-source-id: d6362ea
Pull Request resolved: #91238
@lezcano
Collaborator Author

lezcano commented Dec 29, 2022

@yanbing-j I think the code would break in the case where the input is 1D and the weight is 0D: the line dim_w[1] = weight.get_dim(0); would then be executed and fail.
This code will no longer be necessary with this PR, as the inputs given to all the backends will already have the same number of dimensions.

lezcano added a commit that referenced this pull request Dec 29, 2022
ghstack-source-id: 73bc805
Pull Request resolved: #91238
@jgong5
Collaborator

jgong5 commented Dec 29, 2022

> what's the procedure to land this into ideep?

No, I don't think we can do that, since oneDNN expects the weight dims to match the input's. See: https://oneapi-src.github.io/oneDNN/v2/dev_guide_prelu.html

> I think the code would break in the case where input is 1D and weight is 0D, then the code dim_w[1] = weight.get_dim(0); would be executed and would break.

We can fix this one for the case when the weight is a scalar. But it would go in ATen/native/mkldnn/Prelu.cpp, not in ideep, I guess, since ideep doesn't support 0-dim tensors. @yanbing-j

@lezcano lezcano added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 29, 2022
@lezcano
Collaborator Author

lezcano commented Dec 29, 2022

If ideep does not support 0-dim tensors, then the code there was correct all along.

About the ideep fix: the thing is that ideep does some manual resizing that is no longer necessary after this PR, so that code can be removed. Note that oneDNN would still get tensors with the same number of dimensions, so nothing changes there; the only difference is that the resizing is now done at the ATen level rather than at the ideep level.
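To illustrate what "resizing at the ATen level" means in practice (a sketch, not the actual code): the composite wrapper reshapes a per-channel weight to have the same number of dimensions as the input before any backend sees it.

```python
import torch

x = torch.randn(2, 4, 8, 8)
w = torch.randn(4)              # per-channel weight, 1-D
w_view = w.reshape(1, 4, 1, 1)  # same ndim as the input; this is what backends now receive
assert w_view.dim() == x.dim()
y = torch.where(x > 0, x, w_view * x)
```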

@lezcano
Collaborator Author

lezcano commented Dec 29, 2022

Can I please get reviews from:
@kulinseth for the MPS changes
@albanD @soulitzer for the changes in the derivatives
@jbschlosser @albanD for the PR in general

I think that the quantisation and the MKLDNN changes are minimal so they don't need much of a review, or they can be reviewed by non-domain experts.

Contributor

@jbschlosser jbschlosser left a comment


Thanks for all your work getting this cleaned up! I'm good with it on my end, though I think ideally we would avoid adding any ops.

Without adding _prelu_kernel / _prelu_backward_kernel, I think we could still deduplicate the broadcasting logic by dispatching CPU / CUDA to the same function and using REGISTER_DISPATCH from there. However, I do understand that this doesn't fix vmap, simplify writing batch rules, or simplify derivative computation, which I think are valuable. So imo the tradeoffs are worth it.

Collaborator

@albanD albanD left a comment


SGTM

lezcano added a commit that referenced this pull request Dec 29, 2022
ghstack-source-id: fc4c358
Pull Request resolved: #91238
@lezcano
Collaborator Author

lezcano commented Dec 29, 2022

Great coordinated reviews. Even better that both of them are approvals :D

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 additional jobs have failed, first few of them are: trunk, trunk / linux-focal-rocm5.3-py3.8 / test (default, 1, 2, linux.rocm.gpu)

Details for Dev Infra team Raised by workflow job

@yanbing-j
Collaborator

> We can fix this one for the case when weight is a scalar. But it would be in the ATen/native/mkldnn/Prelu.cpp not in ideep, I guess, since ideep doesn't support 0-dim tensor. @yanbing-j

Yes. We have changed the weight dimension to the input dimension in https://github.com/pytorch/pytorch/pull/91238/files#diff-3dc8af925ec5d52a7cc5722279de60aeac5458097166318bcf6bfd58c420e629L710. The scenario of a 1D input and a 0D weight is supported. @jgong5 @lezcano

> the only thing that would change is that the resizing would now be done at the ATen level, rather than at the ideep level.

Correct. The resizing is now done at the ATen level; if the dimensions of src and weight are the same, the resizing at the ideep level is not invoked, so it can be regarded as a double check. @lezcano

@Xia-Weiwen
Collaborator

Hi @lezcano, looks like the quantization issue is fixed, right? #91238 (comment)

@lezcano
Collaborator Author

lezcano commented Dec 30, 2022

Yeah, all good on the quantisation end. Thank you for checking in though, @Xia-Weiwen.

@lezcano
Collaborator Author

lezcano commented Dec 30, 2022

@pytorchbot merge -f "flaky tests"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Feb 28, 2023
### Motivation
Add `prelu` to the lower-precision cast policy on AutocastCPU to fix #95365:

Before: within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16), `prelu` cannot handle different dtypes for `input` and `weight` and raises a RuntimeError. This scenario is common under autocast: with `autocast` to `bf16`, if the op before `prelu` produces a `bf16` output (which becomes `prelu`'s input) while `prelu`'s weight is `fp32`, a RuntimeError is raised.

After: within the scope of torch.cpu.amp.autocast(dtype=torch.bfloat16), `prelu` is forced to run in the `bf16` data type.

Before #91238, when the input was `bf16`, the weight was force-cast to `bf16`. After #91238, this kind of scenario raises a RuntimeError instead. There is no precision loss, since the previously working behavior also cast to `bf16`.

This also aligns with the Autocast CUDA whitelist.

Pull Request resolved: #95366
Approved by: https://github.com/ngimel, https://github.com/lezcano, https://github.com/leslie-fang-intel
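A minimal sketch of the scenario described above (my example; it assumes an fp32 module whose output under AutocastCPU is bf16, feeding a PReLU with an fp32 weight):

```python
import torch

conv = torch.nn.Conv2d(3, 3, kernel_size=1)  # fp32 parameters
prelu = torch.nn.PReLU(3)                    # fp32 weight
x = torch.randn(2, 3, 8, 8)

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    y = conv(x)   # conv runs in bf16 under AutocastCPU, so y is bf16
    # With prelu in the lower-precision cast policy, both arguments are cast
    # to bf16 here; without it, bf16 input + fp32 weight raised a RuntimeError.
    z = prelu(y)
```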
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 2, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
@facebook-github-bot facebook-github-bot deleted the gh/Lezcano/164/head branch June 8, 2023 14:43

Labels

ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: cpu (CPU specific problem (e.g., perf, algorithm)) · module: nn (Related to torch.nn) · open source · release notes: quantization (release notes category) · topic: not user facing (topic category)
