Make JAX an optional dependency #2281

Open
martinkim0 opened this issue Oct 10, 2023 · 4 comments

Comments

@martinkim0
Contributor

No description provided.

@martinkim0 martinkim0 added this to the scvi-tools 1.1.0 milestone Oct 10, 2023
@racng

racng commented Nov 15, 2023

I would find this very helpful.
I keep getting this error when trying to import scvi-tools:

ImportError: cannot import name 'ShapedArray' from 'jax' (/users/rng/mambaforge/envs/scai-v4/lib/python3.10/site-packages/jax/__init__.py)

I tried pinning jax=0.4.13, but it keeps getting replaced with v0.4.20 because that version is required by dependencies of pertpy, and scvi-tools is itself a requirement of pertpy.
I can't figure out an environment where scvi-tools, scarches, and pertpy can coexist.
I would appreciate it if anyone could share a working combination of package versions, especially for these packages:

python
pytorch
torchvision
torchaudio
pytorch-cuda
jax
jaxlib
chex
flax
scvi-tools
scarches
pertpy

Thanks!
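When comparing environments, it can help to record exactly which versions are installed. Below is a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the distribution names are assumptions: conda's `pytorch` is published on PyPI as `torch`, and `pytorch-cuda` is left out because it is a conda metapackage that may not expose Python metadata.

```python
import sys
from importlib import metadata

# Distribution names are assumptions (e.g. conda's "pytorch" is "torch" on PyPI).
PACKAGES = [
    "torch", "torchvision", "torchaudio",
    "jax", "jaxlib", "chex", "flax",
    "scvi-tools", "scarches", "pertpy",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, version in installed_versions(PACKAGES).items():
        print(name, version or "not installed")
```

Posting this output alongside a traceback makes it easier for others to spot a mismatched pair such as jax and flax.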

@martinkim0
Contributor Author

Hi, sorry that you're running into this issue - we're working on making this available in #2318 as part of our next scvi-tools release (v1.1). However, I'm unable to reproduce the ImportError you're getting with JAX 0.4.20. Could you post the full traceback you're seeing? Thanks.
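For readers unfamiliar with the idea in the issue title, an optional dependency is usually imported lazily, only when the feature that needs it is used. The following is a generic sketch of that pattern, not the approach taken in #2318. The helper `require_optional`, the function `build_jax_model`, and the `some-package[...]` install hint are all hypothetical placeholders.

```python
import importlib


def require_optional(module_name, extra):
    """Import an optional module, or raise an ImportError naming the extra to install."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        # "some-package" is a placeholder, not a real distribution or extra.
        raise ImportError(
            f"'{module_name}' is required for this feature. "
            f"Install it with an extra such as `pip install some-package[{extra}]`."
        ) from err


def build_jax_model():
    # The import happens only when the JAX-backed feature is requested,
    # so importing the surrounding package would not require jax or flax.
    jax = require_optional("jax", extra="jax")
    return jax
```

With this structure, a broken or missing jax installation only surfaces when a JAX-backed model is requested, rather than at package import time.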

@martinkim0 martinkim0 removed this from the scvi-tools 1.1.0 milestone Nov 23, 2023
@Jingyao12

I am running into the same issue. My JAX version is 0.4.23.

ImportError Traceback (most recent call last)
Cell In[8], line 1
----> 1 import scvi

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/__init__.py:10
7 from ._settings import settings
9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import data, model, external, utils
12 # python-poetry/poetry#2366 (comment)
13 # python-poetry/poetry#144 (comment)
14 try:

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/model/__init__.py:2
1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
3 from ._autozi import AUTOZI
4 from ._condscvi import CondSCVI

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/model/_amortizedlda.py:14
12 from scvi.data import AnnDataManager
13 from scvi.data.fields import LayerField
---> 14 from scvi.module import AmortizedLDAPyroModule
15 from scvi.utils import setup_anndata_dsp
17 from .base import BaseModelClass, PyroSviTrainMixin

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/module/__init__.py:4
2 from ._autozivae import AutoZIVAE
3 from ._classifier import Classifier
----> 4 from ._jaxvae import JaxVAE
5 from ._mrdeconv import MRDeconv
6 from ._multivae import MULTIVAE

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/module/_jaxvae.py:7
5 import numpy as np
6 import numpyro.distributions as dist
----> 7 from flax import linen as nn
8 from flax.linen.initializers import variance_scaling
10 from scvi import REGISTRY_KEYS

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/__init__.py:20
18 from . import core
19 from . import jax_utils
---> 20 from . import linen
21 from . import serialization
22 from . import traverse_util

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/__init__.py:47
18 # pylint: disable=g-multiple-import
19 # re-export commonly used modules and functions
20 from .activation import (
21 PReLU as PReLU,
22 celu as celu,
(...)
45 tanh as tanh
46 )
---> 47 from .attention import (
48 MultiHeadDotProductAttention as MultiHeadDotProductAttention,
49 SelfAttention as SelfAttention,
50 combine_masks as combine_masks,
51 dot_product_attention as dot_product_attention,
52 dot_product_attention_weights as dot_product_attention_weights,
53 make_attention_mask as make_attention_mask,
54 make_causal_mask as make_causal_mask
55 )
56 from .combinators import Sequential as Sequential
57 from ..core import (
58 DenyList as DenyList,
59 FrozenDict as FrozenDict,
60 broadcast as broadcast
61 )

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/attention.py:22
19 from flax.linen.dtypes import promote_dtype
21 from flax.linen.initializers import zeros
---> 22 from flax.linen.linear import default_kernel_init
23 from flax.linen.linear import DenseGeneral
24 from flax.linen.linear import PrecisionLike

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/linear.py:30
28 from jax import eval_shape
29 from jax import lax
---> 30 from jax import ShapedArray
31 import jax.numpy as jnp
32 import numpy as np

ImportError: cannot import name 'ShapedArray' from 'jax'
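The last frame shows flax's `linear.py` running `from jax import ShapedArray`, so the error is consistent with the installed flax expecting a top-level name that the installed jax does not export. A minimal sketch for checking this in a given environment follows. The helper name `has_top_level_name` is hypothetical, and the check only reports whether the name exists, not which versions are compatible.

```python
import importlib


def has_top_level_name(module_name, attr):
    """Return True/False if the module exposes attr, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # False here would match the ImportError shown in the traceback above.
    print("jax.ShapedArray available:", has_top_level_name("jax", "ShapedArray"))
```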

@Jingyao12

It was fixed by installing flax; see #2216.

@martinkim0 martinkim0 added the P1 label Jul 12, 2024
@martinkim0 martinkim0 added this to the scvi-tools 2.0 milestone Jul 12, 2024

3 participants