Add nvFuser support for aten.native_batch_norm_backward #84546
Conversation
Dr. CI: No failures (0 pending) as of commit c287399.
LGTM.
Note: there are a couple of other uses of this pattern in this file; no need to fix them in this PR.
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
Merge failed. Reason: the following mandatory check(s) failed. Dig deeper by viewing the failures on hud. Raised by workflow job.
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
Hey @IvanYashchuk.
Summary: Replacing `tensor.reshape(broadcast_mask)` with unsqueezes makes the implementation of `batch_norm_backward` more friendly for PrimTorch+nvFuser.

Pull Request resolved: #84546
Approved by: https://github.com/Chillee
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6363b1b3587aa64ad055ba0a905af28d8dec52d2
Reviewed By: izaitsevfb
Differential Revision: D39296978
fbshipit-source-id: 3f30341348796290006268cfb9f0af1f02718a5c
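To illustrate the change, here is a minimal sketch of the two equivalent patterns. The helper names are hypothetical (not the actual functions in the PR); the point is that a sequence of `unsqueeze` calls produces the same broadcast-ready shape as `reshape(broadcast_mask)`, while being expressed as primitive dimension insertions that tracing backends like PrimTorch+nvFuser can handle more easily.

```python
import torch

def broadcast_to_mask_reshape(t, broadcast_mask):
    # Original pattern: reshape a per-channel tensor (shape [C]) to the
    # broadcast shape, e.g. [1, C, 1, 1] for a 4-D input.
    return t.reshape(broadcast_mask)

def broadcast_to_mask_unsqueeze(t, broadcast_mask):
    # Trace-friendly pattern: insert a size-1 dimension at every position
    # where the mask has a 1, leaving the channel dimension in place.
    for dim, size in enumerate(broadcast_mask):
        if size == 1:
            t = t.unsqueeze(dim)
    return t
```

Both produce a tensor that broadcasts against the 4-D activation tensor in the batch-norm backward formulas; the unsqueeze form avoids a data-dependent `reshape` in the traced graph.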