Initial implementation of stateless RNG APIs#177229
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177229
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit b60b2dd with merge base 99dee05 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Attention! native_functions.yaml was changedIf you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by: |
|
But would this be similar to accepting Or is it mainly to stripp the |
ghstack-source-id: e02a04b Pull-Request: pytorch/pytorch#177229
| uint32_t r3 = curand(&state); | ||
|
|
||
| uint64_t new_seed = static_cast<uint64_t>(r0) | (static_cast<uint64_t>(r1) << 32); | ||
| uint64_t new_offset = static_cast<uint64_t>(r2) | (static_cast<uint64_t>(r3) << 32); |
There was a problem hiding this comment.
Is there a specific reason for choosing this hash?
There was a problem hiding this comment.
It's simply reinterpreting the returned 128-bit randomness sampled at this location as a new (64-bit seed, 64-bit offset). This is a convenient way to jump to a location within the 3D PRNG number space (seed, subsequence=0, offset) that is v. likely to result in an independent PRNG stream to avoid undesirable reuse.
| uint64_t offset = input[key_idx * 2 + 1]; | ||
|
|
||
| curandStatePhilox4_32_10_t state; | ||
| curand_init(seed, /*subsequence=*/0, /*offset=*/offset, &state); |
There was a problem hiding this comment.
Why not expose the Philox subsequence dimension? Different subsequences occupy disjoint regions of the counter space by construction.
There was a problem hiding this comment.
Restricting Philox subsequence to 0 in these kernels (and further not exposing subsequence) is a design decision that helps satisfy some useful properties:
- Consistency of random number generation across GPU types. The underlying kernels for
torch.randn()/torch.rand()today use thread ID-based subsequence numbers, with the # of threads being SM-dependent. This makes the generated outputs different between e.g. A100 vs. H100. - Consistency of generated values over time AKA "sequential equivalence". This means making the outputs independent of input size / batching, requiring subsequence numbers that are not based on input size; subsequence=0 is the easiest way to achieve this. This property is useful for Distributed, who would like to guarantee that a sharded parameter will have its shards initialized with the same values as if the parameter were initialized all at once. Ex:
key = torch.random.key(42)
vals = torch.random.uniform(key, 10)
sharded_vals1 = torch.random.uniform(key, 5)
# TODO: introduce a public API for this "offset hacking"
offset_key = torch.tensor([42, 5], device=key.device, dtype=key.dtype)
sharded_vals2 = torch.random.uniform(offset_key, 5)
assert torch.equal(vals, torch.cat([sharded_vals1, sharded_vals2], dim=0))Note that ensuring the above properties does come at a performance cost. Using thread-based subsequence numbers allows for better memory coalescing behavior and fewer expensive curand_init() calls. I think it may be useful to satisfy these properties by default, but provide a flag for opting out for those who care about absolute performance (for this case, simply call into the existing kernels).
@vadimkantorov Conceptually yes, this is similar. And you're correct that Generator is the current concept holding an RNG state. There may or may not be a provided public translation between manual "RNG key" state specification and Generators.
A nice property we get by representing RNG state as bare Tensors is the ability to batch them efficiently. Here's an example where a different RNG key is used for each sample within a sampled batch, which is useful to achieve "batch invariance": # === Batch invariance example (per-sample keys) ===
key = torch.random.key(seed=42)
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 2)
keys = torch.random.split(key, num=B+1)
key, batch_key = keys[0], keys[1:]
# each set of 32 is generated via a corresponding batch item key
x = torch.random.normal(B, 32, key=batch_key)Note that you could do this with a batch-sized stack of Generators, but it's much less efficient. Ex:
And of course, there's the fact that Generators today are not supported in |
Maybe would be possible to somehow make existing |
Would this also nicely make Generators work in torch.compile? (make the torch.Generator a tensor subclass and implement the methods as methods changing the value) |
@vadimkantorov unfortunately, I don't think it's technically feasible to make torch.Generator be a tensor subclass without substantial refactoring / BC-breakage. @albanD can correct me if I'm wrong though |
|
if not - maybe at least the new Key tensor-subclass could implement some inplace methods as torch.Generator, and make it accepted in old generator-related APIs, then torch.Generator could be deprecated... |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
|
||
| int64_t total_threads = num_keys * num_splits; | ||
| constexpr int block_size = 256; | ||
| int num_blocks = std::min( |
There was a problem hiding this comment.
btw you don't need to bother with sizing the number of blocks to a particular GPU. Previously because curand_init call was expensive (don't exactly know why and it's not important) we tried to reuse it across grid_stride loop, but all you are doing is philox_4x32 so there's no reason for a grid-stride loop, just launch enough blocks. IT's a nit for this PR but could simplify the next one.
There was a problem hiding this comment.
nice, love this!
ghstack-source-id: 9074ceb Pull-Request: pytorch#177229
Stateless API variant for generates random numbers from a normal distribution deterministically given an RNG key.
~~There is some complexity required to keep normal generation consistent across odd and even offsets: the Box-Muller transform is used to efficiently transform sampled uniforms -> sampled normals and operates on pairs of normals. So we handle this case explicitly by using an effective offset of `offset - (offset % 4)` then skipping `offset % 4` sampled values to determine where to begin our output stream. This way, the same pairs of uniforms are always paired and we get a consistent, offsetable stream of generated normal values.~~
Update: scrapped Box-Muller handling for simplicity. Also, we treat offset in units of 128-bit randomness chunks instead of 32-bit units, so we can provide fewer guarantees around sequential equivalence but the kernels are simpler.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: #177230
Approved by: https://github.com/ngimel
ghstack dependencies: #177229
Initial implementation for a JAX-like stateless RNG API, starting with:
* `torch.func._random.key(seed)`
* `torch.func._random.split(key, num)`
* `torch.func._random.fold_in(key, data)`
Both `split()` and `fold_in()` support arbitrarily-batched keys. The former will add a new left-most batch dimension while the latter maintains the input shape.
In the details, both `split()` and `fold_in()` sample 128-bits of randomness at each (seed, offset) pair to produce a new (seed, offset) pair. 64-bits of the sampled randomness are used for the seed and 64-bits are used for the offset.
Underneath, the `split()` and `fold_in()` kernels utilize a custom stateless Philox-4x32-10 round implementation. This implementation fixes `subsequence=0` permanently to ensure consistent random number generation across # of CUDA threads and thus across GPU devices, input shapes, etc.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: #177229
Approved by: https://github.com/albanD
Stateless API variant for generates random numbers from a normal distribution deterministically given an RNG key.
~~There is some complexity required to keep normal generation consistent across odd and even offsets: the Box-Muller transform is used to efficiently transform sampled uniforms -> sampled normals and operates on pairs of normals. So we handle this case explicitly by using an effective offset of `offset - (offset % 4)` then skipping `offset % 4` sampled values to determine where to begin our output stream. This way, the same pairs of uniforms are always paired and we get a consistent, offsetable stream of generated normal values.~~
Update: scrapped Box-Muller handling for simplicity. Also, we treat offset in units of 128-bit randomness chunks instead of 32-bit units, so we can provide fewer guarantees around sequential equivalence but the kernels are simpler.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: #177230
Approved by: https://github.com/ngimel
ghstack dependencies: #177229
Initial implementation for a JAX-like stateless RNG API, starting with:
* `torch.func._random.key(seed)`
* `torch.func._random.split(key, num)`
* `torch.func._random.fold_in(key, data)`
Both `split()` and `fold_in()` support arbitrarily-batched keys. The former will add a new left-most batch dimension while the latter maintains the input shape.
In the details, both `split()` and `fold_in()` sample 128-bits of randomness at each (seed, offset) pair to produce a new (seed, offset) pair. 64-bits of the sampled randomness are used for the seed and 64-bits are used for the offset.
Underneath, the `split()` and `fold_in()` kernels utilize a custom stateless Philox-4x32-10 round implementation. This implementation fixes `subsequence=0` permanently to ensure consistent random number generation across # of CUDA threads and thus across GPU devices, input shapes, etc.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: pytorch#177229
Approved by: https://github.com/albanD
Stateless API variant for generates random numbers from a normal distribution deterministically given an RNG key.
~~There is some complexity required to keep normal generation consistent across odd and even offsets: the Box-Muller transform is used to efficiently transform sampled uniforms -> sampled normals and operates on pairs of normals. So we handle this case explicitly by using an effective offset of `offset - (offset % 4)` then skipping `offset % 4` sampled values to determine where to begin our output stream. This way, the same pairs of uniforms are always paired and we get a consistent, offsetable stream of generated normal values.~~
Update: scrapped Box-Muller handling for simplicity. Also, we treat offset in units of 128-bit randomness chunks instead of 32-bit units, so we can provide fewer guarantees around sequential equivalence but the kernels are simpler.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: pytorch#177230
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#177229
Initial implementation for a JAX-like stateless RNG API, starting with:
* `torch.func._random.key(seed)`
* `torch.func._random.split(key, num)`
* `torch.func._random.fold_in(key, data)`
Both `split()` and `fold_in()` support arbitrarily-batched keys. The former will add a new left-most batch dimension while the latter maintains the input shape.
In the details, both `split()` and `fold_in()` sample 128-bits of randomness at each (seed, offset) pair to produce a new (seed, offset) pair. 64-bits of the sampled randomness are used for the seed and 64-bits are used for the offset.
Underneath, the `split()` and `fold_in()` kernels utilize a custom stateless Philox-4x32-10 round implementation. This implementation fixes `subsequence=0` permanently to ensure consistent random number generation across # of CUDA threads and thus across GPU devices, input shapes, etc.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: pytorch#177229
Approved by: https://github.com/albanD
Stateless API variant for generates random numbers from a normal distribution deterministically given an RNG key.
~~There is some complexity required to keep normal generation consistent across odd and even offsets: the Box-Muller transform is used to efficiently transform sampled uniforms -> sampled normals and operates on pairs of normals. So we handle this case explicitly by using an effective offset of `offset - (offset % 4)` then skipping `offset % 4` sampled values to determine where to begin our output stream. This way, the same pairs of uniforms are always paired and we get a consistent, offsetable stream of generated normal values.~~
Update: scrapped Box-Muller handling for simplicity. Also, we treat offset in units of 128-bit randomness chunks instead of 32-bit units, so we can provide fewer guarantees around sequential equivalence but the kernels are simpler.
Example API usage:
```python
import torch.func._random as r
# === Randomness using split() ===
key = r.key(42, device="cuda")
for i in range(3):
new_key, subkey = r.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = r.normal(key, 3)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
# === Less involved and mostly equivalent to the above ===
key = r.key(42, device="cuda")
for i in range(3):
key, subkey = r.split(key)
val = r.normal(subkey, 3)
print(f"draw {i}: {val}")
# === Alternative approach using split with explicit num ===
key = r.key(42, device="cuda")
keys = r.split(key, num=4)
key, subkeys = keys[0], keys[1:]
vals = [r.normal(subkey, 3) for subkey in subkeys]
print(f"draw: {vals}")
# === Alternative approach using fold_in() ===
key = r.key(42, device="cuda")
for i in range(3):
val = r.normal(r.fold_in(key, i), 3)
print(f"draw {i}: {val}")
# This is needed for future randomness calls to avoid key reuse
key = r.fold_in(key, 3)
# === Batch invariance example (per-sample keys) ===
key = r.key(42, device="cuda")
B = 64
# 1 key for future randomness and the other B for per-batch-item usage.
# key shape: (2,)
# batch_key shape: (B, 1, 2)
keys = r.split(key, num=B+1)
key, batch_key = keys[0], keys[1:].unsqueeze(-2)
# each set of 32 is generated via a corresponding batch item key
x = r.normal(batch_key, B, 32)
```
Pull Request resolved: pytorch#177230
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#177229
Stack from ghstack (oldest at bottom):
Initial implementation for a JAX-like stateless RNG API, starting with:
torch.func._random.key(seed)torch.func._random.split(key, num)torch.func._random.fold_in(key, data)Both
split()andfold_in()support arbitrarily-batched keys. The former will add a new left-most batch dimension while the latter maintains the input shape.In the details, both
split()andfold_in()sample 128-bits of randomness at each (seed, offset) pair to produce a new (seed, offset) pair. 64-bits of the sampled randomness are used for the seed and 64-bits are used for the offset.Underneath, the
split()andfold_in()kernels utilize a custom stateless Philox-4x32-10 round implementation. This implementation fixessubsequence=0permanently to ensure consistent random number generation across # of CUDA threads and thus across GPU devices, input shapes, etc.Example API usage: