From 599be0beac66760c00bfecf23f02b55fa070ca8d Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Tue, 18 Jun 2024 13:08:17 -0700 Subject: [PATCH 1/7] tests Signed-off-by: Scott Lee --- .../data/_internal/planner/plan_udf_map_op.py | 162 ++++++++++++------ python/ray/data/tests/test_map.py | 29 ++++ 2 files changed, 143 insertions(+), 48 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index b1747616c655..8dbb24cbd71b 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -1,4 +1,8 @@ +import asyncio import collections +import inspect +import queue +from concurrent.futures import ThreadPoolExecutor from types import GeneratorType from typing import Any, Callable, Iterable, Iterator, List, Optional @@ -104,16 +108,6 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_args = op._fn_constructor_args or () fn_constructor_kwargs = op._fn_constructor_kwargs or {} - op_fn = make_callable_class_concurrent(op_fn) - - def fn(item: Any) -> Any: - assert ray.data._cached_fn is not None - assert ray.data._cached_cls == op_fn - try: - return ray.data._cached_fn(item, *fn_args, **fn_kwargs) - except Exception as e: - _handle_debugger_exception(e) - def init_fn(): if ray.data._cached_fn is None: ray.data._cached_cls = op_fn @@ -121,6 +115,28 @@ def init_fn(): *fn_constructor_args, **fn_constructor_kwargs ) + if inspect.isasyncgenfunction(op._fn.__call__): + + async def fn(item: Any) -> Any: + assert ray.data._cached_fn is not None + assert ray.data._cached_cls == op_fn + + try: + return ray.data._cached_fn(item, *fn_args, **fn_kwargs) + except Exception as e: + _handle_debugger_exception(e) + + else: + op_fn = make_callable_class_concurrent(op_fn) + + def fn(item: Any) -> Any: + assert ray.data._cached_fn is not None + assert ray.data._cached_cls == op_fn + try: + return ray.data._cached_fn(item, *fn_args, **fn_kwargs) + except Exception as e: + _handle_debugger_exception(e) + else: def fn(item: Any) -> Any: @@ -158,6 +174,7 @@ def _validate_batch_output(batch: Block) -> None: np.ndarray, collections.abc.Mapping, pd.core.frame.DataFrame, + dict, ), ): raise ValueError( @@ -193,45 +210,94 @@ def _validate_batch_output(batch: Block) -> None: def _generate_transform_fn_for_map_batches( fn: UserDefinedFunction, ) -> MapTransformCallable[DataBatch, DataBatch]: - def transform_fn( - batches: Iterable[DataBatch], _: TaskContext - ) -> Iterable[DataBatch]: - for batch in batches: - try: - if ( - not isinstance(batch, collections.abc.Mapping) - and BlockAccessor.for_block(batch).num_rows() == 0 - ): - # For empty input blocks, we directly ouptut them without - # calling the UDF. - # TODO(hchen): This workaround is because some all-to-all - # operators output empty blocks with no schema. - res = [batch] - else: - res = fn(batch) - if not isinstance(res, GeneratorType): - res = [res] - except ValueError as e: - read_only_msgs = [ - "assignment destination is read-only", - "buffer source array is read-only", - ] - err_msg = str(e) - if any(msg in err_msg for msg in read_only_msgs): - raise ValueError( - f"Batch mapper function {fn.__name__} tried to mutate a " - "zero-copy read-only batch. To be able to mutate the " - "batch, pass zero_copy_batch=False to map_batches(); " - "this will create a writable copy of the batch before " - "giving it to fn. To elide this copy, modify your mapper " - "function so it doesn't try to mutate its input." - ) from e + if inspect.iscoroutinefunction(fn): + # UDF is a callable class with async generator `__call__` method. + def transform_fn( + input_iterable: Iterable[DataBatch], _: TaskContext + ) -> Iterable[DataBatch]: + # Use a queue to store results from async generator calls. + # In the main event loop, we will put results into this queue + # from async generator, and yield them from the queue as they + # become available. + result_queue = queue.Queue() + + async def process_batch(batch: DataBatch): + output_batch_iterator = await fn(batch) + # As soon as results become available from the async generator, + # put them into the result queue so they can be yielded. + async for output_row in output_batch_iterator: + result_queue.put(output_row) + + async def process_all_batches(): + tasks = [asyncio.create_task(process_batch(x)) for x in input_iterable] + for task in asyncio.as_completed(tasks): + await task + # Sentinel to indicate completion. + result_queue.put(None) + + def run_event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(process_all_batches()) + loop.close() + + # Start the event loop in a new thread + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(run_event_loop) + + # Yield results as they become available + while True: + # `out_batch` here is a one-row batch which contains + # output from the async generator, corresponding to a + # single row from the input batch. + out_batch = result_queue.get() + # Exit when sentinel is received. + if out_batch is None: + break + _validate_batch_output(out_batch) + yield out_batch + + else: + + def transform_fn( + batches: Iterable[DataBatch], _: TaskContext + ) -> Iterable[DataBatch]: + for batch in batches: + try: + if ( + not isinstance(batch, collections.abc.Mapping) + and BlockAccessor.for_block(batch).num_rows() == 0 + ): + # For empty input blocks, we directly ouptut them without + # calling the UDF. + # TODO(hchen): This workaround is because some all-to-all + # operators output empty blocks with no schema. + res = [batch] + else: + res = fn(batch) + if not isinstance(res, GeneratorType): + res = [res] + except ValueError as e: + read_only_msgs = [ + "assignment destination is read-only", + "buffer source array is read-only", + ] + err_msg = str(e) + if any(msg in err_msg for msg in read_only_msgs): + raise ValueError( + f"Batch mapper function {fn.__name__} tried to mutate a " + "zero-copy read-only batch. To be able to mutate the " + "batch, pass zero_copy_batch=False to map_batches(); " + "this will create a writable copy of the batch before " + "giving it to fn. To elide this copy, modify your mapper " + "function so it doesn't try to mutate its input." + ) from e + else: + raise e from None else: - raise e from None - else: - for out_batch in res: - _validate_batch_output(out_batch) - yield out_batch + for out_batch in res: + _validate_batch_output(out_batch) + yield out_batch return transform_fn diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index a511eb7255c3..a2209bcedec2 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1,3 +1,4 @@ +import asyncio import itertools import math import os @@ -700,6 +701,34 @@ def fail_generator(batch): ).take() +def test_map_batches_async_generator(ray_start_regular_shared): + async def sleep_and_yield(i): + await asyncio.sleep(i) + return {"input": [i], "output": [2**i]} + + class AsyncActor: + def __init__(self): + pass + + async def __call__(self, batch): + tasks = [sleep_and_yield(i) for i in batch["id"]] + results = await asyncio.gather(*tasks) + for result in results: + yield result + + n = 5 + ds = ray.data.range(n, override_num_blocks=1) + ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1) + + start_t = time.time() + output = ds.take_all() + runtime = time.time() - start_t + assert runtime < sum(range(n)), runtime + + expected_output = [{"input": i, "output": 2**i} for i in range(n)] + assert output == expected_output, (output, expected_output) + + def test_map_batches_actors_preserves_order(shutdown_only): class UDFClass: def __call__(self, x): From dec32a15d6315256b3fb20e3b2314ab8e7fd930a Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 20 Jun 2024 21:17:09 -0700 Subject: [PATCH 2/7] initialize loop in init_fn Signed-off-by: Scott Lee --- .../data/_internal/planner/plan_udf_map_op.py | 73 +++++++++++-------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 8dbb24cbd71b..b02338b6599a 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -108,18 +108,21 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_args = op._fn_constructor_args or () fn_constructor_kwargs = op._fn_constructor_kwargs or {} - def init_fn(): - if ray.data._cached_fn is None: - ray.data._cached_cls = op_fn - ray.data._cached_fn = op_fn( - *fn_constructor_args, **fn_constructor_kwargs - ) - if inspect.isasyncgenfunction(op._fn.__call__): + def init_fn(): + if ray.data._cached_fn is None: + ray.data._cached_cls = op_fn + ray.data._cached_fn = op_fn( + *fn_constructor_args, **fn_constructor_kwargs + ) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + ray.data._cached_loop = loop async def fn(item: Any) -> Any: assert ray.data._cached_fn is not None assert ray.data._cached_cls == op_fn + assert ray.data._cached_loop is not None try: return ray.data._cached_fn(item, *fn_args, **fn_kwargs) @@ -129,6 +132,13 @@ async def fn(item: Any) -> Any: else: op_fn = make_callable_class_concurrent(op_fn) + def init_fn(): + if ray.data._cached_fn is None: + ray.data._cached_cls = op_fn + ray.data._cached_fn = op_fn( + *fn_constructor_args, **fn_constructor_kwargs + ) + def fn(item: Any) -> Any: assert ray.data._cached_fn is not None assert ray.data._cached_cls == op_fn @@ -215,42 +225,45 @@ def _generate_transform_fn_for_map_batches( def transform_fn( input_iterable: Iterable[DataBatch], _: TaskContext ) -> Iterable[DataBatch]: - # Use a queue to store results from async generator calls. - # In the main event loop, we will put results into this queue - # from async generator, and yield them from the queue as they - # become available. - result_queue = queue.Queue() + # Use a queue to store outputs from async generator calls. + # We will put output batches into this queue from async + # generators, and in the main event loop, yield them from + # the queue as they become available. + output_batch_queue = queue.Queue() async def process_batch(batch: DataBatch): output_batch_iterator = await fn(batch) # As soon as results become available from the async generator, # put them into the result queue so they can be yielded. - async for output_row in output_batch_iterator: - result_queue.put(output_row) + async for output_batch in output_batch_iterator: + output_batch_queue.put(output_batch) async def process_all_batches(): - tasks = [asyncio.create_task(process_batch(x)) for x in input_iterable] - for task in asyncio.as_completed(tasks): - await task - # Sentinel to indicate completion. - result_queue.put(None) + loop = ray.data._cached_loop + tasks = [loop.create_task(process_batch(x)) for x in input_iterable] - def run_event_loop(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(process_all_batches()) - loop.close() + ctx = ray.data.DataContext.get_current() + if ctx.execution_options.preserve_order: + for task in tasks: + await task() + else: + for task in asyncio.as_completed(tasks): + await task + # Sentinel to indicate completion. + output_batch_queue.put(None) - # Start the event loop in a new thread - executor = ThreadPoolExecutor(max_workers=1) - executor.submit(run_event_loop) + # Use the existing event loop to create and run + # Tasks to process each batch + loop = ray.data._cached_loop + loop.run_until_complete(process_all_batches()) # Yield results as they become available while True: - # `out_batch` here is a one-row batch which contains - # output from the async generator, corresponding to a + # `out_batch` here is a one-row output batch + # from the async generator, corresponding to a # single row from the input batch. - out_batch = result_queue.get() + out_batch = output_batch_queue.get() + # Exit when sentinel is received. if out_batch is None: break From 4292b93954cf27936b69275045d9bad8fc846e60 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 20 Jun 2024 21:17:44 -0700 Subject: [PATCH 3/7] import Signed-off-by: Scott Lee --- python/ray/data/_internal/planner/plan_udf_map_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index b02338b6599a..03991ad543f7 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -2,7 +2,6 @@ import collections import inspect import queue -from concurrent.futures import ThreadPoolExecutor from types import GeneratorType from typing import Any, Callable, Iterable, Iterator, List, Optional @@ -109,6 +108,7 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_kwargs = op._fn_constructor_kwargs or {} if inspect.isasyncgenfunction(op._fn.__call__): + def init_fn(): if ray.data._cached_fn is None: ray.data._cached_cls = op_fn From de61a0d6223b941e33bd95fb844ba94c2e5206fa Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 20 Jun 2024 21:22:02 -0700 Subject: [PATCH 4/7] avoid using asyncio.gather in test Signed-off-by: Scott Lee --- python/ray/data/tests/test_map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index a2209bcedec2..e88784c29e25 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -712,9 +712,9 @@ def __init__(self): async def __call__(self, batch): tasks = [sleep_and_yield(i) for i in batch["id"]] - results = await asyncio.gather(*tasks) - for result in results: - yield result + tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]] + for task in tasks: + yield await task n = 5 ds = ray.data.range(n, override_num_blocks=1) From d217c1cb37e6e63aba5b7cae22926274671811cb Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 21 Jun 2024 21:07:24 -0700 Subject: [PATCH 5/7] address comments Signed-off-by: Scott Lee --- .../data/_internal/planner/plan_udf_map_op.py | 182 +++++++++++------- python/ray/data/tests/test_map.py | 61 +++--- 2 files changed, 149 insertions(+), 94 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 03991ad543f7..95a7cf09a2c3 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -2,6 +2,7 @@ import collections import inspect import queue +from threading import Thread from types import GeneratorType from typing import Any, Callable, Iterable, Iterator, List, Optional @@ -46,6 +47,22 @@ from ray.util.rpdb import _is_ray_debugger_enabled +class _MapActorContext: + def __init__( + self, + cached_cls: UserDefinedFunction, + cached_fn: Callable[[Any], Any], + cached_loop: Optional[asyncio.AbstractEventLoop] = None, + cached_asyncio_thread: Optional[Thread] = None, + ): + self.cached_cls = cached_cls + self.cached_fn = cached_fn + + # Only used for callable class with async generator `__call__` method. + self.cached_loop = cached_loop + self.cached_asyncio_thread = cached_asyncio_thread + + def plan_udf_map_op( op: AbstractUDFMap, physical_children: List[PhysicalOperator] ) -> MapOperator: @@ -108,24 +125,38 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_kwargs = op._fn_constructor_kwargs or {} if inspect.isasyncgenfunction(op._fn.__call__): - + # TODO(scottjlee): (1) support non-generator async functions + # (2) make the map actor async def init_fn(): - if ray.data._cached_fn is None: - ray.data._cached_cls = op_fn - ray.data._cached_fn = op_fn( - *fn_constructor_args, **fn_constructor_kwargs - ) + if ray.data._map_actor_context is None: loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - ray.data._cached_loop = loop + + def run_loop(): + asyncio.set_event_loop(loop) + loop.run_forever() + + thread = Thread(target=run_loop) + thread.start() + + ray.data._map_actor_context = _MapActorContext( + cached_cls=op_fn, + cached_fn=op_fn( + *fn_constructor_args, + **fn_constructor_kwargs, + ), + cached_loop=loop, + cached_asyncio_thread=thread, + ) async def fn(item: Any) -> Any: - assert ray.data._cached_fn is not None - assert ray.data._cached_cls == op_fn - assert ray.data._cached_loop is not None + assert ray.data._map_actor_context is not None try: - return ray.data._cached_fn(item, *fn_args, **fn_kwargs) + return ray.data._map_actor_context.cached_fn( + item, + *fn_args, + **fn_kwargs, + ) except Exception as e: _handle_debugger_exception(e) @@ -133,17 +164,23 @@ async def fn(item: Any) -> Any: op_fn = make_callable_class_concurrent(op_fn) def init_fn(): - if ray.data._cached_fn is None: - ray.data._cached_cls = op_fn - ray.data._cached_fn = op_fn( - *fn_constructor_args, **fn_constructor_kwargs + if ray.data._map_actor_context is None: + ray.data._map_actor_context = _MapActorContext( + cached_cls=op_fn, + cached_fn=op_fn( + *fn_constructor_args, + **fn_constructor_kwargs, + ), ) def fn(item: Any) -> Any: - assert ray.data._cached_fn is not None - assert ray.data._cached_cls == op_fn + assert ray.data._map_actor_context is not None try: - return ray.data._cached_fn(item, *fn_args, **fn_kwargs) + return ray.data._map_actor_context.cached_fn( + item, + *fn_args, + **fn_kwargs, + ) except Exception as e: _handle_debugger_exception(e) @@ -222,53 +259,7 @@ def _generate_transform_fn_for_map_batches( ) -> MapTransformCallable[DataBatch, DataBatch]: if inspect.iscoroutinefunction(fn): # UDF is a callable class with async generator `__call__` method. - def transform_fn( - input_iterable: Iterable[DataBatch], _: TaskContext - ) -> Iterable[DataBatch]: - # Use a queue to store outputs from async generator calls. - # We will put output batches into this queue from async - # generators, and in the main event loop, yield them from - # the queue as they become available. - output_batch_queue = queue.Queue() - - async def process_batch(batch: DataBatch): - output_batch_iterator = await fn(batch) - # As soon as results become available from the async generator, - # put them into the result queue so they can be yielded. - async for output_batch in output_batch_iterator: - output_batch_queue.put(output_batch) - - async def process_all_batches(): - loop = ray.data._cached_loop - tasks = [loop.create_task(process_batch(x)) for x in input_iterable] - - ctx = ray.data.DataContext.get_current() - if ctx.execution_options.preserve_order: - for task in tasks: - await task() - else: - for task in asyncio.as_completed(tasks): - await task - # Sentinel to indicate completion. - output_batch_queue.put(None) - - # Use the existing event loop to create and run - # Tasks to process each batch - loop = ray.data._cached_loop - loop.run_until_complete(process_all_batches()) - - # Yield results as they become available - while True: - # `out_batch` here is a one-row output batch - # from the async generator, corresponding to a - # single row from the input batch. - out_batch = output_batch_queue.get() - - # Exit when sentinel is received. - if out_batch is None: - break - _validate_batch_output(out_batch) - yield out_batch + transform_fn = _generate_transform_fn_for_async_map_batches(fn) else: @@ -315,6 +306,65 @@ def transform_fn( return transform_fn +def _generate_transform_fn_for_async_map_batches( + fn: UserDefinedFunction, +) -> MapTransformCallable[DataBatch, DataBatch]: + class OutputQueueSentinel: + """Sentinel to indicate completion of async generator.""" + + pass + + def transform_fn( + input_iterable: Iterable[DataBatch], _: TaskContext + ) -> Iterable[DataBatch]: + # Use a queue to store outputs from async generator calls. + # We will put output batches into this queue from async + # generators, and in the main event loop, yield them from + # the queue as they become available. + output_batch_queue = queue.Queue() + + async def process_batch(batch: DataBatch): + output_batch_iterator = await fn(batch) + # As soon as results become available from the async generator, + # put them into the result queue so they can be yielded. + async for output_batch in output_batch_iterator: + output_batch_queue.put(output_batch) + + async def process_all_batches(): + loop = ray.data._map_actor_context.cached_loop + tasks = [loop.create_task(process_batch(x)) for x in input_iterable] + + ctx = ray.data.DataContext.get_current() + if ctx.execution_options.preserve_order: + for task in tasks: + await task() + else: + for task in asyncio.as_completed(tasks): + await task + # Sentinel to indicate completion. + output_batch_queue.put(OutputQueueSentinel()) + + # Use the existing event loop to create and run + # Tasks to process each batch + loop = ray.data._map_actor_context.cached_loop + future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop) + + # Yield results as they become available + while not future.done(): + # `out_batch` here is a one-row output batch + # from the async generator, corresponding to a + # single row from the input batch. + out_batch = output_batch_queue.get() + + # Exit when sentinel is received. + if isinstance(out_batch, OutputQueueSentinel): + break + _validate_batch_output(out_batch) + yield out_batch + + return transform_fn + + def _validate_row_output(item): if not isinstance(item, collections.abc.Mapping): raise ValueError( diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index e88784c29e25..5e723075b2d6 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -701,34 +701,6 @@ def fail_generator(batch): ).take() -def test_map_batches_async_generator(ray_start_regular_shared): - async def sleep_and_yield(i): - await asyncio.sleep(i) - return {"input": [i], "output": [2**i]} - - class AsyncActor: - def __init__(self): - pass - - async def __call__(self, batch): - tasks = [sleep_and_yield(i) for i in batch["id"]] - tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]] - for task in tasks: - yield await task - - n = 5 - ds = ray.data.range(n, override_num_blocks=1) - ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1) - - start_t = time.time() - output = ds.take_all() - runtime = time.time() - start_t - assert runtime < sum(range(n)), runtime - - expected_output = [{"input": i, "output": 2**i} for i in range(n)] - assert output == expected_output, (output, expected_output) - - def test_map_batches_actors_preserves_order(shutdown_only): class UDFClass: def __call__(self, x): @@ -1086,6 +1058,39 @@ def test_nonserializable_map_batches(shutdown_only): x.map_batches(lambda _: lock).take(1) +def test_map_batches_async_generator(shutdown_only): + ray.shutdown() + ray.init(num_cpus=10) + + async def sleep_and_yield(i): + print("sleep", i) + await asyncio.sleep(i % 5) + print("yield", i) + return {"input": [i], "output": [2**i]} + + class AsyncActor: + def __init__(self): + pass + + async def __call__(self, batch): + tasks = [sleep_and_yield(i) for i in batch["id"]] + tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]] + for task in tasks: + yield await task + + n = 10 + ds = ray.data.range(n, override_num_blocks=2) + ds = ds.map_batches(AsyncActor, batch_size=1, concurrency=1, max_concurrency=2) + + start_t = time.time() + output = ds.take_all() + runtime = time.time() - start_t + assert runtime < sum(range(n)), runtime + + expected_output = [{"input": i, "output": 2**i} for i in range(n)] + assert output == expected_output, (output, expected_output) + + if __name__ == "__main__": import sys From 9175edd3d1c2886c3db637852db1442e35197177 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 21 Jun 2024 21:51:14 -0700 Subject: [PATCH 6/7] add _map_actor_context Signed-off-by: Scott Lee --- python/ray/data/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index 901a282775c5..9a27c0b6585d 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -62,8 +62,7 @@ # Module-level cached global functions for callable classes. It needs to be defined here # since it has to be process-global across cloudpickled funcs. -_cached_fn = None -_cached_cls = None +_map_actor_context = None configure_logging() From c5af4ae163382e0f46469260c2995c099bd70ceb Mon Sep 17 00:00:00 2001 From: sjl Date: Tue, 25 Jun 2024 20:53:30 +0000 Subject: [PATCH 7/7] comments --- .../data/_internal/planner/plan_udf_map_op.py | 111 ++++++++---------- python/ray/data/tests/test_map.py | 2 +- 2 files changed, 51 insertions(+), 62 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 95a7cf09a2c3..d1be38f8ac52 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -50,17 +50,31 @@ class _MapActorContext: def __init__( self, - cached_cls: UserDefinedFunction, - cached_fn: Callable[[Any], Any], - cached_loop: Optional[asyncio.AbstractEventLoop] = None, - cached_asyncio_thread: Optional[Thread] = None, + udf_map_cls: UserDefinedFunction, + udf_map_fn: Callable[[Any], Any], + is_async: bool, ): - self.cached_cls = cached_cls - self.cached_fn = cached_fn + self.udf_map_cls = udf_map_cls + self.udf_map_fn = udf_map_fn + self.is_async = is_async + self.udf_map_asyncio_loop = None + self.udf_map_asyncio_thread = None + if is_async: + self._init_async() + + def _init_async(self): # Only used for callable class with async generator `__call__` method. - self.cached_loop = cached_loop - self.cached_asyncio_thread = cached_asyncio_thread + loop = asyncio.new_event_loop() + + def run_loop(): + asyncio.set_event_loop(loop) + loop.run_forever() + + thread = Thread(target=run_loop) + thread.start() + self.udf_map_asyncio_loop = loop + self.udf_map_asyncio_thread = thread def plan_udf_map_op( @@ -124,35 +138,32 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_args = op._fn_constructor_args or () fn_constructor_kwargs = op._fn_constructor_kwargs or {} - if inspect.isasyncgenfunction(op._fn.__call__): - # TODO(scottjlee): (1) support non-generator async functions - # (2) make the map actor async - def init_fn(): - if ray.data._map_actor_context is None: - loop = asyncio.new_event_loop() - - def run_loop(): - asyncio.set_event_loop(loop) - loop.run_forever() - - thread = Thread(target=run_loop) - thread.start() - - ray.data._map_actor_context = _MapActorContext( - cached_cls=op_fn, - cached_fn=op_fn( - *fn_constructor_args, - **fn_constructor_kwargs, - ), - cached_loop=loop, - cached_asyncio_thread=thread, - ) + is_async_gen = inspect.isasyncgenfunction(op._fn.__call__) + + # TODO(scottjlee): (1) support non-generator async functions + # (2) make the map actor async + if not is_async_gen: + op_fn = make_callable_class_concurrent(op_fn) + + def init_fn(): + if ray.data._map_actor_context is None: + ray.data._map_actor_context = _MapActorContext( + udf_map_cls=op_fn, + udf_map_fn=op_fn( + *fn_constructor_args, + **fn_constructor_kwargs, + ), + is_async=is_async_gen, + ) + + if is_async_gen: async def fn(item: Any) -> Any: assert ray.data._map_actor_context is not None + assert ray.data._map_actor_context.is_async try: - return ray.data._map_actor_context.cached_fn( + return ray.data._map_actor_context.udf_map_fn( item, *fn_args, **fn_kwargs, @@ -161,22 +172,12 @@ async def fn(item: Any) -> Any: _handle_debugger_exception(e) else: - op_fn = make_callable_class_concurrent(op_fn) - - def init_fn(): - if ray.data._map_actor_context is None: - ray.data._map_actor_context = _MapActorContext( - cached_cls=op_fn, - cached_fn=op_fn( - *fn_constructor_args, - **fn_constructor_kwargs, - ), - ) def fn(item: Any) -> Any: assert ray.data._map_actor_context is not None + assert not ray.data._map_actor_context.is_async try: - return ray.data._map_actor_context.cached_fn( + return ray.data._map_actor_context.udf_map_fn( item, *fn_args, **fn_kwargs, @@ -309,11 +310,6 @@ def transform_fn( def _generate_transform_fn_for_async_map_batches( fn: UserDefinedFunction, ) -> MapTransformCallable[DataBatch, DataBatch]: - class OutputQueueSentinel: - """Sentinel to indicate completion of async generator.""" - - pass - def transform_fn( input_iterable: Iterable[DataBatch], _: TaskContext ) -> Iterable[DataBatch]: @@ -331,7 +327,7 @@ async def process_batch(batch: DataBatch): output_batch_queue.put(output_batch) async def process_all_batches(): - loop = ray.data._map_actor_context.cached_loop + loop = ray.data._map_actor_context.udf_map_asyncio_loop tasks = [loop.create_task(process_batch(x)) for x in input_iterable] ctx = ray.data.DataContext.get_current() @@ -341,24 +337,17 @@ async def process_all_batches(): else: for task in asyncio.as_completed(tasks): await task - # Sentinel to indicate completion. - output_batch_queue.put(OutputQueueSentinel()) - # Use the existing event loop to create and run - # Tasks to process each batch - loop = ray.data._map_actor_context.cached_loop + # Use the existing event loop to create and run Tasks to process each batch + loop = ray.data._map_actor_context.udf_map_asyncio_loop future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop) - # Yield results as they become available + # Yield results as they become available. while not future.done(): - # `out_batch` here is a one-row output batch + # Here, `out_batch` is a one-row output batch # from the async generator, corresponding to a # single row from the input batch. out_batch = output_batch_queue.get() - - # Exit when sentinel is received. - if isinstance(out_batch, OutputQueueSentinel): - break _validate_batch_output(out_batch) yield out_batch diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 5e723075b2d6..337bacc6f77d 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1073,13 +1073,13 @@ def __init__(self): pass async def __call__(self, batch): - tasks = [sleep_and_yield(i) for i in batch["id"]] tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]] for task in tasks: yield await task n = 10 ds = ray.data.range(n, override_num_blocks=2) + ds = ds.map(lambda x: x) ds = ds.map_batches(AsyncActor, batch_size=1, concurrency=1, max_concurrency=2) start_t = time.time()