Fix NJT OpInfo entry for nn.functional.prelu #144582
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144582
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit fd35809 with merge base 3a5bf0b.
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. The OpInfo entry for prelu was wrong before this PR; `weight` needs to be passed as well. The op isn't fully implemented yet.
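For reference, a minimal sketch (dense tensors, illustrative shapes and values) of the call the OpInfo entry needs to produce; `prelu` takes `weight` as a required second argument:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 3)  # (batch, channels, length)

# One learnable slope per channel (a single shared slope also works);
# `weight` is required -- the old OpInfo entry omitted it.
weight = torch.randn(4)

out = F.prelu(x, weight)  # calling F.prelu(x) alone raises a TypeError
```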
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. `value_selecting_reduction_backward()` is used in the backward for min/max, so this PR implements it for NJT. Notably, this isn't enough for reducing over the ragged dim, since that results in a dense tensor, and thus NJT's `__torch_dispatch__` will not be called for this op. We need factory function support for nested ints to fix that case. Pull Request resolved: #144583 Approved by: https://github.com/soulitzer ghstack dependencies: #144582
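Conceptually (a dense-tensor sketch, not the actual ATen signature), the backward for a value-selecting reduction like `max(dim=...)` scatters the incoming gradient back to the positions that were selected in the forward:

```python
import torch

x = torch.randn(2, 5)
values, indices = x.max(dim=1)

# Gradient flowing into `values`; in the real backward this comes from autograd.
grad_out = torch.ones_like(values)

# Scatter the gradient back to the argmax positions; all other entries get zero.
grad_in = torch.zeros_like(x).scatter(1, indices.unsqueeze(1), grad_out.unsqueeze(1))
```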
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. Implements `chunk()` backward on the batch dim, which was previously missing. This PR unbinds the components and invokes `copy_()` on them to pass along the appropriate gradients. Pull Request resolved: #144584 Approved by: https://github.com/soulitzer ghstack dependencies: #144582, #144583
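A usage sketch of the case this covers (shapes are illustrative; exact NJT op coverage depends on your build):

```python
import torch

# Jagged NJT: a batch of 3 variable-length sequences.
njt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3), torch.randn(1, 3)],
    layout=torch.jagged,
    requires_grad=True,
)

# Chunking along the batch dim (dim=0) splits the jagged batch into pieces;
# with this PR, gradients flow back through the per-component copies.
chunks = njt.chunk(2, dim=0)
```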
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. However, the op returns a tuple, which the fallback doesn't handle. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` tensors as NJTs. Pull Request resolved: #144585 Approved by: https://github.com/soulitzer ghstack dependencies: #144582, #144583, #144584
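For reference, the semantics being wrapped (dense example): `frexp()` returns a `(mantissa, exponent)` pair satisfying `x == mantissa * 2 ** exponent`; after this PR, both outputs are NJTs when the input is one.

```python
import torch

x = torch.tensor([0.5, 3.0, 10.0])
mantissa, exponent = torch.frexp(x)

# Each element decomposes as x = mantissa * 2**exponent, with the mantissa
# of a nonzero input lying in [0.5, 1).
assert torch.allclose(torch.ldexp(mantissa, exponent), x)
```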