Bug in actor loss for sac_continuous_action.py #379
Comments
Yeah, the version using It might be worthwhile to investigate why this hasn't been an issue previously though. Your fix should work, @terencenwz. If you wanna do a PR, I can merge it if that's also fine with @dosssman and @vwxyzjn.
Thanks, appreciate it. Will come back to this in the middle of the week if no changes by then.
Thanks for raising this issue, @terencenwz. Alongside the PR, we should probably re-run the benchmark experiments as well, given that this is a performance-impacting change. The specific steps are listed at https://docs.cleanrl.dev/contribution/#rlops-for-performance-impacting-changes
After further tests, I found that the outer product gives the same mean, so the actor_loss is actually unaffected.
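A short derivation of why the mean is unchanged, as a sketch: the symbols $a_i$, $b_j$ and $N$ are notation introduced here for illustration, with $a_i = \alpha \log \pi_i$, $b_j$ the $j$-th entry of min_qf_pi, and $N$ the batch size. Averaging every entry of the broadcast [batch_size x batch_size] matrix gives

$$
\frac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N}\left(a_i - b_j\right)
= \frac{1}{N}\sum_{i=1}^{N} a_i - \frac{1}{N}\sum_{j=1}^{N} b_j
= \frac{1}{N}\sum_{i=1}^{N}\left(a_i - b_i\right),
$$

which is the mean of the intended element-wise difference. The two losses should therefore agree up to floating-point rounding, although the broadcast version builds a much larger intermediate tensor.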
@terencenwz Thanks, I will update #383 with your suggested change. I'm planning on running performance tests this week.
Problem Description
In the following line
cleanrl/cleanrl/sac_continuous_action.py
Line 270 in 9f8b64b
should be
Otherwise, (alpha * log_pi) - min_qf_pi produces a matrix of shape [batch_size x batch_size] instead of just [batch_size], and gives a different actor loss in my tests:
min_qf_pi.shape: torch.Size([8])
log_pi.shape: torch.Size([8, 1])
((alpha * log_pi) - min_qf_pi):
tensor([[ 8.7687, 8.6482, 5.3872, 8.6279, 8.7512, 6.9031, 7.5819, 5.7800],
[ 9.0996, 8.9791, 5.7181, 8.9588, 9.0821, 7.2340, 7.9129, 6.1109],
[ 4.5497, 4.4292, 1.1682, 4.4089, 4.5323, 2.6841, 3.3630, 1.5610],
[ 9.8283, 9.7078, 6.4468, 9.6875, 9.8109, 7.9627, 8.6416, 6.8396],
[ 9.3948, 9.2743, 6.0133, 9.2540, 9.3773, 7.5292, 8.2081, 6.4061],
[ 6.0864, 5.9659, 2.7049, 5.9456, 6.0689, 4.2208, 4.8996, 3.0977],
[ 3.0503, 2.9298, -0.3312, 2.9095, 3.0328, 1.1847, 1.8635, 0.0616],
[ 1.6122, 1.4917, -1.7694, 1.4714, 1.5947, -0.2535, 0.4254, -1.3766]],
device='cuda:0', grad_fn=)
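The shape behaviour above can be reproduced with a minimal sketch like the following. It is not the CleanRL code: the random tensors, the seed, and the value alpha = 0.2 are assumptions chosen for illustration, and reshaping with view(-1, 1) is just one way to align the shapes, not necessarily the fix applied in the repository.

```python
import torch

torch.manual_seed(0)
batch_size = 8
alpha = 0.2  # hypothetical entropy coefficient, for illustration only

# Random stand-ins with the shapes printed in the report above.
log_pi = torch.randn(batch_size, 1)   # shape [8, 1]
min_qf_pi = torch.randn(batch_size)   # shape [8]

# [8, 1] - [8] broadcasts to an [8, 8] matrix of all pairwise differences.
broadcast_diff = (alpha * log_pi) - min_qf_pi

# Aligning the shapes first gives the element-wise difference, shape [8, 1].
elementwise_diff = (alpha * log_pi) - min_qf_pi.view(-1, 1)

print(broadcast_diff.shape)    # torch.Size([8, 8])
print(elementwise_diff.shape)  # torch.Size([8, 1])

# The intended per-sample values sit on the diagonal of the broadcast matrix.
same_diagonal = torch.allclose(
    torch.diagonal(broadcast_diff), elementwise_diff.squeeze(1)
)

# The two means agree up to floating-point rounding.
same_mean = torch.allclose(
    broadcast_diff.mean(), elementwise_diff.mean(), atol=1e-5
)
print(same_diagonal, same_mean)
```

As noted in the comments on this issue, the mean of the broadcast matrix matches the mean of the element-wise difference, so a check like this mainly exposes the unintended shape rather than a change in the loss value.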
The line in the Atari version is correct
cleanrl/cleanrl/sac_atari.py
Line 304 in 9f8b64b