From 78357a25948d45ce747aac3286dd9f2d8077e1b6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 30 Sep 2019 16:05:02 -0700 Subject: [PATCH] Enable arrow tests, fix bug with Arrow stream message reading --- .../arrow/kernels/arrow_stream_client_unix.cc | 3 ++- .../arrow/python/ops/arrow_dataset_ops.py | 16 +++++++++------- tests/test_arrow_eager.py | 2 -- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc b/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc index 3e45fc094..4710dbfb0 100644 --- a/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc +++ b/tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc @@ -121,8 +121,9 @@ arrow::Status ArrowStreamClient::Tell(int64_t* position) const { arrow::Status ArrowStreamClient::Read(int64_t nbytes, int64_t* bytes_read, void* out) { - // TODO: look into why 0 bytes are requested + // TODO: 0 bytes requested when message body length == 0 if (nbytes == 0) { + *bytes_read = 0; return arrow::Status::OK(); } diff --git a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py index 1963f9a25..8c1216d86 100644 --- a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py +++ b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py @@ -27,8 +27,7 @@ import tensorflow as tf from tensorflow import dtypes -from tensorflow.compat.v2 import data -from tensorflow.python.data.ops.dataset_ops import flat_structure +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure as structure_lib from tensorflow_io.core.python.ops import core_ops @@ -88,7 +87,7 @@ def arrow_schema_to_tensor_types(schema): return tensor_types, tensor_shapes -class ArrowBaseDataset(data.Dataset): +class ArrowBaseDataset(dataset_ops.DatasetV2): """Base class for Arrow Datasets to provide columns used in record batches and corresponding output tensor types, shapes and classes. """ @@ -121,21 +120,24 @@ def __init__(self, dtypes.string, name="batch_mode") if batch_size is not None or batch_mode == 'auto': + spec_batch_size = batch_size if batch_mode == 'drop_remainder' else None # pylint: disable=protected-access - self._structure = self._structure._batch( - batch_size if batch_mode == 'drop_remainder' else None) + self._structure = nest.map_structure( + lambda component_spec: component_spec._batch(spec_batch_size), + self._structure) + print(self._flat_structure) variant_tensor = make_variant_fn( columns=self._columns, batch_size=self._batch_size, batch_mode=self._batch_mode, - **flat_structure(self)) + **self._flat_structure) super(ArrowBaseDataset, self).__init__(variant_tensor) def _inputs(self): return [] @property - def _element_structure(self): + def element_spec(self): return self._structure @property diff --git a/tests/test_arrow_eager.py b/tests/test_arrow_eager.py index ca337070a..f153da0ad 100644 --- a/tests/test_arrow_eager.py +++ b/tests/test_arrow_eager.py @@ -35,8 +35,6 @@ from tensorflow import errors # pylint: disable=wrong-import-position from tensorflow import test # pylint: disable=wrong-import-position -pytest.skip( - "arrow test is disabled temporarily", allow_module_level=True) import tensorflow_io.arrow as arrow_io # pylint: disable=wrong-import-position if sys.version_info == (3, 4):