MLX: on-device token sampling with Gumbel-max #20454

Draft
kiymetakdemir wants to merge 2 commits into pytorch:main from kiymetakdemir:mlx-ondevice-sampling

@kiymetakdemir


Summary

Adds token sampling that runs inside the exported .pte for the MLX backend: a model wrapped in SamplingHead returns a sampled token id instead of [B, S, vocab] logits, avoiding the per-step logits copy to host and the host-side softmax+multinomial.

Sampling uses Gumbel-max: argmax(logits / temperature + g), where g = -log(-log(u)) and u is uniform on (0, 1). The only new schema primitive is a random source, RandomBitsNode; the rest reuses existing nodes. Greedy decoding is the limit temperature → 0. temperature is a runtime input; seed is optional.
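For readers unfamiliar with the trick, here is a minimal host-side sketch of Gumbel-max sampling in plain Python. It is not the PR's lowering or its CPU reference; the function name `gumbel_max_sample` and the `1e-12` floor on `u` are assumptions made only for this illustration.

```python
import math
import random


def gumbel_max_sample(logits, temperature, rng):
    """Return an index distributed as softmax(logits / temperature).

    Hypothetical helper for illustration; assumes temperature > 0.
    """
    best_index, best_score = 0, -math.inf
    for i, logit in enumerate(logits):
        # rng.random() is in [0, 1); floor it so both logs stay finite.
        u = max(rng.random(), 1e-12)
        g = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        score = logit / temperature + g
        if score > best_score:
            best_index, best_score = i, score
    return best_index


rng = random.Random(0)
token = gumbel_max_sample([2.0, 1.0, 0.0], temperature=1.0, rng=rng)
```

Taking the argmax of the perturbed scores avoids computing a softmax or a cumulative distribution at all, which is why the graph only needs a random source plus existing elementwise and argmax nodes. With a very small temperature the scaled logits dominate the noise and the result matches plain argmax.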

Changes

  • schema.fbs: new RandomBitsNode (append-only union member, optional seed).
  • custom_kernel_ops/sample.py: mlx::sample op + register_fake + CPU reference.
  • ops.py: _sample_handler lowering the Gumbel-max graph.
  • runtime/MLXInterpreter.h: exec_random_bits + dispatch.
  • llm/sampling.py: SamplingHead wrapper.
  • generate.py: None-guard optional compound fields so the optional seed (de)serializes.

Notes

  • Uniform/Gumbel noise is computed in fp32: bf16 rounds the clamp just below 1.0 up to 1.0, so log(0) = -inf poisons the argmax.
  • Tests: custom_kernel_ops/test/test_sample.py, eager parity/distribution/determinism, export+partition lowering, and on-device e2e (incl. a bf16 large-vocab regression).
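The bf16 note can be checked with a small bit-level sketch. The helpers below (`f32_to_bits`, `bits_to_f32`, `round_to_bf16`) are hypothetical and only emulate float32-to-bfloat16 round-to-nearest-even for ordinary finite values; they are not part of the PR.

```python
import math
import struct


def f32_to_bits(x):
    """Bit pattern of x as a float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    """float32 value for a 32-bit pattern."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def round_to_bf16(x):
    """Round a float32 to bfloat16 (nearest-even); no NaN/overflow handling."""
    b = f32_to_bits(x)
    b += 0x7FFF + ((b >> 16) & 1)  # rounding bias on the dropped 16 bits
    return bits_to_f32(b & 0xFFFF0000)


# Largest float32 strictly below 1.0, i.e. nextafter(1, 0) in fp32.
CLAMP_F32 = bits_to_f32(f32_to_bits(1.0) - 1)
# The same clamp after a cast to bf16: it rounds up to exactly 1.0.
CLAMP_BF16 = round_to_bf16(CLAMP_F32)

# In fp32 the Gumbel value at the clamp is finite.
g_at_clamp = -math.log(-math.log(CLAMP_F32))
```

Once the clamp is exactly 1.0, -log(u) is 0 and the outer log is evaluated at 0, which is the failure described in the note above. Keeping the uniform and Gumbel steps in fp32 keeps the clamp strictly below 1.0.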

Fixes #20353

@pytorch-bot

pytorch-bot Bot commented Jun 23, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20454

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⚠️ 13 Awaiting Approval

As of commit bddb819 with merge base 6f6225c (image):

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 23, 2026
@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

ExecutionState& st,
StreamOrDevice s) {
auto shape = to_shape(n.shape, st);
check_allocation_bounded(shape, uint32, "random_bits");

Contributor

How does uint32 here compare to n.width in random::bits?

If someone sets width to another constant, does it need to stay in sync with uint32?

Author

Since nothing emits a non-4 width, I now assert n.width == 4.

from torch import Tensor


@torch.library.custom_op("mlx::sample", mutates_args=())

Contributor

This should be in backends/mlx/custom_ops.py

(It's not a custom kernel op b/c it doesn't have metal shader code)

Author

Moved it to custom_ops.py and relocated the test.

@metascroy

Contributor

@claude review this code

@claude

claude Bot commented Jun 23, 2026


Claude finished @metascroy's task in 3m 8s


Code review: MLX on-device token sampling

Read through the full diff (sample.py, _sample_handler, exec_random_bits, SamplingHead, schema + generator changes, and the tests). Overall this is clean, well-tested work — the Gumbel-max lowering is faithful to MLX's uniform → gumbel → argmax layering, the schema change is correctly append-only, and the bf16 fp32-clamp fix is a nice catch with a dedicated regression test. Comments below.

Addressing the two inline review comments

1. MLXInterpreter.h:1703: uint32 vs. n.width (your question)

Yes, they must stay in sync, and right now they're only implicitly consistent:

check_allocation_bounded(shape, uint32, "random_bits");   // hard-coded 4-byte dtype
...
st.set_tensor(n.out, random::bits(shape, n.width, key, s)); // n.width drives real element size

The bound check assumes 4 bytes/element while the actual allocation is n.width bytes/element. They agree only because _sample_handler always emits width=4 and the schema default is 4. If anyone ever sets a different width, the bound check silently uses the wrong size and random::bits produces a non-uint32 dtype that the rest of the graph (the AsTypeNode → /uint32_max chain) assumes is uint32.

Since nothing emits a non-4 width, the simplest hardening is to assert the invariant rather than carry a flexible-but-unsupported field:

if (n.width != 4) {
  throw std::runtime_error("random_bits: only width=4 (uint32) is supported");
}

That makes the uint32 in the bound check provably correct. (Alternatively, map n.width → dtype and pass that to check_allocation_bounded, but that's more machinery than the single supported path warrants.)

2. sample.py location — move to custom_ops.py

Agreed. mlx::sample has no Metal shader; it's a graph-lowered op like mlx::rope / mlx::gather_mm / mlx::gather_qmm, which all live in backends/mlx/custom_ops.py. custom_kernel_ops/ is for ops with actual kernel source (tq_norm, tq4_compress, gated_delta_rule, …). Moving it requires updating the two import sites:

  • backends/mlx/ops.py:22 (import ...custom_kernel_ops.sample)
  • backends/mlx/custom_kernel_ops/test/test_sample.py:27 (and ideally relocate the test to backends/mlx/test/).


Additional observations

Reference op vs. on-device graph are two different implementations. The CPU reference in sample.py uses torch.rand directly (no nextafter clamp), while the lowered graph uses bits/uint32_max → min(·, nextafter(1,0)) → gumbel. The eager tests (test_matches_independent_gumbel_reference) validate the reference against an independent torch impl, and the e2e tests validate the graph distributionally — but the two paths are never checked to agree with each other, and seeded determinism differs (torch RNG host-side vs. MLX RNG on-device). That's a reasonable design given the host-copy avoidance goal, but worth a one-line comment in sample.py noting the reference is for export/shape only and is not bit-identical to the delegated path, so a future reader doesn't assume seeded parity across host/device.
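To make the on-device chain named above concrete, here is a rough plain-Python emulation of bits/uint32_max → min(·, nextafter(1,0)) → gumbel with fp32 rounding. The helper names and the exact order of the fp32 casts are assumptions for illustration; the real graph is built from MLX nodes and may differ in detail.

```python
import math
import struct


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# 4294967295.0 is not representable in fp32; it rounds to 2**32.
UINT32_MAX_F32 = to_f32(4294967295.0)
# Largest float32 strictly below 1.0, i.e. nextafter(1, 0).
ONE_BELOW = struct.unpack("<f", struct.pack("<I", 0x3F7FFFFF))[0]


def gumbel_from_bits(bits):
    """Map one uint32 of random bits to Gumbel(0, 1) noise, emulating fp32."""
    u = to_f32(to_f32(float(bits)) / UINT32_MAX_F32)
    u = min(u, ONE_BELOW)  # keep u < 1 so that -log(u) > 0
    if u == 0.0:
        # IEEE arithmetic gives -inf here; math.log(0.0) would raise instead.
        return -math.inf
    return -math.log(-math.log(u))


g_top = gumbel_from_bits(0xFFFFFFFF)  # would hit u == 1.0 without the clamp
g_mid = gumbel_from_bits(0x80000000)  # u == 0.5
```

In this emulation the largest bit patterns divide to exactly 1.0 in fp32, so the clamp is what keeps the noise finite. The torch.rand reference draws from a different generator entirely, so the same seed should not be expected to yield the same token on host and device.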

temperature = 0 → division by zero. Greedy is documented as "temperature → 0", and tests use 1e-4. Since temperature is a runtime input, an actual 0.0 produces inf/nan in logits / temperature. Not a blocker (callers are expected to pass a small epsilon), but a brief note in SamplingHead's docstring that temperature must be > 0 would prevent a confusing on-device result.
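One way a caller could protect itself is a small host-side check before feeding the runtime input. The function `check_temperature` below is a hypothetical helper sketched for this comment, not something that exists in SamplingHead or the PR.

```python
import math


def check_temperature(temperature):
    """Reject temperatures that would make logits / temperature inf or nan.

    Hypothetical guard; returns the value unchanged when it is usable.
    """
    if not math.isfinite(temperature) or temperature <= 0.0:
        raise ValueError(
            f"temperature must be a finite value > 0, got {temperature!r}"
        )
    return temperature


greedy_like = check_temperature(1e-4)  # small epsilon, as the tests use
```

Raising on the host keeps the failure visible at the call site instead of surfacing later as an arbitrary token id from the device.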

ArgmaxNode(axis=-1) — confirm exec_argmax/MLX argmax accepts a negative axis at runtime. Logits are [B, vocab] so -1 is unambiguous, but if MLX expects a non-negative axis it'd be safer to emit the resolved positive axis (the handler has n.args[0].meta available). The e2e test would catch this if it's a problem; flagging in case the e2e runner wasn't exercised on this path.
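If a non-negative axis turns out to be required, the resolution itself is simple. The sketch below uses a hypothetical `resolve_axis` helper and takes the rank as a plain integer; in the handler the rank would presumably come from the tensor metadata mentioned above.

```python
def resolve_axis(axis, ndim):
    """Convert a possibly negative axis into its non-negative equivalent."""
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of range for a rank-{ndim} tensor")
    return axis % ndim


# Logits of shape [B, vocab] have rank 2, so axis=-1 resolves to the vocab axis.
vocab_axis = resolve_axis(-1, 2)
```

Emitting the resolved value means the serialized node never depends on how the runtime interprets negative axes.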

Minor: emit_lifted_constant(P, 4294967295.0, ...) and the nextafter constant are recomputed each call — fine, but if make_or_get_constant-style dedup applies these will already be shared; no action needed.

None of these are blockers. The two inline items (width invariant + file relocation) are the ones worth doing before merge.

return self.head(logits, temperature=temperature)


class TestSampleOp(unittest.TestCase):

Contributor

These test cases should be inside backends/mlx/test/test_ops.py

pte,
)
self.assertEqual(count_mlx_delegate_segments(pte), 1)
counts = get_mlx_node_counts(pte)

Contributor

See test_ops.py. There are utilities for testing node counts

# optional_str carries its own None handling; other compound offset
# fields (int_or_vid, etc.) must be guarded when optional so a None
# value is serialized as an absent field rather than crashing.
if fld.required or kind == "optional_str":

Contributor

Why are these changes needed? You don't have a string arg on the new node?

Gumbel-max sampling from softmax(logits / temperature).
logits: [B, vocab]
temperature: scalar float tensor (runtime input)
seed: scalar int tensor or None

Contributor

Does it not export if seed is an int?

Comment thread backends/mlx/ops.py
AsTypeNode(
x=P.slot_to_tid(g_f32),
out=P.slot_to_tid(g),
scalar_type=torch_dtype_to_scalar_type(dt),

Contributor

Should we have this at all? Why not compute divide/argmax/etc. in the same fp32 dtype? The final output type is an integer.

"""
Gumbel-max sampling from softmax(logits / temperature).
logits: [B, vocab]
temperature: scalar float tensor (runtime input)

Contributor

Can we add top-p as well?

Comment thread backends/mlx/ops.py
logits, temperature = args[0], args[1]
seed = args[2] if len(args) > 2 and args[2] is not None else None

dt = n.args[0].meta["val"].dtype

Contributor

Can we use emit_if_else to specialize on temperature 0 as argmax?

Development

Successfully merging this pull request may close these issues.

# Good First Issue: In-Model Sampling Head for the MLX Backend (Gumbel Sampling)
