Jax dependency error, solution other than downgrade? #2693
Comments
Hi. Can you please verify the installed version? Since scvi-tools 1.1.0, there is no explicit dependency in scvi-tools on chex anymore. It seems that flax is causing this ImportError. Can you import |
Thank you for the quick response! I have confirmed the version:
And running
Looking at the traceback more closely, the trace seems to be showing: Thanks! |
Flax is used within scvi-tools, but there is no specific requirement on the Flax version. However, the JAX and Flax versions installed in your environment are mismatched (Flax is older than JAX), and this is causing the error. If you install JAX from scratch in a new environment from PyPI, the error shouldn't occur. You can try uninstalling Flax and JAX in the current environment and reinstalling JAX (this will install a correct version of Flax), which may fix it. In my experience, it's easier to set up a new environment.
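To make a mismatch like the one described above easier to spot, here is a minimal sketch that reports the installed versions of the relevant packages without importing them, so it still runs when `import flax` fails. The helper name `installed_versions` and the default package list are assumptions chosen for illustration; they are not part of scvi-tools.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "flax", "chex", "scvi-tools")):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper: reads package metadata only, so a broken import
    in one of the packages does not prevent the report from being built.
    """
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting this output into the issue shows at a glance which versions are installed side by side.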
Thanks, I will give that a try! |
Looking into the versions, this is what I have - both flax and jax seem to be the latest versions shown on their respective github pages, so I'm not sure that is the issue here.
I went ahead and uninstalled flax, jax, and jaxlib. Then ran this
Sorry, you also need to install Flax. Can you please install JAX and Flax in a new environment and check whether it works? It's very difficult to fix a conda environment with wrong dependencies. What we can support is creating a new conda environment and following the installation guide: https://docs.scvi-tools.org/en/stable/installation.html.
I know this issue has been raised in other issues (#2501 #2530), but downgrading jax causes other packages in my environment to throw errors, which seems like a rabbit hole to go down. Are there plans to fix this in a near-future release? Thanks!
For reference I am using scvi-tools 1.1.2, and am receiving the import error that others have reported: "AttributeError: module 'jax.random' has no attribute 'KeyArray'"
See full traceback below.
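As a quick diagnostic for the error quoted above, the following sketch checks whether the installed JAX still exposes `jax.random.KeyArray`. The function name `has_key_array` is hypothetical. A result of `False` would suggest that some other installed package still references the attribute and was built against an older JAX.

```python
import importlib.util


def has_key_array():
    """Return True/False if jax is importable, or None if jax is not installed.

    Hypothetical helper: it only probes for the attribute named in the
    AttributeError and does not attempt to repair the environment.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax.random

    return hasattr(jax.random, "KeyArray")


if __name__ == "__main__":
    print("jax.random.KeyArray available:", has_key_array())
```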