
Add GPU-side Gumbel-max sampling for CUDA graph compatibility #18844

Merged
Gasoonjia merged 42 commits into main from cuda-graph-sampling on Apr 27, 2026

Conversation

@Gasoonjia
Contributor

@Gasoonjia commented Apr 13, 2026

This PR replaces the CPU sampler with a CUDA sampler and fuses sampling into the forward method, which both eliminates an unnecessary data transfer and improves sampling efficiency. Decode performance increases from 113.8 tokens/s to 119.5 tokens/s.

Once we land the device support pipeline, we should split sampling back out of the forward method.
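For context, the fused step boils down to the Gumbel-max trick: adding Gumbel(0, 1) noise to temperature-scaled logits and taking an argmax is equivalent to drawing from softmax(logits/T), so the whole sample stays on the GPU. A minimal sketch, assuming float32 logits and a scalar temperature tensor (the helper name is illustrative; the actual code lives in model.py and is quoted in the review thread below):

import torch

def gumbel_max_sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # argmax(logits/T + g) with g ~ Gumbel(0, 1) is distributed like a draw
    # from softmax(logits/T); no torch.multinomial and no CPU round-trip.
    logits = logits.float() / temperature.clamp(min=1e-6)
    u = torch.rand_like(logits)                        # uniform noise, stays on device
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    return torch.argmax(logits + gumbel, dim=-1)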

Gasoonjia and others added 23 commits April 1, 2026 23:06
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1,
  refactor chunked pipeline into _launch_chunked(), dispatch via a Python
  if-statement inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic
is internal to the triton_op, no model/export/runner changes needed.
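Roughly, the internal dispatch has the shape sketched below, assuming the torch.library.triton_op decorator; the operator name is illustrative, _launch_recurrent and _launch_chunked are stand-ins for the real launch helpers, and state-mutation details are omitted:

import torch
from torch.library import triton_op

def _launch_recurrent(q, k, v, g, beta, state):
    ...  # stand-in for the fused T=1 Triton launch

def _launch_chunked(q, k, v, g, beta, state):
    ...  # stand-in for the chunked FLA pipeline launch

@triton_op("fla::chunk_gated_delta_rule", mutates_args={})
def chunk_gated_delta_rule(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    g: torch.Tensor, beta: torch.Tensor, state: torch.Tensor,
) -> torch.Tensor:
    # The decode/prefill split lives inside the op, so the exported graph
    # contains a single call site and AOTI never sees a torch.cond.
    if q.shape[1] == 1:
        return _launch_recurrent(q, k, v, g, beta, state)  # T=1 decode path
    return _launch_chunked(q, k, v, g, beta, state)        # chunked prefill path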
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch
  recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap
  options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU
  for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
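For orientation, a simplified sketch of what the dual-method export can look like; PrefillWrapper, DecodeWrapper, and the example inputs are hypothetical placeholders, and the CUDA/AOTI lowering plus the buffer-sharing options from the later commits are omitted:

import torch
from torch.export import export
from executorch.exir import to_edge

# Each wrapper exposes one entry point of the shared model.
prefill_ep = export(PrefillWrapper(model), (prompt_tokens, prompt_pos))
decode_ep = export(DecodeWrapper(model), (cur_token, cur_pos))

# to_edge accepts a dict of ExportedPrograms; each key becomes a method in the PTE.
edge = to_edge({"prefill": prefill_ep, "decode": decode_ep})
executorch_program = edge.to_executorch()
with open("qwen3_5_moe.pte", "wb") as f:
    f.write(executorch_program.buffer)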
Revert from explicit state passing back to registered buffers with
in-place updates (KVCache, conv_state, recurrent_state). Export with
share_mutable_buffers=True so both prefill and forward methods share
mutable state via mem_id=2. C++ runner uses share_memory_arenas=true
and only passes (tokens, input_pos) — no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile,
65 D2H memcpy (logits only).
Add runtime buffer sharing between AOTI containers so that prefill and
decode methods operate on the same GPU tensors (KV cache, conv_state,
etc.) without unnecessary H2D/D2H copies or getter/setter overhead.

The first container to initialize extracts its constants (keyed by
original FQN). Subsequent containers with matching FQNs are updated via
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point
to the same GPU memory (user_managed = true, no copy).

Also switch main.cpp prefill to token-by-token decode path while the
chunked FLA triton_op numerical issue is being resolved.

Tested E2E: "What is the capital of France?" → "Paris" with 966
constants shared between prefill and decode containers on A100.
- cuda_backend.cpp: Use codegen name (from GetConstantName) instead of
  original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI
  API matches against internal codegen names, not FQNs — using FQNs caused
  silent no-op sharing, breaking KV cache flow between prefill and decode.

- main.cpp: Add chunked prefill path using the "prefill" method (T>=2) with
  cudaDeviceSynchronize between prefill and decode for cross-stream safety.
  Add --decode_only flag to fall back to token-by-token decode for all tokens.

- inference.py: Update docstring to reflect that chunked FLA is used in PTE
  mode (not eager).

Verified E2E: "What is the capital of France?" → "The capital of France is Paris."
Prefill: 105 tok/s (chunked FLA), Decode: 87 tok/s (recurrent delta rule).
- cuda_backend.cpp: Replace debug printf with ET_LOG for errors/info only
- main.cpp: Remove --decode_only flag, keep only chunked prefill path
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error
  handling + cudaDeviceSynchronize after weight transfer, add logging for
  missing weights_blob
- main.cpp: Support single "forward" method fallback when prefill/decode
  not available, use prefill_method variable, remove debug printf
Implements CUDA graph support in the CUDA backend to reduce CPU kernel
launch overhead during autoregressive decoding:

- cuda_backend.cpp: 3-phase execution (warmup → capture → replay) with
  static input/output GPU buffers, cudaMemcpyAsync for I/O, and
  cudaGraphInstantiateFlagAutoFreeOnLaunch for cudaMallocAsync compat
- cuda_delegate_handle.h: CUDA graph state (phase, graph objects, static
  buffer metadata) with RAII cleanup in destructor
- main.cpp: --cuda_graph flag that sets BackendOptions before load_method
- test_model_e2e.sh: Enable --cuda_graph for Qwen3.5 MoE CI, set
  PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync

Benchmark (A100, Qwen3.5-35B-A3B HQQ-INT4): 82→98 tok/s (1.20x)
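The same warmup, capture, and replay phases can be illustrated with PyTorch's torch.cuda.CUDAGraph API (the backend implements this in C++ against the CUDA runtime API; model, num_new_tokens, and next_token below are placeholders for one decode step and its inputs):

import torch

static_token = torch.zeros(1, 1, dtype=torch.long, device="cuda")

# Phase 1: warmup on a side stream so lazy initialization work is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = model(static_token)
torch.cuda.current_stream().wait_stream(s)

# Phase 2: capture one decode step into a graph; I/O goes through static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_token)

# Phase 3: replay. Refresh the static input, then relaunch the whole step in one call.
for _ in range(num_new_tokens):
    static_token.copy_(next_token)
    graph.replay()
    next_token = static_out.clone()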
@pytorch-bot

pytorch-bot Bot commented Apr 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18844

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 7 New Failures, 6 Pending, 1 Unrelated Failure

As of commit 3e185c0 with merge base b4d4507:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 13, 2026
@Gasoonjia force-pushed the cuda-graph-sampling branch 2 times, most recently from b4f9eca to c7450dd, on April 13, 2026 20:34
@Gasoonjia force-pushed the cuda-graph-sampling branch from b55f894 to c7450dd on April 13, 2026 21:14

extern "C" {

AOTITorchError aoti_torch_cuda_rand(
Contributor

Is this from PyTorch/Aten or we are rolling our own?

Contributor Author

Rolling our own. I didn't see an ATen version.

Comment thread: examples/models/qwen3_5_moe/export.py (Outdated)
example_prefill_len = config.max_seq_len - 1
prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
Contributor

Why do we need to revert these changes, or is this just a cherry-pick artifact?

Contributor Author

Just needed another round of main-branch updates. I've brought it back.

Comment on lines -112 to -117
def _sample(logits, temperature):
"""Sample from logits with temperature."""
if temperature <= 0:
return logits.argmax(dim=-1)
probs = torch.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
Contributor

Do we need to verify that we stay close to this with some unit tests?
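One possible shape for such a test is a statistical check that the Gumbel-max sampler's empirical distribution matches the softmax distribution the removed _sample drew from; gumbel_max_sample is a placeholder for the extracted sampler (matching the sketch near the top of the PR description), and exact equality is not expected since the RNG streams differ:

import torch

def test_gumbel_max_matches_softmax_distribution():
    torch.manual_seed(0)
    logits = torch.randn(32)
    temperature = torch.tensor(0.8)
    n = 200_000

    # Empirical distribution from the on-device Gumbel-max sampler.
    samples = gumbel_max_sample(logits.expand(n, -1), temperature)
    empirical = torch.bincount(samples, minlength=logits.numel()).float() / n

    # Reference distribution the removed torch.multinomial path sampled from.
    expected = torch.softmax(logits / temperature, dim=-1)

    assert torch.allclose(empirical, expected, atol=5e-3)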

using SizesType = executorch::aten::SizesType;

// Read a sampled token from the model output tensor [B, 1].
// The model performs Gumbel-max sampling on-device and returns a single
Contributor

Can you add a comment pointing to the paper or another reference for Gumbel-max?

prev_token = cur_token;
stats.on_sampling_begin();
cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature);
stats.on_sampling_end();
Contributor

@digantdesai Apr 21, 2026

I guess we can't report sampling time separately anymore.

Contributor Author

We can't until we land the device support pipeline. With device support we can keep the model output on CUDA and pull the sampler back out.

Comment thread: examples/models/qwen3_5_moe/model.py (Outdated)
# GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise)
# Equivalent to sampling from softmax(logits/T) but fully on-device.
logits = logits / temperature.clamp(min=1e-6)
noise = torch.rand_like(logits)
Contributor

Does this one need rand.cu?

Contributor Author

Yes, torch.rand_like needs it.

* advances a GPU-resident counter for unique offsets, making this fully
* compatible with CUDA graph capture and replay.
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(
Contributor

No unit-tests?

@digantdesai
Contributor

@claude review this PR.

@claude

claude Bot commented Apr 21, 2026

Claude finished @digantdesai's task in 2m 46s.


PR Review: Add GPU-side Gumbel-max sampling for CUDA graph compatibility

  • Gather context and read changed files
  • Review rand.cu and rand.h (new CUDA sampling implementation)
  • Review changes to examples/models/qwen3_5_moe/ files
  • Review build system changes
  • Post review feedback

Summary

This PR fuses sampling into the model's forward() method using the Gumbel-max trick on GPU, replacing the previous CPU-side logits_to_token / _sample approach. This eliminates a GPU→CPU→GPU roundtrip for the full [B, V] logits tensor during decode, and makes the sampling operation compatible with CUDA graph capture (since the RNG state lives on device). The new rand.cu/rand.h files implement custom CUDA RNG shims that use GPU-resident Philox state with atomic counter advancement.

Issues

1. Prefill exports with minimal example shape instead of max sequence length

export.py:655-656 — The prefill example was changed from torch.zeros((1, max_seq_len-1)) to torch.tensor([[0, 1]]). The original comment (still present at line 649-651) explains why the large example was needed:

Example T must equal max_seq_len-1 so AOTI compiles kernels (especially chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes that reject longer prompts at runtime.

If this comment is still accurate, the change will break prefill for prompts longer than 2 tokens. If the constraint has been relaxed, the comment should be updated.


2. CUDA::curand not linked on MSVC path

CMakeLists.txt:154-157 — The MSVC branch of target_link_libraries(aoti_cuda_shims ...) doesn't include CUDA::curand, but the non-MSVC path (line 166) does. Since rand.cu uses curand_kernel.h, this will cause link failures on Windows MSVC builds.


3. Global mutable state without thread safety or cleanup

rand.cu:48-49 — d_rng and g_rng_init_done are bare globals with no synchronization. If ensure_rng_init is called concurrently from multiple threads (e.g. multiple models or streams), there's a data race on both the pointer and the flag. Consider using std::once_flag / std::call_once, or at minimum a std::atomic<bool>. Also, d_rng is never freed — not critical for a long-running process, but worth noting.

4. atomicAdd per-element is expensive for large tensors

rand.cu:88,104,118 — Each thread calls atomicAdd(&rng->counter, 1ULL). The comment at line 83 acknowledges this is fine for numel=1, but the kernels support arbitrary numel. For large tensors, the serial contention on a single counter will be a bottleneck. A common pattern is to have one thread (or one per block) grab a range via atomicAdd(&rng->counter, (unsigned long long)numel) and then distribute offsets within the block.

5. Gumbel noise epsilon values are ad-hoc

model.py:640 — The 1e-20 epsilon in torch.log(-torch.log(noise + 1e-20) + 1e-20) works for float32 but could be brittle. torch.rand_like returns values in [0, 1) (i.e., it can return exactly 0.0), so the inner log needs protection. The outer epsilon is effectively redundant in float32, since -torch.log(noise + 1e-20) stays strictly positive, but it is harmless. The values seem fine for float32, but it would be good to add a brief note that this assumes float32 logits (which is guaranteed by the .float() cast on line 635).
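A clamp-based variant avoids leaning on the additive epsilons at all; a sketch only, with illustrative eps bounds that are not taken from the PR:

import torch

def gumbel_noise(shape, device="cuda", eps=1e-20):
    # Keep u strictly inside (0, 1) so both logs stay finite without extra terms.
    u = torch.rand(shape, device=device).clamp(min=eps, max=1.0 - 1e-7)
    return -torch.log(-torch.log(u))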

6. No error checking on CUDA API calls in read_token

main.cpp:54 — cudaMemcpy return value is unchecked. If the copy fails for any reason, val will contain uninitialized data, silently producing garbage tokens.

7. Missing temperature in MLX export path

export.py:537-549 — The MLX export still uses the old 2-argument signature (example_tokens, example_input_pos). Since the model's forward() now requires temperature as a third argument, the MLX path will fail at export time. The forward signature change should either be gated on the backend, or the MLX export should be updated to pass temperature.


Suggestions

8. Seeding with time(nullptr) gives poor entropy

rand.cu:57 — time(nullptr) has second-level resolution and is predictable. For a sampling RNG this doesn't affect correctness, but for reproducibility it would be better to allow an explicit seed parameter, and for better entropy consider mixing in the device clock or using curandGenerateSeeds.

9. Consider making the forward method composable

As the PR description notes, once the device support pipeline lands, forward() should be decomposed. Currently, the model's forward() couples the transformer computation with sampling, which means:

  • You can't get raw logits for tasks like perplexity evaluation or beam search
  • The output is float token IDs rather than Long, requiring the read_token float→uint64 cast
  • Temperature is baked into the model graph

This is reasonable for the current CUDA-graph-focused use case but should be tracked for future cleanup.

10. read_token could use the output dtype directly

main.cpp:45-58 — The model returns argmax(...).float() which means the output is a float representation of an integer. The cast chain float → uint64_t works but is unusual. If the model returned Long instead, you could avoid the float intermediary and read an int64_t directly.

CI Failures

The strides_vec compilation errors in CI are from the base cuda-graph branch (cuda_backend.cpp:697 uses strides_vec before it's defined at line 715), not from this PR's changes. The 2 Windows failures are also from trunk.

Addressing Reviewer Questions

Re: @digantdesai's questions:

  • rand.cu origin (line 134): This is a custom implementation, not from PyTorch/ATen. It implements the AOTI shim interface (aoti_torch_cuda_rand, aoti_torch_cuda_randint_low_out) using cuRAND Philox primitives with GPU-resident state for CUDA graph compatibility.
  • Unit tests (rand.h:56): Agree — the new RNG shims need unit tests, particularly verifying (1) output is in expected range, (2) different replay invocations produce different sequences, and (3) basic statistical properties (uniformity).
  • export.py reverts (line 655): See Issue 1 above — the prefill example shape change may break long prompts.
  • rand.cu needed by model.py? (line 639): Yes — torch.rand_like(logits) in the forward method compiles to an AOTI rand call, which dispatches to aoti_torch_cuda_rand at runtime.
  • Sampling time reporting (main.cpp:216): Correct, sampling is now part of the model execution time. You could add a cudaEvent pair around the read_token call to measure the D2H transfer time, but the sampling kernel time would be inseparable from the model forward pass.

Contributor

@digantdesai left a comment

Thanks Songhao. Can we make a util such that this can be used by other models easily?

@Gasoonjia force-pushed the cuda-graph-sampling branch from 93bee20 to d3bca0d on April 23, 2026 03:58
@Gasoonjia force-pushed the cuda-graph-sampling branch from d3bca0d to 5245f64 on April 23, 2026 05:20
Base automatically changed from cuda-graph to main April 27, 2026 05:39
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia
Contributor Author

Gasoonjia commented Apr 27, 2026

I've extracted the sampler into an individual Python file so other models can use it more easily, added unit tests for the sampler and rand.cu, and addressed Claude's and @digantdesai's comments.

After we add device support, we can rewrite the sampler as a pure CUDA/Triton kernel shared by other models, without changing the model definition.

@Gasoonjia merged commit c1731fd into main on Apr 27, 2026
264 of 285 checks passed
@Gasoonjia deleted the cuda-graph-sampling branch on April 27, 2026 07:14

Labels

ciflow/cuda, CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
