Due to issues like jax-ml/jax#480, it makes sense to use discrete samplers directly from scipy and then transfer the results back to the device using `jax.device_put()`. I checked that this is often an order of magnitude faster on the CPU, and it should be safe since the discrete samplers aren't reparametrized. Once we have a JAX-native multinomial distribution, we can switch to that later.
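A minimal sketch of the workaround, assuming NumPy's `Generator.multinomial` as the host-side sampler (`scipy.stats.multinomial.rvs` could be swapped in). The helper `sample_multinomial_on_host` and its parameters are hypothetical, not an existing API:

```python
import numpy as np
import jax
import jax.numpy as jnp


def sample_multinomial_on_host(seed, n, probs, size=None):
    """Hypothetical helper: draw multinomial counts on the host CPU with
    NumPy, then move the result onto the default JAX device."""
    rng = np.random.default_rng(seed)
    # Bring the probabilities to the host as float64 (works for JAX arrays too).
    p = np.asarray(probs, dtype=np.float64)
    p = p / p.sum()  # renormalize to guard against rounding error
    counts = rng.multinomial(n, p, size=size)
    # Cast to int32 to match JAX's default integer width (x64 disabled),
    # then transfer the host result to the device.
    return jax.device_put(counts.astype(np.int32))


# Usage: probabilities may come from a JAX computation.
probs = jnp.array([0.2, 0.3, 0.5])
counts = sample_multinomial_on_host(0, 10, probs)
batch = sample_multinomial_on_host(1, 10, probs, size=(4,))
```

Because the draw happens in NumPy rather than in JAX, this helper cannot be called inside `jax.jit`-traced code (the probabilities would be tracers there). That is acceptable only because no gradients need to flow through the discrete sample.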