
[RLlib] ConnectorV2: Enhance performance of add_n_batch_items() for already batched items. #43669

Conversation

sven1977 (Contributor) commented Mar 4, 2024

ConnectorV2: Enhance performance of add_n_batch_items() for already batched items:

  • When adding an entire episode's observations to a batch (for example, to build a train batch) via ConnectorV2.add_n_batch_items(), the method would split the already-batched item into a list, add the list items to the batch individually, and then re-batch them. This is very expensive, especially for complex spaces.
  • Instead, we now allow adding already-batched items via this same method and then "concatenate" the individual items along the existing batch axes afterwards (instead of stacking them onto a new batch axis).

As an example of how much speedup this change delivers, run the nested_action_spaces.py example script with and without this fix (use the --enable-new-api-stack --num-agent=0 options). The running time drops from about 100s to about 80s.
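To illustrate the difference in plain NumPy (a hypothetical sketch, not RLlib's actual implementation): the old path splits an already-batched array into per-timestep items and re-stacks them onto a new batch axis, while the new path concatenates batched chunks along the existing batch axis.

```python
import numpy as np

# Hypothetical stand-in for an episode's already-batched observations:
# shape (T, obs_dim).
episode_obs = np.random.rand(1000, 84)

# Old path: split into T individual items, then re-stack onto a new
# batch axis (expensive, especially for complex/nested spaces).
items = [episode_obs[t] for t in range(len(episode_obs))]
restacked = np.stack(items, axis=0)

# New path: keep batched chunks intact and concatenate them along the
# existing batch axis (shown here with two chunks of one episode).
chunks = [episode_obs[:500], episode_obs[500:]]
concatenated = np.concatenate(chunks, axis=0)

assert restacked.shape == concatenated.shape == (1000, 84)
assert np.array_equal(restacked, concatenated)
```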

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
simonsays1980 (Collaborator) left a comment:

LGTM. Some remarks regarding performance and `_has_batched_items`.

    # Before (one call per timestep):
    [
        sa_episode.get_observations(indices=ts)
        for ts in range(len(sa_episode))
    ],
    # Suggested (single slice call):
    items_to_add=sa_episode.get_observations(slice(0, len(sa_episode))),
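The suggested one-liner replaces len(sa_episode) individual lookups with a single slice call. The effect can be sketched with plain NumPy (a hypothetical stand-in, not the actual episode API):

```python
import numpy as np

# Hypothetical observation buffer standing in for the episode's storage.
observations = np.arange(20).reshape(10, 2)  # (T, obs_dim)

# Per-timestep lookups: T separate indexing calls, each yielding a row
# that later has to be re-batched.
per_step = [observations[ts] for ts in range(len(observations))]

# Single slice: one call, and for contiguous NumPy storage it returns a
# view without copying the data.
sliced = observations[slice(0, len(observations))]

assert np.array_equal(np.stack(per_step), sliced)
assert sliced.base is not None  # a view, not a copy
```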
simonsays1980 (Collaborator):

Yup, that looks cleaner and faster.

sven1977 (Contributor, Author):

Yeah, thanks for mentioning this in the other PR! This helped me pinpoint the performance decrease this time.

    data = tree.map_structure(
        # Expand on axis 0 (the to-be-time-dim) if the item has not been
        # batched yet, otherwise on axis=1 (the time-dim).
        lambda s: np.expand_dims(
simonsays1980 (Collaborator):

expand_dims has an impact on the memory management of NumPy arrays. Is a `reshape` possible here, or do we need to fill the new axis with new values? For example, do we turn a (32, 4) array into a (32, 1, 4) one or into a (32, max_seq_len, 4) one?

sven1977 (Contributor, Author):

Ah, great point! I didn't know that. Maybe we can just replace it with a reshape. No, we don't expect nor want this axis to be >1; it's the "simple" time-axis=1 add for action-computing forward passes.
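For reference on the point discussed: when the new axis has length 1, np.expand_dims and reshape are interchangeable, and neither fills the axis with new values; both return views of the same buffer (a minimal check, assuming a (32, 4) float array):

```python
import numpy as np

batch = np.zeros((32, 4), dtype=np.float32)

# Add a length-1 time axis at position 1: (32, 4) -> (32, 1, 4).
expanded = np.expand_dims(batch, axis=1)
reshaped = batch.reshape(32, 1, 4)

assert expanded.shape == reshaped.shape == (32, 1, 4)
# Both are views onto batch's buffer; no data is copied or filled in.
assert expanded.base is batch and reshaped.base is batch
```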

    # Use __new__ to create a new instance of our subclass.
    obj = np.asarray(input_array).view(cls)
    # Set the _has_batch_dim property.
    obj._has_batch_dim = True
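For context, the snippet above uses NumPy's view-casting subclass pattern. A minimal self-contained sketch (hypothetical, not RLlib's actual class; note that `__array_finalize__` is needed so the flag survives slicing and other view creation):

```python
import numpy as np

class BatchedNdArray(np.ndarray):
    """Marks an ndarray as already carrying a batch dimension."""

    def __new__(cls, input_array):
        # View-cast the input as our subclass (no data copy).
        obj = np.asarray(input_array).view(cls)
        # Set the _has_batch_dim property.
        obj._has_batch_dim = True
        return obj

    def __array_finalize__(self, obj):
        # Called for views/slices too; propagate the flag.
        self._has_batch_dim = getattr(obj, "_has_batch_dim", True)

arr = BatchedNdArray(np.arange(6).reshape(2, 3))
assert isinstance(arr, BatchedNdArray) and arr._has_batch_dim
assert arr[:1]._has_batch_dim  # flag survives slicing
```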
simonsays1980 (Collaborator):

Why do we need this attribute? What I see is that we only check whether an object is an instance of type BatchedNdArray, never whether `_has_batch_dim` is set.

sven1977 (Contributor, Author):

I think you are right. This being another class should be sufficient.

Another option would be to signal via the list in which we are storing these items that this is a list-to-be-concatenated (rather than a list-to-be-stacked). 🤔

That would remove the messiness of the BatchedNdArray approach, in which we should really check whether all items in the list are of that class (which we don't do right now!).

sven1977 (Contributor, Author):

Removed.

@sven1977 sven1977 merged commit 5237e49 into ray-project:master Mar 5, 2024
9 checks passed
@sven1977 sven1977 deleted the connector_v2_enhance_performance_of_add_n_batch_items branch March 6, 2024 09:30