Bug in actor loss for sac_continuous_action.py #379

Closed
terencenwz opened this issue May 5, 2023 · 5 comments

@terencenwz
terencenwz commented May 5, 2023

Problem Description

In the following lines of sac_continuous_action.py:

min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

should be

min_qf_pi = torch.min(qf1_pi, qf2_pi)
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

Otherwise, (alpha * log_pi) - min_qf_pi broadcasts to a [batch_size x batch_size] matrix instead of one value per batch element, and it gave a different actor loss in my tests:
min_qf_pi.shape: torch.Size([8])
log_pi.shape: torch.Size([8, 1])
((alpha * log_pi) - min_qf_pi):
tensor([[ 8.7687, 8.6482, 5.3872, 8.6279, 8.7512, 6.9031, 7.5819, 5.7800],
[ 9.0996, 8.9791, 5.7181, 8.9588, 9.0821, 7.2340, 7.9129, 6.1109],
[ 4.5497, 4.4292, 1.1682, 4.4089, 4.5323, 2.6841, 3.3630, 1.5610],
[ 9.8283, 9.7078, 6.4468, 9.6875, 9.8109, 7.9627, 8.6416, 6.8396],
[ 9.3948, 9.2743, 6.0133, 9.2540, 9.3773, 7.5292, 8.2081, 6.4061],
[ 6.0864, 5.9659, 2.7049, 5.9456, 6.0689, 4.2208, 4.8996, 3.0977],
[ 3.0503, 2.9298, -0.3312, 2.9095, 3.0328, 1.1847, 1.8635, 0.0616],
[ 1.6122, 1.4917, -1.7694, 1.4714, 1.5947, -0.2535, 0.4254, -1.3766]],
device='cuda:0', grad_fn=<SubBackward0>)

The corresponding line in the Atari version is correct:

min_qf_values = torch.min(qf1_values, qf2_values)
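The shape mismatch described above can be reproduced in isolation. This is a minimal sketch, not the CleanRL code itself: the tensors are random stand-ins, and `batch_size = 8` and `alpha = 0.2` are assumed values chosen only to mirror the shapes reported in this issue.

```python
import torch

batch_size = 8
alpha = 0.2  # assumed entropy coefficient, for illustration only

# Random stand-ins for the critic outputs and log-probabilities, each [batch_size, 1].
qf1_pi = torch.randn(batch_size, 1)
qf2_pi = torch.randn(batch_size, 1)
log_pi = torch.randn(batch_size, 1)

# With .view(-1): [8, 1] - [8] broadcasts to an [8, 8] matrix of pairwise differences.
flat_min = torch.min(qf1_pi, qf2_pi).view(-1)
pairwise = (alpha * log_pi) - flat_min

# Without .view(-1): [8, 1] - [8, 1] stays elementwise.
col_min = torch.min(qf1_pi, qf2_pi)
elementwise = (alpha * log_pi) - col_min

print(pairwise.shape)     # torch.Size([8, 8])
print(elementwise.shape)  # torch.Size([8, 1])
```

Printing the shapes of both operands before the subtraction is usually the quickest way to spot this kind of silent broadcasting.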

@timoklein
Collaborator

Yeah, the version using min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) computes an outer (pairwise) difference, because broadcasting combines the differing shapes [batch_size, 1] and [batch_size]. I ran into the same issue, which is why the sac_atari version omits the .view(-1).

It might be worthwhile to investigate why this hasn't been an issue previously though.

Your fix should work, @terencenwz. If you wanna do a PR, I can merge it if that's also fine for @dosssman and @vwxyzjn.

@dosssman
Collaborator

dosssman commented May 6, 2023

Thanks, appreciate it.

Will come back to this in the middle of the week if no changes by then.

@vwxyzjn
Owner

vwxyzjn commented May 6, 2023

Thanks for raising this issue, @terencenwz. Alongside the PR, we should probably re-run the benchmark experiments as well, given that this is a performance-impacting change. The specific steps are listed at https://docs.cleanrl.dev/contribution/#rlops-for-performance-impacting-changes

@terencenwz
Author

After further tests, I found that the outer product gives the same mean, so the actor_loss is actually unaffected.
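One way to see why the mean is unchanged: the broadcast matrix holds every pairwise difference a_i - b_j, and averaging over all i and j gives mean(a) - mean(b), which is also what the elementwise version averages to. The sketch below checks this numerically. It assumes random stand-in tensors, an illustrative `alpha = 0.2`, and float64 so that rounding noise stays negligible; none of these come from the CleanRL script.

```python
import torch

torch.manual_seed(0)
batch_size = 8
alpha = 0.2  # assumed value, for illustration only

# Random stand-ins for the real tensors, each shaped [batch_size, 1].
log_pi = torch.randn(batch_size, 1, dtype=torch.float64)
min_qf_pi = torch.randn(batch_size, 1, dtype=torch.float64)

# Original code path: [8, 1] - [8] broadcasts to an [8, 8] matrix of a_i - b_j.
loss_broadcast = ((alpha * log_pi) - min_qf_pi.view(-1)).mean()

# Proposed fix: [8, 1] - [8, 1] is elementwise, a_i - b_i.
loss_elementwise = ((alpha * log_pi) - min_qf_pi).mean()

# Both reduce to mean(alpha * log_pi) - mean(min_qf_pi).
same_mean = torch.allclose(loss_broadcast, loss_elementwise)
print(same_mean)
```

The equality holds only because the reduction is a plain mean; any per-sample weighting or other nonlinear use of the difference would make the two versions diverge, so dropping the .view(-1) is still the safer form.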

@pseudo-rnd-thoughts
Collaborator

@terencenwz Thanks, I will update #383 with your suggested change. I'm planning on running performance tests this week.
