-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rllib] observation function api for multi-agent #8236
Conversation
Can one of the admins verify this patch? |
Test PASSed. |
Test PASSed. |
>>> # Observer that merges global state into individual obs. It is | ||
... # rewriting the discrete obs into a tuple with global state. | ||
>>> example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...) | ||
{"a": [1, 101], "b": [2, 101]} | ||
|
||
>>> # Observer for e.g., custom centralized critic model. It is | ||
... # rewriting the discrete obs into a dict with more data. | ||
>>> example_obs_fn2({"a": 1, "b": 2}, ...) | ||
{"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
meta (tip): .. code-block:: python
is easier to type and renders equally well.
TODO(ekl): enable batch processing. | ||
|
||
Args: | ||
agent_obs (dict): Dictionary of default observations from the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tip - if you put Dict[AgentID, TensorType]
, readthedocs automatically hyperlinks it.
def central_critic_observer(agent_obs, **kw): | ||
"""Rewrites the agent obs to include opponent data for training.""" | ||
|
||
to_update = info["post_batch"][SampleBatch.CUR_OBS] | ||
my_id = info["agent_id"] | ||
other_id = 1 if my_id == 0 else 0 | ||
action_encoder = ModelCatalog.get_preprocessor_for_space(Discrete(2)) | ||
|
||
# set the opponent actions into the observation | ||
_, opponent_batch = info["all_pre_batches"][other_id] | ||
opponent_actions = np.array([ | ||
action_encoder.transform(a) | ||
for a in opponent_batch[SampleBatch.ACTIONS] | ||
]) | ||
to_update[:, -2:] = opponent_actions | ||
new_obs = { | ||
0: { | ||
"own_obs": agent_obs[0], | ||
"opponent_obs": agent_obs[1], | ||
"opponent_action": 0, # filled in by FillInActions | ||
}, | ||
1: { | ||
"own_obs": agent_obs[1], | ||
"opponent_obs": agent_obs[0], | ||
"opponent_action": 0, # filled in by FillInActions | ||
}, | ||
} | ||
return new_obs | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would this look different if you subclass the ObservationFunction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably just move it into call.
@@ -1,9 +1,9 @@ | |||
"""An example of implementing a centralized critic by modifying the env. | |||
"""An example of implementing a centralized critic with ObservationFunction. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
user may ask, what is "ObservationFunction" given you don't use it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the interface definition. You can extend it or not, I don't know how to document a function signature otherwise though
Test PASSed. |
Why are these changes needed?
This adds an observation function that sits between multi agent envs and policies. It can handle communication / data sharing between local observations, or merging global observations into local observations.
MultiAgentEnv -> policies
MultiAgentEnv -> obs_func -> policies
In the future, the obs func can also be made differentiable to enable shared computation with multi-agent.
Lightly documented; will add more documentation as we iterate on the API.