prototype jax with dqn #222
Conversation
Added an Atari DQN implementation as well.
Thanks for this PR! Trying to run it.
I did a quick run with DQN Atari (https://wandb.ai/costa-huang/cleanRL/runs/2ymc7qnx?workspace=user-costa-huang), but it cannot replicate the same level of performance as the torch version. Would you mind looking into it?
Thanks for this PR. I looked into it a bit more. There are two complications:

**Image dimensions (NCHW vs NHWC)**

A pre-processed Atari game image has height H=84, width W=84, channels C=4 from the frame stack, and a batch dimension N=1. PyTorch's convolution layers expect NCHW inputs, whereas Flax's `nn.Conv` expects NHWC (channels-last) inputs. We can print out the models in the current implementation to compare and confirm this issue:

```python
# in dqn_atari.py
from torchsummary import summary
summary(q_network, (4, 84, 84))

# in dqn_atari_jax.py
print(q_network.tabulate(q_key, obs))
```
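As a quick illustration (not part of the PR's code), this is the layout change the transpose performs; SB3's replay buffer hands back PyTorch-style NCHW observations, while Flax convolutions treat the last axis as channels:

```python
# Minimal, hypothetical shape check: convert an NCHW Atari batch to the NHWC
# layout that Flax's nn.Conv expects. `obs_nchw` is a stand-in for a replay-buffer sample.
import jax.numpy as jnp

obs_nchw = jnp.zeros((1, 4, 84, 84))              # N, C, H, W (PyTorch-style)
obs_nhwc = jnp.transpose(obs_nchw, (0, 2, 3, 1))  # N, H, W, C (Flax-style)
print(obs_nchw.shape, "->", obs_nhwc.shape)       # (1, 4, 84, 84) -> (1, 84, 84, 4)
```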
We need to fix the image input format in Flax by adding a transpose:

```diff
  @nn.compact
  def __call__(self, x):
+     x = jnp.transpose(x, (0, 2, 3, 1))
      x = x / (255.0)
      x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4))(x)
      x = nn.relu(x)
      x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2))(x)
      x = nn.relu(x)
      x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1))(x)
      x = nn.relu(x)
      x = x.reshape((x.shape[0], -1))
      x = nn.Dense(512)(x)
      x = nn.relu(x)
      x = nn.Dense(self.action_dim)(x)
      return x
```

**Padding**

After fixing the image format issue, the two models still don't produce exactly the same shapes. It looks like Flax's `nn.Conv` defaults to `padding='SAME'`, whereas PyTorch's `Conv2d` applies no padding by default. I looked into the reputable dqn_zoo implementation and found they used `padding='VALID'`. Implementing that gives the following (a quick shape check is sketched after the diff):
```diff
  @nn.compact
  def __call__(self, x):
      x = jnp.transpose(x, (0, 2, 3, 1))
      x = x / (255.0)
-     x = nn.Conv(32, kernel_size=(8, 8), strides=4)(x)
+     x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding='VALID')(x)
      x = nn.relu(x)
-     x = nn.Conv(64, kernel_size=(4, 4), strides=2)(x)
+     x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding='VALID')(x)
      x = nn.relu(x)
-     x = nn.Conv(64, kernel_size=(3, 3), strides=1)(x)
+     x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
      x = nn.relu(x)
      x = x.reshape((x.shape[0], -1))
      x = nn.Dense(512)(x)
      x = nn.relu(x)
      x = nn.Dense(self.action_dim)(x)
      return x
```
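For reference, here is a quick, hypothetical shape check (not from the PR) of why `padding='VALID'` is needed: Flax's default `'SAME'` padding keeps 21x21 feature maps after the first convolution, while `'VALID'` gives the 20x20 maps that PyTorch's default (no padding) produces.

```python
# Hypothetical sanity check: compare Flax's default padding='SAME' with
# padding='VALID' on an 84x84 NHWC input.
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.zeros((1, 84, 84, 4))  # NHWC input after the transpose
key = jax.random.PRNGKey(0)

conv_same = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4))                    # Flax default: padding='SAME'
conv_valid = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding='VALID')  # matches PyTorch's default (no padding)

print(conv_same.apply(conv_same.init(key, x), x).shape)    # (1, 21, 21, 32): ceil(84 / 4) = 21
print(conv_valid.apply(conv_valid.init(key, x), x).shape)  # (1, 20, 20, 32): floor((84 - 8) / 4) + 1 = 20
```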
The results look great so far. One other thing: would you mind modifying the script to use the `TrainState` API?
Thanks for looking into the issue. I completely missed the channel-location preprocessing; I didn't know SB3 preprocessed it for PyTorch. I've updated the code to use the `TrainState` API, but I'm not able to use a jitted apply function from `state.apply_fn`, even if I replace `apply_fn` after the jit operation as below:

```python
q_state = TrainState.create(
    apply_fn=q_network.apply,
    params=q_network.init(q_key, obs),
    target_params=q_network.init(q_key, obs),
    tx=optax.adam(learning_rate=args.learning_rate),
)
q_network.apply = jax.jit(q_network.apply)
q_state = q_state.replace(apply_fn=q_network.apply)
```

And directly applying the jit operation during the creation of `TrainState` leads to errors during the initialization of parameters.
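One possible workaround, sketched here under the assumption that `q_network`, `q_state`, and `obs` are defined as above (and not necessarily what the PR settled on), is to jit a small wrapper function that takes the parameters explicitly instead of jitting `state.apply_fn` or `q_network.apply` in place:

```python
import jax

# Hypothetical sketch: jit a standalone forward function and pass the parameters
# explicitly, leaving TrainState.apply_fn untouched.
@jax.jit
def forward(params, observations):
    return q_network.apply(params, observations)

q_values = forward(q_state.params, obs)
```

Because the parameters are an explicit argument, the same jitted function can also be called with `q_state.target_params` for the target network.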
Maybe not, we can probably just stick with `q_network.apply`.
cc @yooceii
Scratch that. Let me just run the experiments. It shouldn't take that long.
@kinalmehta I got an error; see https://wandb.ai/openrlbenchmark/cleanrl/runs/1joyhhtw/logs?workspace=user-costa-huang
Nvm, sorry to disturb you @kinalmehta. There was a zombie process that took over the GPU memory, and once I removed that process things started to work again.
Hey @kinalmehta @yooceii, I did another round of quick benchmarks and found that jitting the action sampling slows down throughput (* see explanation below); see the report. The reason #231 (comment) found jitting to be faster appears to be related to the non-jitted baseline used there. Can @kinalmehta and @yooceii confirm my findings? The difference could also stem from hardware differences. Namely, please compare the SPS at 200k steps for the jitted and non-jitted variants.
Thank you. |
Everything LGTM. Thanks @kinalmehta!
Description
JAX implementation for DQN
Implementation for #220
Types of changes
Checklist:
- `pre-commit run --all-files` passes (required).
- Documentation previewed via `mkdocs serve`.

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

- Experiments run with the `--capture-video` flag toggled on (required).
- Documentation previewed via `mkdocs serve`.
- Videos embedded in the documentation (`width=500` and `height=300`).