[RLlib] ConnectorV2: Enhance performance of add_n_batch_items() for already batched items. #43669
Conversation
…ector_v2_enhance_performance_of_add_n_batch_items
…batch_items Signed-off-by: Sven Mika <svenmika1977@gmail.com>
…_of_add_n_batch_items' into connector_v2_enhance_performance_of_add_n_batch_items
Signed-off-by: sven1977 <svenmika1977@gmail.com>
LGTM. Some remarks regarding performance and the _has_batched_items attribute.
-     sa_episode.get_observations(indices=ts)
-     for ts in range(len(sa_episode))
- ],
+ items_to_add=sa_episode.get_observations(slice(0, len(sa_episode))),
Yup, that looks cleaner and faster.
Yeah, thanks for mentioning this in the other PR! This helped me pinpoint the performance decrease this time.
data = tree.map_structure(
    # Expand on axis 0 (the to-be-time-dim) if item has not been batched
    # yet, otherwise axis=1 (the time-dim).
    lambda s: np.expand_dims(
expand_dims has an impact on the memory management of NumPy arrays. Would a reshape be possible here instead, or do we need to fill the new axis with new values? For example, do we turn a (32, 4) array into a (32, 1, 4) one or into a (32, max_seq_len, 4) one?
Ah, great point! I didn't know that. Maybe we can just replace it with a reshape. No, we don't expect nor want this axis to be >1. It's the "simple" time-axis=1 add for action-computing forward passes.
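For what it's worth, on a contiguous input both np.expand_dims and reshape return views on the same buffer, so neither copies data for the length-1 axis case discussed above. A quick sketch of the (32, 4) → (32, 1, 4) case:

```python
import numpy as np

x = np.zeros((32, 4), dtype=np.float32)

a = np.expand_dims(x, axis=1)  # -> shape (32, 1, 4)
b = x.reshape(32, 1, 4)        # same result via reshape

assert a.shape == b.shape == (32, 1, 4)
# On a contiguous input, both are views (no copy): the inserted axis has
# length 1, so no new values need to be filled in.
assert a.base is not None and b.base is not None
```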
rllib/utils/spaces/space_utils.py (outdated diff)
# Use __new__ to create a new instance of our subclass.
obj = np.asarray(input_array).view(cls)
# Set the _has_batch_dim property.
obj._has_batch_dim = True
Why do we need this attribute? As far as I can see, we only ever check whether an object is an instance of BatchedNdArray, never whether _has_batch_dim is set.
I think you are right. This being a separate class should be sufficient.
Another option would be to signal via the list in which we store these items that it is a list-to-be-concatenated (rather than a list-to-be-stacked). 🤔
That would remove the messiness of the BatchedNdArray approach, in which we normally ought to check whether all items in the list really are of that class (which we don't do right now!).
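For context, a minimal runnable sketch of the BatchedNdArray marker-subclass pattern shown in the diff, and of the isinstance-only check being discussed (details beyond the diff are assumed, not taken from the actual RLlib code):

```python
import numpy as np

class BatchedNdArray(np.ndarray):
    """Marker subclass: the array already carries a batch dimension."""

    def __new__(cls, input_array):
        # view() reinterprets the same buffer as our subclass (no copy).
        obj = np.asarray(input_array).view(cls)
        # The extra flag questioned above; in practice the type check
        # alone is what gets used.
        obj._has_batch_dim = True
        return obj

arr = BatchedNdArray(np.zeros((5, 4)))

# Only the type is ever checked, never _has_batch_dim:
assert isinstance(arr, BatchedNdArray)
assert not isinstance(np.zeros((5, 4)), BatchedNdArray)
```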
Removed.
Signed-off-by: sven1977 <svenmika1977@gmail.com>
ConnectorV2: Enhance performance of add_n_batch_items() for already batched items.

Why are these changes needed?

Previously, in ConnectorV2.add_n_items_to_batch, the method would split the (already batched) item into a list, add the list items individually to the batch, then re-batch. This is very expensive, especially for complex spaces.

As an example of how much speedup this change delivers, run the nested_action_spaces.py example script with and without this fix (use the --enable-new-api-stack --num-agents=0 options). The difference in running time is about 100s vs. 80s.
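The cost difference described above can be sketched in plain NumPy (illustrative only, not the actual ConnectorV2 code): splitting an already batched array into per-item rows and re-stacking touches every row individually, whereas keeping batched chunks whole and concatenating once is a single vectorized call.

```python
import numpy as np

batched = np.ones((64, 8), dtype=np.float32)  # an already batched item

# Old path (sketch): split into a list of rows, then re-batch.
items = [batched[i] for i in range(len(batched))]
rebatched = np.stack(items)  # 64 individual gathers + one stack

# New path (sketch): keep batched chunks whole, concatenate once.
merged = np.concatenate([batched, batched], axis=0)

assert rebatched.shape == (64, 8)
assert merged.shape == (128, 8)
```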
Related issue number

Checks

- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I was adding/changing a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.