
[RLlib] Attention Nets: tf #12753

Merged (20 commits) on Dec 21, 2020

Conversation

sven1977 (Contributor) commented on Dec 10, 2020:

So far, RLlib's attention nets (GTrXL) have been forced to run "inside" RLlib's RNN API (previous internal states are passed back in as new state-ins at subsequent timesteps). This is a poor fit for attention nets, which need different handling and time-slicing of their past states (the attention net's memory). The trajectory view API allows specifying the exact time-step ranges needed for forward passes and batched train passes through attention nets.
Besides the above, the handling of the attention nets' tau-memory was also incorrect. This PR fixes these existing bugs.
In a follow-up PR, the torch version of GTrXL will be fully included in the tests as well (to make sure it is 100% on par with the tf version).
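For illustration, a minimal sketch in the spirit of the trajectory view API (the memory length, layer count, and exact ViewRequirement arguments below are assumptions, not the code added in this PR) of how a GTrXL-style model can request a whole range of past state-outs as its memory, instead of threading a single state through the RNN path:

from ray.rllib.policy.view_requirement import ViewRequirement

view_requirements = {}
memory_tau = 50            # assumed attention-memory length
num_transformer_units = 2  # assumed number of GTrXL layers

for i in range(num_transformer_units):
    # Request the last `memory_tau` state-outs of layer i as this timestep's
    # memory input, rather than just the single previous state.
    view_requirements["state_in_{}".format(i)] = ViewRequirement(
        data_col="state_out_{}".format(i),
        shift="-{}:-1".format(memory_tau))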

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

input_dict = policy.model.get_input_dict(sample_batch, index=-1)
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
state_in_view_req = policy.model.inference_view_requirements.get(
sven1977 (Contributor, Author):

Tried solving this for now without an extra get_input_dict method in the model or the sample collector.
This is, however, inefficient, as we really only need the last timestep of the trajectory (especially wasteful for large observation and internal-state spaces).

ericl (Contributor):

What was the motivation for not using model.get_input_dict()? The previous code seemed much cleaner, whereas the current implementation requires the user to know a lot of details of the internal implementation. For code in policy, I would expect "trajectory view" internal implementation details to be hidden away.

sven1977 (Contributor, Author), Dec 11, 2020:

The reason for not using the model to produce the input dict (from the postprocessed sample batch) is as follows:
Take the simpler RNN case: the Model's view requirements contain "obs" and "state_in_0/1" (shifted from "state_out_0/1" by -1). However, the input dict has to be built from "next_obs" (the very last one in the trajectory) and "state_out_0/1" (the very last ones in the trajectory). This need to run yet another action/value computation at the very end of the trajectory is specific to the Policy (in this case: PPO), not to the Model!
We need to completely separate what the Model must know (it should not need to know which algo it is being run with) from what the Policy knows (its postprocessing and loss requirements). So Model.get_input_dict cannot solve this.

The best solution is still:

  • Let the Policy declare: "I need an (already compiled!) single-timestep(!) input dict at the very end of the trajectory in each postprocessing batch, because I need to calculate GAE. However, I do not know or care how this input dict should be built (the Model has to specify that)."
  • The SampleCollector provides that input dict within the postprocessing SampleBatch, thereby abiding by the model's view requirements (obs -> next_obs shift and state_in_0/1 -> state_out_0/1 shift) and the given trajectory index (-1).
  • Because the input dict is only for a single action/value computation at the end of the trajectory, we avoid having to copy next_obs and any state_outs at all. In experiments, this gave another performance boost (~20%), which we are currently missing out on.

ericl (Contributor):

Hmm I see, so basically you're saying the previous implementation wasn't correct since it was using the last "obs" and not the last "next_obs"? I see how this is a problem.

However, can't we still move the code here into some method of the model so that the policy code is simplified? I.e., why doesn't this work?

if traj_view_api:
   input_dict = policy.model.get_input_dict(sample_batch, index=MAX_T_PLUS_ONE)   # special index to get the view for the next obs after the end of the rollout
   last_r = policy._value(**input_dict)

ericl (Contributor):

Then, the attention net calculations can still be handled in general purpose code and not in policy-specific code.

sven1977 (Contributor, Author):

Well, that's exactly what I was achieving with the previous "input_dict view" (ViewRequirement(is_input_dict=True, index=-1)). It would have solved the separation-of-concerns goal and made this very efficient:

a) Make the Policy work without it ever knowing about the model it's using.
b) Make the Model work without it ever having to know which algo it's being run in.
c) Leave the logic dealing with ViewRequirement dicts inside the SampleCollector, which is the only object with direct access to the still-in-progress buffers and can handle these requests most efficiently.

MAX_T_PLUS_ONE is just as hackish, isn't it? Using an int (e.g. -1) wouldn't work, as it could be confused with an actual index. So a string?

ericl (Contributor), Dec 13, 2020:

Yeah, so it seems we agree that the policy shouldn't be handling model-specific code. The difference is what the general-purpose code should look like.

IMO, the problem with the input-dict view was that it's confusing what it means (you'd have to explain to me what it means in terms of traj view requirements). I also disagree with requirement (c); I really don't think this is perf-critical, and we shouldn't have to sacrifice code simplicity for it.

If you think MAX_T_PLUS_ONE or -1 is confusing, why not add some helper method like get_input_dict_for_traj_end, which is pretty easy to understand ("this returns model inputs for the single observation after the end of the trajectory, useful for value bootstrapping")?
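A rough sketch of what such a helper could look like (hypothetical; this PR ultimately went with model.get_input_dict(sample_batch, index="last") instead). It builds a batch-of-1 input dict for the single observation after the end of the rollout, applying the obs -> next_obs and state_in_i -> state_out_i shifts discussed above:

import numpy as np

def get_input_dict_for_traj_end(sample_batch, view_requirements):
    """Batch-of-1 model inputs for the timestep right after the rollout end."""
    input_dict = {}
    for view_col, view_req in view_requirements.items():
        if view_col == "obs":
            # The "next" observation is the last entry of the new_obs column.
            data_col = "new_obs"
        elif view_col.startswith("state_in_"):
            # The "next" internal state is the last state_out of the same index.
            data_col = view_col.replace("state_in_", "state_out_")
        else:
            data_col = getattr(view_req, "data_col", None) or view_col
        # Slice (rather than index) to keep a batch dim of 1.
        input_dict[view_col] = np.asarray(sample_batch[data_col])[-1:]
    return input_dict

# Policy-side usage would then reduce to something like:
# input_dict = get_input_dict_for_traj_end(
#     sample_batch, policy.model.inference_view_requirements)
# last_r = policy._value(**input_dict)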

ericl (Contributor), Dec 13, 2020:

For the optimization, I think the longer-term solution is to "always" track the 1-step extra metadata in the sample batch / collectors, so that get_input_dict_for_traj_end can be computed with no extra overhead.

The "sometimes we track this extra thing" approach is unnecessarily confusing.

tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]
@make_tf_callable(self.get_session())
sven1977 (Contributor, Author):

Same as above: went back to the simpler solution.

], convert_to_torch_tensor(np.asarray([1]), self.device))
# [0] = remove the batch dim.
return self.model.value_function()[0]

sven1977 (Contributor, Author):

Same as above.

np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype) for _ in range(shift)

sven1977 (Contributor, Author):

This is becoming more complex as we now allow for more sophisticated view requirements (ranges of timesteps).

@@ -273,20 +355,22 @@ def build(self):
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
batch = SampleBatch(
sven1977 (Contributor, Author):

Internal states are already time-chunked (only one internal-state value every max_seq_len timesteps). That's why we shouldn't compare column sizes in the SampleBatch anymore (e.g. the obs column will have a different batch dim than the state_in_0 one).
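A toy illustration of this (all shapes assumed): with max_seq_len=5 and a 12-step rollout, the obs column has 12 rows, but state_in_0 has only ceil(12 / 5) = 3 rows (one per sequence chunk), so a shared-batch-dim assertion no longer holds.

import numpy as np

max_seq_len = 5
T = 12  # rollout length
buffers = {
    # One row per env step.
    "obs": np.zeros((T, 4), dtype=np.float32),
    # Only one internal-state row per max_seq_len chunk.
    "state_in_0": np.zeros((int(np.ceil(T / max_seq_len)), 64), dtype=np.float32),
}
print({col: arr.shape[0] for col, arr in buffers.items()})
# -> {'obs': 12, 'state_in_0': 3}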

view_req.shift - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.AGENT_INDEX] else 0)
delta = -1 if data_col in [
sven1977 (Contributor, Author):

Building an action input dict is also more complex now: inputs could span a range of timesteps.

@@ -314,29 +318,6 @@ def is_time_major(self) -> bool:
"""
return self.time_major is True

# TODO: (sven) Experimental method.
def get_input_dict(self, sample_batch,
sven1977 (Contributor, Author):

No longer needed due to the simplification in PPO's (and other algos') postprocessing function.

ericl (Contributor) left a review comment:

I made a first pass over the interface changes. I think the key questions here are:

  • Why do we need to add ModelV2.preprocess_train_batch? Is there a fundamental problem here that the traj view API doesn't handle?
  • Why increase the complexity of "user code" (policy code) by removing model.get_input_dict()?


return ret

# TODO: (sven) Experimental method.
def preprocess_train_batch(self, train_batch):
ericl (Contributor):

Can you comment on the necessity of this API? It would be preferable not to have to add new public API surface.

sven1977 (Contributor, Author):

Yeah, thinking about this further:

  • I introduced this b/c I thought that padding RNN inputs vs. attention inputs would be very different (it's actually not that much different).
  • More importantly: the logic should stay in the view requirements, so maybe we can leave the padding function call outside the model, but allow it to peek at the view requirements to know what to do (see the sketch below).

I'll change this.
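A rough sketch (an assumed helper, not the code this PR ends up with) of what padding outside the model could look like. A real version would peek at the model's view requirements to decide which columns span time; this toy version simply treats state_in_* columns as per-sequence:

import numpy as np

def pad_batch_to_sequences(batch, seq_lens, max_seq_len):
    """Zero-pads each per-timestep column to max_seq_len per sequence."""
    padded = {}
    for col, data in batch.items():
        data = np.asarray(data)
        if col.startswith("state_in_"):
            # Already one row per sequence; nothing to pad.
            padded[col] = data
            continue
        chunks, t = [], 0
        for seq_len in seq_lens:
            chunk = data[t:t + seq_len]
            pad = np.zeros((max_seq_len - seq_len,) + chunk.shape[1:], chunk.dtype)
            chunks.append(np.concatenate([chunk, pad]))
            t += seq_len
        padded[col] = np.concatenate(chunks)
    return padded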

ericl added and then removed the @author-action-required label (Dec 10, 2020).
ericl (Contributor) left a review comment:

I see the problem, but I think we can still do some code-reorganization so the policy class doesn't need to be aware of the Attention Net. This is important so that you can plug in any model to the policy without needing to "make the policy compatible".

ericl added the @author-action-required label (Dec 11, 2020).
sven1977 (Contributor, Author):

I feel the exact same way. That's why I used the input-dict view approach in the beginning (I explained above why this would still be the cleanest solution). Without this option in the ViewRequirement, the other solutions are now more or less hacks.

sven1977 removed the @author-action-required label (Dec 13, 2020).
sven1977 (Contributor, Author):

Updates:

  • Removed all new ModelV2 API methods (the ModelV2 API remains unchanged): preprocess_train_batch and get_input_dict.
  • The RNN hack in PPO's postprocess_fn has been expanded to work for attention nets as well. Happy to keep this for now. Let's discuss, in a follow-up PR, how to solve this properly (see the initially suggested "input_dict" view option added by the Policy).

ericl added the @author-action-required label (Dec 13, 2020).
ericl (Contributor) commented on Dec 13, 2020:

Let's iterate on this API. I think adding a -1 arg or an input_dict_for_traj_end method is the clearest API.

sven1977 removed the @author-action-required label (Dec 17, 2020).
@@ -198,7 +198,8 @@ def postprocess_ppo_gae(
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
input_dict = policy.model.get_input_dict(
sample_batch, index="last")
ericl (Contributor):

👍 love it.

ericl (Contributor):

Thinking ahead, is it possible to unify model.from_batch() and model.get_input_dict()? It seems that currently only PPO uses this get_input_dict method, so maybe we should remove it.

Thinking... model.from_batch(sample_batch, batch_index="last") would be pretty clean. This is probably out of scope for this PR, though.
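A purely hypothetical sketch of that unification (not an existing RLlib API). It assumes the get_input_dict(..., index="last") method from this PR and ModelV2's usual (input_dict, state, seq_lens) call signature:

def from_batch(self, sample_batch, is_training=True, batch_index=None):
    """ModelV2.from_batch() with an optional batch_index="last" mode (sketch)."""
    if batch_index == "last":
        # Single-timestep inputs for the very end of the trajectory.
        input_dict = self.get_input_dict(sample_batch, index="last")
    else:
        input_dict = dict(sample_batch)
    input_dict["is_training"] = is_training
    # Collect RNN/attention state inputs in index order.
    states, i = [], 0
    while "state_in_{}".format(i) in input_dict:
        states.append(input_dict["state_in_{}".format(i)])
        i += 1
    return self.__call__(input_dict, states, input_dict.get("seq_lens"))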

@@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
return arr


_INIT_COLS = [SampleBatch.OBS]
ericl (Contributor):

Comment?

sven1977 (Contributor, Author):

+1

ericl added the @author-action-required label (Dec 17, 2020).
ericl (Contributor) commented on Dec 17, 2020:

LGTM, had one thought on a possible future cleanup.

sven1977 (Contributor, Author):

Waiting for all tests ...

sven1977 merged commit b2bcab7 into ray-project:master on Dec 21, 2020.
sven1977 deleted the attention_nets_tf branch on March 27, 2021.
Labels: @author-action-required (The PR author is responsible for the next step. Remove tag to send back to the reviewer.)

2 participants: sven1977, ericl