tensorflow[and-cuda] 2.15.0/2.15.1 compatibility with jax[cuda12] #68290
Comments
@attaluris tensorflow[and-cuda] 2.15.0/2.15.1 is likely not compatible with jax[cuda12] because of a version mismatch in the NVIDIA NCCL library (nvidia-nccl-cu12), which both TensorFlow and JAX need for GPU support. Thank you!
@sushreebarsa Is there a way we could loosen the strict requirement in tensorflow from …? I also raised an issue in the JAX repo.
@attaluris Could you try updating tensorflow to 2.16.1 and jax to 0.4.28? Thank you!
@sushreebarsa thanks for the detail! 👀
@attaluris On an Ubuntu 20.04 LTS machine with a Tesla P100 GPU, we installed the CUDA build of JAX 0.4.28 first and then tensorflow 2.15.1. Both installed successfully, and both JAX and TF detected the GPUs. Please see the screenshot below for reference. Thank you!
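The check described in the comment above can be repeated with a short script. This is a minimal sketch, not the maintainers' actual procedure: it assumes Python 3.8+ (for importlib.metadata), and the helper names installed_versions and visible_gpus are hypothetical. It reports the installed versions of tensorflow, jax, jaxlib and nvidia-nccl-cu12, then lists the GPUs each framework can see using jax.devices() and tf.config.list_physical_devices("GPU").

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Distributions involved in the NCCL conflict discussed in this issue.
PACKAGES = ("tensorflow", "jax", "jaxlib", "nvidia-nccl-cu12")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def visible_gpus():
    """Return {framework: [GPU names]} for whichever of JAX and TensorFlow
    can be imported; frameworks that are not installed are skipped."""
    found = {}
    if importlib.util.find_spec("jax") is not None:
        import jax
        # CUDA devices report platform "gpu"; a CPU-only install yields [].
        found["jax"] = [str(d) for d in jax.devices() if d.platform == "gpu"]
    if importlib.util.find_spec("tensorflow") is not None:
        import tensorflow as tf
        # PhysicalDevice objects expose a .name such as "/physical_device:GPU:0".
        found["tensorflow"] = [d.name for d in tf.config.list_physical_devices("GPU")]
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
    for framework, gpus in visible_gpus().items():
        print(f"{framework}: {len(gpus)} GPU(s) {gpus}")
```

Printing nvidia-nccl-cu12 next to the framework versions shows which NCCL wheel the resolver actually settled on, which is the detail at the heart of this issue.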
This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
source
TensorFlow version
2.15.0/2.15.1
Custom code
Yes
OS platform and distribution
Debian Bullseye
Mobile device
No response
Python version
3.9/3.10
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
12.2
GPU model and memory
V100
Current behavior?
Hey y'all! I think tensorflow[and-cuda] is incompatible with jax[cuda12], and I just wanted to clarify whether this is expected. The solve error I'm getting is shown under "Relevant log output" below, and none of the jax[cuda12] versions with GPU support are compatible with nvidia-nccl-cu12==2.16.5. Does this requirement need to be a hard pin, or can it be loosened to accommodate higher versions of nvidia-nccl-cu12?
Standalone code to reproduce the issue
Relevant log output
Because tensorflow[and-cuda] (2.15.0) depends on nvidia-nccl-cu12 (2.16.5) and jax[cuda12] (0.4.23) depends on nvidia-nccl-cu12 (>=2.18.3), tensorflow[and-cuda] (2.15.0) is incompatible with jax[cuda12] (0.4.23). So, because hex-packages depends on both jax[cuda12] (0.4.23) and tensorflow[and-cuda] (2.15.0), version solving failed.
I get this error whenever I use a version of jax that has the cuda12 extra.
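The solver failure above amounts to an empty intersection of two version ranges for nvidia-nccl-cu12: tensorflow[and-cuda] 2.15.0 requires exactly 2.16.5, while jax[cuda12] 0.4.23 requires at least 2.18.3. The sketch below is a simplified illustration of that reasoning, not how pip or Poetry actually resolve dependencies (real tools follow PEP 440 rules); the helpers parse and intersect are hypothetical and only handle plain dotted releases of equal length. It also shows that a hypothetical looser pin such as >=2.16.5, as asked in the issue, would leave a non-empty range.

```python
from typing import Optional, Tuple

Version = Tuple[int, ...]
# A closed range (lo, hi); hi=None means "no upper bound".
Range = Tuple[Version, Optional[Version]]


def parse(v: str) -> Version:
    """Parse a plain dotted release like '2.18.3' into a comparable tuple.
    Simplification: no pre-releases, local tags, or unequal-length padding."""
    return tuple(int(part) for part in v.split("."))


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Return the overlap of two ranges, or None if they are disjoint."""
    lo = max(a[0], b[0])
    uppers = [h for h in (a[1], b[1]) if h is not None]
    hi = min(uppers) if uppers else None
    if hi is not None and lo > hi:
        return None
    return (lo, hi)


# Constraints on nvidia-nccl-cu12 taken from the solver log.
tf_exact: Range = (parse("2.16.5"), parse("2.16.5"))  # tensorflow[and-cuda] 2.15.0: == 2.16.5
jax_min: Range = (parse("2.18.3"), None)              # jax[cuda12] 0.4.23: >= 2.18.3

print(intersect(tf_exact, jax_min))  # -> None (no NCCL version satisfies both)

# Hypothetical looser TensorFlow pin (>= 2.16.5), for comparison only.
tf_loose: Range = (parse("2.16.5"), None)
print(intersect(tf_loose, jax_min))  # -> ((2, 18, 3), None)
```

Whether TensorFlow's GPU build actually works with a newer NCCL is a separate question that this sketch does not answer; it only shows why the resolver cannot satisfy the exact pin.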