Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Omni RNG (`orng`)

`orng` provides a thin facade over several Array API–compatible random number
`orng` provides a thin wrapper over several Array API–compatible random number
generators. It mirrors the subset of the `numpy.random.Generator` API:

- `random`
Expand Down Expand Up @@ -39,15 +39,15 @@ You can also combine extras, e.g. `pip install "orng[numpy,torch]"`.
## Quick Start

```python
from orng import ArrayRNG
from orng import RandomGenerator

rng = ArrayRNG(backend="numpy", seed=42)
rng = RandomGenerator(backend="numpy", seed=42)
samples = rng.normal(loc=0.0, scale=1.0, size=5)
uniform = rng.uniform(low=-1.0, high=1.0, size=(2, 2))
```

The backend module is imported lazily. If the requested library is missing,
`ArrayRNG` will raise an informative `ImportError` that points to the matching
`RandomGenerator` will raise an informative `ImportError` that points to the matching
extra.

## Functional Backend API
Expand Down Expand Up @@ -143,7 +143,7 @@ through the same functional interface. JAX always uses and returns a PRNG key.

### Backend State Reference

When you pass the optional `generator` argument to `ArrayRNG`, the expected
When you pass the optional `generator` argument to `RandomGenerator`, the expected
object depends on the backend:

| Backend | Generator argument |
Expand All @@ -162,7 +162,7 @@ orng/
├── src/orng/
│ ├── __init__.py # package exports
│ ├── _utils.py # shared helpers (internal)
│ ├── orng.py # ArrayRNG facade
│ ├── orng.py # RandomGenerator wrapper
│ └── backends/ # backend-specific implementations
└── README.md
```
Expand Down
3 changes: 2 additions & 1 deletion src/orng/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
create_functional_backend,
create_functional_backend_from_xp,
)
from .orng import ArrayRNG
from .orng import ArrayRNG, RandomGenerator

__all__ = [
"ArrayRNG",
"RandomGenerator",
"create_backend_from_xp",
"create_functional_backend",
"create_functional_backend_from_xp",
Expand Down
22 changes: 18 additions & 4 deletions src/orng/orng.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Backend-aware random number generation helpers.

This module introduces :class:`ArrayRNG`, a small facade that mimics the subset
This module introduces :class:`RandomGenerator`, a small facade that mimics the subset
of ``numpy.random.Generator`` APIs. The class presents a uniform interface
across NumPy, PyTorch, CuPy, and JAX.
"""
Expand Down Expand Up @@ -59,7 +59,7 @@ def choice(


@dataclass
class ArrayRNG:
class RandomGenerator:
"""Facade exposing ``numpy.random.Generator``-style helpers across backends.

Parameters
Expand Down Expand Up @@ -98,7 +98,7 @@ def from_xp(
seed: int | None = None,
generator: Any | None = None,
device: Any | None = None,
) -> "ArrayRNG":
) -> "RandomGenerator":
return cls(
backend=infer_backend_name_from_xp(xp),
seed=seed,
Expand Down Expand Up @@ -205,4 +205,18 @@ def to_functional(
return backend, self._impl._state


__all__ = ["ArrayRNG"]
class ArrayRNG(RandomGenerator):
"""Deprecated alias for :class:`RandomGenerator`."""

def __post_init__(self) -> None:
import warnings

warnings.warn(
"ArrayRNG is deprecated and will be removed in a future release. "
"Please use orng.Generator instead.",
FutureWarning,
)
super().__post_init__()


__all__ = ["ArrayRNG", "RandomGenerator"]
4 changes: 2 additions & 2 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from orng import ArrayRNG
from orng import RandomGenerator
from orng.backends.cupy import CuPyBackend
from orng.backends.jax import JAXBackend
from orng.backends.numpy import NumPyBackend
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_backend_seeding(backend_case, method_case):

@pytest.fixture()
def rng(backend_name, seed):
return ArrayRNG(backend_name, seed=seed)
return RandomGenerator(backend_name, seed=seed)


def test_array_rng_random(rng):
Expand Down
16 changes: 9 additions & 7 deletions tests/test_orng.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from orng import ArrayRNG, create_backend_from_xp
from orng import RandomGenerator, create_backend_from_xp
from orng._utils import normalize_shape, total_size
from orng.backends import _FACTORIES
from orng.backends import numpy as numpy_backend
Expand Down Expand Up @@ -76,7 +76,9 @@ def fake_factory(**kwargs):

monkeypatch.setitem(_FACTORIES, "numpy", fake_factory)

rng = ArrayRNG(backend="numpy", seed=7, generator="sentinel", device="cpu")
rng = RandomGenerator(
backend="numpy", seed=7, generator="sentinel", device="cpu"
)
assert len(instances) == 1
backend = instances[0]
assert backend.seed == 7
Expand Down Expand Up @@ -129,7 +131,7 @@ def fake_factory(**kwargs):

def test_array_rng_rejects_unknown_backend():
with pytest.raises(ValueError):
ArrayRNG(backend="unknown")
RandomGenerator(backend="unknown")


def test_array_rng_from_xp_infers_backend(monkeypatch):
Expand Down Expand Up @@ -160,7 +162,7 @@ def choice(self, population, *, size, replace, probabilities):
_FACTORIES, "numpy", lambda **kwargs: DummyBackend(**kwargs)
)

rng = ArrayRNG.from_xp(np, seed=123)
rng = RandomGenerator.from_xp(np, seed=123)

assert rng.backend == "numpy"
assert captured_kwargs["seed"] == 123
Expand Down Expand Up @@ -218,7 +220,7 @@ def choice(self, population, *, size, replace, probabilities):
)

key = ("jax-key",)
ArrayRNG(backend="jax", generator=key)
RandomGenerator(backend="jax", generator=key)

assert captured_kwargs["generator"] is key

Expand Down Expand Up @@ -253,7 +255,7 @@ def choice(self, state, population, *, size, replace, probabilities):
lambda name, pure: backend,
)

rng = ArrayRNG(backend="numpy", seed=123)
rng = RandomGenerator(backend="numpy", seed=123)
functional_backend, state = rng.to_functional()

assert functional_backend is backend
Expand Down Expand Up @@ -288,7 +290,7 @@ def choice(self, state, population, *, size, replace, probabilities):
or DummyFunctionalBackend(),
)

rng = ArrayRNG(backend="numpy", seed=123)
rng = RandomGenerator(backend="numpy", seed=123)
rng.to_functional(pure=False)

assert captured == {"name": "numpy", "pure": False}
Expand Down
Loading