[Data] Postpone reader.get_read_tasks until execution #38373

Merged · 6 commits · Aug 16, 2023
4 changes: 4 additions & 0 deletions python/ray/data/_default_config.py

@@ -0,0 +1,4 @@
+# Default file metadata shuffler class to use.
+DEFAULT_FILE_METADATA_SHUFFLER = (
+    "ray.data.datasource.file_metadata_shuffler.SequentialFileMetadataShuffler"
+)
22 changes: 13 additions & 9 deletions python/ray/data/_internal/logical/operators/read_operator.py

@@ -1,7 +1,7 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from ray.data._internal.logical.operators.map_operator import AbstractMap
-from ray.data.datasource.datasource import Datasource, ReadTask
+from ray.data.datasource.datasource import Datasource, Reader


 class Read(AbstractMap):
@@ -10,18 +10,22 @@ class Read(AbstractMap):
     def __init__(
         self,
         datasource: Datasource,
-        read_tasks: List[ReadTask],
-        estimated_num_blocks: int,
+        reader: Reader,
+        parallelism: int,
+        additional_split_factor: Optional[int] = None,
         ray_remote_args: Optional[Dict[str, Any]] = None,
     ):
-        if len(read_tasks) == estimated_num_blocks:
+        if additional_split_factor is None:
             suffix = ""
+            self._estimated_num_blocks = parallelism
         else:
-            suffix = f"->SplitBlocks({int(estimated_num_blocks / len(read_tasks))})"
+            suffix = f"->SplitBlocks({additional_split_factor})"
+            self._estimated_num_blocks = parallelism * additional_split_factor
         super().__init__(f"Read{datasource.get_name()}{suffix}", None, ray_remote_args)
         self._datasource = datasource
-        self._estimated_num_blocks = estimated_num_blocks
-        self._read_tasks = read_tasks
+        self._reader = reader
+        self._parallelism = parallelism
+        self._additional_split_factor = additional_split_factor

     def fusable(self) -> bool:
         """Whether this should be fused with downstream operators.

Contributor (review comment): where does self._reader get used here?

@@ -30,4 +34,4 @@ def fusable(self) -> bool:
         as fusion would prevent the blocks from being dispatched to multiple processes
         for parallel processing in downstream operators.
         """
-        return self._estimated_num_blocks == len(self._read_tasks)
+        return self._parallelism == self._estimated_num_blocks
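
For illustration, the bookkeeping the new Read.__init__ performs can be traced with stand-in values (a minimal sketch; no real Datasource or Reader objects are constructed, and the numbers are hypothetical):

```python
# Sketch of the estimated-block-count and operator-name logic above,
# using stand-in values instead of real Datasource/Reader objects.
parallelism = 20
additional_split_factor = 5      # None when no extra splitting is requested
datasource_name = "Parquet"      # what datasource.get_name() would return

if additional_split_factor is None:
    suffix = ""
    estimated_num_blocks = parallelism
else:
    suffix = f"->SplitBlocks({additional_split_factor})"
    estimated_num_blocks = parallelism * additional_split_factor

assert estimated_num_blocks == 100
assert f"Read{datasource_name}{suffix}" == "ReadParquet->SplitBlocks(5)"
```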
5 changes: 4 additions & 1 deletion python/ray/data/_internal/planner/plan_read_op.py

@@ -42,7 +42,10 @@ def _plan_read_op(op: Read) -> PhysicalOperator:
     """

     def get_input_data() -> List[RefBundle]:
-        read_tasks = op._read_tasks
+        read_tasks = op._reader.get_read_tasks(op._parallelism)
+        if op._additional_split_factor is not None:
+            for r in read_tasks:
+                r._set_additional_split_factor(op._additional_split_factor)
         return [
             RefBundle(
                 [

Contributor Author (c21, Aug 15, 2023), replying to "where does self._reader get used here?": @amogkam - this is used in the planner here.
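
In other words, read tasks are now materialized inside the planner's input-data callback rather than when the Dataset is constructed. A minimal self-contained sketch of that deferral (FakeReader below is illustrative, not a Ray class):

```python
# Nothing expensive runs until the physical plan asks for its input data.
from typing import Callable, List


class FakeReader:
    """Stand-in for a Datasource Reader."""

    def get_read_tasks(self, parallelism: int) -> List[str]:
        # In Ray Data this is where file listing/partitioning work happens.
        return [f"read_task_{i}" for i in range(parallelism)]


def make_get_input_data(reader: FakeReader, parallelism: int) -> Callable[[], List[str]]:
    def get_input_data() -> List[str]:
        return reader.get_read_tasks(parallelism)

    return get_input_data


fn = make_get_input_data(FakeReader(), 4)  # "plan time": no tasks created yet
tasks = fn()                               # "execution time": tasks materialize
assert len(tasks) == 4
```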
17 changes: 17 additions & 0 deletions python/ray/data/_internal/util.py

@@ -513,3 +513,20 @@ def unify_block_metadata_schema(
         # return the first schema.
         return schemas_to_unify[0]
     return None
+
+
+def get_attribute_from_class_name(class_name: str) -> Any:
+    """Get Python attribute from the provided class name.
+
+    The caller needs to make sure the provided class name includes
+    full module name, and can be imported successfully.
+    """
+    from importlib import import_module
+
+    paths = class_name.split(".")
+    if len(paths) < 2:
+        raise ValueError(f"Cannot create object from {class_name}.")
+
+    module_name = ".".join(paths[:-1])
+    attribute_name = paths[-1]
+    return getattr(import_module(module_name), attribute_name)

Contributor Author (review comment on get_attribute_from_class_name): Checked online, it looks like it's the recommended way to do it - https://stackoverflow.com/questions/452969/does-python-have-an-equivalent-to-java-class-forname
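
As a usage note, this helper resolves a fully qualified dotted name to a Python attribute; a small sketch of how the datasource code below uses it (assumes Ray with this change installed):

```python
# Resolve the default shuffler class by its dotted path, then construct it
# the same way FileBasedDatasource does (a positional reader_args dict).
from ray.data._internal.util import get_attribute_from_class_name

shuffler_cls = get_attribute_from_class_name(
    "ray.data.datasource.file_metadata_shuffler.SequentialFileMetadataShuffler"
)
shuffler = shuffler_cls({})  # reader_args
paths, sizes = shuffler.shuffle_files(["a.parquet", "b.parquet"], [10, 20])
assert paths == ["a.parquet", "b.parquet"]  # the sequential shuffler is a no-op
```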
4 changes: 4 additions & 0 deletions python/ray/data/context.py

@@ -4,6 +4,7 @@

 import ray
 from ray._private.ray_constants import env_integer
+from ray.data._default_config import DEFAULT_FILE_METADATA_SHUFFLER
 from ray.util.annotations import DeveloperAPI
 from ray.util.scheduling_strategies import SchedulingStrategyT

@@ -171,6 +172,7 @@ def __init__(
         use_ray_tqdm: bool,
         use_legacy_iter_batches: bool,
         enable_progress_bars: bool,
+        file_metadata_shuffler: str,
     ):
         """Private constructor (use get_current() instead)."""
         self.target_max_block_size = target_max_block_size
@@ -204,6 +206,7 @@ def __init__(
         self.use_ray_tqdm = use_ray_tqdm
         self.use_legacy_iter_batches = use_legacy_iter_batches
         self.enable_progress_bars = enable_progress_bars
+        self.file_metadata_shuffler = file_metadata_shuffler

     @staticmethod
     def get_current() -> "DataContext":
@@ -253,6 +256,7 @@ def get_current() -> "DataContext":
                 use_ray_tqdm=DEFAULT_USE_RAY_TQDM,
                 use_legacy_iter_batches=DEFAULT_USE_LEGACY_ITER_BATCHES,
                 enable_progress_bars=DEFAULT_ENABLE_PROGRESS_BARS,
+                file_metadata_shuffler=DEFAULT_FILE_METADATA_SHUFFLER,
             )

         return _default_context
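
As a usage sketch, the new context field is a dotted class path that file-based datasource readers resolve when they are constructed. Pointing it at a custom shuffler might look like this (my_pkg.shufflers.RandomFileMetadataShuffler is a hypothetical user-defined subclass, not something added by this PR):

```python
# Override the file metadata shuffler for subsequent file-based reads.
import ray
from ray.data.context import DataContext

ctx = DataContext.get_current()
# Hypothetical custom subclass of FileMetadataShuffler:
ctx.file_metadata_shuffler = "my_pkg.shufflers.RandomFileMetadataShuffler"

# File-based reads now resolve this class by name and call its
# shuffle_files(paths, file_sizes) before building read tasks.
ds = ray.data.read_parquet("example://iris.parquet")
```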
16 changes: 14 additions & 2 deletions python/ray/data/datasource/file_based_datasource.py

@@ -27,7 +27,11 @@
 from ray.data._internal.output_buffer import BlockOutputBuffer
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
-from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
+from ray.data._internal.util import (
+    _check_pyarrow_version,
+    _resolve_custom_scheme,
+    get_attribute_from_class_name,
+)
 from ray.data.block import Block, BlockAccessor
 from ray.data.context import DataContext
 from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
@@ -490,6 +494,10 @@ def __init__(
                 "'partition_filter' field is set properly."
             )

+        ctx = DataContext.get_current()
+        shuffler_class = get_attribute_from_class_name(ctx.file_metadata_shuffler)
+        self._file_metadata_shuffler = shuffler_class(self._reader_args)
+
     def estimate_inmemory_data_size(self) -> Optional[int]:
         total_size = 0
         for sz in self._file_sizes:
@@ -505,7 +513,11 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
         reader_args = self._reader_args
         partitioning = self._partitioning

-        paths, file_sizes = self._paths, self._file_sizes
+        paths, file_sizes = self._file_metadata_shuffler.shuffle_files(
+            self._paths,
+            self._file_sizes,
+        )
+
         read_stream = self._delegate._read_stream
         filesystem = _wrap_s3_serialization_workaround(self._filesystem)

37 changes: 37 additions & 0 deletions python/ray/data/datasource/file_metadata_shuffler.py

@@ -0,0 +1,37 @@
+from typing import Any, Dict, List, Tuple
+
+
+class FileMetadataShuffler:
+    """Abstract class for file metadata shuffler.
+
+    Shufflers live on the driver side of the Dataset only.
+    """
+
+    def __init__(self, reader_args: Dict[str, Any]):
+        self._reader_args = reader_args
+
+    def shuffle_files(
+        self,
+        paths: List[str],
+        file_sizes: List[int],
+    ) -> Tuple[List[str], List[int]]:
+        """Shuffle files in the given paths and sizes.
+
+        Args:
+            paths: The file paths to shuffle.
+            file_sizes: The size of file paths, corresponding to `paths`.
+
+        Returns:
+            The file paths and their size after shuffling.
+        """
+        raise NotImplementedError
+
+
+class SequentialFileMetadataShuffler(FileMetadataShuffler):
+    def shuffle_files(
+        self,
+        paths: List[str],
+        file_sizes: List[int],
+    ) -> Tuple[List[str], List[int]]:
+        """Return files in the given paths and sizes sequentially."""
+        return (paths, file_sizes)

Contributor (review comment on shuffle_files): nit; similar to class name update, should we update the method name + docstrings to something like shuffle_file_metadatas or shuffle_metadatas? Is it possibly confusing with shuffling files instead of their metadata?

Contributor Author: hmm I guess it's probably fine? given we already have FileMetadataShuffler as class name.
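
To illustrate the extension point, a randomized shuffler could subclass FileMetadataShuffler as sketched below (RandomFileMetadataShuffler and the shuffle_seed reader arg are hypothetical; this PR only ships the sequential implementation):

```python
# Hypothetical custom shuffler; not part of this PR.
import random
from typing import List, Tuple

from ray.data.datasource.file_metadata_shuffler import FileMetadataShuffler


class RandomFileMetadataShuffler(FileMetadataShuffler):
    """Randomly permute paths and sizes together before read tasks are built."""

    def shuffle_files(
        self,
        paths: List[str],
        file_sizes: List[int],
    ) -> Tuple[List[str], List[int]]:
        seed = self._reader_args.get("shuffle_seed")  # illustrative knob
        indices = list(range(len(paths)))
        random.Random(seed).shuffle(indices)
        return (
            [paths[i] for i in indices],
            [file_sizes[i] for i in indices],
        )
```

It would be enabled by setting DataContext.file_metadata_shuffler to its fully qualified class name, as in the context.py sketch earlier.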
24 changes: 15 additions & 9 deletions python/ray/data/datasource/mongo_datasource.py

@@ -76,8 +76,6 @@ def __init__(
         schema: Optional["pymongoarrow.api.Schema"] = None,
         **mongo_args,
     ):
-        import pymongo
-
         self._uri = uri
         self._database = database
         self._collection = collection
@@ -87,13 +85,8 @@ def __init__(
         # If pipeline is unspecified, read the entire collection.
         if not pipeline:
             self._pipeline = [{"$match": {"_id": {"$exists": "true"}}}]
-
-        self._client = pymongo.MongoClient(uri)
-        _validate_database_collection_exist(self._client, database, collection)
-
-        self._avg_obj_size = self._client[database].command("collstats", collection)[
-            "avgObjSize"
-        ]
+        # Initialize Mongo client lazily later when creating read tasks.
+        self._client = None

     def estimate_inmemory_data_size(self) -> Optional[int]:
         # TODO(jian): Add memory size estimation to improve auto-tune of parallelism.
@@ -104,9 +97,22 @@ def _get_match_query(self, pipeline: List[Dict]) -> Dict:
             return {}
         return pipeline[0]["$match"]

+    def _get_or_create_client(self):
+        import pymongo
+
+        if self._client is None:
+            self._client = pymongo.MongoClient(self._uri)
+            _validate_database_collection_exist(
+                self._client, self._database, self._collection
+            )
+            self._avg_obj_size = self._client[self._database].command(
+                "collstats", self._collection
+            )["avgObjSize"]
+
     def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
         from bson.objectid import ObjectId

+        self._get_or_create_client()
         coll = self._client[self._database][self._collection]
         match_query = self._get_match_query(self._pipeline)
         partitions_ids = list(
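
For context, the client construction is deferred because the reader can now sit on the logical Read operator until execution. A simplified, self-contained sketch of the lazy-initialization pattern used above (FakeClient stands in for pymongo.MongoClient; validation and the avgObjSize lookup are omitted):

```python
class FakeClient:
    """Stand-in for pymongo.MongoClient; a real client would connect here."""

    def __init__(self, uri: str):
        self.uri = uri


class LazyClientReader:
    def __init__(self, uri: str):
        self._uri = uri
        self._client = None  # deferred, mirroring the Mongo reader above

    def _get_or_create_client(self) -> FakeClient:
        if self._client is None:
            self._client = FakeClient(self._uri)
        return self._client

    def get_read_tasks(self, parallelism: int):
        client = self._get_or_create_client()  # first real work happens here
        return [f"task-{i}@{client.uri}" for i in range(parallelism)]
```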
45 changes: 25 additions & 20 deletions python/ray/data/read_api.py

@@ -58,7 +58,6 @@
     ParquetMetadataProvider,
     PathPartitionFilter,
     RangeDatasource,
-    ReadTask,
     SQLDatasource,
     TextDatasource,
     TFRecordDatasource,
@@ -70,6 +69,7 @@
     get_parquet_bulk_metadata_provider,
     get_parquet_metadata_provider,
 )
+from ray.data.datasource.datasource import Reader
 from ray.data.datasource.file_based_datasource import (
     _unwrap_arrow_serialization_workaround,
     _wrap_arrow_serialization_workaround,
@@ -351,8 +351,8 @@ def read_datasource(
             requested_parallelism,
             min_safe_parallelism,
             inmemory_size,
-            read_tasks,
-        ) = _get_read_tasks(datasource, ctx, cur_pg, parallelism, local_uri, read_args)
+            reader,
+        ) = _get_reader(datasource, ctx, cur_pg, parallelism, local_uri, read_args)
     else:
         # Prepare read in a remote task at same node.
         # NOTE: in Ray client mode, this is expected to be run on head node.
@@ -361,17 +361,12 @@ def read_datasource(
             ray.get_runtime_context().get_node_id(),
             soft=False,
         )
-        get_read_tasks = cached_remote_fn(
-            _get_read_tasks, retry_exceptions=False, num_cpus=0
+        get_reader = cached_remote_fn(
+            _get_reader, retry_exceptions=False, num_cpus=0
         ).options(scheduling_strategy=scheduling_strategy)

-        (
-            requested_parallelism,
-            min_safe_parallelism,
-            inmemory_size,
-            read_tasks,
-        ) = ray.get(
-            get_read_tasks.remote(
+        (requested_parallelism, min_safe_parallelism, inmemory_size, reader,) = ray.get(
+            get_reader.remote(
                 datasource,
                 ctx,
                 cur_pg,
@@ -381,9 +376,14 @@ def read_datasource(
             )
         )

+    # TODO(hchen/chengsu): Remove the duplicated get_read_tasks call here after
+    # removing LazyBlockList code path.
+    read_tasks = reader.get_read_tasks(requested_parallelism)
+
     # Compute the number of blocks the read will return. If the number of blocks is
     # expected to be less than the requested parallelism, boost the number of blocks
     # by adding an additional split into `k` pieces to each read task.
+    additional_split_factor = None
     if read_tasks:
         if inmemory_size:
             expected_block_size = inmemory_size / len(read_tasks)

Contributor (review comment on additional_split_factor): One small issue. For the new code path, get_read_tasks has actually created the read tasks. But the read tasks are only used for the following calculations and then discarded. I'm wondering if we can move the following calculation code to the reader. So here we only create the reader, but not the read tasks. Also we don't need to expose this additional_split_factor to the operator.

Contributor Author: yes, that's the major weird thing here. The other code path, LazyBlockList, has all code paths depending on List[ReadTask], so I didn't spend time refactoring LazyBlockList. Shall we just delete DatasetPipeline and LazyBlockList/BlockList during 2.8?

Contributor: that's fine. can you leave a TODO here?

Contributor Author: SG, added.
@@ -407,6 +407,7 @@ def read_datasource(
             for r in read_tasks:
                 r._set_additional_split_factor(k)
             estimated_num_blocks = estimated_num_blocks * k
+            additional_split_factor = k
             logger.debug("Estimated num output blocks {estimated_num_blocks}")
     else:
         estimated_num_blocks = 0
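
A small worked example of the bookkeeping in this hunk (the numbers and the choice of k are hypothetical; the logic that picks k lives in the elided portion of read_datasource):

```python
# Once a split factor k has been chosen, each read task is told to split its
# output into k blocks, and the estimate and operator metadata are updated.
num_read_tasks = 100      # len(read_tasks) from reader.get_read_tasks(...)
k = 4                     # hypothetical additional split factor
estimated_num_blocks = num_read_tasks * k
additional_split_factor = k
assert estimated_num_blocks == 400
# The Read logical operator later renders this as e.g. "Read...->SplitBlocks(4)".
```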
@@ -437,9 +438,13 @@ def read_datasource(
     )
     block_list._estimated_num_blocks = estimated_num_blocks

-    # TODO(hchen): move _get_read_tasks and related code to the Read physical operator,
-    # after removing LazyBlockList code path.
-    read_op = Read(datasource, read_tasks, estimated_num_blocks, ray_remote_args)
+    read_op = Read(
+        datasource,
+        reader,
+        requested_parallelism,
+        additional_split_factor,
+        ray_remote_args,
+    )
     logical_plan = LogicalPlan(read_op)

     return Dataset(
@@ -2281,15 +2286,15 @@ def from_torch(
     return from_items(list(dataset))


-def _get_read_tasks(
+def _get_reader(
     ds: Datasource,
     ctx: DataContext,
     cur_pg: Optional[PlacementGroup],
     parallelism: int,
     local_uri: bool,
     kwargs: dict,
-) -> Tuple[int, int, Optional[int], List[ReadTask]]:
-    """Generates read tasks.
+) -> Tuple[int, int, Optional[int], Reader]:
+    """Generates reader.

     Args:
         ds: Datasource to read from.
@@ -2300,7 +2305,7 @@ def _get_read_tasks(

     Returns:
         Request parallelism from the datasource, the min safe parallelism to avoid
-        OOM, the estimated inmemory data size, and list of read tasks generated.
+        OOM, the estimated inmemory data size, and the reader generated.
     """
     kwargs = _unwrap_arrow_serialization_workaround(kwargs)
     if local_uri:
@@ -2314,7 +2319,7 @@ def _get_read_tasks(
         requested_parallelism,
         min_safe_parallelism,
         mem_size,
-        reader.get_read_tasks(requested_parallelism),
+        reader,
     )

6 changes: 0 additions & 6 deletions python/ray/data/tests/test_consumption.py

@@ -1413,12 +1413,6 @@ def test_unsupported_pyarrow_versions_check_disabled(
     except ImportError as e:
         pytest.fail(f"_check_pyarrow_version failed unexpectedly: {e}")

-    # Test read_parquet.
-    try:
-        ray.data.read_parquet("example://iris.parquet").take_all()
-    except ImportError as e:
-        pytest.fail(f"_check_pyarrow_version failed unexpectedly: {e}")
-
     # Test from_numpy (we use Arrow for representing the tensors).
     try:
         ray.data.from_numpy(np.arange(12).reshape((3, 2, 2)))

Contributor Author (review comment on the removed read_parquet check): Pyarrow 5 does not support pickling the Parquet reader class. Given we do not support Pyarrow 5, remove the test code here. Already verified for Pyarrow 6+, it's not an issue.

Contributor: should we just remove the PyArrow 5 CI?

Contributor Author: We already removed it. CI only tests 6 and 12. This test manually installs 5.