-
Notifications
You must be signed in to change notification settings - Fork 146
Closed
Description
Description
The innermost JAX dispatch functions for RandomVariables, registered through jax_sample_fn, look like this:
pytensor/pytensor/link/jax/dispatch/random.py
Lines 146 to 172 in 964cccb
@jax_sample_fn.register(ptr.CauchyRV)
@jax_sample_fn.register(ptr.GumbelRV)
@jax_sample_fn.register(ptr.LaplaceRV)
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
def jax_sample_fn_loc_scale(op, node):
    """Build the JAX sampler for random variables of the loc-scale families.

    ``jax.random`` only provides the standard (loc=0, scale=1) samplers for
    these distributions, so the returned function draws a standard sample and
    then shifts and rescales it by hand.

    The standard sampler is looked up on ``jax.random`` by ``op.name``.
    ``node`` is accepted for signature compatibility and is not used here.

    The returned ``sample_fn(rng, size, dtype, *parameters)`` expects
    ``parameters`` to be exactly ``(loc, scale)`` and ``rng`` to be a mapping
    holding the current key under ``"jax_state"``. It stores the advanced key
    back into ``rng`` in place and returns ``(rng, sample)``.
    """
    standard_sampler = getattr(jax.random, op.name)

    def sample_fn(rng, size, dtype, *parameters):
        # One half of the split is carried forward as the new state, the
        # other half is consumed by this draw.
        next_key, draw_key = jax.random.split(rng["jax_state"], 2)
        loc, scale = parameters

        # Without an explicit size, fall back to the broadcast shape of the
        # two parameters.
        if size is None:
            size = jax.numpy.broadcast_arrays(loc, scale)[0].shape

        standard_draw = standard_sampler(draw_key, size, dtype)
        shifted_and_scaled = loc + standard_draw * scale

        # Only advance the stored key once the draw has been computed.
        rng["jax_state"] = next_key
        return (rng, shifted_and_scaled)

    return sample_fn
The whole rng logic (splitting the key and storing the updated state) could instead be handled once in the outermost dispatch, jax_funcify_RandomVariable:
pytensor/pytensor/link/jax/dispatch/random.py
Lines 104 to 117 in 964cccb
if None in static_size: | |
assert_size_argument_jax_compatible(node) | |
def sample_fn(rng, size, *parameters): | |
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters) | |
else: | |
def sample_fn(rng, size, *parameters): | |
return jax_sample_fn(op, node=node)( | |
rng, static_size, out_dtype, *parameters | |
) | |
return sample_fn |
If an implementation needs a split other than 2, it can simply split the provided rng key again itself.