
Add SPS and q-values metrics for value-based methods #126

Merged
merged 7 commits into master on Mar 9, 2022

Conversation

vwxyzjn
Owner

@vwxyzjn vwxyzjn commented Feb 28, 2022

This PR adds SPS and q-values metrics for value-based methods.

This PR also uses F.mse_loss() instead of loss_fn = nn.MSELoss().
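For context, the functional and module forms of the MSE loss compute the same value; the functional form just avoids constructing and carrying around a module object. A minimal sketch (the tensors are illustrative, not from the PR):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])

# Functional form used by this PR: no loss object to instantiate.
loss_functional = F.mse_loss(pred, target)

# Module form being replaced: same result, extra boilerplate.
loss_fn = nn.MSELoss()
loss_module = loss_fn(pred, target)
```

Both default to mean reduction, so `loss_functional` and `loss_module` are identical here.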


@dosssman do you think we should also use qf1() instead, which calls forward() by default?


@dosssman
Collaborator

@vwxyzjn Regarding qf1() versus qf1.forward(): I have a subjective preference for qf1() because it just looks "cleaner" to me, and that is what I use in my own code bases.

One potential downside that comes to mind, however, is that a newcomer who does not know this is the default behavior of PyTorch might get confused at first.
With qf1.forward(), it could be easier for the reader to realize that this call executes the logic defined in the def forward(self, ...) of the QNetwork class.

Although SB3 does seem to use the qf1() style, I think:
https://github.com/DLR-RM/stable-baselines3/blob/cdaa9ab418aec18f41c7e8e12e0ad28f238553eb/stable_baselines3/common/torch_layers.py#L233

In any case, I floated this change in sac.py last time, but I am still a little on the fence about how appropriate it would be for CleanRL.
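For readers unfamiliar with the PyTorch default behavior discussed above: calling a module instance dispatches through nn.Module.__call__, which runs any registered hooks and then invokes forward(), so with no hooks the two styles produce identical outputs. A minimal sketch (the network shape is illustrative, not CleanRL's actual QNetwork):

```python
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)  # toy sizes for illustration

    def forward(self, x):
        return self.net(x)

qf1 = QNetwork()
obs = torch.randn(1, 4)

# nn.Module.__call__ runs hooks, then dispatches to forward(),
# so these two lines compute the same thing (no hooks registered here):
out_call = qf1(obs)
out_forward = qf1.forward(obs)
```

One practical reason to prefer qf1(obs): calling forward() directly would silently skip any registered forward hooks.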

@vwxyzjn
Owner Author

vwxyzjn commented Feb 28, 2022

Fair enough, let's go ahead with this change: qf1() instead of qf1.forward().

@vwxyzjn
Owner Author

vwxyzjn commented Feb 28, 2022

@dosssman it's ready for review.

Collaborator

@dosssman dosssman left a comment


Sorry for the late answer. The review request escaped me somehow.

First, I would like to confirm the meaning of SPS (samples or steps per second?).

I have compiled a few minor tweaks that build on top of this PR in #130 .
Hopefully I did not misunderstand the intent of this PR.

Some additional comments are:

  • This might just be my OCD, but when printing the SPS to the terminal and logging it to TensorBoard / wandb, like here for example:
    print("SPS:", int(global_step / (time.time() - start_time)))
    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

how about computing it once, like this?

SPS = int(global_step / (time.time() - start_time))
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)

Furthermore, we could also fuse the SPS and episodic-return logging here so we only print once:

for info in infos:
    if "episode" in info.keys():
        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
        writer.add_scalar("charts/epsilon", epsilon, global_step)
        break
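The fused version suggested above could look like the sketch below. This is an assumption about the intended shape, not code from the PR; the StubWriter class and the concrete values of global_step, epsilon, and infos are hypothetical stand-ins for the training loop's state so the snippet is self-contained:

```python
import time

class StubWriter:
    """Hypothetical stand-in for torch.utils.tensorboard.SummaryWriter."""
    def __init__(self):
        self.logged = []

    def add_scalar(self, tag, value, step):
        self.logged.append((tag, value, step))

# Stand-in training-loop state (illustrative values).
writer = StubWriter()
start_time = time.time() - 2.0  # pretend training started ~2 s ago
global_step = 1000
epsilon = 0.05
infos = [{"episode": {"r": 21.0, "l": 500}}]

# Fused logging: compute SPS once, then print and log the episodic
# return, SPS, and epsilon together in a single place.
for info in infos:
    if "episode" in info.keys():
        sps = int(global_step / (time.time() - start_time))
        print(f"global_step={global_step}, episodic_return={info['episode']['r']}, SPS={sps}")
        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
        writer.add_scalar("charts/SPS", sps, global_step)
        writer.add_scalar("charts/epsilon", epsilon, global_step)
        break
```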

  • The update to gym==0.23.0 breaks a few scripts, namely rnd_ppo.py, apex_dqn_atari.py, and the offline/* scripts, due to the Monitor gym wrapper being removed in that version. This is why the tests in Tweaks for PR 'add-sps-qvalues' #130 break.

Regarding cleanrl/offline

  • SPS logging is missing
  • Monitor cannot be imported anymore due to the gym==0.23.0 update
  • Tests for those two scripts are missing
  • Unlike dqn_atari, the wrappers are not imported from SB3
  • The offline-env-id that is required to load the dataset does not seem to work anymore. Is there a missing dependency, such as d4rl or d4rl_atari, for example?

Since the offline scripts are not really related to this PR, I did not go into too much detail.
We could merge this PR and open an issue for the offline scripts, if one does not already exist.

In any case, great job as always.

* Additional tweaks regarding SPS and Q-values; added tests for Atari-related scripts while at it

* test_atari.py: fixed the apex_dqn_atari.py test

* quick fix

* Fix pre-commit

Co-authored-by: Costa Huang <costa.huang@outlook.com>
@vwxyzjn
Owner Author

vwxyzjn commented Mar 9, 2022

Merging now.

@vwxyzjn vwxyzjn merged commit 2828e83 into master Mar 9, 2022