Jax dependency error, solution other than downgrade? #2693

Closed
gouinK opened this issue Apr 6, 2024 · 6 comments

gouinK commented Apr 6, 2024

I know this has been raised in other issues (#2501, #2530), but downgrading jax causes other packages in my environment to throw errors, which seems like a rabbit hole to go down. Are there plans to fix this in a near-future release? Thanks!

For reference, I am using scvi-tools 1.1.2 and get the same error on import that others have reported:
"AttributeError: module 'jax.random' has no attribute 'KeyArray'"

See the full traceback below.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 11
      9 import gc
     10 import os
---> 11 import scvi
     12 import rapids_singlecell as rsc
     13 import matplotlib.pyplot as plt

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import data, model, external, utils
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/data/__init__.py:25
      4 from ._datasets import (
      5     annotation_simulation,
      6     brainlarge_dataset,
   (...)
     22     synthetic_iid,
     23 )
     24 from ._manager import AnnDataManager, AnnDataManagerValidationCheck
---> 25 from ._preprocessing import (
     26     add_dna_sequence,
     27     organize_cite_seq_10x,
     28     organize_multiome_anndatas,
     29     poisson_gene_selection,
     30     reads_to_fragments,
     31 )
     32 from ._read import read_10x_atac, read_10x_multiome
     34 __all__ = [
     35     "AnnTorchDataset",
     36     "AnnDataManagerValidationCheck",
   (...)
     66     "cellxgene",
     67 ]

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/data/_preprocessing.py:12
      9 import torch
     10 from scipy.sparse import issparse
---> 12 from scvi.model._utils import parse_device_args
     13 from scvi.utils import dependencies, track
     14 from scvi.utils._docstrings import devices_dsp

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/model/_amortizedlda.py:15
     13 from scvi.data import AnnDataManager
     14 from scvi.data.fields import LayerField
---> 15 from scvi.module import AmortizedLDAPyroModule
     16 from scvi.utils import setup_anndata_dsp
     18 from .base import BaseModelClass, PyroSviTrainMixin

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/module/_amortizedlda.py:15
     13 from scvi._constants import REGISTRY_KEYS
     14 from scvi._types import Tunable
---> 15 from scvi.module.base import PyroBaseModuleClass, auto_move_data
     16 from scvi.nn import Encoder
     18 _AMORTIZED_LDA_PYRO_MODULE_NAME = "amortized_lda"

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
      2     BaseMinifiedModeModuleClass,
      3     BaseModuleClass,
      4     JaxBaseModuleClass,
      5     LossOutput,
      6     PyroBaseModuleClass,
      7     TrainStateWithState,
      8 )
      9 from ._decorators import auto_move_data, flax_configure
     11 __all__ = [
     12     "BaseModuleClass",
     13     "LossOutput",
   (...)
     19     "BaseMinifiedModeModuleClass",
     20 ]

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/scvi/module/base/_base_module.py:14
     12 import pyro
     13 import torch
---> 14 from flax.training import train_state
     15 from jax import random
     16 from jaxlib.xla_extension import Device

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/flax/training/train_state.py:17
      1 # Copyright 2024 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     15 from typing import Any, Callable
---> 17 import optax
     19 from flax import core, struct
     20 from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/__init__.py:17
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/__init__.py:17
      1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contributed optimizers in Optax."""
---> 17 from optax.contrib.cocob import cocob
     18 from optax.contrib.cocob import COCOBState
     19 from optax.contrib.complex_valued import split_real_and_imaginary

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/cocob.py:25
     23 import jax.numpy as jnp
     24 import jax.tree_util as jtu
---> 25 from optax._src import base
     28 class COCOBState(NamedTuple):
     29   """State for COntinuous COin Betting."""

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/_src/base.py:19
     15 """Base interfaces and datatypes."""
     17 from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union
---> 19 import chex
     20 import jax
     21 import jax.numpy as jnp

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/pytypes.py:54
     52 Numeric = Union[Array, Scalar]
     53 Shape = jax.core.Shape
---> 54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef
     56 Device = jax.Device

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
     52   warnings.warn(message, DeprecationWarning, stacklevel=2)
     53   return fn
---> 54 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'KeyArray'

gouinK added the bug label Apr 6, 2024
canergen (Contributor) commented Apr 6, 2024

Hi. Can you please verify the installed version? Since scvi-tools 1.1.0, scvi-tools no longer has an explicit dependency on chex. It seems that importing flax is what triggers this error. Can you run from flax.training import train_state? If flax can't be imported correctly, the easiest solution would be to set up a new conda environment. Otherwise, fixing these dependency issues can unfortunately be a lengthy process.
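
Both checks (installed versions and the flax import) can be done from the same Python session where the error appears. The snippet below is a minimal diagnostic sketch, assuming only the standard library; the helper names installed_versions and can_import are hypothetical, and the package list simply mirrors the packages that show up in the traceback. Note that the import check catches any exception, because the failure here is an AttributeError raised during import rather than an ImportError.

```python
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version


def installed_versions(dists):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for dist in dists:
        try:
            result[dist] = version(dist)
        except PackageNotFoundError:
            result[dist] = None
    return result


def can_import(module_name):
    """Try to import a module; return (True, None) or (False, 'ErrorType: message')."""
    try:
        import_module(module_name)
    except Exception as exc:  # also catches the AttributeError seen in this issue
        return False, f"{type(exc).__name__}: {exc}"
    return True, None


if __name__ == "__main__":
    print(installed_versions(["scvi-tools", "flax", "jax", "jaxlib", "optax", "chex"]))
    print(can_import("flax.training.train_state"))
```
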

gouinK (Author) commented Apr 6, 2024

Thank you for the quick response!

I have confirmed the version:

scvi-tools                1.1.2                    pypi_0    pypi

And running from flax.training import train_state does indeed throw the same error; see the traceback below:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 from flax.training import train_state

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/flax/training/train_state.py:17
      1 # Copyright 2024 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     15 from typing import Any, Callable
---> 17 import optax
     19 from flax import core, struct
     20 from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/__init__.py:17
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/__init__.py:17
      1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contributed optimizers in Optax."""
---> 17 from optax.contrib.cocob import cocob
     18 from optax.contrib.cocob import COCOBState
     19 from optax.contrib.complex_valued import split_real_and_imaginary

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/cocob.py:25
     23 import jax.numpy as jnp
     24 import jax.tree_util as jtu
---> 25 from optax._src import base
     28 class COCOBState(NamedTuple):
     29   """State for COntinuous COin Betting."""

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/_src/base.py:19
     15 """Base interfaces and datatypes."""
     17 from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union
---> 19 import chex
     20 import jax
     21 import jax.numpy as jnp

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/pytypes.py:54
     52 Numeric = Union[Array, Scalar]
     53 Shape = jax.core.Shape
---> 54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef
     56 Device = jax.Device

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
     52   warnings.warn(message, DeprecationWarning, stacklevel=2)
     53   return fn
---> 54 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'KeyArray'

Looking at the traceback more closely, the import chain seems to be:
flax --> optax --> chex --> jax
with the failure in chex (chex/_src/pytypes.py, line 54), which reads jax.random.KeyArray. Is flax a dependency of scvi-tools, and if so, is there a suggested version of flax to use?
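
The chain in the traceback can be checked one link at a time. The sketch below (standard library only; the helper names first_broken_link and jax_has_keyarray are hypothetical) imports the modules from the bottom of the chain upward and reports the first one that fails, then checks whether the installed jax.random still exposes KeyArray, which is the attribute chex/_src/pytypes.py line 54 reads in the traceback above. It is a diagnostic aid under those assumptions, not a fix.

```python
from importlib import import_module

# Bottom-up order of the chain in the traceback: jax <- chex <- optax <- flax.
CHAIN = ["jax", "chex", "optax", "flax.training.train_state"]


def first_broken_link(modules):
    """Return (module_name, 'ErrorType: message') for the first failing import, or None."""
    for name in modules:
        try:
            import_module(name)
        except Exception as exc:
            return name, f"{type(exc).__name__}: {exc}"
    return None


def jax_has_keyarray():
    """True/False for whether jax.random has KeyArray; None if jax cannot be imported."""
    try:
        jax_random = import_module("jax.random")
    except Exception:
        return None
    # hasattr returns False when jax.random's module __getattr__ raises AttributeError.
    return hasattr(jax_random, "KeyArray")


if __name__ == "__main__":
    print("first failing import:", first_broken_link(CHAIN))
    print("jax.random.KeyArray present:", jax_has_keyarray())
```

If jax imports cleanly but chex fails and KeyArray is reported as absent, that points at the installed chex being older than what the installed jax supports, consistent with the traceback.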

Thanks!

canergen (Contributor) commented Apr 6, 2024

Flax is used within scvi-tools, but there isn't a specific requirement for the Flax version. However, the JAX and Flax versions installed in your environment are mismatched (Flax is older than JAX), and this is causing the issue. If you install JAX from scratch in a new environment using PyPI, the error shouldn't occur. You can try uninstalling Flax and JAX in the current environment and reinstalling JAX (which will install a correct version of Flax) and hope that fixes it. In my experience, though, it's easier to set up a new environment.
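
One way to make "mismatch" concrete is to compare what each package declares it requires against what is actually installed. A minimal sketch, assuming only the standard library: importlib.metadata.requires returns the raw Requires-Dist strings from a package's metadata, and the helper names project_name and declared_requirements are hypothetical. The version specifiers are printed as-is rather than evaluated, so the comparison itself is left to the reader.

```python
import re
from importlib.metadata import PackageNotFoundError, requires, version


def project_name(requirement):
    """Lowercase project name from a Requires-Dist string, e.g. 'jax>=0.4.19' -> 'jax'.

    Simplified parsing: stops at the first space, extra, specifier or marker character,
    and does not apply PEP 503 name normalization.
    """
    return re.split(r"[\s\[<>=!~;(]", requirement, maxsplit=1)[0].lower()


def declared_requirements(dist, of_interest):
    """Requires-Dist entries of `dist` naming a project in `of_interest`; None if not installed."""
    try:
        reqs = requires(dist) or []
    except PackageNotFoundError:
        return None
    return [req for req in reqs if project_name(req) in of_interest]


if __name__ == "__main__":
    stack = {"flax", "optax", "chex", "jax", "jaxlib"}
    for dist in ["flax", "optax", "chex"]:
        print(f"{dist} declares:", declared_requirements(dist, stack))
    for dist in sorted(stack):
        try:
            print(f"{dist} installed:", version(dist))
        except PackageNotFoundError:
            print(f"{dist} not installed")
```
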

canergen closed this as completed Apr 6, 2024
gouinK (Author) commented Apr 6, 2024

Thanks, I will give that a try!

gouinK (Author) commented Apr 6, 2024

Looking into the versions, this is what I have. Both flax and jax seem to be the latest versions shown on their respective GitHub pages, so I'm not sure a version mismatch is the issue here:

flax                      0.8.2 
jax                       0.4.26
jaxlib                    0.4.26
chex                      0.1.7
optax                     0.2.1

I went ahead and uninstalled flax, jax, and jaxlib, then ran

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

which resulted in this:

scvi-tools 1.1.2 requires flax, which is not installed.
Successfully installed jax-0.4.26 jaxlib-0.4.26+cuda12.cudnn89 nvidia-cuda-nvcc-cu12-12.4.131
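
The "scvi-tools 1.1.2 requires flax, which is not installed." line appears to be pip's report of an unmet dependency after the install. A similar report can be requested at any time with pip check, which verifies that installed packages have compatible dependencies. The sketch below runs it for the current interpreter via sys.executable -m pip, so it checks the same environment the notebook uses; the wrapper name run_pip_check is hypothetical and only a convenience.

```python
import subprocess
import sys


def run_pip_check():
    """Run `pip check` for this interpreter; return (exit_code, combined_output)."""
    proc = subprocess.run(
        [sys.executable, "-m", "pip", "check"],
        capture_output=True,
        text=True,
    )
    return proc.returncode, (proc.stdout + proc.stderr).strip()


if __name__ == "__main__":
    code, output = run_pip_check()
    # A nonzero exit code means pip found at least one broken or missing requirement.
    print(code)
    print(output)
```
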

canergen (Contributor) commented Apr 6, 2024

Sorry, you also need to install Flax. Can you please try installing JAX and Flax in a new environment and check that it works? It's very difficult to fix a conda environment with broken dependencies. What we can support is creating a new conda environment and following the installation guide: https://docs.scvi-tools.org/en/stable/installation.html.
