[Data] Remove _base_dataset from StreamSplitDataIterator#61607
Conversation
There was a problem hiding this comment.
Code Review
This pull request effectively decouples StreamSplitDataIterator from direct Dataset access by delegating operations to the SplitCoordinator. This is a significant improvement for reducing memory usage on the head node, particularly in distributed training scenarios. The use of a lock for schema computation is a solid approach to handle concurrency. The accompanying tests are thorough and cover the new functionality well. I have one suggestion to enhance the robustness of the schema() method to align its behavior with other DataIterator implementations.
## Summary `TrainRunContext` previously held a direct reference to the `datasets` dict, which contains full `Dataset` objects. Because `TrainRunContext` is serialized and sent to every training worker during `init_train_context`, this caused head node object store usage to scale linearly with the number of workers. At scale, this contributed to head node OOMs and training ingest instability. This is a companion change to #61607, which addressed the same serialization issue in `StreamSplitDataIterator`. Together, these two PRs ensure that `Dataset` objects are no longer serialized to train workers, eliminating a class of memory pressure that scales with worker count during training ingest. ## Changes - Remove the `datasets` field from `TrainRunContext` in `context.py`, so the `Dataset` objects are no longer part of the serialized context - Update `DatasetsCallback.__init__` in `datasets.py` to accept `datasets` as a separate argument instead of extracting it from `TrainRunContext` - Update `DataParallelTrainer._create_default_callbacks()` in `data_parallel_trainer.py` to pass `self.datasets` directly to `DatasetsCallback` - Update `create_dummy_run_context` test utility and all affected tests ## Tests - Existing tests in `test_data_resource_cleanup.py` and `test_data_integration.py` updated to use the new `DatasetsCallback` signature - All 7 affected tests pass: `test_datasets_callback`, `test_datasets_callback_v1_uses_exclude_resources`, `test_v2_no_negative_exclude_resources`, `test_datasets_callback_multiple_datasets`, `test_after_worker_group_abort`, `test_after_worker_group_shutdown`, `test_split_coordinator_shutdown_executor` Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
## Summary `TrainRunContext` previously held a direct reference to the `datasets` dict, which contains full `Dataset` objects. Because `TrainRunContext` is serialized and sent to every training worker during `init_train_context`, this caused head node object store usage to scale linearly with the number of workers. At scale, this contributed to head node OOMs and training ingest instability. This is a companion change to ray-project#61607, which addressed the same serialization issue in `StreamSplitDataIterator`. Together, these two PRs ensure that `Dataset` objects are no longer serialized to train workers, eliminating a class of memory pressure that scales with worker count during training ingest. ## Changes - Remove the `datasets` field from `TrainRunContext` in `context.py`, so the `Dataset` objects are no longer part of the serialized context - Update `DatasetsCallback.__init__` in `datasets.py` to accept `datasets` as a separate argument instead of extracting it from `TrainRunContext` - Update `DataParallelTrainer._create_default_callbacks()` in `data_parallel_trainer.py` to pass `self.datasets` directly to `DatasetsCallback` - Update `create_dummy_run_context` test utility and all affected tests ## Tests - Existing tests in `test_data_resource_cleanup.py` and `test_data_integration.py` updated to use the new `DatasetsCallback` signature - All 7 affected tests pass: `test_datasets_callback`, `test_datasets_callback_v1_uses_exclude_resources`, `test_v2_no_negative_exclude_resources`, `test_datasets_callback_multiple_datasets`, `test_after_worker_group_abort`, `test_after_worker_group_shutdown`, `test_split_coordinator_shutdown_executor` Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
47f7c6a to
cd51b7f
Compare
Move schema() resolution from StreamSplitDataIterator to the SplitCoordinator actor, which already holds the dataset. This avoids accessing _base_dataset directly from the iterator for schema calls, and adds thread-safe caching with a guard against schema resolution during active execution. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
Move remaining _base_dataset usages (get_context, _get_dataset_tag) to SplitCoordinator and remove the redundant client-side _run_index increment. This fully decouples StreamSplitDataIterator from direct dataset access. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…n and updated tests Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
cd51b7f to
b94271f
Compare
Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com>
Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
| raise RuntimeError( | ||
| "Cannot call schema() during active dataset execution. " |
There was a problem hiding this comment.
can you try running this without raising the error to see what happens? and add the output to the PR description appendix
There was a problem hiding this comment.
Ran without the safe guard, it can lead to a hang because the second executor will try to schedule tasks but won't have any resources available, added to PR description.
Co-authored-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com>
Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
| return [ | ||
| StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n) | ||
| ] | ||
| return [StreamSplitDataIterator(coord_actor, i, n) for i in range(n)] |
There was a problem hiding this comment.
Just to confirm -- this StreamSplitDataIterator was being serialized 2x?
There was a problem hiding this comment.
No, each worker only gets one iterator so it's only serialized once. The other copy of the dataset object came from the TrainRunContext which was removed in #61953
| with self._dataset_state_lock: | ||
| if self._schema is not None: | ||
| return self._schema | ||
| if self._current_executor is not None and self._current_executor.is_alive(): |
There was a problem hiding this comment.
I read the PR description but still a bit confused -- Are you implying there are scenarios where the schema is None, but current executor is not None, and that's why we need this guard? Like, can the executor be running when there is no schema produced yet? I think adding a comment about the if guard would be very helpful for future readers
There was a problem hiding this comment.
self._schema is for caching. If get_dataset_schema was not called previously, then it is very possible for it to be empty when called during execution. This guard is to primary to prevent two executions on the same dataset- which can lead to deadlock. Will update pr description to be more clear
…#61607) `StreamSplitDataIterator` previously held a direct reference to the base `Dataset` object, which was used for `schema()`, `get_context()`, `_get_dataset_tag()`, and incrementing `_run_index`. During `init_train_context` on `RayTrainWorker` actors, the iterator is serialized and copied once per training worker, and because it contained the full dataset object, this caused significant head node object store usage proportional to the number of workers. At scale, this led to training ingest instability and head node OOMs. This PR moves all remaining dataset interactions into the `SplitCoordinator` actor, so the iterator communicates exclusively through the coordinator and no longer serializes the dataset object. ## Changes - Remove `_base_dataset` from `StreamSplitDataIterator` and route `schema()`, `get_context()`, and `_get_dataset_tag()` through new coordinator methods (`get_dataset_schema()`, `get_dataset_context()`, `get_dataset_tag()`) via `ray.get()` - Remove `_run_index` increment from `_to_ref_bundle_iterator` since it is now handled on the `SplitCoordinator` side via `ExecutionPlan.create_executor()` --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: Frank Mancina <fmancina@haproxy.com>
## Summary `TrainRunContext` previously held a direct reference to the `datasets` dict, which contains full `Dataset` objects. Because `TrainRunContext` is serialized and sent to every training worker during `init_train_context`, this caused head node object store usage to scale linearly with the number of workers. At scale, this contributed to head node OOMs and training ingest instability. This is a companion change to ray-project#61607, which addressed the same serialization issue in `StreamSplitDataIterator`. Together, these two PRs ensure that `Dataset` objects are no longer serialized to train workers, eliminating a class of memory pressure that scales with worker count during training ingest. ## Changes - Remove the `datasets` field from `TrainRunContext` in `context.py`, so the `Dataset` objects are no longer part of the serialized context - Update `DatasetsCallback.__init__` in `datasets.py` to accept `datasets` as a separate argument instead of extracting it from `TrainRunContext` - Update `DataParallelTrainer._create_default_callbacks()` in `data_parallel_trainer.py` to pass `self.datasets` directly to `DatasetsCallback` - Update `create_dummy_run_context` test utility and all affected tests ## Tests - Existing tests in `test_data_resource_cleanup.py` and `test_data_integration.py` updated to use the new `DatasetsCallback` signature - All 7 affected tests pass: `test_datasets_callback`, `test_datasets_callback_v1_uses_exclude_resources`, `test_v2_no_negative_exclude_resources`, `test_datasets_callback_multiple_datasets`, `test_after_worker_group_abort`, `test_after_worker_group_shutdown`, `test_split_coordinator_shutdown_executor` Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
…#61607) `StreamSplitDataIterator` previously held a direct reference to the base `Dataset` object, which was used for `schema()`, `get_context()`, `_get_dataset_tag()`, and incrementing `_run_index`. During `init_train_context` on `RayTrainWorker` actors, the iterator is serialized and copied once per training worker, and because it contained the full dataset object, this caused significant head node object store usage proportional to the number of workers. At scale, this led to training ingest instability and head node OOMs. This PR moves all remaining dataset interactions into the `SplitCoordinator` actor, so the iterator communicates exclusively through the coordinator and no longer serializes the dataset object. ## Changes - Remove `_base_dataset` from `StreamSplitDataIterator` and route `schema()`, `get_context()`, and `_get_dataset_tag()` through new coordinator methods (`get_dataset_schema()`, `get_dataset_context()`, `get_dataset_tag()`) via `ray.get()` - Remove `_run_index` increment from `_to_ref_bundle_iterator` since it is now handled on the `SplitCoordinator` side via `ExecutionPlan.create_executor()` --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Revert the DatasetManager removal from #55760. That revert was needed because StreamSplitDataIterator held large Dataset references, making them slow to send to workers via get_dataset_shard. #61607 removed those Dataset references, so the original performance concern no longer applies. --------- Signed-off-by: Timothy Seah <tseah@anyscale.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Justin Yu <justin.v.yu@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Summary
StreamSplitDataIteratorpreviously held a direct reference to the baseDatasetobject, which was used forschema(),get_context(),_get_dataset_tag(), and incrementing_run_index. Duringinit_train_contextonRayTrainWorkeractors, the iterator is serialized and copied once per training worker, and because it contained the full dataset object, this caused significant head node object store usage proportional to the number of workers. At scale, this led to training ingest instability and head node OOMs. This PR moves all remaining dataset interactions into theSplitCoordinatoractor, so the iterator communicates exclusively through the coordinator and no longer serializes the dataset object.Changes
_base_datasetfromStreamSplitDataIteratorand routeschema(),get_context(), and_get_dataset_tag()through new coordinator methods (get_dataset_schema(),get_dataset_context(),get_dataset_tag()) viaray.get()_run_indexincrement from_to_ref_bundle_iteratorsince it is now handled on theSplitCoordinatorside viaExecutionPlan.create_executor()_dataset_state_lockfor schema accessget_dataset_schema()on the coordinator is guarded by a_dataset_state_lock. Schema computation can trigger dataset execution, so the lock ensures this only happens once. The first caller computes and caches the result inself._schema, and subsequent callers return the cached value without redundant execution. The lock also prevents race conditions from concurrent access to dataset state.Schema access during active dataset execution
If a cached schema is available (e.g., computed before iteration started),
get_dataset_schema()returns it immediately, even during active execution. Otherwise, if the schema has never been computed and the streaming executor is active, it raises aRuntimeError.This guard is necessary because
ExecutionPlan.schema()with an unknown output type (e.g., a UDF whose return schema can't be inferred statically) falls back toexecute_to_iterator()inplan.py, which creates a secondStreamingExecutoron the same plan. That second executor computes its resource budget independently from the cluster viaResourceManager.get_global_limits(), unaware that the first executor has already claimed all available CPUs. The second executor's map tasks are submitted to Ray but can never be scheduled.Because
schema()blocks waiting for the first output bundle (next(gen)inexecute_to_iterator), theSplitCoordinatoractor becomes permanently blocked. This in turn blocks the consumers of the first executor, creating a circular deadlock where neither executor can make progress.SplitCoordinatorconcurrencyget_dataset_context()andget_dataset_tag()are lightweight, non-blocking calls that don't contend with the main execution threads, so the existingmax_concurrency=n+1on the coordinator actor is sufficient.Tests
test_streaming_split_schema_before_execution: verifies schema retrieval works before any iteration startstest_streaming_split_schema_during_execution: verifiesschema()raisesRuntimeErrorwhen called during active iteration (called from inside the iteration loop to avoid race conditions)test_streaming_split_schema_after_execution: verifies schema retrieval works after execution completestest_streaming_split_context: verifiesget_context()returns a validDataContextfrom the coordinatortest_streaming_split_dataset_tag: verifies_get_dataset_tag()returns correct split-indexed tags (_split_0,_split_1) from the coordinatorstreaming_splittests pass, plustest_context_propagation.py::test_streaming_split