PPO + JAX + EnvPool + MuJoCo #217
Conversation
It seems that there isn't that much benefit for PPO: the SPS metric is not much better, as shown below. Note: there is probably a bug... that's why the sample efficiency suffers. Maybe I was implementing PPO with the wrong paradigm for JAX. Any thoughts on this, @joaogui1 and @ikostrikov? Thanks!
I'm not sure if obs = obs.at[step].set(x) is indeed in-place inside of jit. I think in this specific case it still creates a new array; it's truly in-place only in specific situations, for example when memory is donated (on TPU and GPU only). Could you double-check that?
if args.anneal_lr:
    frac = 1.0 - (update - 1.0) / num_updates
    lrnow = frac * args.learning_rate
    agent_optimizer_state[1].hyperparams["learning_rate"] = lrnow
    agent_optimizer.update(agent_params, agent_optimizer_state)
From my experience, there's a gain if the main for loop can be replaced with lax.fori_loop.
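In case a concrete shape of that helps, here is a minimal sketch of the lax.fori_loop pattern; the shapes and the per-step computation are made up for illustration and are not the PR's actual rollout code:

```python
import jax
import jax.numpy as jnp

num_steps, obs_dim = 128, 17  # hypothetical rollout size

def body_fun(step, obs_buffer):
    # stand-in for whatever the environment step would produce at this index
    x = jnp.full((obs_dim,), step, dtype=obs_buffer.dtype)
    return obs_buffer.at[step].set(x)

obs_buffer = jax.lax.fori_loop(0, num_steps, body_fun, jnp.zeros((num_steps, obs_dim)))
```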
Maybe the documentation meant that if you had created the array inside the JIT, the operation would be in place? I tested it out.
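As a rough, hypothetical illustration (not the exact test above): outside of jit, .at[...].set() returns a new array and leaves the original untouched:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0. 0.]  -- the original array is unchanged, so a copy was made
print(y)  # [1. 0. 0. 0.]
```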
@vwxyzjn yes, I think it's either for arrays created inside of jit or donated arguments.
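For completeness, a hypothetical sketch of the donation case (illustrative names, not code from this PR); with donate_argnums, XLA may reuse the donated input buffer for the output on GPU/TPU, so the update can happen in place there, while CPU ignores donation:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def write_step(obs, step, x):
    return obs.at[step].set(x)

obs = jnp.zeros((128, 17))
obs = write_step(obs, 0, jnp.ones(17))  # the donated (old) obs buffer must not be reused afterwards
```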
advantages = advantages.at[:].set(0.0)  # reset advantages
next_value = critic.apply(agent_params.critic_params, next_obs).squeeze()
lastgaelam = 0
for t in reversed(range(args.num_steps)):
Was looking through your code to get some idea of how other people are writing RL algos in JAX (and how far people jit things), and I think this might be an issue during the first compile step. The for loop will basically be unrolled, and when I tried this the compile time was very long, especially if args.num_steps is big.
I ended up using jax.lax.scan and replaced the loop like this (the code doesn't fit yours exactly, but the idea is there):
not_dones = ~dones
value_diffs = gamma * values[1:] * not_dones - values[:-1]
deltas = rewards + value_diffs

def body_fun(gae, t):
    gae = deltas[t] + gamma * gae_lambda * not_dones[t] * gae
    return gae, gae

indices = jnp.arange(N)[::-1]  # N = number of rollout steps, scanned from the last step to the first
gae, advantages = jax.lax.scan(body_fun, 0.0, indices)
advantages = advantages[::-1]
This also avoids using the .at and .set functions (I'm still not sure what their performance is like). Maybe this might be useful.
You can use reverse=True in the scan so you don't have to flip it.
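For example, a hypothetical rewrite of the snippet above that reuses its deltas, not_dones, gamma, and gae_lambda:

```python
def body_fun(gae, xs):
    delta, not_done = xs
    gae = delta + gamma * gae_lambda * not_done * gae
    return gae, gae

# reverse=True consumes the inputs from the last step to the first and stacks the
# outputs back in their original positions, so no index array or flipping is needed.
_, advantages = jax.lax.scan(body_fun, 0.0, (deltas, not_dones), reverse=True)
```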
Jitting the epochs in update_ppo() results in extremely high startup times for large epoch values and doesn't provide any speedup once it's finally running.
envs = gym.wrappers.ClipAction(envs)
envs = gym.wrappers.NormalizeObservation(envs)
envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10))
envs = gym.wrappers.NormalizeReward(envs)
envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10))
It is desirable to implement these in JAX, which should help speed up training and will allow us to use the XLA interface in the future. I think it's worth changing to a JAX-based implementation.
The command used is: Note: the training code was removed, as the collection time correlates with the
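For reference, a rough sketch of what a JAX-side observation normalizer could look like (hypothetical names and structure, not the implementation that was proposed or benchmarked in this thread):

```python
from typing import NamedTuple

import jax.numpy as jnp

class RunningStats(NamedTuple):
    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray

def update_stats(stats: RunningStats, batch: jnp.ndarray) -> RunningStats:
    # parallel (Welford-style) update of the running mean/variance with a batch of observations
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]
    delta = batch_mean - stats.mean
    total = stats.count + batch_count
    new_mean = stats.mean + delta * batch_count / total
    m2 = stats.var * stats.count + batch_var * batch_count + delta**2 * stats.count * batch_count / total
    return RunningStats(new_mean, m2 / total, total)

def normalize(stats: RunningStats, obs: jnp.ndarray, clip: float = 10.0) -> jnp.ndarray:
    return jnp.clip((obs - stats.mean) / jnp.sqrt(stats.var + 1e-8), -clip, clip)
```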
@51616 thanks for raising this issue. Could you share the snippet that produced these numbers?
@vwxyzjn Here's the code
I can make a PR for this. I also think we should use the output of the
The code is a bit cleaner and uses the output from
@vwxyzjn Was there any reason why this wasn't merged in the end?
Nothing really. If you'd like, feel free to take on the PR :)
Description
Types of changes
Checklist:
- pre-commit run --all-files passes (required).
- Documentation previewed via mkdocs serve.

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
- Experiments tracked with the --capture-video flag toggled on (required).
- Documentation previewed via mkdocs serve.
- Learning curves added (width=500 and height=300).