Conda tensorflow adventure (cont'd) #5

Closed
stephane-caron opened this issue Feb 28, 2023 · 7 comments

@stephane-caron
Contributor

This issue follows the conda installation instructions from #4, but it happens post-installation. Everything imports well, but trying to create a SAC agent fails.

Reproduction code

import jax
import xpag

assert jax.lib.xla_bridge.get_backend().platform == "gpu"
nb_envs = 10  # the number of rollouts in parallel during training
env, eval_env, env_info = xpag.wrappers.gym_vec_env("HalfCheetah-v4", nb_envs)
agent = xpag.agents.SAC(
    env_info["observation_dim"]
    if not env_info["is_goalenv"]
    else env_info["observation_dim"] + env_info["desired_goal_dim"],
    env_info["action_dim"],
    {"actor_lr": 3e-3, "critic_lr": 3e-3, "tau": 5e-2, "seed": 0},
)

Outcome

First, a not-so-inviting warning:

2023-02-28 18:06:42.237635: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:114] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

Followed by:

2023-02-28 18:06:44.866666: E external/org_tensorflow/tensorflow/compiler/xla/python/pjit.cc:476] cache miss fail: XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.
@stephane-caron
Contributor Author

This answer suggests installing CUDA from Conda as well, as it will otherwise use the old version in my /usr/bin. I'm trying this now (it will take a while on a bad connection).

@stephane-caron
Contributor Author

The answer did remove the first error, but the gist of the second error remains:

2023-02-28 19:08:03.177444: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2410] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.

I reduced the number of environments to 1 → same outcome.

Do you know how much GPU memory a SAC agent for HalfCheetah-v4 should require?

@stephane-caron
Contributor Author

OK, never mind, this is not an xpag-related issue. I just need to figure out a way to make all the dependencies work together on my machine 😅

@perrin-isir
Owner

perrin-isir commented Feb 28, 2023

The amount of GPU memory is rarely an issue I think (the default SAC agents in the examples use relatively small networks), but JAX tries to preallocate 90% of the total GPU memory, and I have encountered various issues at this step. One command that has proven useful in some specific cases is the following, before executing the code:
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7
(Of course you can test with other values < 0.9)
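
For scripts or notebooks, the same setting can also be applied from Python. The following is a minimal sketch, not something xpag provides: the helper name limit_jax_gpu_preallocation is hypothetical, and 0.7 is only the example value from above. It assumes the variable is set before jax is imported in the process.

```python
import os


def limit_jax_gpu_preallocation(fraction: float = 0.7) -> str:
    """Hypothetical helper: cap the share of GPU memory JAX preallocates.

    Call it before the first `import jax` in the process.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    value = str(fraction)
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# 0.7 is the example value from the comment above; tune it for your GPU.
limit_jax_gpu_preallocation(0.7)

# Alternative worth trying: disable preallocation entirely.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Import JAX only after the environment is configured:
# import jax
```

Setting the variable inside the script keeps the configuration next to the code that needs it, at the cost of having to run before any module imports jax.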

By the way, can you use JAX for simple operations (something like jax.numpy.sqrt(4.)) on your machine? Is it failing specifically when creating a SAC agent? In that case, could you also try with a TD3 agent (it has fewer dependencies; in particular, TD3 doesn't use tensorflow_probability)?

@stephane-caron
Contributor Author

Thanks for suggesting the memory fraction setting. I started from a clean slate (thanks conda); what follows is in a new environment created from xpag's environment.yaml.

In this new setting, the "gpuGetLastError() failed: out of memory" error is gone and the script seems to execute further, up to another error → #6 😅

By the way, can you use JAX for simple operations (something like jax.numpy.sqrt(4.)) on your machine?

Yes, this works, although I probably should look into that ptxas warning:

Python 3.10.9 | packaged by conda-forge | (main, Feb  2 2023, 20:20:04) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.numpy.sqrt(4.)
2023-03-01 13:47:16.968355: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:114] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

Array(2., dtype=float32, weak_type=True)
>>> jax.numpy.sqrt(4.)
Array(2., dtype=float32, weak_type=True)

@perrin-isir
Owner

On my main computer the ptxas version is 11.5.119.
Here it is mentioned that JAX requires CUDA >= 11.4, so your ptxas 10.1.243 may indeed cause more serious bugs.
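
To compare versions across machines, a small script can report which ptxas comes first on PATH. This is a rough sketch under assumptions: the helper names are hypothetical, the parser expects the "V10.1.243"-style token that ptxas --version prints, and XLA may locate ptxas through other means than PATH, so treat the result as a hint only.

```python
import re
import shutil
import subprocess

MIN_PTXAS = (11, 1)  # threshold quoted in the XLA warning


def parse_ptxas_version(text):
    """Extract (major, minor, patch) from `ptxas --version` output, or None."""
    match = re.search(r"V(\d+)\.(\d+)\.(\d+)", text)
    return tuple(int(part) for part in match.groups()) if match else None


def find_ptxas_version():
    """Return the version of the first ptxas on PATH, or None if unavailable."""
    path = shutil.which("ptxas")
    if path is None:
        return None
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return parse_ptxas_version(result.stdout)


version = find_ptxas_version()
if version is None:
    print("ptxas not found on PATH, or its version could not be parsed")
elif version[:2] < MIN_PTXAS:
    print("ptxas", version, "is older than 11.1; XLA warns about miscompilation")
else:
    print("ptxas", version, "meets the 11.1 minimum from the XLA warning")
```

Running it inside and outside the conda environment shows whether the environment's CUDA packages shadow the system ptxas.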

@stephane-caron
Contributor Author

Yes, this is now solved by #6 (comment). I could upgrade to CUDA toolkit >= 12.1 by installing CUDA driver == 11.4 (the one corresponding to my GPU driver) from the nvidia conda channel.
