
Problem of auto alpha in SAC #258

Closed
milkpku opened this issue Dec 8, 2020 · 8 comments
Labels
question Further information is requested

Comments

@milkpku

milkpku commented Dec 8, 2020

Dear author,

I find there is an inconsistency between your implementation and the algorithm described in the original paper: https://arxiv.org/abs/1812.05905

Instead of

alpha_loss = -(self._log_alpha * log_prob).mean()

the optimization objective for the temperature alpha should be

alpha_loss = -(torch.exp(self._log_alpha) * log_prob).mean()
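
For concreteness, here is a minimal standalone sketch of the two objectives (toy numbers, not the actual Tianshou code); log_prob is just a stand-in for whatever detached term multiplies alpha in the snippets above:

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # learnable log(alpha)
log_prob = torch.tensor([0.5, -1.0, -0.2])       # hypothetical detached values

# current implementation: d(loss)/d(log_alpha) = -mean(log_prob)
loss_log_alpha = -(log_alpha * log_prob).mean()

# paper-style objective: d(loss)/d(log_alpha) = -exp(log_alpha) * mean(log_prob),
# i.e. the same sign but scaled by alpha = exp(log_alpha)
loss_exp_alpha = -(torch.exp(log_alpha) * log_prob).mean()

The two losses push log_alpha in the same direction, but the second one scales the step by alpha itself.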

Trinkle23897 added the question label Dec 8, 2020
@milkpku
Author

milkpku commented Dec 8, 2020

Interesting. I see they've discussed two problems:

  1. Which variable to optimize: _alpha vs. _log_alpha
  2. Which objective to optimize: exp(_log_alpha) * log_prob vs. _log_alpha * log_prob

The first problem is settled: optimizing _log_alpha prevents alpha from becoming negative, which would otherwise cause NaN errors.

But the second problem is not well discussed, in my view. Some argue that the only effective information is the sign of the gradient, but my argument is that the gradient of exp(_log_alpha) shrinks with alpha and therefore prevents alpha from dropping too fast once it is close to zero. Still, this needs further experiments to be validated.
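
A quick numerical check of that gradient-scaling point (toy numbers, my own sketch rather than the Tianshou code): when alpha is already small, the exp(_log_alpha) objective produces a much smaller gradient on _log_alpha, so alpha stops shrinking so quickly.

import torch

log_prob = torch.tensor([-0.3])                        # stand-in for a detached term that pushes alpha down
log_alpha = torch.tensor([-6.0], requires_grad=True)   # alpha = exp(-6) ~ 0.0025

grad_v1, = torch.autograd.grad(-(log_alpha * log_prob).mean(), log_alpha)
grad_v2, = torch.autograd.grad(-(torch.exp(log_alpha) * log_prob).mean(), log_alpha)

print(grad_v1.item())  # 0.3: constant-size push on log_alpha, so alpha keeps decaying exponentially
print(grad_v2.item())  # ~0.00074: the push vanishes as alpha -> 0, so alpha stops dropping so fast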

@dxing-cs

dxing-cs commented Dec 13, 2020

I don't understand why we need to maintain an optimizer for alpha. Since log_prob has been detached, isn't alpha simply determined by:

  1. (log_prob + target_entropy) > 0: alpha = inf
  2. (log_prob + target_entropy) <= 0: alpha = 0

Another question: why does log_prob need to be detached?

@danagi
Collaborator

danagi commented Dec 13, 2020

I don't understand why we need to maintain an optimizer for alpha. [...] Another question: why does log_prob need to be detached?

Check the paper here. The original objective for alpha is a dual problem, which is optimized by approximate dual gradient descent: alternating between optimizing the policy with respect to the current alpha and taking a single gradient step on alpha. So alpha is not simply determined by the sign of (log_prob + target_entropy).
Also, optimizing the exact objective is impractical, so a truncated version is used, which leads to detaching log_prob.
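
To make the alternation concrete, here is a toy loop for the alpha step only (hypothetical numbers and names, not the Tianshou API), using the log_alpha form discussed above: each iteration sees fresh detached log-probs from the changing policy and takes a single small gradient step, so alpha drifts rather than jumping to 0 or +inf.

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -3.0                                  # e.g. -dim(action_space)

for step in range(5):
    # stand-in for detached log pi(a|s) sampled under the *current* policy,
    # which changes between alpha updates (the policy step is omitted here)
    log_prob = torch.randn(256) - 3.0
    alpha_loss = -(log_alpha * (log_prob + target_entropy)).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    print(step, log_alpha.exp().item())                # alpha moves a little each step, never to 0 or inf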

@dxing-cs

The original objective for alpha is a dual problem, which is optimized by approximate dual gradient descent [...] So alpha is not simply determined by the sign of (log_prob + target_entropy).

Thanks for your reply. I still don't understand why alpha is NOT determined by the sign of (log_prob + target_entropy). Since the optimization is done by alternately updating alpha and the policy, it seems to me that when updating alpha, the term (log_prob + target_entropy) can be regarded as a constant. (Please correct me if I'm wrong.)

@danagi
Collaborator

danagi commented Dec 13, 2020

I still don't understand why alpha is NOT determined by the sign of (log_prob + target_entropy). [...] when updating alpha, the term (log_prob + target_entropy) can be regarded as a constant.

Hi, it seems you are not familiar with constrained optimization or the method of Lagrange multipliers, and it's hard to explain them in a few words. A simple but possibly inaccurate explanation is that the optimal alpha is expressed in terms of the optimal policy; since the policy is not yet optimal during gradient descent, you cannot set alpha to some value directly.
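
For reference, the temperature objective from the paper, as I read it (notation mine):

J(\alpha) = \mathbb{E}_{a \sim \pi}\left[ -\alpha \log \pi(a \mid s) - \alpha \bar{H} \right], \qquad \alpha \ge 0,

where \bar{H} is the target entropy. The minimizer over alpha depends on the current policy \pi, which is itself still being improved, so a gradient step on alpha is interleaved with the policy update instead of setting alpha to 0 or +inf in closed form.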

@alexnikulkov
Contributor

I think we should use alpha instead of log_alpha for 2 reasons:

  1. [Philosophical reason] That's how it's implemented in the SAC paper, so we should stick as close as possible to their implementation unless we have evidence that our implementation performs better.
  2. [Technical reason] The loss represents the Lagrangian of a constrained optimization problem, and alpha is the dual variable for the entropy constraint. Since the entropy constraint is an inequality, the dual variable has to be non-negative: log_alpha takes values in (-inf, +inf), while alpha takes values in [0, +inf).

It looks like they use alpha and not log_alpha in the official SAC repo as well: https://github.com/rail-berkeley/softlearning/blob/master/softlearning/algorithms/sac.py#L256
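
If it helps, the variant being argued for here, as I understand it, keeps log_alpha as the learnable parameter (so alpha = exp(log_alpha) >= 0 automatically) but puts alpha itself, rather than log_alpha, into the loss. A rough sketch under those assumptions (not a quote of the softlearning or Tianshou code):

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # parameterize alpha = exp(log_alpha) >= 0
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -6.0                            # hypothetical, e.g. -dim(action_space)
log_prob = torch.randn(256) - 5.0                # stand-in for detached log pi(a|s)

alpha = log_alpha.exp()                          # alpha, not log_alpha, enters the loss
alpha_loss = -(alpha * (log_prob + target_entropy)).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()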

@Trinkle23897
Collaborator

unless we have evidence that our implementation performs better.

https://github.com/thu-ml/tianshou/tree/master/examples/mujoco#sac
