
[BugFix] RewardSum transform for multiple reward keys #1544

Merged: 22 commits into pytorch:main from fix_reward_sum, Oct 2, 2023

Conversation

matteobettini (Contributor)

This PR extends the RewardSum to work with multiple reward keys.
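For context, a minimal sketch of the extended usage (the key names are illustrative; the nested ("agents", "reward") entry follows the MARL grouping API):

from torchrl.envs.transforms import RewardSum

# Track cumulative episode sums for two reward entries at once.
transform = RewardSum(
    in_keys=["reward", ("agents", "reward")],
    out_keys=["episode_reward", ("agents", "episode_reward")],
)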

Signed-off-by: Matteo Bettini <matbet@meta.com>
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 19, 2023
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
@matteobettini matteobettini marked this pull request as ready for review September 19, 2023 15:03
@matteobettini matteobettini changed the title [BugFIx] RewardSum transform for multiple reward keys [BugFix] RewardSum transform for multiple reward keys Sep 19, 2023
@matteobettini matteobettini added the bug Something isn't working label Sep 19, 2023
vmoens (Contributor) left a comment

How is this supposed to work with many done signals?
I think we should wait for #1539 to land first

torchrl/envs/transforms/transforms.py: 6 outdated review threads (resolved)
matteobettini (Contributor, Author)

This transform is supposed to work according to the MARL grouping API.

It will look for the "_reset" entry in the root and consider it the default reset. For each reward key, if it finds a "_reset" in that key's tensordict, it will use that instead.

So each reward key will by default be associated with the "_reset" in its own tensordict and, if that is not present, with the "_reset" at the root (if present).

This behavior aligns with the MARL API, as sketched below.
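In rough pseudocode, that lookup could be sketched as follows (a sketch only, not the merged implementation; find_reset is a hypothetical helper name):

def find_reset(td, reward_key):
    # Prefer the "_reset" entry that lives next to the reward key...
    if isinstance(reward_key, tuple) and len(reward_key) > 1:
        local_reset = reward_key[:-1] + ("_reset",)
        if local_reset in td.keys(include_nested=True):
            return td.get(local_reset)
    # ...otherwise fall back to the default "_reset" at the root, if any.
    return td.get("_reset", None)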

Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini (Contributor, Author)

It should not depend on #1539 if we follow this logic.

Signed-off-by: Matteo Bettini <matbet@meta.com>
vmoens (Contributor) commented Sep 22, 2023

> This transform is supposed to work according to the MARL grouping API.
>
> It will look for the "_reset" entry in the root and consider it the default reset.

I thought our default was that there wasn't a "_reset" at the root (only if there is a "done" at the root)?

vmoens (Contributor) commented Sep 22, 2023

To be clear, the direction I understood we were moving towards:

TensorDict({"_reset": reset, "nested": {"done": done}}, [])  # not allowed!
TensorDict({"nested": {"_reset": reset, "done": done}}, [])  # allowed
TensorDict({"_reset": reset, "done": done}, [])  # allowed

The first is not allowed because:

  • the "_reset" is not part of the env.reset_keys which MUST be an iterable that follows the grouped done
  • it is unclear what this reset refers to and what to do with it in specific cases: does it prevail over sub-resets? do we allow any arbitrary shape? There are way too many decisions to be taken IMO
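Under that convention, env.reset_keys mirrors the done layout with one entry per done group. For the two allowed layouts above, the reset keys would be, respectively (illustrative values, not output copied from the library):

reset_keys = [("nested", "_reset")]  # done lives at ("nested", "done")
reset_keys = ["_reset"]              # done lives at the root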

Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini (Contributor, Author)

I have updated it per what we discussed; now I'll just have to test it.

@matteobettini matteobettini marked this pull request as draft September 22, 2023 14:08
@vmoens vmoens marked this pull request as ready for review October 1, 2023 06:08
vmoens (Contributor) left a comment

If @matteobettini you're happy with 8da298f I'm good with merging this

matteobettini (Contributor, Author) left a comment

LGTM, some comments below.

If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
However, multiple rewards (e.g. ``"reward1"`` and ``"reward2"``) can also be specified.
out_keys (list of NestedKeys, optional): The output sum keys, one for each input key.
reset_keys (list of NestedKeys, optional): the list of reset_keys to be
matteobettini (Contributor, Author)

Here I preferred having done keys rather than reset keys, because users are familiar with what a done key is and might not know about reset keys. Plus, there is a 1:1 match between the two.

vmoens (Contributor)

Not the way it was done: if I pass env.done_keys there are some duplicates (e.g. truncation / termination / done).

Having a "done_keys" list is, to me, more dangerous because of this, and ultimately the only thing we're pointing to is the tree structure where the reset_keys should be found. I personally prefer to pass reset_keys: it's what is needed, i.e. there is a lower risk that refactoring the done_keys mechanism in the future will break this transform. Per se, asking users to pass a list X, when we interpolate a list Y that is already present within the env as env.Y, seems a convoluted way of doing things.

vmoens (Contributor)

Another thought about this: per se, most users won't need to pass reset keys. We just support it in case someone really wants to do nasty things like summing part of the rewards but not all, etc. Advanced usage requires advanced understanding, so it's fine to ask for reset_keys even if this isn't something that is always user-facing.
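For instance, a sketch of such advanced usage (key names illustrative), summing only the nested agents' reward while leaving everything else untouched:

from torchrl.envs.transforms import RewardSum

transform = RewardSum(
    in_keys=[("agents", "reward")],           # sum only this reward
    out_keys=[("agents", "episode_reward")],  # where the cumulative sum is written
    reset_keys=[("agents", "_reset")],        # the reset entry that zeroes the sum
)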

test/test_transforms.py: review thread (resolved)
torchrl/envs/transforms/transforms.py: review thread (resolved)
@matteobettini matteobettini marked this pull request as draft October 2, 2023 08:36
@matteobettini matteobettini marked this pull request as ready for review October 2, 2023 08:44
@matteobettini matteobettini marked this pull request as draft October 2, 2023 08:47
@vmoens vmoens marked this pull request as ready for review October 2, 2023 12:30
matteobettini (Contributor, Author)

LGTM

@vmoens vmoens merged commit 1697102 into pytorch:main Oct 2, 2023
54 of 59 checks passed
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
@matteobettini matteobettini deleted the fix_reward_sum branch December 4, 2023 11:14