Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tensorflow_io/arrow/kernels/arrow_stream_client_unix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ arrow::Status ArrowStreamClient::Tell(int64_t* position) const {
arrow::Status ArrowStreamClient::Read(int64_t nbytes,
int64_t* bytes_read,
void* out) {
// TODO: look into why 0 bytes are requested
// TODO: 0 bytes requested when message body length == 0
if (nbytes == 0) {
*bytes_read = 0;
return arrow::Status::OK();
}

Expand Down
16 changes: 9 additions & 7 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@

import tensorflow as tf
from tensorflow import dtypes
from tensorflow.compat.v2 import data
from tensorflow.python.data.ops.dataset_ops import flat_structure
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow_io.core.python.ops import core_ops

Expand Down Expand Up @@ -88,7 +87,7 @@ def arrow_schema_to_tensor_types(schema):
return tensor_types, tensor_shapes


class ArrowBaseDataset(data.Dataset):
class ArrowBaseDataset(dataset_ops.DatasetV2):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this ok to use as a base for 1.15.0?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should be OK I think. As long as dataset could be passed to tf.keras (with 1.15 and 2.0) we should be fine. Both DatasetV1 and DatasetV2 works with tf.keras. 👍

"""Base class for Arrow Datasets to provide columns used in record batches
and corresponding output tensor types, shapes and classes.
"""
Expand Down Expand Up @@ -121,21 +120,24 @@ def __init__(self,
dtypes.string,
name="batch_mode")
if batch_size is not None or batch_mode == 'auto':
spec_batch_size = batch_size if batch_mode == 'drop_remainder' else None
# pylint: disable=protected-access
self._structure = self._structure._batch(
batch_size if batch_mode == 'drop_remainder' else None)
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(spec_batch_size),
self._structure)
print(self._flat_structure)
variant_tensor = make_variant_fn(
columns=self._columns,
batch_size=self._batch_size,
batch_mode=self._batch_mode,
**flat_structure(self))
**self._flat_structure)
super(ArrowBaseDataset, self).__init__(variant_tensor)

def _inputs(self):
return []

@property
def _element_structure(self):
def element_spec(self):
return self._structure

@property
Expand Down
2 changes: 0 additions & 2 deletions tests/test_arrow_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from tensorflow import errors # pylint: disable=wrong-import-position
from tensorflow import test # pylint: disable=wrong-import-position

pytest.skip(
"arrow test is disabled temporarily", allow_module_level=True)
import tensorflow_io.arrow as arrow_io # pylint: disable=wrong-import-position

if sys.version_info == (3, 4):
Expand Down