
[Algorithm] Update PPO examples #1495

Merged 29 commits into pytorch:main on Sep 21, 2023

Conversation

@albertbou92 (Contributor) commented Sep 6, 2023

Description

Updated PPO examples. The scripts now reproduce the Atari and MuJoCo results from the original PPO paper.
Some common improvements are also included, such as recomputing the advantage at every epoch; see the sketch below.
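A minimal sketch of what that per-epoch advantage recomputation can look like with torchrl's GAE module; critic, num_epochs, loss_module, optim, and split_into_minibatches are placeholders, not the exact names used in the scripts:

    import torch
    from torchrl.objectives.value import GAE

    # Build the advantage module once, reusing the value network (critic).
    adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic)

    for epoch in range(num_epochs):
        # Recompute advantages with the current value network before every
        # pass over the data, rather than once per collected batch.
        with torch.no_grad():
            data = adv_module(data)
        for minibatch in split_into_minibatches(data):  # placeholder helper
            loss_td = loss_module(minibatch)
            loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
            loss.backward()
            optim.step()
            optim.zero_grad()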

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves issue #15213.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Sep 6, 2023.
@vmoens added the new algo label on Sep 6, 2023.
@albertbou92 (Contributor, Author) commented:
This is my suggestion for updating the PPO examples. I have one script for MuJoCo and one for Atari, since the architectures and the env transforms are different. I think it is easier to read, but we could also have a single script since the rest is almost the same. Maybe we could all review it, as @matteobettini suggested, and agree on a template for the other examples? What do you think? @vmoens @BY571

@vmoens (Contributor) commented Sep 8, 2023

I think it's great. Having two separate files is fine; it's better to have clarity than highly engineered, poorly tested, and unreadable single scripts :).
We can (if you want) write a .md file in the directory that says we have two files for clarity, even though it's self-explanatory...


    # Test logging
    with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
        if (collected_frames - frames_in_batch) // cfg.logger.test_interval < (
            collected_frames // cfg.logger.test_interval
        ):
@matteobettini (Contributor) commented Sep 9, 2023:
Can't you use a variable from the enumeration of the collector here?

@albertbou92 (Contributor, Author) replied:

I'm doing it implicitly, since collected_frames is the "number of batches collected" * frames_in_batch.
I can make it more explicit:

    if (i - 1) * frames_in_batch // test_interval < i * frames_in_batch // test_interval:

@matteobettini (Contributor) commented Sep 10, 2023:

Why not if i % cfg.eval.evaluation_interval == 0?

@matteobettini (Contributor) added:

Ah, because your test interval is in frames. Then if collected_frames_till_now % test_interval == 0?
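A small sketch of the check the script uses instead, with the names from the snippet above: collected_frames advances in steps of frames_in_batch, so it may never land exactly on a multiple of test_interval, and a plain % == 0 test can silently skip evals; comparing the integer divisions fires whenever a boundary is crossed. run_evaluation is a placeholder for the test-logging block.

    # Fire an eval whenever the running frame count crosses a multiple of
    # cfg.logger.test_interval, even if frames_in_batch does not divide it.
    prev_frames = collected_frames - frames_in_batch
    if prev_frames // cfg.logger.test_interval < (
        collected_frames // cfg.logger.test_interval
    ):
        run_evaluation()  # placeholder for the test-logging block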

    episode_rewards = data["next", "episode_reward"][data["next", "done"]]
    if len(episode_rewards) > 0:
        logger.log_scalar(
            "reward_train", episode_rewards.mean().item(), collected_frames
        )
Contributor:

I usually use names like "train/reward"; this makes wandb automatically divide the metrics into separate panels.

Here is the stuff I log for training, and this is some of what I log for eval:

    "eval/episode_reward_min": min(rewards),

Contributor:

I think we at least need to time the scripts and log times for both collection and training.

@albertbou92 (Contributor, Author) replied:

Makes sense to use this kind of naming. The metrics will probably vary a bit from example to example, but at least we can agree to always prefix them with "train/..." and "eval/...".

Adding the timing also makes sense; a rough sketch follows.
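A sketch of that timing, wrapping the collection and training phases with time.time(); the loop structure and names are illustrative, not the exact ones in the scripts:

    import time

    start = time.time()
    for i, data in enumerate(collector):
        collection_time = time.time() - start
        train_start = time.time()
        ...  # advantage computation, epochs, minibatch updates
        training_time = time.time() - train_start
        logger.log_scalar("train/collection_time", collection_time, collected_frames)
        logger.log_scalar("train/training_time", training_time, collected_frames)
        start = time.time()  # restart the clock for the next collection phase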

@BY571 mentioned this pull request on Sep 14, 2023.
@vmoens (Contributor) commented Sep 17, 2023

@albertbou92, we also need to update the examples CI (cc @BY571)

@vmoens vmoens changed the title [Feature] Update PPO examples [Algorithm] Update PPO examples Sep 18, 2023
Resolved review threads: examples/ppo/ppo_mujoco.py, examples/ppo/ppo_atari.py, examples/ppo/utils_atari.py (outdated), examples/ppo/utils_mujoco.py (outdated).
@@ -0,0 +1,116 @@
import gym
Contributor:

Since you're using HalfCheetah-v4, we can use gymnasium, no?
This version of HalfCheetah crashes with the gym version from the CI: the CI uses gym 0.23 for D4RL compatibility, and gymnasium for newer stuff.
(Welcome to gym wonderland.)

@albertbou92 (Contributor, Author) replied Sep 18, 2023:

I used HalfCheetah-v4 in case the CI env did not have MuJoCo; the default config has HalfCheetah-v3.

But using gymnasium seems to work out of the box, so I changed to that. A sketch of the gymnasium route is below.
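A minimal sketch, assuming the set_gym_backend helper is available in the installed torchrl version:

    from torchrl.envs.libs.gym import GymEnv, set_gym_backend

    # Force the gymnasium backend explicitly instead of relying on
    # whichever of gym/gymnasium torchrl picks up first.
    with set_gym_backend("gymnasium"):
        env = GymEnv("HalfCheetah-v4")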

@albertbou92 (Contributor, Author) commented:

For Atari, we can now speed up training with the new vectorized envs, right?

@vmoens (Contributor) commented Sep 18, 2023

I think so, I can check.
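For reference, one way to run batched Atari envs in torchrl is ParallelEnv; whether this is the same mechanism as the "new vectorized envs" mentioned above is an assumption, and the env name is illustrative:

    from torchrl.envs import ParallelEnv
    from torchrl.envs.libs.gym import GymEnv

    def make_env():
        return GymEnv("PongNoFrameskip-v4")

    # Runs 4 copies in worker processes; rollouts come back batched.
    env = ParallelEnv(4, make_env)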

@vmoens (Contributor) commented Sep 20, 2023

@albertbou92 the examples tests are failing

@albertbou92 (Contributor, Author) replied:

Replying to "@albertbou92 the examples tests are failing":

Solved! Also for A2C. @vmoens

@vmoens (Contributor) left a comment:

LGTM!

@vmoens merged commit fc9794d into pytorch:main on Sep 21, 2023 (50 of 59 checks passed).
@vmoens deleted the update_ppo_example branch on September 21, 2023 at 12:41.
@vmoens added a commit to hyerra/rl referencing this pull request on Oct 10, 2023 (Co-authored-by: vmoens <vincentmoens@gmail.com>).