The errors occur because the new version of JAX requires a specific pattern for using `jit`'s `static_argnums`. See google/jax#595 for more context. The failing functions are:

- `xlogy`
- `xlog1py`
- `standard_gamma`

In addition, I think we can simplify the implementation of these functions by using the `custom_transform` decorator, as in `cumsum` and `cumprod`.
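Below is a minimal sketch of both ideas. It assumes current JAX, where the old custom-transform decorator has been superseded by `jax.custom_jvp`. `xlogy` gets an explicit derivative rule, and a hypothetical helper, `xlogy_sum`, shows the `functools.partial(jax.jit, static_argnums=...)` pattern for an argument that must remain a concrete Python value. This is not NumPyro's actual implementation, and whether it matches the exact pattern google/jax#595 asks for is an assumption.

```python
from functools import partial

import jax
import jax.numpy as jnp


# xlogy(x, y) = x * log(y), with the convention xlogy(0, y) = 0 (even for y = 0).
# Sketch only: jax.custom_jvp stands in for the custom_transform decorator.
@jax.custom_jvp
def xlogy(x, y):
    # Replace y where x == 0 so log() never sees a 0 that the where() would discard.
    safe_y = jnp.where(x == 0, 1.0, y)
    return jnp.where(x == 0, 0.0, x * jnp.log(safe_y))


@xlogy.defjvp
def _xlogy_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = xlogy(x, y)
    safe_y = jnp.where(x == 0, 1.0, y)
    # d/dx = log(y), d/dy = x / y (exactly 0 where x == 0).
    tangent_out = x_dot * jnp.log(y) + y_dot * jnp.where(x == 0, 0.0, x / safe_y)
    return primal_out, tangent_out


# Hypothetical helper: `axis` is marked static because jnp.sum needs a concrete
# Python int for it; a traced value would fail inside jit.
@partial(jax.jit, static_argnums=(2,))
def xlogy_sum(x, y, axis):
    return jnp.sum(xlogy(x, y), axis=axis)


if __name__ == "__main__":
    print(xlogy(0.0, 0.0))                 # 0 by convention
    print(jax.grad(xlogy)(0.0, 2.0))       # log(2), from the custom rule
    x = jnp.array([[0.0, 1.0], [2.0, 3.0]])
    y = jnp.array([[5.0, 2.0], [2.0, 3.0]])
    print(xlogy_sum(x, y, 0))              # column sums of xlogy(x, y)
```

The custom rule matters at `x == 0`: differentiating the plain `jnp.where` expression would give a zero derivative with respect to `x` there, because the gradient only flows through the selected constant branch, whereas the sketch returns `log(y)`.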
Are these functions failing with JAX master? We aren't using `jit`, at least in `xlogy` and `xlog1py`, so that doesn't seem to conflict directly with the JAX change. It also seems like a bunch of functions have been reorganized, so it's probably worth tackling all of this in one shot with the next JAX release?
Yes, they fail with JAX master, though maybe for a different reason. I agree that we should address these kinds of issues in one shot before our alpha release. I guess it is almost there. :)
cc @neerajprad