Closed
Description
I am using the latest version of JAX, v0.7.0, with TensorFlow Probability and encounter the following error when defining a distribution:
```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

probs = jnp.array(object=[[0.4, 0.5, 0.1], [0.1, 0.2, 0.7]])
categorical_dist = tfp.distributions.Categorical(probs=probs)
```
The full traceback is:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 56, in __getattr__
    module = self._load()
             ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 43, in _load
    module = importlib.import_module(self.__name__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 42, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax import v1
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import linalg_impl
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/linalg_impl.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import ops
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
    jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
```
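The deprecation message names `jax.core.pytype_aval_mappings` as the replacement. A minimal sketch of a lookup that prefers that location and falls back to the old one on earlier JAX releases might look like this. It is not necessarily what the upstream TFP fix does; the function name `find_pytype_aval_mappings` is hypothetical, and it takes the `jax` module as an argument so it can be exercised without JAX installed.

```python
import types


def find_pytype_aval_mappings(jax_module: types.ModuleType):
    """Return JAX's mapping from Python types to abstract-value (aval) handlers.

    Hypothetical helper: tries the location named in the JAX v0.7.0
    deprecation message first, then the older location that v0.7.0 removed.
    """
    # New location suggested by the error message: jax.core.pytype_aval_mappings.
    core = getattr(jax_module, "core", None)
    mappings = getattr(core, "pytype_aval_mappings", None)
    if mappings is not None:
        return mappings
    # Old location: deprecated in JAX v0.5.0 and removed in JAX v0.7.0.
    return jax_module.interpreters.xla.pytype_aval_mappings
```

With JAX installed, it would be called as `find_pytype_aval_mappings(jax)` after `import jax.core`, which makes sure the `core` submodule is loaded as an attribute of `jax`.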
System information
python==3.11
jax==0.7.0
tensorflow_probability==0.25.0
Update
I just found out that this has been addressed in commit 135080b6b1ac5724fc1731b0a9ca6f2010b1aea5. However, the latest release does not include that change, hence the error.
Could the TensorFlow Probability team publish a new release of tensorflow-probability that works with JAX 0.7.0 and later?
slinderman