[Data] Use application-level retries for Parquet metadata tasks #42922

Merged: 10 commits, Feb 13, 2024
Changes from 4 commits
2 changes: 1 addition & 1 deletion python/ray/data/_internal/planner/plan_read_op.py
@@ -89,8 +89,8 @@ def do_read(blocks: Iterable[ReadTask], _: TaskContext) -> Iterable[Block]:

yield from call_with_retry(
f=read_task,
match=READ_FILE_RETRY_ON_ERRORS,
description=f"read file {read_fn_name}",
match=READ_FILE_RETRY_ON_ERRORS,
max_attempts=READ_FILE_MAX_ATTEMPTS,
max_backoff_s=READ_FILE_RETRY_MAX_BACKOFF_SECONDS,
)
11 changes: 7 additions & 4 deletions python/ray/data/_internal/util.py
@@ -943,17 +943,18 @@ def execute_computation(thread_index: int):

def call_with_retry(
f: Callable[[], Any],
match: List[str],
description: str,
*,
match: Optional[List[str]] = None,
max_attempts: int = 10,
max_backoff_s: int = 32,
) -> Any:
"""Retry a function with exponential backoff.

Args:
f: The function to retry.
match: A list of strings to match in the exception message.
match: A list of strings to match in the exception message. If ``None``, any
error is retried.
description: An imperative description of the function being retried. For
example, "open the file".
max_attempts: The maximum number of attempts to retry.
@@ -965,10 +966,12 @@
try:
return f()
except Exception as e:
is_retryable = any([pattern in str(e) for pattern in match])
is_retryable = match is None or any(
[pattern in str(e) for pattern in match]
)
if is_retryable and i + 1 < max_attempts:
# Retry with binary exponential backoff with random jitter.
backoff = min((2 ** (i + 1)) * random.random(), max_backoff_s)
backoff = min((2 ** (i + 1)), max_backoff_s) * random.random()
logger.debug(
f"Retrying {i+1} attempts to {description} after {backoff} seconds."
)
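For reference, the updated signature takes `f` and `description` positionally and turns `match` into an optional keyword argument, while the reworked backoff caps the exponential term before applying jitter, so each sleep is drawn uniformly between 0 and the capped value rather than being clipped at `max_backoff_s`. The sketch below shows both call shapes; `flaky_fetch` and the error strings are hypothetical placeholders, not code from this PR.

from ray.data._internal.util import call_with_retry

def flaky_fetch():
    """Hypothetical callable that occasionally raises a transient error."""
    ...

# With `match` omitted (None), any exception is treated as retryable.
call_with_retry(
    flaky_fetch,
    description="fetch the example payload",
    max_attempts=5,
    max_backoff_s=16,
)

# With `match` given, only errors whose message contains one of the listed
# substrings are retried; anything else is raised immediately.
call_with_retry(
    flaky_fetch,
    description="fetch the example payload",
    match=["Timeout", "AWS Error ACCESS_DENIED"],
    max_attempts=5,
    max_backoff_s=16,
)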
2 changes: 1 addition & 1 deletion python/ray/data/datasource/file_based_datasource.py
@@ -760,8 +760,8 @@ def _open_file_with_retry(

return call_with_retry(
open_file,
match=OPEN_FILE_RETRY_ON_ERRORS,
description=f"open file {file_path}",
match=OPEN_FILE_RETRY_ON_ERRORS,
max_attempts=OPEN_FILE_MAX_ATTEMPTS,
max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS,
)
4 changes: 2 additions & 2 deletions python/ray/data/datasource/file_datasink.py
@@ -202,8 +202,8 @@ def write_row_to_path():
logger.get_logger().debug(f"Writing {write_path} file.")
call_with_retry(
write_row_to_path,
match=DataContext.get_current().write_file_retry_on_errors,
description=f"write '{write_path}'",
match=DataContext.get_current().write_file_retry_on_errors,
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
)
@@ -268,8 +268,8 @@ def write_block_to_path():
logger.get_logger().debug(f"Writing {write_path} file.")
call_with_retry(
write_block_to_path,
match=DataContext.get_current().write_file_retry_on_errors,
description=f"write '{write_path}'",
match=DataContext.get_current().write_file_retry_on_errors,
max_attempts=WRITE_FILE_MAX_ATTEMPTS,
max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
)
99 changes: 32 additions & 67 deletions python/ray/data/datasource/parquet_datasource.py
@@ -17,7 +17,11 @@
import ray.cloudpickle as cloudpickle
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _is_local_scheme
from ray.data._internal.util import (
_check_pyarrow_version,
_is_local_scheme,
call_with_retry,
)
from ray.data.block import Block
from ray.data.context import DataContext
from ray.data.datasource import Datasource
@@ -52,9 +56,9 @@
NUM_CPUS_FOR_META_FETCH_TASK = 0.5

# The application-level exceptions to retry for metadata prefetching task.
# Default to retry on `OSError` because AWS S3 would throw this transient
# error when load is too high.
RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = [OSError]
# Default to retry on access denied and read timeout errors because AWS S3 would throw
# these transient errors when load is too high.
RETRY_EXCEPTIONS_FOR_META_FETCH_TASK = ["AWS Error ACCESS_DENIED", "Timeout"]

# The number of rows to read per batch. This is sized to generate 10MiB batches
# for rows about 1KiB in size.
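Unlike Ray's task-level `retry_exceptions` option, which matches exception types such as `OSError`, these application-level retries classify an error by substring-matching its message. A minimal sketch of that check, using a hypothetical S3-style error, follows; it mirrors the `is_retryable` test inside `call_with_retry`.

RETRY_PATTERNS = ["AWS Error ACCESS_DENIED", "Timeout"]

try:
    # Hypothetical transient S3 failure surfaced by pyarrow as an OSError.
    raise OSError("AWS Error ACCESS_DENIED during HeadObject operation")
except Exception as e:
    is_retryable = any(pattern in str(e) for pattern in RETRY_PATTERNS)
    print(is_retryable)  # True, because the message contains a listed pattern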
@@ -117,60 +121,6 @@ def _deserialize_fragments(
return [p.deserialize() for p in serialized_fragments]


def _deserialize_fragments_with_retry(
serialized_fragments: List[_SerializedFragment],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
"""
Deserialize the given serialized_fragments with retry upon errors.

This retry helps when the upstream datasource is not able to handle
overloaded read request or failed with some retriable failures.
For example when reading data from HA hdfs service, hdfs might
lose connection for some unknown reason, especially when
simultaneously running many hyperparameter tuning jobs
with ray.data parallelism setting at high value like the default 200
Such connection failure can be restored with some waiting and retry.
"""
min_interval = 0
final_exception = None
for i in range(FILE_READING_RETRY):
try:
return _deserialize_fragments(serialized_fragments)
except Exception as e:
import random
import time

retry_timing = (
""
if i == FILE_READING_RETRY - 1
else (f"Retry after {min_interval} sec. ")
)
log_only_show_in_1st_retry = (
""
if i
else (
f"If earlier read attempt threw certain Exception"
f", it may or may not be an issue depends on these retries "
f"succeed or not. serialized_fragments:{serialized_fragments}"
)
)
logger.exception(
f"{i + 1}th attempt to deserialize ParquetFileFragment failed. "
f"{retry_timing}"
f"{log_only_show_in_1st_retry}"
)
if not min_interval:
# to make retries of different process hit hdfs server
# at slightly different time
min_interval = 1 + random.random()
# exponential backoff at
# 1, 2, 4, 8, 16, 32, 64
time.sleep(min_interval)
min_interval = min_interval * 2
final_exception = e
raise final_exception


@PublicAPI
class ParquetDatasource(Datasource):
"""Parquet datasource, for reading and writing Parquet files.
@@ -292,10 +242,6 @@ def __init__(
prefetch_remote_args[
"scheduling_strategy"
] = DataContext.get_current().scheduling_strategy
if RETRY_EXCEPTIONS_FOR_META_FETCH_TASK is not None:
prefetch_remote_args[
"retry_exceptions"
] = RETRY_EXCEPTIONS_FOR_META_FETCH_TASK

self._metadata = (
meta_provider.prefetch_file_metadata(
@@ -549,14 +495,33 @@ def _read_fragments(
yield table


def _deserialize_fragments_with_retry(fragments):
# The deserialization retry helps when the upstream datasource can't handle
# overloaded read requests or fails with retriable errors. For example, when
# reading data from an HA HDFS service, HDFS might lose the connection,
# especially when many hyperparameter tuning jobs run simultaneously with the
# Ray Data parallelism set to a high value like the default 200. Such
# connection failures can be recovered with some waiting and retries.
return call_with_retry(
lambda: _deserialize_fragments(fragments),
description="deserialize fragments",
max_attempts=FILE_READING_RETRY,
)


def _fetch_metadata_serialization_wrapper(
fragments: List[_SerializedFragment],
) -> List["pyarrow.parquet.FileMetaData"]:
fragments: List[
"pyarrow._dataset.ParquetFileFragment"
] = _deserialize_fragments_with_retry(fragments)

return _fetch_metadata(fragments)
deserialized_fragments = _deserialize_fragments_with_retry(fragments)
metadata = call_with_retry(
lambda: _fetch_metadata(deserialized_fragments),
description="fetch metdata",
match=RETRY_EXCEPTIONS_FOR_META_FETCH_TASK,
max_attempts=32,
max_backoff_s=32,
)
return metadata


def _fetch_metadata(
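Taken together, the wrapper above layers two application-level retries inside the metadata-fetch task: deserialization retries on any error, while the metadata fetch retries only on the matched transient errors, replacing the task-level `retry_exceptions=[OSError]` removed from the prefetch remote args. The sketch below restates that flow with placeholder `deserialize` and `fetch` callables and an assumed value for `FILE_READING_RETRY`; it is illustrative, not the datasource code.

from ray.data._internal.util import call_with_retry

FILE_READING_RETRY = 8  # placeholder value, assumed for illustration
META_FETCH_RETRY_PATTERNS = ["AWS Error ACCESS_DENIED", "Timeout"]

def fetch_metadata_with_app_level_retries(serialized_fragments, deserialize, fetch):
    # Step 1: deserialization retries on *any* error (match defaults to None).
    fragments = call_with_retry(
        lambda: deserialize(serialized_fragments),
        description="deserialize fragments",
        max_attempts=FILE_READING_RETRY,
    )
    # Step 2: the metadata fetch retries only on transient S3-style errors.
    return call_with_retry(
        lambda: fetch(fragments),
        description="fetch metadata",
        match=META_FETCH_RETRY_PATTERNS,
        max_attempts=32,
        max_backoff_s=32,
    )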
5 changes: 0 additions & 5 deletions python/ray/data/tests/test_parquet.py
@@ -20,7 +20,6 @@
from ray.data.datasource.parquet_datasource import (
NUM_CPUS_FOR_META_FETCH_TASK,
PARALLELIZE_META_FETCH_THRESHOLD,
RETRY_EXCEPTIONS_FOR_META_FETCH_TASK,
ParquetDatasource,
_deserialize_fragments_with_retry,
_SerializedFragment,
@@ -212,10 +211,6 @@ def prefetch_file_metadata(self, fragments, **ray_remote_args):
ray_remote_args["scheduling_strategy"]
== DataContext.get_current().scheduling_strategy
)
assert (
ray_remote_args["retry_exceptions"]
== RETRY_EXCEPTIONS_FOR_META_FETCH_TASK
)
return None

ds = ray.data.read_parquet(