[RLlib] - Get and set states in MultiAgentEpisode and SingleAgentEpisode #45012
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
…get_state' and 'from_state' to 'SingleAgentEpisode' together with test. Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
…ith a test. Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
@@ -44,7 +44,7 @@ def build(self) -> None:
  # Note that the KL coeff is not controlled by a Scheduler, but seeks
  # to stay close to a given kl_target value in our implementation of
  # `self.additional_update_for_module()`.
- self.curr_kl_coeffs_per_module: Dict[ModuleID, Scheduler] = LambdaDefaultDict(
+ self.curr_kl_coeffs_per_module: Dict[ModuleID, TensorType] = LambdaDefaultDict(
nice catch
rllib/env/multi_agent_episode.py
Outdated
@@ -1704,50 +1704,84 @@ def get_state(self) -> Dict[str, Any]:
return list(
{
"id_": self.id_,
"agent_ids": self.agent_ids,
"agent_to_module_mapping_fn": self.agent_to_module_mapping_fn,
Actually, let's make the state a dict. States should always be Dict[str, Any]. I think this is a leftover from the early DreamerV3 days :)
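The request above (return a `Dict[str, Any]` rather than `list({...}.items())`) can be illustrated with a minimal sketch. `ToyEpisode` and its fields are hypothetical stand-ins chosen for illustration; this is not the RLlib implementation, only the general shape of a dict-based `get_state` / `from_state` pair.

```python
from typing import Any, Dict, List


class ToyEpisode:
    """Hypothetical stand-in for an episode class; not the RLlib implementation."""

    def __init__(self, id_: str, observations: List[int], rewards: List[float]):
        self.id_ = id_
        self.observations = observations
        self.rewards = rewards

    def get_state(self) -> Dict[str, Any]:
        # Return a plain dict keyed by attribute name instead of a list of
        # (key, value) tuples, so callers can look fields up by key.
        return {
            "id_": self.id_,
            "observations": list(self.observations),
            "rewards": list(self.rewards),
        }

    @staticmethod
    def from_state(state: Dict[str, Any]) -> "ToyEpisode":
        # Rebuild a new episode object from a complete state dict.
        return ToyEpisode(
            id_=state["id_"],
            observations=list(state["observations"]),
            rewards=list(state["rewards"]),
        )


episode = ToyEpisode("ep-0", [0, 1, 2], [1.0, 0.5])
state = episode.get_state()
restored = ToyEpisode.from_state(state)
```

With a dict, `state["id_"]` works directly; a list of tuples would first have to be converted back with `dict(state)`.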
rllib/env/multi_agent_episode.py
Outdated
  }.items()
  )

  @staticmethod
- def from_state(state) -> None:
+ def from_state(state) -> "MultiAgentEpisode":
same here: typehint: state: Dict[str, Any]
rllib/env/multi_agent_episode.py
Outdated
"""Creates a multi-agent episode from a state dictionary.

See `MultiAgentEpisode.get_state()` for creating a state for
a `MultiAgentEpisode` pickable state. For recreating a
`MultiAgentEpisode` from a state, this state has to be complete,
i.e. all data must have been stored in the state.

Args:
state: A list of tuples containing all data required to recreate
same here, let's make this a dict.
rllib/env/single_agent_episode.py
Outdated
@@ -1643,6 +1643,75 @@ def agent_steps(self) -> int:
"""
return self.env_steps()

def get_state(self) -> list:
same here: dict
@@ -342,6 +342,24 @@ def test_sample_with_modules_to_sample(self):
# Assert that all n-steps are 1.0 as passed into `sample`.
self.assertTrue(np.all(n_steps - 1.0 < tolerance))

# def test_get_state_and_set_state(self):
nit: Remove?
Looks good. Just needs the changes from list to dict, then can be merged. :)
Thanks @simonsays1980 !
Why are these changes needed?
This PR adds `get_state` and `from_state` to `MultiAgentEpisode` and `SingleAgentEpisode`. This is needed for checkpointing `EpisodeReplayBuffer` objects.
Related issue number
Checks
… `git commit -s`) in this PR.
… `scripts/format.sh` to lint the changes in this PR.
… method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
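The checkpointing motivation given under "Why are these changes needed?" can be sketched roughly as follows. `ToyEpisode`, `ToyBuffer`, and `set_state` are hypothetical names introduced only for this illustration and make no claim about how `EpisodeReplayBuffer` is implemented; the sketch only shows why a buffer needs each episode to expose a serializable state.

```python
import pickle
from typing import Any, Dict, List


class ToyEpisode:
    """Hypothetical episode with only the fields this sketch needs."""

    def __init__(self, id_: str, rewards: List[float]):
        self.id_ = id_
        self.rewards = rewards

    def get_state(self) -> Dict[str, Any]:
        return {"id_": self.id_, "rewards": list(self.rewards)}

    @staticmethod
    def from_state(state: Dict[str, Any]) -> "ToyEpisode":
        return ToyEpisode(state["id_"], list(state["rewards"]))


class ToyBuffer:
    """Hypothetical buffer that checkpoints itself via its episodes' states."""

    def __init__(self) -> None:
        self.episodes: List[ToyEpisode] = []

    def get_state(self) -> Dict[str, Any]:
        # The buffer state is built from the state of every stored episode.
        return {"episodes": [e.get_state() for e in self.episodes]}

    def set_state(self, state: Dict[str, Any]) -> None:
        # Recreate each episode from its saved state.
        self.episodes = [ToyEpisode.from_state(s) for s in state["episodes"]]


buffer = ToyBuffer()
buffer.episodes.append(ToyEpisode("ep-0", [1.0, 0.5]))
buffer.episodes.append(ToyEpisode("ep-1", [0.0]))

# Serialize the buffer state to bytes, as a checkpoint would.
checkpoint = pickle.dumps(buffer.get_state())

# Restore into a fresh buffer.
restored = ToyBuffer()
restored.set_state(pickle.loads(checkpoint))
```

A state made of plain containers and numbers round-trips through the standard `pickle` module. A state entry that holds a lambda would not, which may be worth keeping in mind for function-valued fields.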