Initial implementation of stateless RNG APIs #177229

Closed
jbschlosser wants to merge 14 commits into gh/jbschlosser/253/base from gh/jbschlosser/253/head
@jbschlosser jbschlosser commented Mar 12, 2026

Stack from ghstack (oldest at bottom):

Initial implementation for a JAX-like stateless RNG API, starting with:

  • torch.func._random.key(seed)
  • torch.func._random.split(key, num)
  • torch.func._random.fold_in(key, data)

Both split() and fold_in() support arbitrarily batched keys. split() adds a new leftmost batch dimension, while fold_in() preserves the input shape.

Internally, both split() and fold_in() sample 128 bits of randomness at each (seed, offset) pair to produce a new (seed, offset) pair: 64 bits of the sample become the new seed and the other 64 bits become the new offset.

Under the hood, the split() and fold_in() kernels use a custom stateless Philox-4x32-10 round implementation. It permanently fixes subsequence=0 so that the generated values do not depend on the number of CUDA threads, and are therefore consistent across GPU devices, input shapes, etc.
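To make the derivation scheme above concrete, here is a minimal pure-Python model of it. It is only a sketch under stated assumptions: BLAKE2b stands in for the Philox-4x32-10 draw, keys are plain (seed, offset) tuples rather than tensors, a fresh key is assumed to start at offset 0, and the `_sample_128`, `split_batched`, and `fold_in_batched` helpers and the domain tags are hypothetical names that do not come from the PR.

```python
import hashlib
import struct

MASK64 = (1 << 64) - 1


def key(seed):
    """Model a key as a (seed, offset) pair. Starting at offset 0 is an assumption."""
    return (seed & MASK64, 0)


def _sample_128(k, data, tag):
    """Stand-in for the Philox draw: 128 bits derived from (seed, offset, data)."""
    seed, offset = k
    payload = struct.pack("<QQQ", seed, offset, data & MASK64)
    digest = hashlib.blake2b(payload, digest_size=16, person=tag).digest()
    new_seed, new_offset = struct.unpack("<QQ", digest)  # 64 bits each
    return (new_seed, new_offset)


def split(k, num=2):
    """Derive `num` child keys from one parent key."""
    return [_sample_128(k, i, b"split") for i in range(num)]


def fold_in(k, data):
    """Derive one child key from a parent key and an integer."""
    return _sample_128(k, data, b"fold_in")


def split_batched(keys, num=2):
    """Batched split: result[i][j] is child i of keys[j], so the new dimension is leftmost."""
    per_key = [split(k, num) for k in keys]
    return [[children[i] for children in per_key] for i in range(num)]


def fold_in_batched(keys, data):
    """Batched fold_in: one output key per input key, so the shape is unchanged."""
    return [fold_in(k, data) for k in keys]


root = key(42)
new_key, subkey = split(root)
batch = split_batched([key(1), key(2), key(3)], num=4)
print(len(batch), len(batch[0]))  # 4 3
```

The point of the model is that every derived key is a pure function of its parent key and an index, so no hidden generator state is involved.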

Example API usage:

import torch.func._random as r

# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
  new_key, subkey = r.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = r.normal(subkey, 3)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.

# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
  key, subkey = r.split(key)
  val = r.normal(subkey, 3)
  print(f"draw {i}: {val}")

# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")

# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
    val = r.normal(r.fold_in(key, i), 3)
    print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)

# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)

pytorch-bot Bot commented Mar 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177229

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b60b2dd with merge base 99dee05 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot Bot commented Mar 12, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@github-actions

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@jbschlosser jbschlosser marked this pull request as draft March 12, 2026 04:55

vadimkantorov commented Mar 12, 2026

But would this be similar to accepting Generator and then also returning the updated Generator object (instead of updating it inplace)? AFAIK the Generator object currently holds the RNG state...

Or is it mainly to strip the Generator type and represent the RNG state as a bare Tensor in the APIs?

@eqy eqy requested a review from eee4017 March 12, 2026 17:23
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
uint32_t r3 = curand(&state);

uint64_t new_seed = static_cast<uint64_t>(r0) | (static_cast<uint64_t>(r1) << 32);
uint64_t new_offset = static_cast<uint64_t>(r2) | (static_cast<uint64_t>(r3) << 32);

Collaborator

Is there a specific reason for choosing this hash?

@jbschlosser jbschlosser Mar 19, 2026

It's simply reinterpreting the 128 bits of randomness sampled at this location as a new (64-bit seed, 64-bit offset) pair. This is a convenient way to jump to a location within the 3D PRNG number space (seed, subsequence=0, offset) that is very likely to yield an independent PRNG stream, which avoids undesirable reuse.
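The reinterpretation described here can be mirrored in a few lines of Python. This is only an illustrative sketch of the bit packing shown in the C++ snippet above; `pack_key` is a hypothetical helper name, and the four input words are arbitrary example values rather than real Philox output.

```python
MASK32 = 0xFFFFFFFF


def pack_key(r0, r1, r2, r3):
    """Reinterpret four 32-bit draws as a (64-bit seed, 64-bit offset) pair."""
    new_seed = (r0 & MASK32) | ((r1 & MASK32) << 32)    # r0 is the low half, r1 the high half
    new_offset = (r2 & MASK32) | ((r3 & MASK32) << 32)  # r2 is the low half, r3 the high half
    return new_seed, new_offset


seed, offset = pack_key(0x11111111, 0x22222222, 0x33333333, 0x44444444)
print(hex(seed), hex(offset))  # 0x2222222211111111 0x4444444433333333
```

No mixing or hashing happens at this step; all of the statistical quality comes from the draw that produced the four words.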

uint64_t offset = input[key_idx * 2 + 1];

curandStatePhilox4_32_10_t state;
curand_init(seed, /*subsequence=*/0, /*offset=*/offset, &state);

Collaborator

Why not expose the Philox subsequence dimension? Different subsequences occupy disjoint regions of the counter space by construction.

@jbschlosser jbschlosser Mar 19, 2026

Restricting Philox subsequence to 0 in these kernels (and further not exposing subsequence) is a design decision that helps satisfy some useful properties:

  • Consistency of random number generation across GPU types. The underlying kernels for torch.randn() / torch.rand() today use thread ID-based subsequence numbers, with the number of threads being SM-dependent. This makes the generated outputs differ between e.g. A100 and H100.
  • Consistency of generated values over time, also known as "sequential equivalence". This means making the outputs independent of input size / batching, which requires subsequence numbers that are not based on input size; subsequence=0 is the easiest way to achieve this. This property is useful for Distributed, which would like to guarantee that a sharded parameter has its shards initialized with the same values as if the parameter were initialized all at once. Example:
key = torch.random.key(42)
vals = torch.random.uniform(key, 10)
sharded_vals1 = torch.random.uniform(key, 5)
# TODO: introduce a public API for this "offset hacking"
offset_key = torch.tensor([42, 5], device=key.device, dtype=key.dtype)
sharded_vals2 = torch.random.uniform(offset_key, 5)
assert torch.equal(vals, torch.cat([sharded_vals1, sharded_vals2], dim=0))
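The sharding example above relies on each output element being a pure function of (seed, offset + index). A small pure-Python model shows why that is sufficient. It is a sketch only: BLAKE2b is a stand-in for Philox, `uniform` here is a hypothetical toy function rather than the proposed API, and treating one offset unit as one output element is an assumption made for this illustration.

```python
import hashlib
import struct


def uniform(key, n):
    """Toy counter-based sampler: element i depends only on (seed, offset + i), never on n."""
    seed, offset = key
    out = []
    for i in range(n):
        payload = struct.pack("<QQ", seed, offset + i)
        (bits,) = struct.unpack("<Q", hashlib.blake2b(payload, digest_size=8).digest())
        out.append((bits >> 11) / float(1 << 53))  # top 53 bits -> float in [0, 1)
    return out


vals = uniform((42, 0), 10)
sharded_vals1 = uniform((42, 0), 5)
sharded_vals2 = uniform((42, 5), 5)  # same seed, offset advanced by the first shard's length
print(vals == sharded_vals1 + sharded_vals2)  # True
```

If the sampler instead derived its counters from the thread count or the output size, the two shards would not reproduce the unsharded result.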

Note that ensuring the above properties does come at a performance cost. Using thread-based subsequence numbers allows for better memory coalescing behavior and fewer expensive curand_init() calls. I think it may be useful to satisfy these properties by default, but provide a flag for opting out for those who care about absolute performance (for this case, simply call into the existing kernels).

jbschlosser commented Mar 19, 2026

But would this be similar to accepting Generator and then also returning the updated Generator object (instead of updating it inplace)? AFAIK the Generator object currently holds the RNG state...

@vadimkantorov Conceptually yes, this is similar. And you're correct that Generator is the current concept holding an RNG state. There may or may not be a provided public translation between manual "RNG key" state specification and Generators.

Or is it mainly to strip the Generator type and represent the RNG state as a bare Tensor in the APIs?

A nice property we get by representing RNG state as bare Tensors is the ability to batch them efficiently. Here's an example where a different RNG key is used for each sample within a sampled batch, which is useful to achieve "batch invariance":

# === Batch invariance example (per-sample keys) ===
key = torch.random.key(seed=42)
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 2)
keys = torch.random.split(key, num=B+1)
key, batch_key = keys[0], keys[1:]
# each set of 32 is generated via a corresponding batch item key
x = torch.random.normal(B, 32, key=batch_key)
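The batch-invariance property can be illustrated without a GPU. The sketch below is a toy model with stated assumptions: Python's `random.Random` stands in for the normal kernel, keys are plain (seed, offset) tuples with made-up values, and `normal_rows` is a hypothetical helper. It shows that a sample's row depends only on that sample's key, not on which other samples share the batch.

```python
import random


def normal_rows(batch_keys, n):
    """Draw one row of n normals per key; each row is a function of its own key only."""
    rows = []
    for seed, offset in batch_keys:
        rng = random.Random((seed << 64) | offset)  # stand-in for a keyed Philox stream
        rows.append([rng.gauss(0.0, 1.0) for _ in range(n)])
    return rows


# Hypothetical per-sample keys; in the API above these would come from split().
batch_keys = [(42, i) for i in range(64)]

full_batch = normal_rows(batch_keys, 32)       # all 64 samples at once
single = normal_rows(batch_keys[10:11], 32)    # sample 10 on its own
print(single[0] == full_batch[10])  # True
```

Because every row is derived independently, changing the batch size or reordering the batch only changes which keys are consumed, not the values generated for a given key.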

Note that you could do this with a batch-sized stack of Generators, but it's much less efficient. Ex:

| Task | Torch Generator-based (with CUDAGraphs) | Torch Stateless (with CUDAGraphs) |
| --- | --- | --- |
| 1024 x rand(1024) | 1.66 ms | 44.79 us |
| 32768 x rand(1024) | 43.96 ms | 1.17 ms |
| 32768 x rand(32768) | 64.68 ms | 36.44 ms |

And of course, there's the fact that Generators today are not supported in torch.compile, while it's trivial to support these Tensor-based APIs there.

vadimkantorov commented Mar 19, 2026

A nice property we get by representing RNG state as bare Tensors is the ability to batch them efficiently.

Maybe it would be possible to somehow make the existing torch.Generator a tensor subclass post hoc, inheriting from torch.Tensor? Then maybe the generator/key APIs could be unified?

@vadimkantorov

Maybe it would be possible to somehow make the existing torch.Generator a tensor subclass post hoc, inheriting from torch.Tensor? Then maybe the generator/key APIs could be unified?

Would this also nicely make Generators work in torch.compile? (make the torch.Generator a tensor subclass and implement the methods as methods changing the value)

jbschlosser commented Apr 2, 2026

Maybe it would be possible to somehow make the existing torch.Generator a tensor subclass post hoc, inheriting from torch.Tensor? Then maybe the generator/key APIs could be unified?
Would this also nicely make Generators work in torch.compile? (make the torch.Generator a tensor subclass and implement the methods as methods changing the value)

@vadimkantorov unfortunately, I don't think it's technically feasible to make torch.Generator be a tensor subclass without substantial refactoring / BC-breakage. @albanD can correct me if I'm wrong though

@vadimkantorov

If not, maybe the new Key tensor subclass could at least implement some of torch.Generator's in-place methods and be accepted by the old generator-related APIs; then torch.Generator could be deprecated...

@jbschlosser jbschlosser requested a review from albanD April 2, 2026 18:39
@jbschlosser jbschlosser added topic: not user facing topic category and removed topic: new features topic category labels Apr 2, 2026

@albanD albanD left a comment

SGTM!

@jbschlosser

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 2, 2026
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

int64_t total_threads = num_keys * num_splits;
constexpr int block_size = 256;
int num_blocks = std::min(

Collaborator

Btw, you don't need to bother with sizing the number of blocks to a particular GPU. Previously, because the curand_init call was expensive (I don't know exactly why, and it's not important), we tried to reuse it across a grid-stride loop. All you are doing here is philox_4x32, so there's no reason for a grid-stride loop; just launch enough blocks. It's a nit for this PR but could simplify the next one.
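The "just launch enough blocks" suggestion comes down to a ceiling division plus a bounds check in each thread. The Python sketch below models only that arithmetic; `launch_config` and `covered_indices` are hypothetical helper names, the block size of 256 is taken from the snippet above, and hardware limits on grid dimensions are deliberately ignored.

```python
def launch_config(num_keys, num_splits, block_size=256):
    """One thread per output key: round the block count up so every index is covered."""
    total_threads = num_keys * num_splits
    num_blocks = (total_threads + block_size - 1) // block_size  # ceiling division
    return num_blocks, block_size


def covered_indices(num_blocks, block_size, total):
    """Indices processed when each thread handles exactly one element behind a bounds check."""
    return [
        block * block_size + thread
        for block in range(num_blocks)
        for thread in range(block_size)
        if block * block_size + thread < total  # threads past the end do nothing
    ]


num_blocks, block_size = launch_config(num_keys=1000, num_splits=2)
print(num_blocks)  # 8
print(covered_indices(num_blocks, block_size, 2000) == list(range(2000)))  # True
```

Compared with a grid-stride loop, the launch configuration no longer depends on the SM count of the device, which fits the device-independence goal discussed earlier in the thread.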

Contributor Author

nice, love this!

mori360 pushed a commit to mori360/pytorch that referenced this pull request Apr 3, 2026
pytorchmergebot pushed a commit that referenced this pull request Apr 3, 2026
Stateless API variant that deterministically generates random numbers from a normal distribution given an RNG key.

~~There is some complexity required to keep normal generation consistent across odd and even offsets: the Box-Muller transform is used to efficiently transform sampled uniforms -> sampled normals and operates on pairs of normals. So we handle this case explicitly by using an effective offset of `offset - (offset % 4)` then skipping `offset % 4` sampled values to determine where to begin our output stream. This way, the same pairs of uniforms are always paired and we get a consistent, offsetable stream of generated normal values.~~

Update: scrapped Box-Muller handling for simplicity. Also, we treat offset in units of 128-bit randomness chunks instead of 32-bit units, so we can provide fewer guarantees around sequential equivalence but the kernels are simpler.
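One way to read the trade-off in the update above is sketched below in pure Python. It rests on assumptions that the commit message does not spell out: each offset unit is taken to yield four 32-bit words, BLAKE2b stands in for Philox, and `chunk` and `draw_words` are hypothetical helpers. Under that reading, a stream can be resumed only at multiples of four values.

```python
import hashlib
import struct


def chunk(seed, offset):
    """Stand-in for one 128-bit draw: four 32-bit words at (seed, offset)."""
    digest = hashlib.blake2b(struct.pack("<QQ", seed, offset), digest_size=16).digest()
    return list(struct.unpack("<4I", digest))


def draw_words(key, n):
    """Draw n 32-bit words, consuming whole 128-bit chunks starting at the key's offset."""
    seed, offset = key
    words = []
    for c in range((n + 3) // 4):  # number of chunks needed, rounded up
        words.extend(chunk(seed, offset + c))
    return words[:n]


full = draw_words((42, 0), 8)
aligned = draw_words((42, 0), 4) + draw_words((42, 1), 4)    # shard boundary at a chunk edge
unaligned = draw_words((42, 0), 3) + draw_words((42, 1), 5)  # shard boundary inside a chunk
print(full == aligned)  # True
```

In this model the unaligned split does not reproduce `full`, because the fourth word of the first chunk is skipped when the second shard restarts at the next whole offset.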

Example API usage:
```python
import torch.func._random as r

# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
  new_key, subkey = r.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = r.normal(subkey, 3)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.

# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
  key, subkey = r.split(key)
  val = r.normal(subkey, 3)
  print(f"draw {i}: {val}")

# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")

# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
    val = r.normal(r.fold_in(key, i), 3)
    print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)

# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: #177230
Approved by: https://github.com/ngimel
ghstack dependencies: #177229
weifengpy pushed a commit that referenced this pull request Apr 7, 2026
Pull Request resolved: #177229
Approved by: https://github.com/albanD
weifengpy pushed a commit that referenced this pull request Apr 7, 2026
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
bobrenjc93 pushed a commit to bobrenjc93/pytorch that referenced this pull request Apr 10, 2026
bobrenjc93 pushed a commit to bobrenjc93/pytorch that referenced this pull request Apr 10, 2026
@github-actions github-actions Bot deleted the gh/jbschlosser/253/head branch May 4, 2026 02:25

Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • release notes: python_frontend (python frontend release notes category)
  • topic: not user facing (topic category)

6 participants