Add topk Triton kernel for CUDA backend #18141

Merged: mergennachin merged 1 commit into main from mergennachin/topk-triton-kernel on Mar 19, 2026

@mergennachin (Contributor) commented Mar 12, 2026

Add topk Triton kernel for CUDA backend

Replaces aten.topk with a Triton implementation compiled directly into
the AOTInductor .so. Algorithm: iterative argmax/argmin with masking.

  • Replacement pass skips N > 4096 (kernel loads entire rows into one
    thread block); falls back to aten for vocab-sized topk
  • NaN handling matches torch.topk: NaN treated as larger than all
    finite values for both largest=True and largest=False
  • Handles empty dimensions (N=0, k=0)
  • Tests: eager correctness, NaN, empty, 3D non-last dim, export, e2e

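This is a naive implementation and is currently slower than torch.topk. Speedup is eager time divided by runner time, so values below 1x mean the runner is slower: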

  ┌──────────────────────────┬────────────┬─────────────┬─────────┐
  │          Config          │ Eager (us) │ Runner (us) │ Speedup │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=4 cols=8 k=2        │ 73.8       │ 210.4       │ 0.35x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=16 cols=8 k=2       │ 79.5       │ 224.6       │ 0.35x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=4 cols=32 k=5       │ 70.1       │ 228.0       │ 0.31x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=32 cols=64 k=10     │ 73.9       │ 299.4       │ 0.25x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=64 cols=128 k=5     │ 76.5       │ 265.2       │ 0.29x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=128 cols=256 k=10   │ 81.2       │ 239.4       │ 0.34x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=256 cols=512 k=20   │ 83.1       │ 352.0       │ 0.24x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=512 cols=32 k=2     │ 79.5       │ 258.1       │ 0.31x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=1024 cols=16 k=4    │ 75.1       │ 244.7       │ 0.31x   │
  ├──────────────────────────┼────────────┼─────────────┼─────────┤
  │ rows=1024 cols=1024 k=10 │ 297.5      │ 623.0       │ 0.48x   │
  └──────────────────────────┴────────────┴─────────────┴─────────┘
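
For readers unfamiliar with the approach, the "iterative argmax/argmin with masking" idea from the description can be sketched in plain Python. This is only an illustration under stated assumptions, not the Triton kernel from this PR: the function name `topk_iterative` is hypothetical, ties are broken by lowest index (an assumption), and NaN is ranked above every other value including +inf (the description only states "larger than all finite values").

```python
import math


def topk_iterative(row, k, largest=True):
    """Pick k entries by repeated argmax/argmin, masking out each winner."""
    if not 0 <= k <= len(row):
        raise ValueError("k must be between 0 and len(row)")

    # Ordering key: NaN ranks above everything else in both modes
    # (assumption: this includes +inf).
    def key(x):
        return (1, 0.0) if math.isnan(x) else (0, x)

    taken = [False] * len(row)  # mask of positions already selected
    values, indices = [], []
    for _ in range(k):
        best = -1
        for i, x in enumerate(row):
            if taken[i]:
                continue
            if best < 0:
                best = i
            elif largest and key(x) > key(row[best]):
                best = i
            elif not largest and key(x) < key(row[best]):
                best = i
        taken[best] = True  # mask the winner so the next pass skips it
        values.append(row[best])
        indices.append(best)
    return values, indices
```

Each of the k passes scans the whole row, which is one reason this style of kernel is simple but not fast. With `largest=True` a NaN is selected first; with `largest=False` it is selected last.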

Copilot AI review requested due to automatic review settings March 12, 2026 22:31
@pytorch-bot (Bot) commented Mar 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18141

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 167 Pending

As of commit cf79091 with merge base 1925873:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (Bot) added the CLA Signed label Mar 12, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI (Contributor) left a comment

Pull request overview

Adds a Triton-based topk kernel for the ExecuTorch CUDA backend, replacing aten.topk.default during graph transformation. The kernel uses iterative argmax/argmin with masking and is registered via @triton_op.

Changes:

  • New Triton topk kernel implementation with iterative max/min and masking algorithm
  • Registration of the kernel in the edge-to-triton replacement pass
  • Tests (eager correctness, export validation, E2E C++ runner) and a dedicated C++ test runner
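
The PR description says the replacement pass skips rows longer than 4096 elements and leaves those nodes on the aten path. A minimal sketch of that decision, assuming the pass only needs the input shape and the topk dimension, might look like the following. The names `MAX_TOPK_N` and `should_use_triton_topk` are hypothetical and do not come from the PR, and whether the real pass handles symbolic shapes is not shown here.

```python
MAX_TOPK_N = 4096  # threshold quoted in the PR description


def should_use_triton_topk(input_shape, dim=-1, max_n=MAX_TOPK_N):
    """Return True when the reduced dimension fits in a single block."""
    n = input_shape[dim]  # negative dims index from the end
    return n <= max_n


# A 256-wide row is replaced; a vocab-sized row stays on aten.topk.
small = should_use_triton_topk((32, 256))
vocab = should_use_triton_topk((1, 128256))
```

Keeping the check in the pass rather than in the kernel means oversized inputs never reach a kernel that loads a whole row into one thread block.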

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| backends/cuda/triton/kernels/topk.py | New Triton topk kernel and its abstract/fake implementation |
| backends/cuda/triton/kernels/__init__.py | Export the new topk symbol |
| backends/cuda/triton/replacement_pass.py | Map aten.topk.default to the Triton kernel |
| backends/cuda/tests/test_topk.py | Eager correctness, export, and E2E tests |
| backends/cuda/tests/topk_runner/main.cpp | C++ runner for E2E testing |
| backends/cuda/tests/topk_runner/CMakeLists.txt | Build config for the C++ runner |


@mergennachin force-pushed the mergennachin/topk-triton-kernel branch from fb5d204 to 00165ab on March 12, 2026 22:50
@@ -0,0 +1,117 @@
/*

Why write a new tiny runner for each new op instead of using a standard runner?

Comment thread backends/cuda/tests/test_topk.py
BLOCK: tl.constexpr,
LARGEST: tl.constexpr,
):
"""Single-block topk: one program per row, iterative max/min with masking."""

Why this kernel? Were there any other options? Just curious.

Contributor Author

Just for simplicity. For MoE we are only selecting 8 from 256, so perf isn't important. torch.topk has a libtorch dependency, so this is the simplest way to untangle from libtorch.
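
To put the "8 from 256" workload in perspective, here is a small standard-library sketch. The router scores are made up for illustration, and `heapq.nlargest` is used only as a reference selection; it is not what the kernel does.

```python
import heapq
import random

NUM_EXPERTS, TOP_K = 256, 8  # sizes quoted in the reply above

random.seed(0)
# Made-up router scores for one token, one score per expert.
logits = [random.gauss(0.0, 1.0) for _ in range(NUM_EXPERTS)]

# Reference selection of the 8 best experts as (score, index) pairs.
top = heapq.nlargest(TOP_K, ((s, i) for i, s in enumerate(logits)))
expert_ids = [i for _, i in top]

# An iterative argmax scans the full row once per selected element.
elements_scanned = TOP_K * NUM_EXPERTS  # 8 * 256 = 2048 per row
```

At roughly two thousand element visits per row, the selection itself is tiny, which is consistent with the point that raw kernel speed matters less here than dropping the libtorch dependency.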

@digantdesai (Contributor) left a comment

Stamping to unblock. We should look into perf later.

Copilot AI review requested due to automatic review settings March 18, 2026 22:43
@mergennachin force-pushed the mergennachin/topk-triton-kernel branch 2 times, most recently from 89bc514 to 38371c6 on March 18, 2026 22:43
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.



Comment thread backends/cuda/triton/kernels/topk.py
Comment thread backends/cuda/triton/kernels/topk.py
Comment thread backends/cuda/triton/kernels/topk.py Outdated
else:
FILL: tl.constexpr = float("inf")

vals = tl.load(row_ptr + offs, mask=mask, other=FILL).to(tl.float32)
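
The snippet above pads out-of-range lanes with a FILL value. The reason is that `tl.arange` needs a power-of-two length, so a block is usually wider than the row, and the extra lanes must never win the argmax/argmin. A plain-Python sketch of the same idea follows, assuming the branch not shown above uses -inf when selecting the largest values; the helper names are hypothetical.

```python
def next_power_of_2(n):
    """Smallest power of two >= n (returns 1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def load_padded(row, largest=True):
    """Mimic a masked load: pad the row to the block width with FILL."""
    block = next_power_of_2(len(row))
    # -inf can never be the maximum, +inf can never be the minimum.
    fill = float("-inf") if largest else float("inf")
    return row + [fill] * (block - len(row))


padded = load_padded([3.0, 1.0, 2.0])  # block width 4, one padded lane
```

One caveat of this scheme: if a row itself contains the FILL value, a padded lane can tie with a real element, so a kernel needs some extra rule to make sure only in-range indices are returned.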
Comment thread backends/cuda/triton/kernels/topk.py Outdated
Comment thread backends/cuda/triton/replacement_pass.py
Comment thread backends/cuda/tests/test_topk.py
@mergennachin force-pushed the mergennachin/topk-triton-kernel branch from 38371c6 to 8a84442 on March 19, 2026 14:10
Copilot AI review requested due to automatic review settings March 19, 2026 14:12
@mergennachin force-pushed the mergennachin/topk-triton-kernel branch from 8a84442 to cf79091 on March 19, 2026 14:12
@mergennachin merged commit 3f33e54 into main on Mar 19, 2026
180 of 182 checks passed
@mergennachin deleted the mergennachin/topk-triton-kernel branch on March 19, 2026 14:16
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment on lines +105 to +107
input_shape = node.args[0].meta["val"].shape
dim = node.args[2] if len(node.args) > 2 else -1
N = input_shape[dim]
Comment on lines +55 to +56
raw_vals = tl.load(row_ptr + offs, mask=mask, other=FILL).to(tl.float32)
idxs = offs.to(tl.int64)
Comment on lines +228 to +232
"""Export succeeds and produces non-empty .pte."""
with tempfile.TemporaryDirectory() as tmpdir:
pte_path, _ = export_topk(tmpdir)
self.assertTrue(os.path.exists(pte_path))
self.assertGreater(os.path.getsize(pte_path), 0)

Labels

ciflow/cuda, CLA Signed


3 participants