Conda tensorflow adventure (cont'd) #5
This answer suggests installing CUDA from conda as well, since otherwise the old version in my … would be used.
The answer did remove the first error, but the gist of the second error remains:
I reduced the number of environments to 1, with the same outcome. Do you know how much GPU memory a SAC agent for HalfCheetah-v4 should require?
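A back-of-the-envelope estimate helps answer the memory question. The sketch below assumes an actor that outputs a mean and a log-std per action dimension, twin critics plus target copies, two hidden layers of 256 units (an assumption, not necessarily xpag's defaults), float32 weights and Adam optimizer state. HalfCheetah-v4 has 17-dimensional observations and 6-dimensional actions. The estimate ignores the replay buffer, activations and XLA's own workspace, and the function names are hypothetical.

```python
from __future__ import annotations


def mlp_param_count(sizes: list[int]) -> int:
    """Weights plus biases of a fully connected net with the given layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


def sac_param_count(obs_dim: int, act_dim: int,
                    hidden: tuple[int, ...] = (256, 256)) -> dict[str, int]:
    """Parameter counts of a SAC agent under the assumed architecture."""
    # Actor outputs a mean and a log-std for each action dimension.
    actor = mlp_param_count([obs_dim, *hidden, 2 * act_dim])
    # Each critic maps (observation, action) to a scalar Q-value.
    critic = mlp_param_count([obs_dim + act_dim, *hidden, 1])
    return {
        "actor": actor,
        "critics": 2 * critic,         # twin Q-networks
        "target_critics": 2 * critic,  # slowly updated copies, not trained directly
    }


def estimate_bytes(counts: dict[str, int], bytes_per_param: int = 4) -> int:
    """Bytes for all parameters plus Adam's two moment buffers per trained parameter."""
    trained = counts["actor"] + counts["critics"]
    total_params = sum(counts.values()) + 2 * trained
    return total_params * bytes_per_param


if __name__ == "__main__":
    counts = sac_param_count(obs_dim=17, act_dim=6)  # HalfCheetah-v4 spaces
    print(counts)
    print(f"{estimate_bytes(counts) / 2**20:.2f} MiB")
```

Under these assumptions the networks and optimizer state take roughly 3 MiB, so the model itself is unlikely to be what exhausts the GPU.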
OK, never mind, this is not an xpag-related issue. I just need to figure out a way to make all the dependencies work together on my machine 😅
The amount of GPU memory is rarely an issue, I think (the default SAC agents in the examples use relatively small networks), but JAX tries to preallocate 90% of the total GPU memory, and I have encountered various issues at this step. One command that has proven useful in some specific cases is the following, to be run before executing the code:

By the way, do you manage to use JAX for simple operations (something like …)?
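The exact command from this reply did not survive in the thread. As a hedged sketch, JAX documents two environment variables for this: `XLA_PYTHON_CLIENT_PREALLOCATE=false` disables preallocation, and `XLA_PYTHON_CLIENT_MEM_FRACTION` sets the fraction of GPU memory that is preallocated. They are read when JAX initializes its GPU backend, so they should be set before `jax` is imported. The helper name `configure_jax_gpu_memory` is hypothetical.

```python
import os
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = True,
                             mem_fraction: Optional[float] = None) -> None:
    """Hypothetical helper: set XLA's GPU memory options via environment variables.

    Call it before `import jax` (or before importing any module that imports jax),
    because the variables are read when JAX initializes its GPU backend.
    """
    if not preallocate:
        # Allocate GPU memory on demand instead of reserving a large block up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory to reserve when preallocation is on.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


# Example: reserve only half of the GPU memory; import jax after this call.
configure_jax_gpu_memory(mem_fraction=0.5)
```

The shell equivalent is to export the variable before launching the script, e.g. `export XLA_PYTHON_CLIENT_MEM_FRACTION=.50`.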
Thanks for suggesting the memory fraction setting. I started from a clean slate (thanks, conda); what follows is in a new environment created from xpag's … In this new setting, …
Yes, this works, although I probably should look into that ptxas warning:

```
Python 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.numpy.sqrt(4.)
2023-03-01 13:47:16.968355: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:114] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
Array(2., dtype=float32, weak_type=True)
>>> jax.numpy.sqrt(4.)
Array(2., dtype=float32, weak_type=True)
```
On my main computer the ptxas version is 11.5.119.
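To see which ptxas a shell would pick up and whether it clears the 11.1 threshold from the warning above, a small check can parse `ptxas --version`. This is a sketch: it assumes the output contains a `V<major>.<minor>.<patch>` token (as in `Cuda compilation tools, release 10.1, V10.1.243`), and it only inspects the first `ptxas` on `PATH`, which is not necessarily the one XLA ends up using. The function names are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# XLA warns when ptxas is older than this (major, minor) version.
MIN_PTXAS = (11, 1)


def parse_ptxas_version(text: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from `ptxas --version` output."""
    match = re.search(r"V(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError("no version string found in ptxas output")
    major, minor, patch = (int(group) for group in match.groups())
    return major, minor, patch


def ptxas_is_recent_enough(version: Tuple[int, int, int]) -> bool:
    """True if the version is at least MIN_PTXAS."""
    return version[:2] >= MIN_PTXAS


def local_ptxas_version() -> Optional[Tuple[int, int, int]]:
    """Version of the first ptxas on PATH, or None if there is none."""
    path = shutil.which("ptxas")
    if path is None:
        return None
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=True)
    return parse_ptxas_version(result.stdout)


if __name__ == "__main__":
    version = local_ptxas_version()
    if version is None:
        print("ptxas not found on PATH")
    else:
        print(version, "OK" if ptxas_is_recent_enough(version) else "older than 11.1")
```

With the two versions reported in this thread, 10.1.243 fails the check and 11.5.119 passes.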
Yes, this is now solved by #6 (comment). I could upgrade to CUDA toolkit >= 12.1 by installing CUDA driver == 11.4 (the one corresponding to my GPU driver) from the nvidia conda channel.
This issue follows the conda installation instructions from #4, but it occurs after installation: everything imports fine, but creating a SAC agent fails.
Reproduction code
Outcome
First, a not-so-inviting warning:
Followed by: