
The choice of the action decoder #11

Closed
SiyuanHuang95 opened this issue May 29, 2023 · 3 comments

Comments

@SiyuanHuang95

Hi, I noticed that you used torch.distributions.Distribution after the MLP to get the final output. Could you share some insights about this choice? What's the advantage compared with directly using an MLP and a softmax?

Also, for the training procedure, should we ignore that head and directly apply the NLL loss to the output of the MLP, or should we apply the NLL to the probability of that distribution? If so, could you give some simple code snippets to demonstrate the training usage?

BTW, congrats on the ICML acceptance, well done!

Best,

@yunfanjiang
Member

Hi there,

Thank you for your congratulatory words. To answer your questions

I noticed that you used torch.distributions.Distribution after the MLP to get the final output. Could you share some insights about this choice? What's the advantage compared with directly using an MLP and a softmax?

Theoretically there is no difference between using a categorical distribution and MLP + softmax. Personally, I found using torch distributions convenient since they implement a uniform interface that works with different strategies for modeling action heads.
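For illustration, here is a minimal sketch (not from the repo; the shapes and head outputs are made up) of why that uniform interface is convenient: a discrete head and a continuous head expose the same sample/log_prob API, so downstream training and rollout code doesn't need to care which family is used.

```python
import torch
from torch.distributions import Categorical, Normal

# Hypothetical MLP outputs for two different action heads.
logits = torch.randn(4, 8)           # (batch, num_bins) for a discrete head
mean = torch.randn(4, 3)             # (batch, action_dim) for a continuous head
std = torch.full((4, 3), 0.1)

disc_dist = Categorical(logits=logits)
cont_dist = Normal(mean, std)

# Both heads expose the same interface regardless of the underlying family.
for dist in (disc_dist, cont_dist):
    action = dist.sample()
    log_prob = dist.log_prob(action)
```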

Also, for the training procedure, should we ignore that head and directly apply the NLL loss to the output of the MLP, or should we apply the NLL to the probability of that distribution? If so, could you give some simple code snippets to demonstrate the training usage?

Sure. In the discrete case, let's say dist is a torch.distributions.Categorical instance predicted by the model and label is the discretized action; the loss is calculated with torch.nn.functional.cross_entropy. Since it takes unnormalized logits as inputs, we can just pass dist.logits (with a proper reshape if necessary) into the loss function. For the continuous case with a unimodal Gaussian or a GMM, I'd recommend checking out these snippets: here and here.
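A minimal sketch of the discrete case, with hypothetical batch size and number of action bins; it just shows dist.logits being fed into cross_entropy as described above.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

batch_size, num_bins = 32, 256
logits = torch.randn(batch_size, num_bins, requires_grad=True)  # MLP output
dist = Categorical(logits=logits)

# label: discretized ground-truth action indices, shape (batch,)
label = torch.randint(0, num_bins, (batch_size,))

# F.cross_entropy applies log_softmax internally, so passing dist.logits
# directly yields the NLL of the labeled action bin.
loss = F.cross_entropy(dist.logits, label)
loss.backward()
```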

@SiyuanHuang95
Copy link
Author

SiyuanHuang95 commented Jun 5, 2023

Many thanks, @yunfanjiang, for the reply and the informative hints!

  1. MLP + softmax case: Okay, I got it. BTW, I noticed that many works use an MSE loss to train the policy network, turning the training into a regression problem. Have you conducted any experiments comparing them?

  2. Okay, thanks. But I noticed that in your work you chose to use discretized actions. What would be the main difference between them?

@yunfanjiang
Member

yunfanjiang commented Jun 9, 2023

Many thanks, @yunfanjiang, for the reply and the informative hints!

  1. MLP + softmax case: Okay, I got it. BTW, I noticed that many works use an MSE loss to train the policy network, turning the training into a regression problem. Have you conducted any experiments comparing them?
  2. Okay, thanks. But I noticed that in your work you chose to use discretized actions. What would be the main difference between them?

Thanks for the followup. To answer them

  1. I assume you were referring to works with continuous actions. In those cases we could certainly opt for a regression loss. However, since a GMM is more expressive and can better handle distributional multimodality (which is the case for our benchmark, where multiple solutions exist for a single task), we only experimented with a GMM for the continuous-action case (see the sketch after this list).
  2. In our case we didn't observe a significant difference empirically, so we opted for the simpler choice.
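In case it helps, here is a rough GMM action-head sketch built from torch.distributions (a MixtureSameFamily over diagonal Gaussians). All shapes and head outputs are illustrative, not taken from the codebase; the NLL is simply the negative log_prob of the ground-truth continuous action.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

batch_size, num_modes, action_dim = 32, 5, 7

# Hypothetical MLP outputs parameterizing the mixture (names are illustrative).
mixture_logits = torch.randn(batch_size, num_modes)
means = torch.randn(batch_size, num_modes, action_dim)
log_stds = torch.randn(batch_size, num_modes, action_dim)

mixture = Categorical(logits=mixture_logits)
components = Independent(Normal(means, log_stds.exp()), 1)  # diagonal Gaussians
gmm = MixtureSameFamily(mixture, components)

# NLL against ground-truth continuous actions, shape (batch, action_dim).
actions = torch.randn(batch_size, action_dim)
loss = -gmm.log_prob(actions).mean()
```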

Hope these are helpful.
