[Feature] Discrete SAC #882

Merged (19 commits) on Mar 24, 2023

BY571 (Contributor) commented Jan 30, 2023

Description

Adding a discrete SAC example

Motivation and Context

The current SAC implementation only supports continuous action spaces. This PR adds the option to run a discrete SAC example based on the paper.

Convergence was checked empirically on CartPole-v1 (wandb):
[image]

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label Jan 30, 2023
@BY571 marked this pull request as ready for review January 31, 2023 10:18
@vmoens added the enhancement and new algo labels Feb 6, 2023
vmoens (Contributor) left a comment

LGTM
Before landing:

  • can we move the loss to sac.py? I'd rather have them all in the same place if that makes sense?
  • Can we add the loss to the doc?
  • Is this supposed to work with gSDE? gSDE is not tailored for discrete action spaces AFAICT

Two outdated, resolved review threads on torchrl/objectives/discrete_sac.py
BY571 (Contributor, Author) commented Feb 7, 2023

LGTM Before landing:

  • can we move the loss to sac.py? I'd rather have them all in the same place if that makes sense?
  • Can we add the loss to the doc?
  • Is this supposed to work with gSDE? gSDE is not tailored for discrete action spaces AFAICT

Do you mean having the discrete and continuous SAC losses in one objective class, or just having both losses in the same file?
I'd prefer to have them in the same class, what do you think? Will have a look at it in the coming days.

Will add it to the doc and also take off the gSDE :)

vmoens (Contributor) commented Feb 7, 2023

I'd prefer to have them in the same class
How much control flow would that entail?
Does it save a lot of code?

I was thinking of having them in the same file. If having them in the same class does not create a monster class I'm happy to consider it.

Will it work with v1 and v2?

BY571 (Contributor, Author) commented Feb 9, 2023

I was thinking of having them in the same file. If having them in the same class does not create a monster class I'm happy to consider it.

Will it work with v1 and v2?

For now, I just added it to the sac.py file in objectives. Since it only works with v2, merging it into the same class might get messy and, as you said, would probably create a monster class. Let me know what you think.

I also removed gSDE from the loss and updated the description of the actor_network to be a TensorDictModule.

How can I update the docs?

vmoens (Contributor) left a comment

LGTM

vmoens (Contributor) left a comment

LGTM -- let's try to merge this :)

# Conflicts:
#	docs/source/reference/objectives.rst

    if target_entropy == "auto":
        target_entropy = -float(
            np.log(1.0 / action_spec["action"].shape[0]) * target_entropy_weight
        )

Contributor:

Careful here: the [0] can be the batch size.
Maybe use the last dimension? Or, since it's discrete, we can check whether it's a one-hot or a discrete encoding and retrieve the number of options directly from the spec?
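
To make the concern concrete, here is a minimal sketch in plain Python, assuming a one-hot encoded action whose shape carries a leading batch dimension. The shape tuple, the weight value, and the helper name `num_actions_from_one_hot_shape` are hypothetical illustrations and not part of the TorchRL API.

```python
import math


def num_actions_from_one_hot_shape(shape):
    """Return the number of discrete actions for a one-hot encoded action.

    Hypothetical helper: assumes the last dimension enumerates the actions
    and that any leading dimensions are batch dimensions.
    """
    if len(shape) == 0:
        raise ValueError("a one-hot action needs at least one dimension")
    return shape[-1]


# CartPole-v1 has 2 actions; suppose the shape was expanded to a batch of 8.
batched_shape = (8, 2)

wrong = batched_shape[0]  # picks up the batch size instead of the action count
right = num_actions_from_one_hot_shape(batched_shape)

target_entropy_weight = 0.2  # assumed value, for illustration only
target_entropy = -math.log(1.0 / right) * target_entropy_weight
```

With a categorical (integer) encoding the shape would not contain the number of actions at all, which is why reading it from the spec, or passing it explicitly, may be the more robust option.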

Contributor:

I think it's the only place where we use action_spec
Maybe we could just pass the number of possible actions rather than passing the spec?
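
One way to read this suggestion, as a sketch only: the loss would take an integer number of actions and derive the "auto" target entropy from it. The function name `resolve_target_entropy` and its signature are assumptions made for illustration; the merged code may look different.

```python
import math


def resolve_target_entropy(target_entropy, num_actions, target_entropy_weight):
    """Resolve the target entropy for a discrete policy.

    Hypothetical helper. With "auto", the target is a fraction of the
    maximum entropy of a categorical distribution over `num_actions`
    actions, which is log(num_actions) = -log(1 / num_actions).
    """
    if target_entropy == "auto":
        if num_actions < 2:
            raise ValueError("num_actions must be at least 2")
        return -math.log(1.0 / num_actions) * target_entropy_weight
    # Otherwise the caller supplied an explicit numeric target.
    return float(target_entropy)


# CartPole-v1 has two actions; the weight here is an illustrative value.
auto_target = resolve_target_entropy("auto", num_actions=2, target_entropy_weight=1.0)
fixed_target = resolve_target_entropy(0.5, num_actions=2, target_entropy_weight=1.0)
```

Passing the count directly removes any dependence on how the action is encoded or batched, at the cost of asking the caller to supply it.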

Contributor (Author):

Fixed that and adapted the example script to recent TorchRL changes.
However, some tests now fail; I'm working on it and hope to resolve them quickly!

Contributor (Author):

Fixed the example script issues and also updated the objective tests, as they were raising several errors.
Hopefully ready to merge now! :)

@vmoens merged commit 8e03f6b into pytorch:main Mar 24, 2023