-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Data] Postpone `reader.get_read_tasks` until execution
#38373
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Default file metadata shuffler class to use.
# Fully-qualified class path of the no-op shuffler, which preserves the
# original file order. NOTE(review): presumably resolved at runtime via
# get_attribute_from_class_name — confirm against the caller.
DEFAULT_FILE_METADATA_SHUFFLER = (
    "ray.data.datasource.file_metadata_shuffler.SequentialFileMetadataShuffler"
)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,10 @@ def _plan_read_op(op: Read) -> PhysicalOperator: | |
""" | ||
|
||
def get_input_data() -> List[RefBundle]: | ||
read_tasks = op._read_tasks | ||
read_tasks = op._reader.get_read_tasks(op._parallelism) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@amogkam - this is used in planner here. |
||
if op._additional_split_factor is not None: | ||
for r in read_tasks: | ||
r._set_additional_split_factor(op._additional_split_factor) | ||
return [ | ||
RefBundle( | ||
[ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -513,3 +513,20 @@ def unify_block_metadata_schema( | |
# return the first schema. | ||
return schemas_to_unify[0] | ||
return None | ||
|
||
|
||
def get_attribute_from_class_name(class_name: str) -> Any:
    """Resolve a fully-qualified class name to the Python attribute it names.

    The caller needs to make sure the provided class name includes
    full module name, and can be imported successfully.

    Raises:
        ValueError: If ``class_name`` does not contain a module part.
    """
    from importlib import import_module

    parts = class_name.split(".")
    # A valid name needs at least "<module>.<attribute>".
    if len(parts) < 2:
        raise ValueError(f"Cannot create object from {class_name}.")
    module = import_module(".".join(parts[:-1]))
    return getattr(module, parts[-1])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Any, Dict, List, Tuple | ||
|
||
|
||
class FileMetadataShuffler:
    """Base interface for shuffling file metadata before reading.

    Shufflers live on the driver side of the Dataset only.
    """

    def __init__(self, reader_args: Dict[str, Any]):
        # Kept for concrete subclasses that may shuffle based on reader args.
        self._reader_args = reader_args

    def shuffle_files(
        self,
        paths: List[str],
        file_sizes: List[int],
    ) -> Tuple[List[str], List[int]]:
        """Shuffle the given file paths together with their sizes.

        Args:
            paths: The file paths to shuffle.
            file_sizes: The size of file paths, corresponding to `paths`.

        Returns:
            The file paths and their size after shuffling.
        """
        # Abstract: subclasses must provide the shuffling policy.
        raise NotImplementedError
|
||
|
||
class SequentialFileMetadataShuffler(FileMetadataShuffler):
    """Pass-through shuffler that keeps files in their original order."""

    def shuffle_files(
        self,
        paths: List[str],
        file_sizes: List[int],
    ) -> Tuple[List[str], List[int]]:
        """Return files in the given paths and sizes sequentially."""
        # No-op: hand back the inputs unchanged.
        return paths, file_sizes
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,7 +58,6 @@ | |
ParquetMetadataProvider, | ||
PathPartitionFilter, | ||
RangeDatasource, | ||
ReadTask, | ||
SQLDatasource, | ||
TextDatasource, | ||
TFRecordDatasource, | ||
|
@@ -70,6 +69,7 @@ | |
get_parquet_bulk_metadata_provider, | ||
get_parquet_metadata_provider, | ||
) | ||
from ray.data.datasource.datasource import Reader | ||
from ray.data.datasource.file_based_datasource import ( | ||
_unwrap_arrow_serialization_workaround, | ||
_wrap_arrow_serialization_workaround, | ||
|
@@ -351,8 +351,8 @@ def read_datasource( | |
requested_parallelism, | ||
min_safe_parallelism, | ||
inmemory_size, | ||
read_tasks, | ||
) = _get_read_tasks(datasource, ctx, cur_pg, parallelism, local_uri, read_args) | ||
reader, | ||
) = _get_reader(datasource, ctx, cur_pg, parallelism, local_uri, read_args) | ||
else: | ||
# Prepare read in a remote task at same node. | ||
# NOTE: in Ray client mode, this is expected to be run on head node. | ||
|
@@ -361,17 +361,12 @@ def read_datasource( | |
ray.get_runtime_context().get_node_id(), | ||
soft=False, | ||
) | ||
get_read_tasks = cached_remote_fn( | ||
_get_read_tasks, retry_exceptions=False, num_cpus=0 | ||
get_reader = cached_remote_fn( | ||
_get_reader, retry_exceptions=False, num_cpus=0 | ||
).options(scheduling_strategy=scheduling_strategy) | ||
|
||
( | ||
requested_parallelism, | ||
min_safe_parallelism, | ||
inmemory_size, | ||
read_tasks, | ||
) = ray.get( | ||
get_read_tasks.remote( | ||
(requested_parallelism, min_safe_parallelism, inmemory_size, reader,) = ray.get( | ||
get_reader.remote( | ||
datasource, | ||
ctx, | ||
cur_pg, | ||
|
@@ -381,9 +376,14 @@ def read_datasource( | |
) | ||
) | ||
|
||
# TODO(hchen/chengsu): Remove the duplicated get_read_tasks call here after | ||
# removing LazyBlockList code path. | ||
read_tasks = reader.get_read_tasks(requested_parallelism) | ||
|
||
# Compute the number of blocks the read will return. If the number of blocks is | ||
# expected to be less than the requested parallelism, boost the number of blocks | ||
# by adding an additional split into `k` pieces to each read task. | ||
additional_split_factor = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One small issue. For the new code path, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, that's the major weird thing here. The other code path There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's fine. can you leave a todo here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SG, added. |
||
if read_tasks: | ||
if inmemory_size: | ||
expected_block_size = inmemory_size / len(read_tasks) | ||
|
@@ -407,6 +407,7 @@ def read_datasource( | |
for r in read_tasks: | ||
r._set_additional_split_factor(k) | ||
estimated_num_blocks = estimated_num_blocks * k | ||
additional_split_factor = k | ||
logger.debug("Estimated num output blocks {estimated_num_blocks}") | ||
else: | ||
estimated_num_blocks = 0 | ||
|
@@ -437,9 +438,13 @@ def read_datasource( | |
) | ||
block_list._estimated_num_blocks = estimated_num_blocks | ||
|
||
# TODO(hchen): move _get_read_tasks and related code to the Read physical operator, | ||
# after removing LazyBlockList code path. | ||
read_op = Read(datasource, read_tasks, estimated_num_blocks, ray_remote_args) | ||
read_op = Read( | ||
datasource, | ||
reader, | ||
requested_parallelism, | ||
additional_split_factor, | ||
ray_remote_args, | ||
) | ||
logical_plan = LogicalPlan(read_op) | ||
|
||
return Dataset( | ||
|
@@ -2281,15 +2286,15 @@ def from_torch( | |
return from_items(list(dataset)) | ||
|
||
|
||
def _get_read_tasks( | ||
def _get_reader( | ||
ds: Datasource, | ||
ctx: DataContext, | ||
cur_pg: Optional[PlacementGroup], | ||
parallelism: int, | ||
local_uri: bool, | ||
kwargs: dict, | ||
) -> Tuple[int, int, Optional[int], List[ReadTask]]: | ||
"""Generates read tasks. | ||
) -> Tuple[int, int, Optional[int], Reader]: | ||
"""Generates reader. | ||
|
||
Args: | ||
ds: Datasource to read from. | ||
|
@@ -2300,7 +2305,7 @@ def _get_read_tasks( | |
|
||
Returns: | ||
Request parallelism from the datasource, the min safe parallelism to avoid | ||
OOM, the estimated inmemory data size, and list of read tasks generated. | ||
OOM, the estimated inmemory data size, and the reader generated. | ||
""" | ||
kwargs = _unwrap_arrow_serialization_workaround(kwargs) | ||
if local_uri: | ||
|
@@ -2314,7 +2319,7 @@ def _get_read_tasks( | |
requested_parallelism, | ||
min_safe_parallelism, | ||
mem_size, | ||
reader.get_read_tasks(requested_parallelism), | ||
reader, | ||
) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1413,12 +1413,6 @@ def test_unsupported_pyarrow_versions_check_disabled( | |
except ImportError as e: | ||
pytest.fail(f"_check_pyarrow_version failed unexpectedly: {e}") | ||
|
||
# Test read_parquet. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pyarrow 5 does not support pickling the Parquet reader class. Given we do not support Pyarrow 5, remove the test code here. Already verified for Pyarrow 6+, it's not an issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we just remove the PyArrow 5 CI? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already removed. CI only tests 6 and 12. This test manually install 5. |
||
try: | ||
ray.data.read_parquet("example://iris.parquet").take_all() | ||
except ImportError as e: | ||
pytest.fail(f"_check_pyarrow_version failed unexpectedly: {e}") | ||
|
||
# Test from_numpy (we use Arrow for representing the tensors). | ||
try: | ||
ray.data.from_numpy(np.arange(12).reshape((3, 2, 2))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where does
self._reader
get used here?