
[RFC][Dataset] Actor based prefetching #23952

Merged
merged 10 commits into ray-project:master from wait on Apr 29, 2022

Conversation

scv119
Contributor

@scv119 scv119 commented Apr 16, 2022

Why are these changes needed?

The prefetch_blocks implementation doesn't work as expected. Because ray.wait() doesn't give us fine-grained control, today we block waiting for any of the blocks to return, and from reading the code it may or may not actually fetch all of the blocks.

A better way to ensure that prefetching doesn't block is to use a Ray remote call, which returns immediately and ensures that the blocks are eventually fetched.

The evaluation shows that the actor-based prefetcher can achieve perfect prefetching when the prefetch speed is faster than the consumption speed.
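A minimal sketch of the actor-based approach described above (class and helper names here are illustrative, not the exact code in this PR): passing the block ObjectRefs as arguments to a remote actor method forces Ray to pull them to the actor's node without blocking the caller.

import ray


@ray.remote(num_cpus=0)
class BlockPrefetcher:
    """Prefetches blocks by receiving their ObjectRefs as task arguments."""

    def prefetch(self, *blocks):
        # Ray resolves the ObjectRef arguments before this method runs,
        # which pulls the blocks into the local object store; the method
        # body itself has nothing to do.
        pass


# Fire-and-forget usage: the .remote() call returns immediately, so the
# consumer never blocks while the blocks are being fetched.
# prefetcher = BlockPrefetcher.remote()
# prefetcher.prefetch.remote(*block_refs)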

Evaluation

A preliminary test shows that the actor-based prefetcher improves the throughput of the following workload on a cluster with 5 CPU nodes and 1 GPU node, which simulates ingesting a validation dataset by (a sketch of the consume loop follows this list):

1. Loading 80GB of data on the 5 CPU nodes.
2. Ingesting the data on the GPU node with ds.iter_batches(prefetch_blocks=1, batch_size=250000), sleeping for a short time after each batch to simulate validation.
3. Measuring and reporting the total time of step 2 (validation).
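A minimal sketch of the consume loop for steps 2-3 (function name and structure assumed; ds is the dataset pipeline produced in step 1):

import time


def consume(ds, batch_size=250000, prefetch_blocks=1, validation_time_s=0.01):
    start = time.time()
    for batch in ds.iter_batches(prefetch_blocks=prefetch_blocks, batch_size=batch_size):
        # Sleep for a small period to simulate per-batch validation work.
        time.sleep(validation_time_s)
    return time.time() - start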

0.01 seconds per-batch validation time:

With a simulated per-batch validation time of 0.01 seconds, the total validation time improves from 88.94s to 68.13s.

ray.wait prefetching:

(consume pid=11517, ip=10.0.1.12) == Pipeline Window 0 ==
(consume pid=11517, ip=10.0.1.12) Stage 1 read: 48/48 blocks executed in 26.37s
(consume pid=11517, ip=10.0.1.12) * Remote wall time: 5.12s min, 24.05s max, 13.77s mean, 660.84s total
(consume pid=11517, ip=10.0.1.12) * Remote cpu time: 4.35s min, 5.33s max, 4.92s mean, 235.99s total
(consume pid=11517, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11517, ip=10.0.1.12) * Output size bytes: 1787500000 min, 1787500000 max, 1787500000 mean, 85800000000 total
(consume pid=11517, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11517, ip=10.0.1.12)
(consume pid=11517, ip=10.0.1.12) Stage 2 map_batches: 48/48 blocks executed in 60.55s
(consume pid=11517, ip=10.0.1.12) * Remote wall time: 11.41s min, 16.25s max, 14.01s mean, 672.57s total
(consume pid=11517, ip=10.0.1.12) * Remote cpu time: 11.95s min, 16.8s max, 14.58s mean, 699.93s total
(consume pid=11517, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11517, ip=10.0.1.12) * Output size bytes: 1680000128 min, 1680000128 max, 1680000128 mean, 80640006144 total
(consume pid=11517, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11517, ip=10.0.1.12)
(consume pid=11517, ip=10.0.1.12) Dataset iterator time breakdown:
(consume pid=11517, ip=10.0.1.12) * In ray.wait(): 41.07s
(consume pid=11517, ip=10.0.1.12) * In ray.get(): 24.78s
(consume pid=11517, ip=10.0.1.12) * In format_batch(): 2.9s
(consume pid=11517, ip=10.0.1.12) * In user code: 19.4s
(consume pid=11517, ip=10.0.1.12) * Total time: 88.22s
(consume pid=11517, ip=10.0.1.12)
(consume pid=11517, ip=10.0.1.12) ##### Overall Pipeline Time Breakdown #####
(consume pid=11517, ip=10.0.1.12) * Time in dataset iterator: 68.82s
(consume pid=11517, ip=10.0.1.12) * Time in user code: 19.39s
(consume pid=11517, ip=10.0.1.12) * Total time: 88.94s
(consume pid=11517, ip=10.0.1.12)

actor-based prefetching:

(consume pid=11209, ip=10.0.1.12) == Pipeline Window 0 ==
(consume pid=11209, ip=10.0.1.12) Stage 1 read: 48/48 blocks executed in 32.15s
(consume pid=11209, ip=10.0.1.12) * Remote wall time: 5.88s min, 30.26s max, 15.79s mean, 758.15s total
(consume pid=11209, ip=10.0.1.12) * Remote cpu time: 4.54s min, 5.84s max, 5.01s mean, 240.63s total
(consume pid=11209, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11209, ip=10.0.1.12) * Output size bytes: 1787500000 min, 1787500000 max, 1787500000 mean, 85800000000 total
(consume pid=11209, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11209, ip=10.0.1.12)
(consume pid=11209, ip=10.0.1.12) Stage 2 map_batches: 48/48 blocks executed in 63.56s
(consume pid=11209, ip=10.0.1.12) * Remote wall time: 11.48s min, 17.06s max, 14.36s mean, 689.4s total
(consume pid=11209, ip=10.0.1.12) * Remote cpu time: 12.08s min, 17.57s max, 14.94s mean, 717.03s total
(consume pid=11209, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11209, ip=10.0.1.12) * Output size bytes: 1680000128 min, 1680000128 max, 1680000128 mean, 80640006144 total
(consume pid=11209, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11209, ip=10.0.1.12)
(consume pid=11209, ip=10.0.1.12) Dataset iterator time breakdown:
(consume pid=11209, ip=10.0.1.12) * In ray.wait(): 33.04ms
(consume pid=11209, ip=10.0.1.12) * In ray.get(): 43.91s
(consume pid=11209, ip=10.0.1.12) * In format_batch(): 3.85s
(consume pid=11209, ip=10.0.1.12) * In user code: 19.52s
(consume pid=11209, ip=10.0.1.12) * Total time: 67.41s
(consume pid=11209, ip=10.0.1.12)
(consume pid=11209, ip=10.0.1.12) ##### Overall Pipeline Time Breakdown #####
(consume pid=11209, ip=10.0.1.12) * Time in dataset iterator: 47.9s
(consume pid=11209, ip=10.0.1.12) * Time in user code: 19.49s
(consume pid=11209, ip=10.0.1.12) * Total time: 68.13s

0.05 seconds per-batch validation time:

With a simulated per-batch validation time of 0.05 seconds, the total validation time improves from 165.99s to 104.1s.

As the experiment shows, we achieve essentially perfect prefetching in this case: in the actor-based run, the combined ray.wait/ray.get time is close to 0.

ray.wait prefetching:

(consume pid=11990, ip=10.0.1.12) == Pipeline Window 0 ==
(consume pid=11990, ip=10.0.1.12) Stage 1 read: 48/48 blocks executed in 30.24s
(consume pid=11990, ip=10.0.1.12) * Remote wall time: 7.08s min, 28.37s max, 13.67s mean, 656.38s total
(consume pid=11990, ip=10.0.1.12) * Remote cpu time: 4.37s min, 5.4s max, 4.88s mean, 234.34s total
(consume pid=11990, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11990, ip=10.0.1.12) * Output size bytes: 1787500000 min, 1787500000 max, 1787500000 mean, 85800000000 total
(consume pid=11990, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11990, ip=10.0.1.12)
(consume pid=11990, ip=10.0.1.12) Stage 2 map_batches: 48/48 blocks executed in 58.15s
(consume pid=11990, ip=10.0.1.12) * Remote wall time: 11.5s min, 16.48s max, 13.9s mean, 667.43s total
(consume pid=11990, ip=10.0.1.12) * Remote cpu time: 12.09s min, 17.02s max, 14.48s mean, 694.83s total
(consume pid=11990, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11990, ip=10.0.1.12) * Output size bytes: 1680000128 min, 1680000128 max, 1680000128 mean, 80640006144 total
(consume pid=11990, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11990, ip=10.0.1.12)
(consume pid=11990, ip=10.0.1.12) Dataset iterator time breakdown:
(consume pid=11990, ip=10.0.1.12) * In ray.wait(): 43.73s
(consume pid=11990, ip=10.0.1.12) * In ray.get(): 22.03s
(consume pid=11990, ip=10.0.1.12) * In format_batch(): 3.14s
(consume pid=11990, ip=10.0.1.12) * In user code: 96.28s
(consume pid=11990, ip=10.0.1.12) * Total time: 165.27s
(consume pid=11990, ip=10.0.1.12)
(consume pid=11990, ip=10.0.1.12) ##### Overall Pipeline Time Breakdown #####
(consume pid=11990, ip=10.0.1.12) * Time in dataset iterator: 69.0s
(consume pid=11990, ip=10.0.1.12) * Time in user code: 96.26s
(consume pid=11990, ip=10.0.1.12) * Total time: 165.99

actor-based prefetching:

(consume pid=11780, ip=10.0.1.12) == Pipeline Window 0 ==
(consume pid=11780, ip=10.0.1.12) Stage 1 read: 48/48 blocks executed in 30.43s
(consume pid=11780, ip=10.0.1.12) * Remote wall time: 5.98s min, 28.4s max, 14.35s mean, 688.75s total
(consume pid=11780, ip=10.0.1.12) * Remote cpu time: 4.53s min, 5.7s max, 4.91s mean, 235.8s total
(consume pid=11780, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11780, ip=10.0.1.12) * Output size bytes: 1787500000 min, 1787500000 max, 1787500000 mean, 85800000000 total
(consume pid=11780, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11780, ip=10.0.1.12)
(consume pid=11780, ip=10.0.1.12) Stage 2 map_batches: 48/48 blocks executed in 61.91s
(consume pid=11780, ip=10.0.1.12) * Remote wall time: 11.27s min, 15.14s max, 13.62s mean, 653.93s total
(consume pid=11780, ip=10.0.1.12) * Remote cpu time: 11.89s min, 15.93s max, 14.22s mean, 682.59s total
(consume pid=11780, ip=10.0.1.12) * Output num rows: 10000000 min, 10000000 max, 10000000 mean, 480000000 total
(consume pid=11780, ip=10.0.1.12) * Output size bytes: 1680000128 min, 1680000128 max, 1680000128 mean, 80640006144 total
(consume pid=11780, ip=10.0.1.12) * Tasks per node: 9 min, 10 max, 9 mean; 5 nodes used
(consume pid=11780, ip=10.0.1.12)
(consume pid=11780, ip=10.0.1.12) Dataset iterator time breakdown:
(consume pid=11780, ip=10.0.1.12) * In ray.wait(): 24.93ms
(consume pid=11780, ip=10.0.1.12) * In ray.get(): 2.93s
(consume pid=11780, ip=10.0.1.12) * In format_batch(): 4.0s
(consume pid=11780, ip=10.0.1.12) * In user code: 96.3s
(consume pid=11780, ip=10.0.1.12) * Total time: 103.38s
(consume pid=11780, ip=10.0.1.12)
(consume pid=11780, ip=10.0.1.12) ##### Overall Pipeline Time Breakdown #####
(consume pid=11780, ip=10.0.1.12) * Time in dataset iterator: 7.09s
(consume pid=11780, ip=10.0.1.12) * Time in user code: 96.26s
(consume pid=11780, ip=10.0.1.12) * Total time: 104.1s

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@scv119 scv119 changed the title [Dataset] try to fix prefetch [RFC][Dataset] Actor based prefetching Apr 16, 2022
@scv119 scv119 marked this pull request as ready for review April 16, 2022 05:07
@scv119 scv119 force-pushed the wait branch 6 times, most recently from f7ce018 to 9ea1089 Compare April 16, 2022 07:25
@clarkzinzow
Contributor

clarkzinzow commented Apr 16, 2022

Awesome, great benchmarking!

It seems like the underlying problem with the current implementation is that ray.wait() isn't working as expected. For ray.wait(refs, num_returns=1, fetch_local=True), a pull request should be issued for every object in refs, and the ray.wait() call shouldn't return until one of the objects is local. The actor-based prefetcher is doing the same thing, triggering a pull request for every object in refs, but isn't blocking until one of them is local. Ultimately, I'd expect the time in prefetching + the time in ray.get() to be ~ the same across these two approaches, with the ray.wait() approach being slightly faster due to the lack of task submission overhead. I think the only difference in the created pull requests is the relative priority of ray.wait() calls vs. task arguments under memory pressure, which I assume isn't relevant in this case.

I'm assuming that we don't know why ray.wait() is not behaving as expected here?
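For reference, a minimal sketch of the ray.wait()-based prefetch pattern described above (illustrative, not the exact Datasets code): each call should issue pulls for all remaining refs and block only until one of them is local.

import ray


def iter_blocks_with_wait(block_refs):
    remaining = list(block_refs)
    while remaining:
        # Expected behavior: trigger a pull for every ref in `remaining` and
        # return as soon as one block is available in the local object store.
        ready, remaining = ray.wait(remaining, num_returns=1, fetch_local=True)
        yield ray.get(ready[0])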

@clarkzinzow
Contributor

clarkzinzow commented Apr 16, 2022

Also, if we were to bypass ray.wait() in our prefetching, instead of creating a detached actor for each consumer node, I'd vote for issuing ray.get prefetches in a background thread within the consumer process (created in .iter_batches()), where the main thread pulls these prefetched blocks via a threading queue. This would avoid the extra actor per trainer node, the task submission overhead, etc. The Ludwig team currently does this background-thread prefetching themselves, since they ran into bugs with our old prefetching implementation, and they've had great results.

We were already planning to do this in order to be able to pipeline some last-last-mile transformations (Torch tensor conversion, local shuffling, device transfer, etc.) with training, so this implementation would stick even if we fixed ray.wait.
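A minimal sketch of the background-thread prefetching described here (names and structure assumed, not the actual Datasets implementation): a producer thread calls ray.get() and feeds a bounded queue that the consumer iterates over.

import queue
import threading

import ray


def iter_blocks_with_thread_prefetch(block_refs, prefetch_blocks=1):
    out = queue.Queue(maxsize=prefetch_blocks + 1)
    sentinel = object()

    def producer():
        for ref in block_refs:
            # ray.get() blocks here, in the background thread, so the main
            # thread only waits when the queue is empty.
            out.put(ray.get(ref))
        out.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        block = out.get()
        if block is sentinel:
            return
        yield block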

@scv119
Contributor Author

scv119 commented Apr 18, 2022

I'm assuming that we don't know why ray.wait() is not behaving as expected here?

I think the ray.wait behavior is not well defined and can only be reasoned about by reading the code.

Also, if we were to bypass ray.wait() in our prefetching, instead of creating a detached actor for each consumer node, I'd vote for issuing ray.get prefetches in a background thread within the consumer process (created in .iter_batches()),

Do you have a code snippet showing how to do that? Is calling ray.get in a separate thread defined behavior?

Contributor

@ericl ericl left a comment

Agree that we could probably do this as a separate thread (which would be less likely to be problematic than using a separate actor).

The root cause here seems to be a bug in ray.wait() though, no? Wait is supposed to be fetching the blocks, and somehow that isn't happening correctly. cc @stephanie-wang

I'm fine with merging a thread/actor-based hack for now, as long as we document this as a workaround and investigate the root cause.

@ericl ericl added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 18, 2022
@stephanie-wang
Contributor

Agreed, this sounds like a bug in ray.wait(fetch_local=True). Could you open an issue with the findings here? As @scv119 says, we should also better define the semantics for this call. It should act the same way as the ray.get hack here as long as there is memory availability, but yes, this is not documented right now.

python/ray/data/impl/block_batching.py: 3 review threads (outdated, resolved)
@clarkzinzow
Contributor

@scv119 In case we go with the actor-based prefetcher, I left a review.

scv119 and others added 2 commits April 24, 2022 08:45
Co-authored-by: Clark Zinzow <clarkzinzow@gmail.com>
@scv119
Contributor Author

scv119 commented Apr 24, 2022

cc @clarkzinzow @jianoaix maybe merge this?

@clarkzinzow
Contributor

@scv119 I'll make a pass at a background thread-based solution on Monday, and if that doesn't pan out, we can merge this actor-based solution.

@jianoaix
Contributor

This PR looks good to me. From my understanding, using actors is cleaner since they're a higher-level abstraction than raw threads. If the performance difference is small, actors should be preferred. Am I missing anything?

@clarkzinzow
Contributor

@jianoaix There are multiple issues with this kind of actor-based prefetching, detailed here.

# ray.wait doesn't work as expected, so we use the actor-based
# prefetcher as a workaround. See
# https://github.com/ray-project/ray/issues/23983 for details.
if len(block_window) > 1 and context.actor_prefetcher_enabled:
Contributor

Are we able to tell whether this is running under Ray Client (maybe from the Ray address given to ray.init()?), and if so, fall back to ray.wait()? I'm checking whether this solution can address the Ray Client issue pointed out by @clarkzinzow in #24158.
If so, it seems like a reasonable short-term fix before we land a better one (thread-based prefetch or anything else). This changes an internal spot, so it should not be too hard to evolve in the future.

Contributor Author

Yeah, there is an easy way to test whether it's Ray Client:
ray.util.client.ray.is_connected()

Contributor Author

comments addressed

Contributor

@jianoaix jianoaix left a comment

LGTM as a short-term fix and to catch the 1.13 cut today.

and context.actor_prefetcher_enabled
and not ray.util.client.ray.is_connected()
):
prefetcher = get_or_create_prefetcher()
Contributor

@clarkzinzow clarkzinzow Apr 29, 2022

This currently executes a named actor handle fetch on every prefetch window; could we cache the actor handle so it's fetched only once per batch_blocks() call?

Contributor

@clarkzinzow clarkzinzow left a comment

LGTM for the 1.13 cut as well, although it'd be great if we could do one named actor handle fetch per .iter_batches() call instead of one per prefetch window, i.e. one call vs. len(blocks) - prefetch_blocks calls.

I'll look at making this change, along with test coverage of each path.
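A minimal sketch of the handle-caching idea (actor and function names assumed, mirroring the get_or_create_prefetcher call shown in the diff above): look up the named actor once per batch_blocks()/.iter_batches() call and reuse the handle for every prefetch window.

import ray


@ray.remote(num_cpus=0)
class BlockPrefetcher:
    def prefetch(self, *blocks):
        pass


def get_or_create_prefetcher(name="block_prefetcher"):
    # ray.get_actor raises ValueError if no actor with this name exists.
    # (Handling of races between concurrent creators is omitted here.)
    try:
        return ray.get_actor(name)
    except ValueError:
        return BlockPrefetcher.options(name=name, lifetime="detached").remote()


# Fetch the handle once, then reuse it for each prefetch window:
# prefetcher = get_or_create_prefetcher()
# for window in block_windows:
#     prefetcher.prefetch.remote(*window)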

…per .iter_batches() call; add test coverage.
Contributor Author

@scv119 scv119 left a comment

LGTM

@clarkzinzow
Contributor

Datasets tests look good, merging!

@clarkzinzow clarkzinzow merged commit f375200 into ray-project:master Apr 29, 2022
Labels
@author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer.

6 participants