[Data] Add FileDatasink subclasses (ray-project#40693)
This PR is part of a larger effort to clean up the Datasource interfaces (ray-project#40296). ray-project#40691 added the new FileDatasink base class, and this PR migrates the FileDatasource implementations to the new API.

The primary motivation for these changes is to reduce the complexity of our internal code base. For more information, see https://docs.google.com/document/d/1Bqhbzvxv7liwpOhyBzRVy5tOzXdy-NiMSFa-6hupr18/edit#heading=h.rytitv546vx5.
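
As a minimal sketch of the write path this migration establishes (BlockBasedFileDatasink and Dataset.write_datasink are the names introduced by this effort; the _TextDatasink sink, its serialization, and the output path are hypothetical):

    import pyarrow
    import ray
    from ray.data.block import BlockAccessor
    from ray.data.datasource import BlockBasedFileDatasink


    class _TextDatasink(BlockBasedFileDatasink):
        # Hypothetical sink, shown only to illustrate the API: the base class
        # handles output paths, directory creation, and filenames, while the
        # subclass decides how one block is serialized into an open file.
        def __init__(self, path: str, **file_datasink_kwargs):
            super().__init__(path, file_format="txt", **file_datasink_kwargs)

        def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
            file.write(block.to_pandas().to_string().encode())


    ds = ray.data.range(8)
    ds.write_datasink(_TextDatasink("/tmp/text_out"))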

---------

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
2 people authored and ujjawal-khare committed Nov 29, 2023
1 parent adb2832 commit 12bed94
Showing 9 changed files with 260 additions and 50 deletions.
85 changes: 38 additions & 47 deletions python/ray/data/dataset.py
@@ -112,18 +112,19 @@
from ray.data.datasource import (
BlockWritePathProvider,
Connection,
CSVDatasource,
Datasink,
Datasource,
FilenameProvider,
ImageDatasource,
JSONDatasource,
NumpyDatasource,
ParquetDatasource,
ReadTask,
TFRecordDatasource,
_BigQueryDatasink,
_CSVDatasink,
_ImageDatasink,
_JSONDatasink,
_NumpyDatasink,
_SQLDatasink,
_TFRecordDatasink,
_WebDatasetDatasink,
)
from ray.data.iterator import DataIterator
from ray.data.random_access_dataset import RandomAccessDataset
@@ -2839,19 +2840,18 @@ def write_json(
:class:`~ray.data.Dataset` block. These
are dict(orient="records", lines=True) by default.
"""
self.write_datasource(
JSONDatasource(),
ray_remote_args=ray_remote_args,
path=path,
dataset_uuid=self._uuid,
datasink = _JSONDatasink(
path,
pandas_json_args_fn=pandas_json_args_fn,
pandas_json_args=pandas_json_args,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
block_path_provider=block_path_provider,
write_args_fn=pandas_json_args_fn,
**pandas_json_args,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@PublicAPI(stability="alpha")
@ConsumptionAPI
@@ -2901,18 +2901,17 @@ def write_images(
opening the file to write to.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
""" # noqa: E501
self.write_datasource(
ImageDatasource(),
ray_remote_args=ray_remote_args,
path=path,
dataset_uuid=self._uuid,
datasink = _ImageDatasink(
path,
column,
file_format,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
column=column,
file_format=file_format,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@ConsumptionAPI
def write_csv(
@@ -2994,19 +2993,18 @@ def write_csv(
#pyarrow.csv.write_csv>`_
when writing each block to a file.
"""
self.write_datasource(
CSVDatasource(),
ray_remote_args=ray_remote_args,
path=path,
dataset_uuid=self._uuid,
datasink = _CSVDatasink(
path,
arrow_csv_args_fn=arrow_csv_args_fn,
arrow_csv_args=arrow_csv_args,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
block_path_provider=block_path_provider,
write_args_fn=arrow_csv_args_fn,
**arrow_csv_args,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@ConsumptionAPI
def write_tfrecords(
@@ -3079,19 +3077,17 @@ def write_tfrecords(
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
"""

self.write_datasource(
TFRecordDatasource(),
ray_remote_args=ray_remote_args,
datasink = _TFRecordDatasink(
path=path,
dataset_uuid=self._uuid,
tf_schema=tf_schema,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
block_path_provider=block_path_provider,
tf_schema=tf_schema,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@PublicAPI(stability="alpha")
@ConsumptionAPI
@@ -3152,21 +3148,17 @@ def write_webdataset(
ray_remote_args: Kwargs passed to ``ray.remote`` in the write tasks.
"""

from ray.data.datasource.webdataset_datasource import WebDatasetDatasource

self.write_datasource(
WebDatasetDatasource(),
ray_remote_args=ray_remote_args,
path=path,
dataset_uuid=self._uuid,
datasink = _WebDatasetDatasink(
path,
encoder=encoder,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
block_path_provider=block_path_provider,
encoder=encoder,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@ConsumptionAPI
def write_numpy(
@@ -3230,18 +3222,17 @@ def write_numpy(
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
"""

self.write_datasource(
NumpyDatasource(),
ray_remote_args=ray_remote_args,
path=path,
dataset_uuid=self._uuid,
column=column,
datasink = _NumpyDatasink(
path,
column,
filesystem=filesystem,
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
filename_provider=filename_provider,
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)

@ConsumptionAPI
def write_sql(
12 changes: 12 additions & 0 deletions python/ray/data/datasource/__init__.py
@@ -5,6 +5,7 @@
BlockWritePathProvider,
DefaultBlockWritePathProvider,
)
from ray.data.datasource.csv_datasink import _CSVDatasink
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.datasource.datasink import Datasink
from ray.data.datasource.datasource import (
@@ -34,9 +35,12 @@
ParquetMetadataProvider,
)
from ray.data.datasource.filename_provider import FilenameProvider
from ray.data.datasource.image_datasink import _ImageDatasink
from ray.data.datasource.image_datasource import ImageDatasource
from ray.data.datasource.json_datasink import _JSONDatasink
from ray.data.datasource.json_datasource import JSONDatasource
from ray.data.datasource.mongo_datasource import MongoDatasource
from ray.data.datasource.numpy_datasink import _NumpyDatasink
from ray.data.datasource.numpy_datasource import NumpyDatasource
from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
from ray.data.datasource.parquet_datasource import ParquetDatasource
@@ -49,8 +53,10 @@
from ray.data.datasource.sql_datasink import _SQLDatasink
from ray.data.datasource.sql_datasource import Connection, SQLDatasource
from ray.data.datasource.text_datasource import TextDatasource
from ray.data.datasource.tfrecords_datasink import _TFRecordDatasink
from ray.data.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data.datasource.torch_datasource import TorchDatasource
from ray.data.datasource.webdataset_datasink import _WebDatasetDatasink
from ray.data.datasource.webdataset_datasource import WebDatasetDatasource

# Note: HuggingFaceDatasource should NOT be imported here, because
@@ -64,6 +70,7 @@
"BlockBasedFileDatasink",
"BlockWritePathProvider",
"Connection",
"_CSVDatasink",
"CSVDatasource",
"Datasink",
"Datasource",
@@ -78,8 +85,11 @@
"FileExtensionFilter",
"FileMetadataProvider",
"FilenameProvider",
"_ImageDatasink",
"ImageDatasource",
"_JSONDatasink",
"JSONDatasource",
"_NumpyDatasink",
"NumpyDatasource",
"ParquetBaseDatasource",
"ParquetDatasource",
@@ -95,8 +105,10 @@
"Reader",
"RowBasedFileDatasink",
"TextDatasource",
"_TFRecordDatasink",
"TFRecordDatasource",
"TorchDatasource",
"_WebDatasetDatasink",
"WebDatasetDatasource",
"WriteResult",
"_S3FileSystemWrapper",
33 changes: 33 additions & 0 deletions python/ray/data/datasource/csv_datasink.py
@@ -0,0 +1,33 @@
from typing import Any, Callable, Dict, Optional

import pyarrow

from ray.data.block import BlockAccessor
from ray.data.datasource.file_based_datasource import _resolve_kwargs
from ray.data.datasource.file_datasink import BlockBasedFileDatasink


class _CSVDatasink(BlockBasedFileDatasink):
def __init__(
self,
path: str,
*,
arrow_csv_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
arrow_csv_args: Optional[Dict[str, Any]] = None,
file_format="csv",
**file_datasink_kwargs,
):
super().__init__(path, file_format=file_format, **file_datasink_kwargs)

if arrow_csv_args is None:
arrow_csv_args = {}

self.arrow_csv_args_fn = arrow_csv_args_fn
self.arrow_csv_args = arrow_csv_args

def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
from pyarrow import csv

writer_args = _resolve_kwargs(self.arrow_csv_args_fn, **self.arrow_csv_args)
write_options = writer_args.pop("write_options", None)
csv.write_csv(block.to_arrow(), file, write_options, **writer_args)
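
A usage sketch for the kwargs plumbing above (the dataset, output path, and options are illustrative): write_options is popped from the resolved kwargs and passed positionally to pyarrow.csv.write_csv, and any remaining entries are forwarded as keyword arguments.

    import ray
    from pyarrow import csv
    from ray.data.datasource import _CSVDatasink

    ds = ray.data.range(4)
    datasink = _CSVDatasink(
        "/tmp/csv_out",
        arrow_csv_args={"write_options": csv.WriteOptions(include_header=False)},
    )
    # Each block is written as one header-less CSV file under /tmp/csv_out.
    ds.write_datasink(datasink)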
24 changes: 24 additions & 0 deletions python/ray/data/datasource/image_datasink.py
@@ -0,0 +1,24 @@
import io
from typing import Any, Dict

import pyarrow

from ray.data.datasource.file_datasink import RowBasedFileDatasink


class _ImageDatasink(RowBasedFileDatasink):
def __init__(
self, path: str, column: str, file_format: str, **file_datasink_kwargs
):
super().__init__(path, file_format=file_format, **file_datasink_kwargs)

self.column = column
self.file_format = file_format

def write_row_to_file(self, row: Dict[str, Any], file: "pyarrow.NativeFile"):
from PIL import Image

image = Image.fromarray(row[self.column])
buffer = io.BytesIO()
image.save(buffer, format=self.file_format)
file.write(buffer.getvalue())
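
A self-contained sketch of what write_row_to_file does for a single row (the column name, array shape, and PNG format are illustrative):

    import io

    import numpy as np
    from PIL import Image

    row = {"image": np.zeros((32, 32, 3), dtype=np.uint8)}  # stands in for one dataset row
    image = Image.fromarray(row["image"])
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")  # self.file_format in the sink
    payload = buffer.getvalue()       # the bytes the sink writes to the per-row file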
33 changes: 33 additions & 0 deletions python/ray/data/datasource/json_datasink.py
@@ -0,0 +1,33 @@
from typing import Any, Callable, Dict, Optional

import pyarrow

from ray.data.block import BlockAccessor
from ray.data.datasource.file_based_datasource import _resolve_kwargs
from ray.data.datasource.file_datasink import BlockBasedFileDatasink


class _JSONDatasink(BlockBasedFileDatasink):
def __init__(
self,
path: str,
*,
pandas_json_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
pandas_json_args: Optional[Dict[str, Any]] = None,
file_format: str = "json",
**file_datasink_kwargs,
):
super().__init__(path, file_format=file_format, **file_datasink_kwargs)

if pandas_json_args is None:
pandas_json_args = {}

self.pandas_json_args_fn = pandas_json_args_fn
self.pandas_json_args = pandas_json_args

def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
writer_args = _resolve_kwargs(self.pandas_json_args_fn, **self.pandas_json_args)
orient = writer_args.pop("orient", "records")
lines = writer_args.pop("lines", True)

block.to_pandas().to_json(file, orient=orient, lines=lines, **writer_args)
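
For reference, the orient="records", lines=True defaults above produce newline-delimited JSON, one object per row, which is what the write_json docstring describes; a small sketch with made-up data:

    import io

    import pandas as pd

    buf = io.StringIO()
    pd.DataFrame({"a": [1, 2]}).to_json(buf, orient="records", lines=True)
    print(buf.getvalue())  # {"a":1}
                           # {"a":2}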
23 changes: 23 additions & 0 deletions python/ray/data/datasource/numpy_datasink.py
@@ -0,0 +1,23 @@
import numpy as np
import pyarrow

from ray.data.block import BlockAccessor
from ray.data.datasource.file_datasink import BlockBasedFileDatasink


class _NumpyDatasink(BlockBasedFileDatasink):
def __init__(
self,
path: str,
column: str,
*,
file_format: str = "npy",
**file_datasink_kwargs,
):
super().__init__(path, file_format=file_format, **file_datasink_kwargs)

self.column = column

def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"):
value = block.to_numpy(self.column)
np.save(file, value)
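
Because numpy.save accepts an open file object, the sink's output round-trips as follows (the path and data are illustrative):

    import numpy as np

    values = np.arange(4)  # stands in for block.to_numpy(self.column)
    with open("/tmp/out.npy", "wb") as f:
        np.save(f, values)
    assert (np.load("/tmp/out.npy") == values).all()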
