MLX: on-device token sampling with Gumbel-max#20454
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20454
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below:
|
This PR needs a
|
| ExecutionState& st, | ||
| StreamOrDevice s) { | ||
| auto shape = to_shape(n.shape, st); | ||
| check_allocation_bounded(shape, uint32, "random_bits"); |
There was a problem hiding this comment.
How does uint32 here compare to n.width in random::bits?
It someone sets width as another constant, does it require to be in sync with uint32?
There was a problem hiding this comment.
Since nothing emits a non-4 width, now I assert n.width == 4.
| from torch import Tensor | ||
|
|
||
|
|
||
| @torch.library.custom_op("mlx::sample", mutates_args=()) |
There was a problem hiding this comment.
This should be in backends/mlx/custom_ops.py
(It's not a custom kernel op b/c it doesn't have metal shader code)
There was a problem hiding this comment.
Moved it to custom_ops.py and relocated the test.
|
@claude review this code |
|
Claude finished @metascroy's task in 3m 8s —— View job Code review: MLX on-device token samplingRead through the full diff ( Addressing the two inline review comments1. Yes, they must stay in sync, and right now they're only implicitly consistent: check_allocation_bounded(shape, uint32, "random_bits"); // hard-coded 4-byte dtype
...
st.set_tensor(n.out, random::bits(shape, n.width, key, s)); // n.width drives real element sizeThe bound check assumes 4 bytes/element while the actual allocation is Since nothing emits a non-4 width, the simplest hardening is to assert the invariant rather than carry a flexible-but-unsupported field: if (n.width != 4) {
throw std::runtime_error("random_bits: only width=4 (uint32) is supported");
}That makes the 2. Agreed.
Additional observationsReference op vs. on-device graph are two different implementations. The CPU reference in
Minor: None of these are blockers. The two inline items (width invariant + file relocation) are the ones worth doing before merge. |
| return self.head(logits, temperature=temperature) | ||
|
|
||
|
|
||
| class TestSampleOp(unittest.TestCase): |
There was a problem hiding this comment.
These test cases should be inside backends/mlx/test/test_ops.py
| pte, | ||
| ) | ||
| self.assertEqual(count_mlx_delegate_segments(pte), 1) | ||
| counts = get_mlx_node_counts(pte) |
There was a problem hiding this comment.
See test_ops.py. There are utilities for testing node counts
| # optional_str carries its own None handling; other compound offset | ||
| # fields (int_or_vid, etc.) must be guarded when optional so a None | ||
| # value is serialized as an absent field rather than crashing. | ||
| if fld.required or kind == "optional_str": |
There was a problem hiding this comment.
Why are these changes needed? You don't have a string arg on the new node?
| Gumbel-max sampling from softmax(logits / temperature). | ||
| logits: [B, vocab] | ||
| temperature: scalar float tensor (runtime input) | ||
| seed: scalar int tensor or None |
There was a problem hiding this comment.
Does it not export if seed is an int?
| AsTypeNode( | ||
| x=P.slot_to_tid(g_f32), | ||
| out=P.slot_to_tid(g), | ||
| scalar_type=torch_dtype_to_scalar_type(dt), |
There was a problem hiding this comment.
Should we have this at all? Why not compute divide/argmax/etc in same fp32 dtype? The final output type is integer
| """ | ||
| Gumbel-max sampling from softmax(logits / temperature). | ||
| logits: [B, vocab] | ||
| temperature: scalar float tensor (runtime input) |
There was a problem hiding this comment.
Can we add top-p as well?
| logits, temperature = args[0], args[1] | ||
| seed = args[2] if len(args) > 2 and args[2] is not None else None | ||
|
|
||
| dt = n.args[0].meta["val"].dtype |
There was a problem hiding this comment.
Can we use emit_if_else to specialize on temperature 0 as argmax?
Summary
Adds token sampling that runs inside the exported .pte for the MLX backend: a model wrapped in SamplingHead returns a sampled token id instead of [B, S, vocab] logits, avoiding the per-step logits copy to host and the host-side softmax+multinomial.
Sampling uses Gumbel-max: argmax(logits / temperature + g), g = -log(-log(u)). The only new schema primitive is a random source, RandomBitsNode, the rest reuses existing nodes. Greedy = temperature → 0. temperature is a runtime input; seed is optional.
Changes
Notes
Fixes #20353