Add topk Triton kernel for CUDA backend #18141
Pull request overview
Adds a Triton-based topk kernel for the ExecuTorch CUDA backend, replacing aten.topk.default during graph transformation. The kernel uses iterative argmax/argmin with masking and is registered via @triton_op.
Changes:
- New Triton topk kernel implementation with iterative max/min and masking algorithm
- Registration of the kernel in the edge-to-triton replacement pass
- Tests (eager correctness, export validation, E2E C++ runner) and a dedicated C++ test runner
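The iterative argmax/argmin-with-masking idea described in the overview can be sketched in plain Python. This is only a reference illustration, not the Triton kernel from the PR: the function name `topk_iterative` and the boolean `taken` mask are hypothetical, ties are assumed to resolve to the lowest index, and NaN inputs are not handled here.

```python
def topk_iterative(row, k, largest=True):
    """Pick k elements by repeated argmax/argmin, masking each winner out."""
    taken = [False] * len(row)      # mask of slots that were already selected
    values, indices = [], []
    for _ in range(min(k, len(row))):
        best = -1
        for i, v in enumerate(row):
            if taken[i]:
                continue            # masked slots can never win again
            if best < 0:
                best = i
            elif (v > row[best]) if largest else (v < row[best]):
                best = i            # strict comparison keeps the lowest index on ties
        values.append(row[best])
        indices.append(best)
        taken[best] = True          # mask the winner before the next pass
    return values, indices


vals, idxs = topk_iterative([3.0, 1.0, 4.0, 1.5, 9.0, 2.0], 3)
print(vals, idxs)
```

Each pass costs O(N), so the whole selection is O(k·N), which is the trade-off the discussion later accepts for small k.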
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| backends/cuda/triton/kernels/topk.py | New Triton topk kernel and its abstract/fake implementation |
| backends/cuda/triton/kernels/__init__.py | Export the new topk symbol |
| backends/cuda/triton/replacement_pass.py | Map aten.topk.default to the Triton kernel |
| backends/cuda/tests/test_topk.py | Eager correctness, export, and E2E tests |
| backends/cuda/tests/topk_runner/main.cpp | C++ runner for E2E testing |
| backends/cuda/tests/topk_runner/CMakeLists.txt | Build config for the C++ runner |
fb5d204 to 00165ab
@@ -0,0 +1,117 @@
/*
Why write a new tiny runner for each new op instead of using a standard runner?
    BLOCK: tl.constexpr,
    LARGEST: tl.constexpr,
):
    """Single-block topk: one program per row, iterative max/min with masking."""
Why this kernel? Were there any other options? Just curious.
Just for simplicity. For MoE we are only selecting 8 from 256, so perf isn't important. torch.topk has a libtorch dependency, so this is the simplest way to untangle from libtorch.
digantdesai left a comment
Stamping to unblock. We should look into perf later.
89bc514 to 38371c6
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
    else:
        FILL: tl.constexpr = float("inf")

    vals = tl.load(row_ptr + offs, mask=mask, other=FILL).to(tl.float32)
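To illustrate why the masked load uses a fill value, here is a small pure-Python sketch. Triton's `tl.arange` requires a power-of-two range, so a row of length N is typically read through a block wider than N; out-of-range lanes get a fill that can never win the reduction. The helper names `next_power_of_2` and `masked_load` are hypothetical, and how the PR actually derives `BLOCK` is an assumption.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n (assumed way of sizing the block).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def masked_load(row, block, largest=True):
    # Lanes past the end of the row receive a fill that loses every comparison:
    # -inf when selecting the largest values, +inf when selecting the smallest.
    fill = -math.inf if largest else math.inf
    return [row[i] if i < len(row) else fill for i in range(block)]


row = [1.0, 2.0, 3.0]
block = next_power_of_2(len(row))
print(masked_load(row, block, largest=True))
print(masked_load(row, block, largest=False))
```

With this padding, a max over the block matches a max over the real row, and likewise for min when `largest` is false.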
38371c6 to 8a84442
Replaces aten.topk with a Triton implementation compiled directly into the AOTInductor .so. Algorithm: iterative argmax/argmin with masking.
- Replacement pass skips N > 4096 (kernel loads entire rows into one thread block); falls back to aten for vocab-sized topk
- NaN handling matches torch.topk: NaN treated as larger than all finite values for both largest=True and largest=False
- Handles empty dimensions (N=0, k=0)
- Tests: eager correctness, NaN, empty, 3D non-last dim, export, e2e
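The NaN and empty-input behavior listed above can be expressed as a short reference in plain Python. This sketch uses a sort with a NaN-aware key rather than the kernel's iterative selection, and the names `nan_aware_key` and `topk_nan_aware` are hypothetical. It also assumes NaN ranks above +inf, which goes slightly beyond the "all finite values" wording in the commit message.

```python
import math


def nan_aware_key(v):
    # NaN ranks above every other value (assumed here to include +inf).
    return (1, 0.0) if math.isnan(v) else (0, v)


def topk_nan_aware(row, k, largest=True):
    # Empty row or k == 0 yields empty outputs.
    if k == 0 or len(row) == 0:
        return [], []
    # Python's sort is stable, also with reverse=True, so ties keep index order.
    order = sorted(range(len(row)),
                   key=lambda i: nan_aware_key(row[i]),
                   reverse=largest)
    picked = order[:k]
    return [row[i] for i in picked], picked


row = [1.0, float("nan"), float("inf"), -2.0]
print(topk_nan_aware(row, 2, largest=True))
print(topk_nan_aware(row, 2, largest=False))
```

Under this ordering NaN is selected first when `largest=True` and only after every other element when `largest=False`.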
8a84442 to cf79091
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
    input_shape = node.args[0].meta["val"].shape
    dim = node.args[2] if len(node.args) > 2 else -1
    N = input_shape[dim]
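The fragment above reads the reduced dimension size from node metadata. A minimal sketch of how such a size check could gate the replacement, using the 4096 threshold from the commit message, might look like the following. The constant `MAX_TOPK_N` and the function `should_use_triton_topk` are hypothetical names, and plain tuples stand in for the FX node arguments.

```python
MAX_TOPK_N = 4096  # threshold from the commit message; the constant name is hypothetical


def should_use_triton_topk(input_shape, args):
    """Return True when the reduced dimension fits in a single block.

    `args` mirrors the layout of node.args: (input, k, dim, ...),
    with dim defaulting to -1 when it is not passed.
    """
    dim = args[2] if len(args) > 2 else -1
    n = input_shape[dim]
    # "Skips N > 4096": larger rows stay on the aten fallback.
    return n <= MAX_TOPK_N


print(should_use_triton_topk((4, 256), ("x", 8)))          # MoE-sized routing
print(should_use_triton_topk((2, 32000), ("x", 50, -1)))   # vocab-sized topk
```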
    raw_vals = tl.load(row_ptr + offs, mask=mask, other=FILL).to(tl.float32)
    idxs = offs.to(tl.int64)
        """Export succeeds and produces non-empty .pte."""
        with tempfile.TemporaryDirectory() as tmpdir:
            pte_path, _ = export_topk(tmpdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)
Add topk Triton kernel for CUDA backend
Replaces aten.topk with a Triton implementation compiled directly into
the AOTInductor .so. Algorithm: iterative argmax/argmin with masking.
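- Replacement pass skips N > 4096 (kernel loads entire rows into one
  thread block); falls back to aten for vocab-sized topk
- NaN handling matches torch.topk: NaN treated as larger than all
  finite values for both largest=True and largest=False
Naive implementation, slower than torch.topk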