
prototype jax with dqn #222

Merged: 27 commits merged into vwxyzjn:master on Jul 31, 2022

Conversation

@kinalmehta (Collaborator) commented Jun 28, 2022

Description

JAX implementation of DQN.
Implements #220.

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
  • I have updated the tests accordingly (if applicable).


@kinalmehta (Collaborator, Author)

Added Atari DQN implementation as well.

@vwxyzjn (Owner) commented Jun 28, 2022

Thanks for this PR!

Trying to run dqn_atari_jax.py, I hit a weird error: UNKNOWN: CUDNN_STATUS_EXECUTION. According to jax-ml/jax#6332 (comment), we should run it as

XLA_PYTHON_CLIENT_MEM_FRACTION=.7 python dqn_atari_jax.py

which worked for me. We should clearly document this when we write the documentation.
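
A minimal sketch of setting the same variable from inside the script instead of the shell (it has to be set before JAX initializes its GPU backend):

import os

# Must be set before the first JAX import / GPU use.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax  # noqa: E402
print(jax.devices())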

@vwxyzjn (Owner) commented Jun 28, 2022

https://wandb.ai/costa-huang/cleanRL/runs/2ymc7qnx?workspace=user-costa-huang

I did a quick run with DQN Atari, but it cannot replicate the same level of performance as the torch version. Would you mind looking into it?

@vwxyzjn (Owner) commented Jun 29, 2022

Thanks for this PR. I looked into it a bit more. There are two complications:

Image dimensions (NCHW vs NHWC)

A pre-processed Atari game image has height H=84, width W=84, C=4 channels from the frame stack, and a batch dimension N=1.

PyTorch's Conv2d expects the input array in NCHW format (see the second sentence of its docs), i.e., the obs input should have shape (1, 4, 84, 84).

Flax's Conv expects the input array in NHWC format (docs), i.e., the obs input should have shape (1, 84, 84, 4).

We can print out the models in the current implementation to compare and confirm this issue:

# in dqn_atari.py
from torchsummary import summary
summary(q_network, (4, 84, 84))
# in dqn_atari_jax.py
print(q_network.tabulate(q_key, obs))

[screenshot: model summaries from torchsummary and q_network.tabulate]

We need to fix the image input format with Flax by adding a transpose:

 @nn.compact
 def __call__(self, x):
+    x = jnp.transpose(x, (0, 2, 3, 1))
     x = x / (255.0)
     x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4))(x)
     x = nn.relu(x)
     x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2))(x)
     x = nn.relu(x)
     x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = nn.relu(x)
     x = x.reshape((x.shape[0], -1))
     x = nn.Dense(512)(x)
     x = nn.relu(x)
     x = nn.Dense(self.action_dim)(x)
     return x
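
As a quick standalone sanity check (illustrative, not part of the PR), the transpose maps the NCHW observation to NHWC:

import jax.numpy as jnp

obs = jnp.zeros((1, 4, 84, 84))                # NCHW, the PyTorch layout
print(jnp.transpose(obs, (0, 2, 3, 1)).shape)  # (1, 84, 84, 4), NHWC for Flax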

Padding

After fixing the image format issue, the two models still aren't exactly the same:

[screenshot: the model summaries still differ]

So it looks like Flax's Conv and torch's Conv2d have different default behavior (Flax's Conv defaults to 'SAME' padding, while torch's Conv2d defaults to no padding). Unfortunately, the original paper does not provide detailed info on the exact model sizes or shapes.

[screenshot: network description from the original DQN paper]

I looked into the reputable dqn_zoo implementation and found it uses padding='VALID' in its Conv layers (see what padding='VALID' means).

Using padding='VALID' aligns the model shapes and parameter counts between PyTorch and JAX:

 @nn.compact
 def __call__(self, x):
     x = jnp.transpose(x, (0, 2, 3, 1))
     x = x / (255.0)
-    x = nn.Conv(32, kernel_size=(8, 8), strides=4)(x)
+    x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding='VALID')(x)
     x = nn.relu(x)
-    x = nn.Conv(64, kernel_size=(4, 4), strides=2)(x)
+    x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding='VALID')(x)
     x = nn.relu(x)
-    x = nn.Conv(64, kernel_size=(3, 3), strides=1)(x)
+    x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
     x = nn.relu(x)
     x = x.reshape((x.shape[0], -1))
     x = nn.Dense(512)(x)
     x = nn.relu(x)
     x = nn.Dense(self.action_dim)(x)
     return x
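
The spatial sizes can also be checked by hand (a small sketch, not part of the PR): with padding='VALID', a conv with kernel size K and stride S produces floor((H - K) / S) + 1 outputs per spatial dimension, matching PyTorch's default of no padding:

def valid_out(h, kernel, stride):
    # Output size of a conv with no padding ('VALID').
    return (h - kernel) // stride + 1

h = valid_out(84, 8, 4)  # 20
h = valid_out(h, 4, 2)   # 9
h = valid_out(h, 3, 1)   # 7
print(h * h * 64)        # 3136 features fed into Dense(512)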

[screenshot: matching model summaries after the padding fix]

@vwxyzjn (Owner) commented Jun 29, 2022

[screenshot: learning-curve results]

The results look great so far. One other thing: would you mind modifying the script to use TrainState as done in #187?

@kinalmehta (Collaborator, Author)

Thanks for looking into the issue. I completely missed the channel-order preprocessing; I didn't know SB3 preprocessed it for PyTorch.

I've updated the code to use the TrainState API.

However, I'm not able to use a jitted apply function through state.apply_fn, even if I replace apply_fn after jitting, as below:

    q_state = TrainState.create(
        apply_fn=q_network.apply,
        params=q_network.init(q_key, obs),
        target_params=q_network.init(q_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )

    q_network.apply = jax.jit(q_network.apply)
    q_state = q_state.replace(apply_fn=q_network.apply)

Applying jit directly when creating the TrainState also leads to errors during parameter initialization.
So, is there no way to use state.apply_fn with jit?

@vwxyzjn (Owner) commented Jun 30, 2022

So is there no possible way to use the state.apply_fn with jit?

Maybe not; we can probably just stick with q_network.apply.
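
For reference, a minimal sketch of the pattern being discussed, with illustrative names rather than this PR's exact code: a TrainState subclass carries the target network parameters, and q_network.apply is jitted once and called directly instead of going through state.apply_fn:

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState


class QTrainState(TrainState):
    # Extra field for the target network's parameters.
    target_params: flax.core.FrozenDict


class QNetwork(nn.Module):
    action_dim: int = 4

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(self.action_dim)(x)


q_network = QNetwork()
obs = jnp.zeros((1, 8))
params = q_network.init(jax.random.PRNGKey(0), obs)
q_state = QTrainState.create(
    apply_fn=q_network.apply,       # stored un-jitted for bookkeeping
    params=params,
    target_params=params,
    tx=optax.adam(2.5e-4),
)

q_apply = jax.jit(q_network.apply)  # jit once, call directly in the loop
q_values = q_apply(q_state.params, obs)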

@vwxyzjn (Owner) commented Jul 7, 2022

cc @yooceii

@vwxyzjn (Owner) commented Jul 18, 2022

[three screenshots]

@vwxyzjn (Owner) commented Jul 28, 2022

Hey @kinalmehta, everything with dqn_atari_jax.py looks good. The only thing left is that we need benchmark experiments before merging dqn_jax.py. Would you prefer to do that in this PR or a separate PR? Either works for me.

Scratch that. Let me just run the experiments. It shouldn't take that long.

@vwxyzjn (Owner) commented Jul 28, 2022

@kinalmehta I got this error:

2022-07-28 10:18:40.054807: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-07-28 10:18:40.054833: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2022-07-28 10:18:40.054886: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: Attempting to perform BLAS operation using StreamExecutor without BLAS support
Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/cleanrl/cleanrl/dqn_jax.py", line 237, in <module>
    loss, old_val, q_state = update(
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/_src/api.py", line 522, in cache_miss
    out_flat = xla.xla_call(
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/core.py", line 1836, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/core.py", line 1852, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/core.py", line 683, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/costa/.cache/pypoetry/virtualenvs/cleanrl-0hpcRfYV-py3.9/lib/python3.9/site-packages/jax/_src/dispatch.py", line 717, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Attempting to perform BLAS operation using StreamExecutor without BLAS support
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

See https://wandb.ai/openrlbenchmark/cleanrl/runs/1joyhhtw/logs?workspace=user-costa-huang

@vwxyzjn (Owner) commented Jul 28, 2022

Never mind, sorry to disturb you @kinalmehta; there was a zombie process taking up GPU memory, and once I killed that process things started working again.

@vwxyzjn (Owner) commented Jul 28, 2022

Hey @kinalmehta @yooceii, I did another round of quick benchmarks and found that jitting the action sampling slows down throughput (see the explanation below).

[screenshots: benchmark comparison charts]

See report

The reason #231 (comment) found jitting to be faster is that its non-jitted baseline (dqn_atari_jax_nojit) was different from this PR's 89dcbb4 (e.g., dqn_atari_jax_nojit uses vmap and 89dcbb4 does not).

Could @kinalmehta and @yooceii confirm my findings? The difference could also stem from hardware differences. Namely, please compare the SPS at 200k steps.

Thank you.

@kinalmehta (Collaborator, Author) commented Jul 28, 2022

Hey guys,

I observed the same behavior as well. Here is the TensorBoard plot for 200k steps for the three JAX variants plus torch, run on my setup (RTX 3060 and Ryzen 9):
[screenshot: TensorBoard SPS comparison]

The code is exactly from the links you shared.

One possible reason I can think of is that the overhead of moving data between the GPU and CPU outweighs the speed-up from jitting the action selection.
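
Here is a toy sketch of that round trip (illustrative only, not the PR's code): whether or not the forward pass is jitted, every environment step pays a host-to-device copy for the observation and a device-to-host copy for the chosen action, and at batch size 1 this transfer can outweigh what jit saves:

import jax
import jax.numpy as jnp
import numpy as np

w = jnp.ones((4, 2))          # stand-in for the Q-network parameters

@jax.jit
def q_values(params, obs):
    return obs @ params       # stand-in for q_network.apply

def select_action(params, obs_np):
    q = q_values(params, jnp.asarray(obs_np))  # host -> device copy
    return int(jnp.argmax(q, axis=-1)[0])      # device -> host copy

action = select_action(w, np.zeros((1, 4), dtype=np.float32))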

Thanks

@yooceii (Collaborator) commented Jul 31, 2022

[screenshot: SPS comparison]
Can confirm the old version has the highest SPS.

@vwxyzjn (Owner) left a review comment

Everything LGTM. Thanks @kinalmehta!

@vwxyzjn merged commit f27b6d7 into vwxyzjn:master on Jul 31, 2022