-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Unify lowerings for auto_functionalized and triton_kernel_wrapper_functional #134466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ctional Fixes #134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134466
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 0dd0222 with merge base 2553278 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…ctional Fixes #134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests ghstack-source-id: 60b7f9f Pull Request resolved: #134466
…wrapper_functional" Fixes #134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
…wrapper_functional" Fixes #134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…zed (#134490) We say Node a is fusible into node b if node b is an auto_functionalized node that may reinplace node a later on. This PR also changes aten.empty to be recomputable w.r.t the Partitioner (it is, like aten.zeros, cheap to recompute and fusible into other ops). Fixes #134468 Test Plan: - new test Pull Request resolved: #134490 Approved by: https://github.com/Chillee ghstack dependencies: #134364, #134466
aten.empty is almost always fusible into its consumer, so we never CSE it. This fixes a bug that looks like the following: ```py @torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}) def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None: out_sin.copy_(x.sin()) out_cos.copy_(x.cos()) @torch.compile def f(x): out0 = torch.empty_like(x) out1 = torch.empty_like(x) sin_cos(x, out0, out1) return x.clone(), out0, out1 x = torch.randn(3, requires_grad=True) f(x) ``` - cse would de-duplicate the empty nodes - reinplacing would add an additional clone (because it can't write to both tensors at the same time) - the clone lowers into a new buffer + a copy_ kernel - the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional buffer, it doesn't matter what the values are. We could attempt to fix this on the reinplacing side but this seemed better as a partitioner heuristic and the reinplacing fix is a bit more tricky (we'd need to identify that the op never reads from the empty node). Test Plan: - new test (the old number was 27, the new number is 21, so this PR helped). Pull Request resolved: #134703 Approved by: https://github.com/yf225 ghstack dependencies: #134466, #134490, #134491
…ctional (pytorch#134466) Fixes pytorch#134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests Pull Request resolved: pytorch#134466 Approved by: https://github.com/oulgen, https://github.com/eellison ghstack dependencies: pytorch#134364
…zed (pytorch#134490) We say Node a is fusible into node b if node b is an auto_functionalized node that may reinplace node a later on. This PR also changes aten.empty to be recomputable w.r.t the Partitioner (it is, like aten.zeros, cheap to recompute and fusible into other ops). Fixes pytorch#134468 Test Plan: - new test Pull Request resolved: pytorch#134490 Approved by: https://github.com/Chillee ghstack dependencies: pytorch#134364, pytorch#134466
…ytorch#134491) mutated arguments to triton kernels are fusible into the triton kernel. Test Plan: - new test Pull Request resolved: pytorch#134491 Approved by: https://github.com/Chillee ghstack dependencies: pytorch#134364, pytorch#134466, pytorch#134490
ROCM doesn't trigger the layout optimization that makes the test case valid so we're going to skip the checks. Should fix the following (I'll close them later) - pytorch#134481 - pytorch#134519 Pull Request resolved: pytorch#134690 Approved by: https://github.com/FindHao ghstack dependencies: pytorch#134466, pytorch#134490, pytorch#134491
Fixes pytorch#134119 From user feedback, it's difficult to understand what the tests do. We clarify the docs more. Pull Request resolved: pytorch#134692 Approved by: https://github.com/albanD ghstack dependencies: pytorch#134466, pytorch#134490, pytorch#134491, pytorch#134690
Fixes pytorch#134278 Test Plan: - tested locally Pull Request resolved: pytorch#134688 Approved by: https://github.com/yushangdi ghstack dependencies: pytorch#134466, pytorch#134490, pytorch#134491, pytorch#134690, pytorch#134692
aten.empty is almost always fusible into its consumer, so we never CSE it. This fixes a bug that looks like the following: ```py @torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}) def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None: out_sin.copy_(x.sin()) out_cos.copy_(x.cos()) @torch.compile def f(x): out0 = torch.empty_like(x) out1 = torch.empty_like(x) sin_cos(x, out0, out1) return x.clone(), out0, out1 x = torch.randn(3, requires_grad=True) f(x) ``` - cse would de-duplicate the empty nodes - reinplacing would add an additional clone (because it can't write to both tensors at the same time) - the clone lowers into a new buffer + a copy_ kernel - the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional buffer, it doesn't matter what the values are. We could attempt to fix this on the reinplacing side but this seemed better as a partitioner heuristic and the reinplacing fix is a bit more tricky (we'd need to identify that the op never reads from the empty node). Test Plan: - new test (the old number was 27, the new number is 21, so this PR helped). Pull Request resolved: pytorch#134703 Approved by: https://github.com/yf225 ghstack dependencies: pytorch#134466, pytorch#134490, pytorch#134491
Stack from ghstack (oldest at bottom):
Fixes #134372
The triton_kernel_wrapper_functional lowering was causing problems (it
was generating small kernels with nans in it, probably from realizing
aten.empty nodes. Instead of having its own manual lowering, we change
triton_kernel_wrapper_functional to go the same route as
auto_functionalized where we decompose the node into clone + mutation
nodes.
Test Plan:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang