Add SPS and q-values metrics for value-based methods #126
This pull request is being automatically deployed with Vercel (learn more). 🔍 Inspect: https://vercel.com/vwxyzjn/cleanrl/FqjMVLeRjw6sZJgzSowvU9U1RFMr
@vwxyzjn Regarding the One potential downside that comes to mind however, is that a newcomer that does not know this is the default behavior of pytorch might get confused at first. Although in SB3, they do seem to use the In any case, I floated this change in sac.py last time, but still a little bit on the fence on how appropriate it would be for
Fair enough, let's go ahead with this change.
@dosssman it's ready for review.
Sorry for the late answer. The review request escaped me somehow.
First, I would like to confirm the meaning of SPS.
(Samples per second, or steps per second?)
I have compiled a few minor tweaks that build on top of this PR in #130.
Hopefully I did not misunderstand the intent of this PR.
Some additional comments are:
- This might just be my OCD, but when printing the SPS for the terminal and tensorboard / wandb logging like here for example:
Lines 213 to 214 in bb8dd13
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
how about doing something like this?
SPS = int(global_step / (time.time() - start_time))
print("SPS:", SPS)
writer.add_scalar("charts/SPS", SPS, global_step)
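For reference, the "compute once, reuse twice" idea can be sketched as a small helper. This is only an illustration: the function name `steps_per_second` and its `now` parameter are assumptions made for this example, not names from the scripts, and the guard against a zero elapsed time is an addition the original lines do not have.

```python
import time
from typing import Optional


def steps_per_second(global_step: int, start_time: float, now: Optional[float] = None) -> int:
    """Return the integer number of environment steps per second since start_time.

    `now` is a hypothetical parameter added so the helper can be checked
    deterministically; by default the current wall-clock time is used.
    """
    if now is None:
        now = time.time()
    elapsed = now - start_time
    if elapsed <= 0:
        # Avoid ZeroDivisionError on the very first call.
        return 0
    return int(global_step / elapsed)


# Compute the value once, then reuse it for both the print and the logger call.
sps = steps_per_second(global_step=1000, start_time=0.0, now=4.0)
print("SPS:", sps)
```

The same `sps` value can then be passed to `writer.add_scalar("charts/SPS", sps, global_step)`, so the terminal and the logger always report the same number.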
Furthermore, could we also fuse the SPS and episodic return logging here, so that we only print once?
Lines 184 to 189 in bb8dd13
for info in infos:
    if "episode" in info.keys():
        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
        writer.add_scalar("charts/epsilon", epsilon, global_step)
        break
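One way the fused logging could look is sketched below. The names `ListWriter` and `log_episode` are hypothetical and exist only for this example; `ListWriter` merely records calls and stands in for the real TensorBoard writer. Note the trade-off this sketch makes: SPS is then only logged when an episode ends, rather than at a fixed step interval.

```python
import time
from typing import Optional


class ListWriter:
    """Hypothetical stand-in for a summary writer; it only records add_scalar calls."""

    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))


def log_episode(writer, infos, global_step, start_time, epsilon, now: Optional[float] = None) -> bool:
    """Print one fused line and log all scalars for the first finished episode, if any."""
    now = time.time() if now is None else now
    elapsed = now - start_time
    sps = int(global_step / elapsed) if elapsed > 0 else 0
    for info in infos:
        if "episode" in info:
            episodic_return = info["episode"]["r"]
            # A single print covers both the return and the throughput.
            print(f"global_step={global_step}, episodic_return={episodic_return}, SPS={sps}")
            writer.add_scalar("charts/episodic_return", episodic_return, global_step)
            writer.add_scalar("charts/epsilon", epsilon, global_step)
            writer.add_scalar("charts/SPS", sps, global_step)
            return True
    return False


demo_writer = ListWriter()
logged = log_episode(demo_writer, [{}, {"episode": {"r": 21.0}}], 100, 0.0, 0.1, now=2.0)
```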
- The update to gym=0.23.0 breaks a few scripts, namely rnd_ppo.py, apex_dqn_atari, and the offline/* scripts, because the Monitor gym wrapper was removed from that version. This is why the tests of #130 (Tweaks for PR 'add-sps-qvalues') break.
Regarding cleanrl/offline
- SPS logging is missing
- Monitor cannot be imported anymore due to gym=0.23.0 update
- Tests for those two scripts are missing
- Unlike dqn_atari, the wrappers are not imported from SB3
- The offline-env-id that is required to load the dataset does not seem to work anymore. Is there any dependency missing, such as d4rl or d4rl_atari for example?
Since the offline scripts are not really related to this PR, I did not go into too much detail.
We could merge this PR while opening an issue for the offline scripts, if not already done.
In any case, great job as always.
* Additional tweaks regarding SPS and Q values, added tests for Atari related scripts while at it
* test_atari.py: fixed the apex_dqn_atari.py test
* quick fix
* Fix pre-commit
Co-authored-by: Costa Huang <costa.huang@outlook.com>
Merging now.
This PR adds SPS and q-values metrics for value-based methods.
This PR also uses F.mse_loss() instead of loss_fn = nn.MSELoss().
@dosssman do you think we should also switch to using qf1() instead, which calls forward() by default?
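As a quick sanity check of the two PyTorch points above, the sketch below compares both spellings on random data. The network here is a plain `nn.Linear` used as a hypothetical stand-in for `qf1`, and the shapes are arbitrary; logging the mean of the predicted Q-values is an assumption about what a "q-values" metric might look like, not a statement about what the scripts do.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

qf1 = nn.Linear(4, 2)        # hypothetical stand-in for a Q-network
obs = torch.randn(8, 4)      # batch of 8 observations with 4 features
target = torch.randn(8, 2)   # arbitrary regression targets

# Calling the module goes through nn.Module.__call__, which runs any
# registered hooks and then forward(); calling forward() directly skips hooks.
q_call = qf1(obs)
q_forward = qf1.forward(obs)
same_q = torch.equal(q_call, q_forward)

# The functional and module forms of the MSE loss compute the same value
# with the default reduction="mean".
loss_functional = F.mse_loss(q_call, target)
loss_module = nn.MSELoss()(q_call, target)
same_loss = torch.allclose(loss_functional, loss_module)

# A scalar that could be logged as a q-values metric.
mean_q = q_call.mean().item()
```

With no hooks registered, both calls return identical tensors, so the choice is mostly one of convention; `qf1(obs)` is the form the PyTorch documentation recommends because it also runs registered hooks.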