
jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0 #2009

@cnguyen10

Description


I am using the latest version of JAX (v0.7.0) with tensorflow_probability and get the error in the title when defining a distribution:

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

probs = jnp.array(object=[[0.4, 0.5, 0.1], [0.1, 0.2, 0.7]])

categorical_dist = tfp.distributions.Categorical(probs=probs)

The full traceback is:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 56, in __getattr__
    module = self._load()
             ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 43, in _load
    module = importlib.import_module(self.__name__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 42, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax import v1
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import linalg_impl
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/linalg_impl.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import ops
  File "/usr/local/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
    jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
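The error message names `jax.core.pytype_aval_mappings` as the replacement. A minimal sketch, assuming library code wants to tolerate both locations, is to try the replacement first and fall back to the old attribute. The helper names `resolve_first` and `jax_pytype_aval_mappings` are hypothetical, this is not necessarily how commit 135080b6b1ac5724fc1731b0a9ca6f2010b1aea5 fixes TFP, and which JAX versions expose each location has not been verified here.

```python
import importlib
from typing import Any, Sequence


def resolve_first(candidates: Sequence[str]) -> Any:
    """Return the object at the first dotted path that resolves.

    Each candidate has the form "package.module.attribute". A candidate is
    skipped if its module cannot be imported or the attribute is missing.
    """
    errors = []
    for path in candidates:
        module_name, _, attr_name = path.rpartition(".")
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError, ValueError) as exc:
            # JAX raises AttributeError when a removed attribute is accessed
            # (see the traceback above), so it is skipped like a missing one.
            # ValueError covers a path with no module part.
            errors.append(f"{path}: {exc}")
    raise AttributeError("no candidate resolved:\n" + "\n".join(errors))


def jax_pytype_aval_mappings() -> Any:
    """Hypothetical helper: prefer the replacement named in the error message,
    then fall back to the location removed in JAX v0.7.0."""
    return resolve_first([
        "jax.core.pytype_aval_mappings",
        "jax.interpreters.xla.pytype_aval_mappings",
    ])
```

Ordering the candidates newest-first means the removed attribute is only touched on installations where the replacement does not resolve.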

System information
python==3.11
jax==0.7.0
tensorflow_probability==0.25.0

Update
I just found out that this has been fixed in commit 135080b6b1ac5724fc1731b0a9ca6f2010b1aea5. However, the latest release does not include that change, hence the error.

Could the TensorFlow Probability team publish a new release of tensorflow-probability that works with JAX 0.7.0 and later?
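Until a release with the fix is available, a project could detect the affected combination before importing TFP. This sketch uses only the standard library (`importlib.metadata`). The thresholds are assumptions taken from this report: JAX 0.7.0 removed the attribute, and TFP 0.25.0 (and, presumably, earlier releases) still reads it. The function names are hypothetical, and nightly builds installed as `tfp-nightly` are not checked.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Assumed thresholds, based on this report only.
FIRST_BREAKING_JAX = (0, 7, 0)
LAST_AFFECTED_TFP = (0, 25, 0)


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, padded to three parts.

    Example: "0.25.0rc1" -> (0, 25, 0), "0.7" -> (0, 7, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    return tuple(parts + [0] * (3 - len(parts)))


def is_affected(jax_version: str, tfp_version: str) -> bool:
    """True if the pair matches the JAX/TFP combination reported here."""
    return (parse_release(jax_version) >= FIRST_BREAKING_JAX
            and parse_release(tfp_version) <= LAST_AFFECTED_TFP)


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_installed() -> Optional[bool]:
    """None if either package is missing, else whether the pair is affected."""
    jax_v = installed_version("jax")
    tfp_v = installed_version("tensorflow-probability")
    if jax_v is None or tfp_v is None:
        return None
    return is_affected(jax_v, tfp_v)
```

A caller could raise an informative error or emit a warning when `check_installed()` returns `True`, instead of letting the `AttributeError` surface from deep inside TFP's import chain.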
