[Data] Retry write if error during file clean up (#42326) (#42327)
We already retry when an error happens during the actual write or while the file is being opened. However, an error can also happen when the file is closed (specifically, when the NativeFile context exits), and that case isn't retried.

To fix this, this PR widens the retry scope so that the whole open-write-close sequence is retried as one unit.
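The change is easiest to see with a small, self-contained sketch (hypothetical code, not Ray Data's implementation; the real code uses the call_with_retry helper shown in the diff below). The point is that the open, the write, and the close all sit inside the retried block, so an error raised while the stream closes is retried too:

class FlakyStream:
    """Hypothetical stream that fails the first time it is closed."""

    failures_left = 1  # shared across attempts so only the first close fails

    def __init__(self, path):
        self._file = open(path, "w")

    def __enter__(self):
        return self

    def write(self, data):
        self._file.write(data)

    def __exit__(self, exc_type, exc_value, traceback):
        self._file.close()
        if FlakyStream.failures_left > 0:
            FlakyStream.failures_left -= 1
            raise RuntimeError("AWS Error NETWORK_CONNECTION")


def write_with_retry(path, data, max_attempts=3):
    for attempt in range(max_attempts):
        try:
            # The whole open-write-close sequence is inside the retried block,
            # so a failure during close (__exit__) triggers another attempt.
            with FlakyStream(path) as stream:
                stream.write(data)
            return
        except RuntimeError:
            if attempt == max_attempts - 1:
                raise


write_with_retry("example.txt", "hello")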

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
bveeramani committed Jan 11, 2024
1 parent b2d50b8 commit dc0c031
Showing 3 changed files with 98 additions and 40 deletions.
6 changes: 5 additions & 1 deletion python/ray/data/context.py
@@ -153,7 +153,11 @@
# Whether to enable get_object_locations for metric
DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS = False

DEFAULT_WRITE_FILE_RETRY_ON_ERRORS = ["AWS Error INTERNAL_FAILURE"]
DEFAULT_WRITE_FILE_RETRY_ON_ERRORS = [
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
]


@DeveloperAPI
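The new defaults above appear to be substrings matched against write errors. As a hedged sketch (not part of this diff), a user could extend the list at runtime through DataContext; the extra "AWS Error ACCESS_DENIED" string is purely illustrative:

from ray.data.context import DataContext

# Hypothetical example: also retry on an additional error substring.
ctx = DataContext.get_current()
ctx.write_file_retry_on_errors = ctx.write_file_retry_on_errors + [
    "AWS Error ACCESS_DENIED",
]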
62 changes: 25 additions & 37 deletions python/ray/data/datasource/file_datasink.py
@@ -10,7 +10,6 @@
from ray.data.context import DataContext
from ray.data.datasource.block_path_provider import BlockWritePathProvider
from ray.data.datasource.datasink import Datasink
from ray.data.datasource.file_based_datasource import _open_file_with_retry
from ray.data.datasource.filename_provider import (
    FilenameProvider,
    _DefaultFilenameProvider,
@@ -86,6 +85,9 @@ def __init__(

        self.has_created_dir = False

    def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
        return self.filesystem.open_output_stream(path, **self.open_stream_args)

    def on_write_start(self) -> None:
        """Create a directory to write files to.
@@ -187,17 +189,6 @@ def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
        """
        raise NotImplementedError

    def _write_row_to_file_with_retry(
        self, row: Dict[str, Any], file: "pyarrow.NativeFile", path: str
    ):
        call_with_retry(
            lambda: self.write_row_to_file(row, file),
            match=DataContext.get_current().write_file_retry_on_errors,
            description=f"write '{path}'",
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        for row_index, row in enumerate(block.iter_rows(public_row_format=False)):
            if self.filename_provider is not None:
@@ -212,14 +203,18 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
                )
            write_path = posixpath.join(self.path, filename)

            def write_row_to_path():
                with self.open_output_stream(write_path) as file:
                    self.write_row_to_file(row, file)

            logger.get_logger().debug(f"Writing {write_path} file.")
            with _open_file_with_retry(
                write_path,
                lambda: self.filesystem.open_output_stream(
                    write_path, **self.open_stream_args
                ),
            ) as file:
                self._write_row_to_file_with_retry(row, file, write_path)
            call_with_retry(
                write_row_to_path,
                match=DataContext.get_current().write_file_retry_on_errors,
                description=f"write '{write_path}'",
                max_attempts=WRITE_FILE_MAX_ATTEMPTS,
                max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
            )


@DeveloperAPI
@@ -250,17 +245,6 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        """
        raise NotImplementedError

    def _write_block_to_file_with_retry(
        self, block: BlockAccessor, file: "pyarrow.NativeFile", path: str
    ):
        call_with_retry(
            lambda: self.write_block_to_file(block, file),
            match=DataContext.get_current().write_file_retry_on_errors,
            description=f"write '{path}'",
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )

    def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
        if self.filename_provider is not None:
            filename = self.filename_provider.get_filename_for_block(
@@ -278,11 +262,15 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext):
                file_format=self.file_format,
            )

        def write_block_to_path():
            with self.open_output_stream(write_path) as file:
                self.write_block_to_file(block, file)

        logger.get_logger().debug(f"Writing {write_path} file.")
        with _open_file_with_retry(
            write_path,
            lambda: self.filesystem.open_output_stream(
                write_path, **self.open_stream_args
            ),
        ) as file:
            self._write_block_to_file_with_retry(block, file, write_path)
        call_with_retry(
            write_block_to_path,
            match=DataContext.get_current().write_file_retry_on_errors,
            description=f"write '{write_path}'",
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )
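Because every per-file open now goes through the open_output_stream hook added at the top of this file, subclasses can intercept it, which is exactly what the new tests below do with FlakyOutputStream. A hedged sketch of a custom datasink using that hook follows; LoggingCSVDatasink and its output path are hypothetical, and the import path for BlockBasedFileDatasink is an assumption since the test file's import lines are collapsed in this view.

import pyarrow
import ray
from ray.data.block import BlockAccessor
from ray.data.datasource import BlockBasedFileDatasink  # assumed import path


class LoggingCSVDatasink(BlockBasedFileDatasink):
    """Hypothetical datasink that logs each file it opens."""

    def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
        # The retry added in this PR spans everything from this open through
        # the stream's close, so this runs again on each retried attempt.
        print(f"opening {path}")
        return super().open_output_stream(path)

    def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        block.to_pandas().to_csv(file)


ds = ray.data.range(10)
ds.write_datasink(LoggingCSVDatasink("/tmp/logging_csv_out"))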
70 changes: 68 additions & 2 deletions python/ray/data/tests/test_file_datasink.py
@@ -3,6 +3,7 @@

import pyarrow
import pytest
from pyarrow.fs import LocalFileSystem

import ray
from ray.data.block import BlockAccessor
@@ -14,7 +15,72 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
        file.write(b"")


def test_flaky_block_write(ray_start_regular_shared, tmp_path):
class FlakyOutputStream:
    def __init__(self, stream: pyarrow.NativeFile, num_attempts: int):
        self._stream = stream
        self._num_attempts = num_attempts

    def __enter__(self):
        return self._stream.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        if self._num_attempts < 2:
            raise RuntimeError("AWS Error NETWORK_CONNECTION")

        self._stream.__exit__(exc_type, exc_value, traceback)


def test_flaky_block_based_open_output_stream(ray_start_regular_shared, tmp_path):
    class FlakyCSVDatasink(BlockBasedFileDatasink):
        def __init__(self, path: str):
            super().__init__(path)
            self._num_attempts = 0
            self._filesystem = LocalFileSystem()

        def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
            stream = self._filesystem.open_output_stream(path)
            flaky_stream = FlakyOutputStream(stream, self._num_attempts)
            self._num_attempts += 1
            return flaky_stream

        def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
            block.to_pandas().to_csv(file)

    ds = ray.data.range(100)

    ds.write_datasink(FlakyCSVDatasink(tmp_path))

    expected_values = list(range(100))
    written_values = [row["id"] for row in ray.data.read_csv(tmp_path).take_all()]
    assert sorted(written_values) == sorted(expected_values)


def test_flaky_row_based_open_output_stream(ray_start_regular_shared, tmp_path):
    class FlakyTextDatasink(RowBasedFileDatasink):
        def __init__(self, path: str):
            super().__init__(path)
            self._num_attempts = 0
            self._filesystem = LocalFileSystem()

        def open_output_stream(self, path: str) -> "pyarrow.NativeFile":
            stream = self._filesystem.open_output_stream(path)
            flaky_stream = FlakyOutputStream(stream, self._num_attempts)
            self._num_attempts += 1
            return flaky_stream

        def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
            file.write(f"{row['id']}".encode())

    ds = ray.data.range(100)

    ds.write_datasink(FlakyTextDatasink(tmp_path))

    expected_values = [str(i) for i in range(100)]
    written_values = [row["text"] for row in ray.data.read_text(tmp_path).take_all()]
    assert sorted(written_values) == sorted(expected_values)


def test_flaky_write_block_to_file(ray_start_regular_shared, tmp_path):
    class FlakyCSVDatasink(BlockBasedFileDatasink):
        def __init__(self, path: str):
            super().__init__(path)
@@ -36,7 +102,7 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
    assert sorted(written_values) == sorted(expected_values)


def test_flaky_row_write(ray_start_regular_shared, tmp_path):
def test_flaky_write_row_to_file(ray_start_regular_shared, tmp_path):
    class FlakyTextDatasink(RowBasedFileDatasink):
        def __init__(self, path: str):
            super().__init__(path)
