[core][rdt] ray.get launches concurrent one-sided transfers for multiple ObjectRefs #61773
Signed-off-by: Stephanie wang <smwang@cs.washington.edu>
Code Review
This pull request introduces concurrent fetching for RDT objects in ray.get by refactoring the tensor transport mechanism to support an asynchronous fetch/wait pattern. The changes are well-structured, introducing a new FetchRequest to manage state for asynchronous operations and providing default synchronous implementations for backward compatibility. The NixlTensorTransport is updated to leverage this new asynchronous pattern, and the RDTManager and Worker are modified to pipeline multiple fetch requests. The implementation appears correct and aligns with the goal of improving performance for ray.get with multiple RDT objects. I have one suggestion regarding a potentially unused class that could cause confusion.
@dataclass
class TransferMetadata:
    """Base class for in-flight tensor transfer state.

    This class holds the minimal state needed to track an async transfer.
    Backend-specific implementations should extend this class with additional fields.

    Args:
        tensors: The tensors being transferred.
    """

    tensors: List[Any]
The TransferMetadata dataclass appears to be unused in this pull request. The NixlFetchRequest, which holds the in-flight transfer state, inherits from FetchRequest, not this class. Additionally, there is another TransferMetadata NamedTuple defined in rdt_manager.py, which could lead to confusion. If this class is not intended for use, consider removing it to improve clarity.
dayshah left a comment
It feels a little weird to have the ray.get path and the implicit get path diverge so much. But I guess eventually we'll unify them a bit if we ever get around to having rdt more integrated into the dependency waiter
class _PipelineCheckingTransport(TensorTransportManager):
    """Fake one-sided transport that records the order of fetch/wait calls.
oh nice, we've been needing these fake transports
Now we can with the custom transport API :)
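The idea of a fake transport that records call order can be illustrated with a minimal sketch. The class and function names below (`RecordingTransport`, `pipelined_get`) are hypothetical and do not mirror the actual `_PipelineCheckingTransport` or `TensorTransportManager` interfaces; the sketch only shows how a test can check that every fetch is issued before the first wait.

```python
class RecordingTransport:
    """Hypothetical fake transport: records each call instead of moving data."""

    def __init__(self):
        self.calls = []

    def fetch(self, obj_id):
        # Start a (pretend) transfer and return a handle for it.
        self.calls.append(("fetch", obj_id))
        return obj_id

    def wait(self, request):
        # Block on a (pretend) transfer and return its value.
        self.calls.append(("wait", request))
        return f"value-{request}"


def pipelined_get(transport, obj_ids):
    # Issue every fetch first, then wait on each handle in order.
    requests = [transport.fetch(obj_id) for obj_id in obj_ids]
    return [transport.wait(request) for request in requests]


transport = RecordingTransport()
values = pipelined_get(transport, ["a", "b", "c"])

# Pipelining holds if no fetch is recorded after the first wait.
kinds = [kind for kind, _ in transport.calls]
first_wait = kinds.index("wait")
assert all(kind == "fetch" for kind in kinds[:first_wait])
assert all(kind == "wait" for kind in kinds[first_wait:])
```

Because the fake does no I/O, a check like this can run as a fast unit test without a cluster or a real transport backend.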
manager = _build_manager([obj_id], backend=_TWO_SIDED_BACKEND_NAME)

with pytest.raises(ValueError, match="use_object_store=True"):
    manager.get_rdt_objects([obj_id], use_object_store=False)
I feel like this test already exists / should exist in the gloo tests where ray.get-ting should fail
Hmm actually I could not find such an existing test. I figured it's nice to have a unit test for the behavior here since it can run much faster. Actually let me update the match string though.
    self._wait_fetch(object_id, fetch_request)
except Exception as e:
    if trigger_exception is None:
        trigger_exception = e
Why wait for all the fetches before raising? I think this differs from the normal ray.get semantics, where an error on any object ID is raised immediately without waiting for the other objects.
Ah yeah the problem is that _wait_fetch also does some cleanup :( Right now it's only relevant for NIXL, though, since it's cleaning up memory registrations.
Another option is to update it so that FetchRequest.__del__ does the cleanup.
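The trade-off discussed here (keep waiting on every fetch so cleanup runs, then raise the first error) can be sketched as follows. This is a simplified illustration, not the PR's code: `wait_all` and `fake_wait` are hypothetical names, and cleanup is modeled as a `finally` block inside the wait function.

```python
def wait_all(fetch_requests, wait_fn):
    """Wait on every request so each one's cleanup runs, then raise the first error."""
    results = {}
    first_error = None
    for object_id, request in fetch_requests.items():
        try:
            results[object_id] = wait_fn(object_id, request)
        except Exception as e:
            # Remember only the first failure; keep waiting on the rest.
            if first_error is None:
                first_error = e
    if first_error is not None:
        raise first_error
    return results


cleaned = []


def fake_wait(object_id, request):
    # Hypothetical wait function: cleanup happens whether or not the fetch failed.
    try:
        if request == "bad":
            raise RuntimeError(f"fetch failed for {object_id}")
        return f"tensors-{object_id}"
    finally:
        cleaned.append(object_id)


error_message = None
try:
    wait_all({"a": "ok", "b": "bad", "c": "bad"}, fake_wait)
except RuntimeError as e:
    error_message = str(e)
```

The cost of this design is that a failure on an early object is reported only after the remaining waits finish, which is the semantic difference from plain `ray.get` raised in the review.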
    )
else:
    result[object_id] = rdt_store.wait_and_get_object(
        object_id, timeout=ray_constants.RDT_FETCH_FAIL_TIMEOUT_SECONDS
the timeouts here could probably be 0 right, otherwise something went wrong in the fetch?
We use this codepath for both ray.get (where the timeout here should be 0) and the implicit get (where the timeout here is real, and we should actually wait). This was an issue before this PR too, I added a TODO.
            trigger_exception = e

if trigger_exception is not None:
    raise trigger_exception
the other objects won't get popped out if one of the obj's raises an exception?
Ah good catch, thanks.
# the remaining fetches.
for object_id, fetch_request in fetch_requests.items():
    try:
        self._wait_fetch(object_id, fetch_request)
wait fetch won't abide by the timeout?
Yeah it didn't before either. Let me see if I can update this.
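One common way to make a sequence of waits respect a single caller-supplied timeout is to compute a deadline once and pass each wait only the time remaining. The sketch below uses standard-library futures as stand-ins for fetch requests; `wait_with_deadline` is a hypothetical helper, and how the PR itself threads the timeout through `_wait_fetch` is not shown in this thread.

```python
import time
from concurrent.futures import Future
from concurrent.futures import TimeoutError as FutureTimeoutError


def wait_with_deadline(futures, timeout):
    """Wait on each future in turn, sharing one overall timeout across all of them."""
    deadline = None if timeout is None else time.monotonic() + timeout
    results = {}
    for key, future in futures.items():
        if deadline is None:
            remaining = None  # No timeout: block until the future completes.
        else:
            # Never pass a negative timeout; 0 means "check without blocking".
            remaining = max(0.0, deadline - time.monotonic())
        results[key] = future.result(timeout=remaining)
    return results


# Stand-ins for fetch requests: one already finished, one that never finishes.
done = Future()
done.set_result("tensors-a")
pending = Future()

timed_out = False
try:
    wait_with_deadline({"a": done, "b": pending}, timeout=0.05)
except FutureTimeoutError:
    timed_out = True
```

With a per-wait timeout instead of a shared deadline, N sequential waits could block for up to N times the requested timeout.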
Yeah I realized this too, but the goal here was just to get this in now and we can later update the implicit get path.
Co-authored-by: Dhyey Shah <dhyey2019@gmail.com> Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit e5e651e.
    owner_address="",
    call_site="",
)
return result
Fetched RDT objects not stored for repeated ray.get
Medium Severity
Objects fetched via _trigger_fetch and _wait_fetch in fetch_and_get_rdt_objects are returned but never added to the rdt_store. The old _fetch_object code called rdt_store.add_object(obj_id, tensors) after a successful fetch, allowing subsequent ray.get calls on the same ObjectRef to find the object in phase 1's store check. Without this caching, every ray.get call on the same RDT ObjectRef triggers a new network transfer, which may fail if the source has already garbage-collected the object.
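The caching behavior this report describes (check the local store first, fetch only what is missing, then store the fetched result) can be sketched as below. `SketchStore` and `get_objects` are hypothetical stand-ins, not the real `rdt_store` API, and whether the PR should cache this way is the bot's suggestion rather than something confirmed in this thread.

```python
class SketchStore:
    """Hypothetical local object store keyed by object ID."""

    def __init__(self):
        self._objects = {}

    def has_object(self, obj_id):
        return obj_id in self._objects

    def get_object(self, obj_id):
        return self._objects[obj_id]

    def add_object(self, obj_id, tensors):
        self._objects[obj_id] = tensors


def get_objects(store, obj_ids, fetch_fn):
    result = {}
    missing = []
    # Phase 1: serve whatever is already in the local store.
    for obj_id in obj_ids:
        if store.has_object(obj_id):
            result[obj_id] = store.get_object(obj_id)
        else:
            missing.append(obj_id)
    # Phase 2: fetch the rest and cache it so a repeated get hits phase 1.
    for obj_id in missing:
        tensors = fetch_fn(obj_id)
        store.add_object(obj_id, tensors)
        result[obj_id] = tensors
    return result


fetch_log = []


def fake_fetch(obj_id):
    # Stand-in for a network transfer; records each call.
    fetch_log.append(obj_id)
    return [f"tensor-{obj_id}"]


store = SketchStore()
first = get_objects(store, ["x"], fake_fetch)
second = get_objects(store, ["x"], fake_fetch)
```

In this sketch the second call performs no fetch, which is the property the report says was lost.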
…ple ObjectRefs (ray-project#61773)

Instead of fetching objects one at a time during `ray.get`, we launch a fetch request per ObjectRef, then wait for all of them to complete. This should allow overlapping network transfers for requests of multiple ObjectRefs at a time. Note that this PR only supports this for `ray.get`, not for transfer of task arguments.

The PR refactors the TensorTransportManager to add new methods, `fetch_multiple_transfers` and `wait_fetch_complete`. The default implementation for these calls the synchronous `recv_multiple_transfers` during fetch and simply returns the tensors during `wait_fetch_complete`. Backends that support asynchronous fetching can override the new methods to allow concurrent transfers (see NixlTensorTransport for an example).

This PR also adds timeout support for ray.get on RDT objects and unit tests for RDTManager.

Closes ray-project#61453.

---------

Signed-off-by: Stephanie wang <smwang@cs.washington.edu>
Co-authored-by: Dhyey Shah <dhyey2019@gmail.com>
Signed-off-by: anindyam1969 <amukherjee@kinetica.com>


Description
Instead of fetching objects one at a time during `ray.get`, we launch a fetch request per ObjectRef, then wait for all of them to complete. This should allow overlapping network transfers for requests of multiple ObjectRefs at a time. Note that this PR only supports this for `ray.get`, not for transfer of task arguments.

The PR refactors the TensorTransportManager to add new methods, `fetch_multiple_transfers` and `wait_fetch_complete`. The default implementation for these calls the synchronous `recv_multiple_transfers` during fetch and simply returns the tensors during `wait_fetch_complete`. Backends that support asynchronous fetching can override the new methods to allow concurrent transfers (see NixlTensorTransport for an example).

This PR also adds timeout support for ray.get on RDT objects and unit tests for RDTManager.
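The default-versus-override behavior described above can be sketched roughly as follows. The method names match the description, but the signatures, the `FetchRequest` fields, and both classes are simplified assumptions for illustration; they are not the actual TensorTransportManager or NixlTensorTransport code. A thread pool stands in for a backend that can run transfers asynchronously.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class FetchRequest:
    """Hypothetical handle for one in-flight fetch."""

    tensors: Optional[List[Any]] = None
    future: Optional[Future] = None


class SketchTransport:
    """Stand-in base class with synchronous defaults (assumed signatures)."""

    def recv_multiple_transfers(self, obj_id: str) -> List[Any]:
        # Blocking receive; a real backend would move tensor data here.
        return [f"tensor-for-{obj_id}"]

    def fetch_multiple_transfers(self, obj_id: str) -> FetchRequest:
        # Default: perform the whole transfer synchronously during fetch.
        return FetchRequest(tensors=self.recv_multiple_transfers(obj_id))

    def wait_fetch_complete(self, request: FetchRequest) -> List[Any]:
        # Default: nothing left to wait for, so just return the tensors.
        return request.tensors


class AsyncSketchTransport(SketchTransport):
    """Stand-in for a backend that overrides both methods to overlap transfers."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=4)

    def fetch_multiple_transfers(self, obj_id: str) -> FetchRequest:
        # Start the transfer in the background and return immediately.
        future = self._pool.submit(self.recv_multiple_transfers, obj_id)
        return FetchRequest(future=future)

    def wait_fetch_complete(self, request: FetchRequest) -> List[Any]:
        # Block until the background transfer finishes.
        return request.future.result()


def get_all(transport: SketchTransport, obj_ids: List[str]) -> dict:
    # Launch one fetch per object, then wait for all of them.
    requests = {obj_id: transport.fetch_multiple_transfers(obj_id) for obj_id in obj_ids}
    return {obj_id: transport.wait_fetch_complete(req) for obj_id, req in requests.items()}
```

The caller loop in `get_all` is the same for both classes: with the synchronous defaults it degrades to one-at-a-time fetching, while an overriding backend gets overlapping transfers without any change to the caller.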
Related issues
Closes #61453.