Skip to content

Commit

Permalink
Replace usages of PRNGKey with key (#2)
Browse files Browse the repository at this point in the history
Deprecates usage of `jax.random.PRNGKey` in favour of `jax.random.key` as per
[JEP 9263](https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html)
  • Loading branch information
tttc3 committed Feb 19, 2024
1 parent 830574d commit 5ad1a1b
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ from mccube import (
gaussian_wasserstein_metric,
)

key = jr.PRNGKey(42)
key = jr.key(42)
n, d = 512, 10
t0 = 0.0
epochs = 512
Expand Down
6 changes: 3 additions & 3 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ from mccube import gaussian_wasserstein_metric, unpack_particles

jax.config.update("jax_enable_x64", True)

key, rng_key = jr.split(jr.PRNGKey(42))
key, rng_key = jr.split(jr.key(42))
n, d = 512, 10
t0 = 0.0
n_epochs = 1024
Expand Down Expand Up @@ -133,7 +133,7 @@ def inference_loop(kernel, initial_state, n_epochs, num_chains, *, key):

return states

key, sampler_key = jr.split(jr.PRNGKey(42))
key, sampler_key = jr.split(jr.key(42))
sampler = blackjax.mala(logdensity, dt0)
init_state = jax.vmap(sampler.init)(y0)
state = inference_loop(
Expand Down Expand Up @@ -181,7 +181,7 @@ from mccube import (
BinaryTreePartitioningKernel,
)

key = jr.PRNGKey(42)
key = jr.key(42)
gaussian_cubature = Hadamard(GaussianRegion(d))
mcc_cde = diffrax.WeaklyDiagonalControlTerm(
lambda t, p, args: jnp.sqrt(2.0),
Expand Down
2 changes: 1 addition & 1 deletion mccube/_kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class PartitioningRecombinationKernel(AbstractRecombinationKernel):
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(42)
key = jr.key(42)
y0 = jnp.ones((64,8))
n, d = y0.shape
Expand Down
4 changes: 2 additions & 2 deletions mccube/_kernels/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MonteCarloKernel(AbstractRecombinationKernel):
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(42)
key = jr.key(42)
kernel = mccube.MonteCarloKernel({"y": 3}, key=key)
y0 = {"y": jnp.ones((10,2))}
result = kernel(..., y0, ...)
Expand Down Expand Up @@ -80,7 +80,7 @@ class MonteCarloPartitioningKernel(AbstractPartitioningKernel):
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(42)
key = jr.key(42)
kernel = mccube.MonteCarloKernel(..., key=key)
partitioning_kernel = mccube.MonteCarloPartitioningKernel(4, kernel)
y0 = jnp.ones((12,2))
Expand Down
2 changes: 1 addition & 1 deletion mccube/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MCCSolver(AbstractWrappedSolver[_SolverState]):
import jax.random as jr
from diffrax import diffeqsolve, Euler
key, rng_key = jr.split(jr.PRNGKey(42))
key, rng_key = jr.split(jr.key(42))
t0, t1 = 0.0, 1.0
dt0 = 0.001
particles = jnp.ones((32,8))
Expand Down
6 changes: 3 additions & 3 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_monte_carlo_partitioning_kernel():

y0 = jnp.array([[1.0, 0.01], [2.0, 1.0], [3.0, 100.0], [4.0, 10000.0]])

key = jr.PRNGKey(42)
key = jr.key(42)
mc_kernel = MonteCarloKernel(None, key=key)
kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel)
values = kernel(0.0, y0, ...)
Expand All @@ -80,7 +80,7 @@ def test_monte_carlo_partitioning_kernel():
jnp.unique(values, return_counts=True), jnp.unique(y0, return_counts=True)
)

key = jr.PRNGKey(42)
key = jr.key(42)
mc_kernel = MonteCarloKernel(None, weighting_function=lambda x: x, key=key)
kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel)
values = kernel(0.0, y0, ..., weighted=True)
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_binary_tree_partitioning_kernel(mode):

# n, d = 64, 2

# key = jr.PRNGKey(42)
# key = jr.key(42)
# y0 = jr.multivariate_normal(key, jnp.zeros(d), jnp.eye(d), (n,))
# weights = jnp.arange(1.0, n + 1.0)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .helpers import gaussian_formulae

key = jr.PRNGKey(42)
key = jr.key(42)
init_key, rng_key = jr.split(key)
t0 = 0.0
dt0 = 0.05
Expand Down Expand Up @@ -46,14 +46,14 @@ def test_diffrax_ula():
ode = ODETerm(lambda t, p, args: grad_logdensity(p))
cde = WeaklyDiagonalControlTerm(
lambda t, p, args: jnp.sqrt(2.0),
VirtualBrownianTree(t0, t1, dt0 / 10, (k, d), key=jr.PRNGKey(42)),
VirtualBrownianTree(t0, t1, dt0 / 10, (k, d), key=jr.key(42)),
)
terms = MultiTerm(ode, cde)
diffeqsolve(terms, Euler(), t0, t1, dt0, y0)


def test_MCCSolver_init():
key = jr.PRNGKey(42)
key = jr.key(42)
with pytest.raises(ValueError) as e, pytest.warns(UserWarning) as w:
mccube.MCCSolver(EulerHeun(), mccube.MonteCarloKernel(10, key=key), 0)

Expand Down

0 comments on commit 5ad1a1b

Please sign in to comment.