add test for compression
Signed-off-by: Martin Bomio <martinbomio@spotify.com>
martinbomio committed Jan 17, 2024
1 parent dc52292 commit 735fb4e
Showing 2 changed files with 12 additions and 8 deletions.
python/ray/data/datasource/tfrecords_datasource.py (4 changes: 0 additions & 4 deletions)
@@ -41,12 +41,9 @@ def __init__(
         self._batch_size = batch_size or DEFAULT_BATCH_SIZE

     def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
-        print(f"_read_stream -> _fast_read: {self._fast_read}")
         if self._fast_read:
-            print("reading fast!")
             yield from self._fast_read_stream(f, path)
         else:
-            print("reading slow!")
             yield from self._slow_read_stream(f, path)

     def _slow_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
@@ -108,7 +105,6 @@ def _convert_example_to_dict(
     example: "tf.train.Example",
     tf_schema: Optional["schema_pb2.Schema"],
 ) -> Dict[str, "pyarrow.Array"]:
-    print("_convert_example_to_dict HERE")
     record = {}
     schema_dict = {}
     # Convert user-specified schema into dict for convenient mapping
python/ray/data/tests/test_tfrecords.py (16 changes: 12 additions & 4 deletions)
@@ -339,12 +339,13 @@ def _str2bytes(d):


 @pytest.mark.parametrize(
-    "with_tf_schema,fast_read",
-    [(True, True), (True, False), (False, True), (False, False)],
+    "with_tf_schema,fast_read,compression",
+    [(True, True, None), (True, True, "GZIP"), (True, False, None), (False, True, None), (False, True, "GZIP"), (False, False, None)],
 )
 def test_read_tfrecords(
     with_tf_schema,
     fast_read,
+    compression,
     ray_start_regular_shared,
     tmp_path,
 ):
@@ -358,11 +359,18 @@ def test_read_tfrecords(
     tf_schema = _features_to_schema(example.features)

     path = os.path.join(tmp_path, "data.tfrecords")
-    with tf.io.TFRecordWriter(path=path) as writer:
+    with tf.io.TFRecordWriter(
+        path=path, options=tf.io.TFRecordOptions(compression_type=compression)
+    ) as writer:
         writer.write(example.SerializeToString())

+    arrow_open_stream_args = None
+    if compression:
+        arrow_open_stream_args = {"compression": compression}
+
     ds = read_tfrecords_with_fast_read_override(
-        path, tf_schema=tf_schema, fast_read=fast_read
+        path, tf_schema=tf_schema, fast_read=fast_read,
+        arrow_open_stream_args=arrow_open_stream_args
     )

     df = ds.to_pandas()
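
For context, a minimal standalone sketch of the round trip this test now covers: write a GZIP-compressed TFRecord file and read it back by forwarding the codec to the Arrow input stream via arrow_open_stream_args. It assumes the public ray.data.read_tfrecords entry point (the test itself goes through the read_tfrecords_with_fast_read_override helper), and the temporary path and feature name are illustrative only.

import os
import tempfile

import ray
import tensorflow as tf

# Write a single GZIP-compressed record, mirroring the writer setup in the test.
path = os.path.join(tempfile.mkdtemp(), "data.tfrecords")
example = tf.train.Example(
    features=tf.train.Features(
        feature={"value": tf.train.Feature(int64_list=tf.train.Int64List(value=[42]))}
    )
)
with tf.io.TFRecordWriter(
    path=path, options=tf.io.TFRecordOptions(compression_type="GZIP")
) as writer:
    writer.write(example.SerializeToString())

# Pass the codec name (e.g. "gzip") so the underlying Arrow stream decompresses the file.
ds = ray.data.read_tfrecords(path, arrow_open_stream_args={"compression": "gzip"})
print(ds.to_pandas())

When compression is None, the test leaves arrow_open_stream_args unset, so the uncompressed parametrizations still exercise the default read path.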