[Algorithm] Update PPO examples #1495
Conversation
This is my suggestion to update the PPO examples. I have one script for MuJoCo and one for Atari, since the architectures and the env transforms are different. I think this is easier to read, but we could also have a single script, since the rest is almost the same. Maybe we could all review it, as @matteobettini suggested, and agree on a template for the other examples? What do you think? @vmoens @BY571
I think it's great. Having two separate files is fine; it's better to have clarity than highly engineered, poorly tested and unreadable single scripts :)
examples/ppo/ppo_mujoco.py
```python
# Test logging
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
    if (collected_frames - frames_in_batch) // cfg.logger.test_interval < (
        collected_frames // cfg.logger.test_interval
    ):
```
Can't you use a variable from the enumeration of the collector here?
I'm doing it implicitly, since collected_frames is "number of batches collected" * frames_in_batch. I can make it more explicit:
```python
if (i - 1) * frames_in_batch // test_interval < i * frames_in_batch // test_interval:
```
Why not `if i % cfg.eval.evaluation_interval == 0`?
Ah, because your test interval is in frames. Then `if collected_frames_till_now % test_interval == 0`?
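To make this thread concrete, here is a hedged sketch of the frame-based trigger (the helper name is mine, not the script's). The floor-division comparison fires whenever a multiple of test_interval is crossed within a batch, which a plain modulo check would miss when frames_in_batch does not divide test_interval evenly:

```python
def should_evaluate(collected_frames, frames_in_batch, test_interval):
    # Fires iff a multiple of test_interval was crossed in the last batch,
    # even when frames_in_batch does not divide test_interval.
    prev = (collected_frames - frames_in_batch) // test_interval
    curr = collected_frames // test_interval
    return prev < curr

# 1000-frame boundary crossed between 950 and 1050 collected frames:
assert should_evaluate(collected_frames=1050, frames_in_batch=100, test_interval=1000)
assert not should_evaluate(collected_frames=950, frames_in_batch=100, test_interval=1000)
```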
examples/ppo/ppo_mujoco.py
```python
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
    logger.log_scalar(
        "reward_train", episode_rewards.mean().item(), collected_frames
    )
```
I usually use names like "train/reward"; this makes wandb automatically divide them into different panels.
Here is the stuff I log for training:
rl/examples/multiagent/utils/logging.py, line 76 in 147de71: `to_log.update(`
and this is some of what I log for eval:
rl/examples/multiagent/utils/logging.py, line 122 in 147de71: `"eval/episode_reward_min": min(rewards),`
I think we at least need to time the scripts and log times for both collection and training.
Makes sense to use this kind of naming. The metrics will probably vary a bit from example to example, but at least we can agree to always use "train/..." and "eval/...".
Adding the timing also makes sense.
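A minimal sketch of that timing suggestion (the function and argument names are illustrative, not the script's): wall-clock the collection and training phases separately each iteration and log them under the same prefix convention.

```python
import time

def train(collector, logger, run_epochs):
    # Time collection and training separately on every collector iteration.
    collected_frames = 0
    sampling_start = time.time()
    for data in collector:
        sampling_time = time.time() - sampling_start
        collected_frames += data.numel()

        training_start = time.time()
        run_epochs(data)  # the PPO optimisation epochs
        training_time = time.time() - training_start

        logger.log_scalar("train/sampling_time", sampling_time, step=collected_frames)
        logger.log_scalar("train/training_time", training_time, step=collected_frames)
        sampling_start = time.time()
```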
@albertbou92 we also need to update the examples CI (cc @BY571)
examples/ppo/utils_mujoco.py
```python
import gym
```
Since you're using HalfCheetah-v4, we can use gymnasium, no?
This version of HalfCheetah crashes with the gym version from the CI.
The CI uses gym 0.23 for D4RL compatibility, and gymnasium for newer stuff.
(welcome to gym wonderland)
I used HalfCheetah-v4 in case the CI env did not have MuJoCo; the default config has HalfCheetah-v3.
But using gymnasium seems to work out of the box, so I changed to that.
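For reference, a minimal sketch of what the gymnasium-backed env construction could look like, assuming torchrl's GymEnv wrapper; the actual utils_mujoco.py builds a larger transform stack, and the env name here just follows the discussion above.

```python
from torchrl.envs import Compose, RewardSum, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

def make_env(env_name="HalfCheetah-v4", device="cpu"):
    # GymEnv can pick up gymnasium as its backend when it is installed.
    env = GymEnv(env_name, device=device)
    return TransformedEnv(env, Compose(RewardSum()))
```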
For Atari, we can now speed up training with the new vectorised envs, right?
I think so, I can check.
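One option for batching Atari envs is torchrl's ParallelEnv, sketched below; note that "the new vectorised envs" in the comment may refer to a different mechanism (e.g. gymnasium's native vector envs), and the env name is only illustrative.

```python
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

def make_parallel_env(env_name="PongNoFrameskip-v4", num_workers=8):
    # Run num_workers copies of the env in subprocesses, batched together.
    return ParallelEnv(num_workers, lambda: GymEnv(env_name))
```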
@albertbou92 the examples tests are failing
Solved! Also for A2C @vmoens
LGTM!
Co-authored-by: vmoens <vincentmoens@gmail.com>
Description
Updated PPO examples. The scripts now reproduce the Atari and MuJoCo results from the original PPO paper.
Some common improvements are added (such as recomputing the advantage at every epoch).
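As an illustration of the advantage recomputation, here is a hedged sketch assuming torchrl's GAE module and ClipPPOLoss-style loss keys (names may differ from the actual script, and minibatching is omitted for brevity): GAE is recomputed on the rollout before each PPO epoch so the advantages track the updated value network.

```python
import torch
from torchrl.objectives.value import GAE

def run_ppo_epochs(data, critic, loss_module, optimizer, num_epochs=10):
    adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
    for _ in range(num_epochs):
        with torch.no_grad():
            adv_module(data)  # refresh "advantage" / "value_target" in data
        loss_vals = loss_module(data)
        loss = (
            loss_vals["loss_objective"]
            + loss_vals["loss_critic"]
            + loss_vals["loss_entropy"]
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```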