36 changes: 23 additions & 13 deletions tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "arrow/util/io-util.h"
#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/graph/graph.h"

#define CHECK_ARROW(arrow_status) \
do { \
@@ -362,8 +363,6 @@ class ArrowOpKernelBase : public DatasetOpKernel {
// memory in a Python process, or a Pandas DataFrame.
class ArrowDatasetOp : public ArrowOpKernelBase {
public:
//using DatasetOpKernel::DatasetOpKernel;

explicit ArrowDatasetOp(OpKernelConstruction* ctx) : ArrowOpKernelBase(ctx) {}

virtual void MakeArrowDataset(
@@ -374,24 +373,29 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
const Tensor* batches_tensor;
OP_REQUIRES_OK(ctx, ctx->input("serialized_batches", &batches_tensor));
OP_REQUIRES(
ctx, batches_tensor->dims() <= 0,
errors::InvalidArgument("`serialized_batches` must be a scalar."));
string batches = batches_tensor->flat<string>()(0);

*output = new Dataset(ctx, batches, columns, output_types_, output_shapes_);
ctx, TensorShapeUtils::IsScalar(batches_tensor->shape()),
errors::InvalidArgument("serialized_batches must be a scalar"));
*output = new Dataset(
ctx,
*batches_tensor,
columns,
output_types_,
output_shapes_);
}

private:
class Dataset : public ArrowDatasetBase {
public:
// Construct a Dataset that consumed Arrow batches from serialized bytes
// in a string. Record batches should be serialized in Arrow File format.
Dataset(OpKernelContext* ctx, const string& serialized_batches,
Dataset(OpKernelContext* ctx,
const Tensor batches_tensor,
const std::vector<int32>& columns,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: ArrowDatasetBase(ctx, columns, output_types, output_shapes),
batches_(serialized_batches) {}
batches_(std::move(batches_tensor)) {
}

string DebugString() const override { return "ArrowDatasetOp::Dataset"; }

@@ -400,7 +404,13 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* batches = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batches_, &batches));
if (ctx->optimization_only()) {
TF_RETURN_IF_ERROR(b->AddPlaceholder(batches_, &batches));
DCHECK_NE(ctx->input_list(), nullptr);
ctx->input_list()->emplace_back(batches->name(), batches_);
} else {
TF_RETURN_IF_ERROR(b->AddTensor(batches_, &batches));
}
Node* columns = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns));
TF_RETURN_IF_ERROR(b->AddDataset(this, {batches, columns}, output));
@@ -422,8 +432,8 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
private:
Status SetupStreamsLocked(Env* env)
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
std::shared_ptr<arrow::Buffer> buffer;
CHECK_ARROW(arrow::Buffer::FromString(dataset()->batches_, &buffer));
const string& batches = dataset()->batches_.scalar<string>()();
auto buffer = std::make_shared<arrow::Buffer>(batches);
auto buffer_reader = std::make_shared<arrow::io::BufferReader>(buffer);
CHECK_ARROW(
arrow::ipc::RecordBatchFileReader::Open(buffer_reader, &reader_));
@@ -458,7 +468,7 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
int num_batches_ GUARDED_BY(mu_) = 0;
};

const string batches_;
const Tensor batches_;
};
};

46 changes: 33 additions & 13 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -120,25 +120,52 @@ class ArrowDataset(ArrowBaseDataset):
"""

def __init__(self,
record_batches,
serialized_batches,
columns,
output_types,
output_shapes=None):
"""Create an ArrowDataset directly from Arrow record batches.
"""Create an ArrowDataset from a Tensor of serialized batches.
This constructor requires pyarrow to be installed.

Args:
record_batches: An Arrow record batch or sequence of record batches
serialized_batches: A string Tensor as a serialized buffer containing
Arrow record batches as Arrow file format
columns: A list of column indices to be used in the Dataset
output_types: Tensor dtypes of the output tensors
output_shapes: TensorShapes of the output tensors or None to
infer partial
"""
self._serialized_batches = serialized_batches
self._columns = columns
self._output_types = output_types
self._output_shapes = output_shapes or \
nest.map_structure(
lambda _: tensorflow.TensorShape(None), self._output_types)
super(ArrowDataset, self).__init__(columns, output_types, output_shapes)

def _as_variant_tensor(self):
return arrow_ops.arrow_dataset(
self._serialized_batches,
self._columns,
nest.flatten(self.output_types),
nest.flatten(self.output_shapes))

@classmethod
def from_record_batches(cls,
record_batches,
columns,
output_types,
output_shapes=None):
"""Create an ArrowDataset directly from Arrow record batches.
This constructor requires pyarrow to be installed.

Args:
record_batches: An Arrow record batch or sequence of record batches
columns: A list of column indices to be used in the Dataset
output_types: Tensor dtypes of the output tensors
output_shapes: TensorShapes of the output tensors or None to
infer partial
"""
import pyarrow as pa
if isinstance(record_batches, pa.RecordBatch):
record_batches = [record_batches]
@@ -148,18 +175,11 @@ def __init__(self,
for batch in record_batches:
writer.write_batch(batch)
writer.close()
self._serialized_batches = tensorflow.convert_to_tensor(
serialized_batches = tensorflow.convert_to_tensor(
buf.getvalue(),
dtype=dtypes.string,
name="serialized_batches")
super(ArrowDataset, self).__init__(columns, output_types, output_shapes)

def _as_variant_tensor(self):
return arrow_ops.arrow_dataset(
self._serialized_batches,
self._columns,
nest.flatten(self.output_types),
nest.flatten(self.output_shapes))
return cls(serialized_batches, columns, output_types, output_shapes)

@classmethod
def from_pandas(cls, df, columns=None, preserve_index=True):
@@ -179,7 +199,7 @@ def from_pandas(cls, df, columns=None, preserve_index=True):
batch = pa.RecordBatch.from_pandas(df, preserve_index=preserve_index)
columns = tuple(range(batch.num_columns))
output_types, output_shapes = arrow_schema_to_tensor_types(batch.schema)
return cls(batch, columns, output_types, output_shapes)
return cls.from_record_batches(batch, columns, output_types, output_shapes)


class ArrowFeatherDataset(ArrowBaseDataset):
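
For orientation, the Python change above splits batch construction into two entry points: the constructor now takes an already-serialized string Tensor (which may be a placeholder, as exercised in `test_feed_batches` below), while the new `from_record_batches` classmethod keeps the pyarrow-based path. A minimal sketch of both paths, assuming TF 1.x graph mode and an `arrow_io` import alias matching the tests (the alias and module path are assumptions, not part of this diff):

```python
import pyarrow as pa
import tensorflow as tf
import tensorflow_io.arrow as arrow_io  # assumed import alias, as used in the tests

# Path 1: build directly from in-memory record batches (requires pyarrow).
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["f0"])
ds_from_batches = arrow_io.ArrowDataset.from_record_batches(
    batch, columns=(0,), output_types=(tf.int64,))

# Path 2: pass an already-serialized buffer as a string Tensor; using a
# placeholder keeps the (possibly large) buffer out of the GraphDef and lets
# it be fed at iterator initialization, as test_feed_batches below does.
buf_placeholder = tf.compat.v1.placeholder(tf.string, tf.TensorShape([]))
ds_from_buffer = arrow_io.ArrowDataset(
    buf_placeholder, columns=(0,), output_types=(tf.int64,))
```
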
44 changes: 39 additions & 5 deletions tests/test_arrow.py
@@ -19,6 +19,7 @@
from __future__ import print_function

from collections import namedtuple
import io
import os
import sys
import socket
@@ -175,7 +176,7 @@ def test_arrow_dataset(self):
batch = self.make_record_batch(truth_data)

# test all columns selected
dataset = arrow_io.ArrowDataset(
dataset = arrow_io.ArrowDataset.from_record_batches(
batch,
list(range(len(truth_data.output_types))),
truth_data.output_types,
@@ -184,7 +185,7 @@

# test column selection
columns = (1, 3, len(truth_data.output_types) - 1)
dataset = arrow_io.ArrowDataset(
dataset = arrow_io.ArrowDataset.from_record_batches(
batch,
columns,
tuple([truth_data.output_types[c] for c in columns]),
@@ -326,7 +327,7 @@ def test_bool_array_type(self):

batch = self.make_record_batch(truth_data)

dataset = arrow_io.ArrowDataset(
dataset = arrow_io.ArrowDataset.from_record_batches(
batch,
(0,),
truth_data.output_types,
@@ -339,7 +340,7 @@ def test_incorrect_column_type(self):
self.scalar_shapes)
batch = self.make_record_batch(truth_data)

dataset = arrow_io.ArrowDataset(
dataset = arrow_io.ArrowDataset.from_record_batches(
batch,
list(range(len(truth_data.output_types))),
tuple([dtypes.int32 for _ in truth_data.output_types]),
@@ -357,7 +358,7 @@ def test_map_and_batch(self):
(dtypes.int32,),
(tensorflow.TensorShape([]),))
batch = self.make_record_batch(truth_data)
dataset = arrow_io.ArrowDataset(
dataset = arrow_io.ArrowDataset.from_record_batches(
batch,
list(range(len(truth_data.output_types))),
truth_data.output_types,
@@ -379,6 +380,39 @@ def test_map_and_batch(self):
except tensorflow.errors.OutOfRangeError:
break

def test_feed_batches(self):
"""
Test that an ArrowDataset can initialize an iterator to feed a placeholder
"""
truth_data = TruthData(
[list(range(10)), [x * 1.1 for x in range(10)]],
(dtypes.int32, dtypes.float64),
(tensorflow.TensorShape([]), tensorflow.TensorShape([])))
batch = self.make_record_batch(truth_data)

buf = io.BytesIO()
writer = pa.RecordBatchFileWriter(buf, batch.schema)
writer.write_batch(batch)
writer.close()

buf_placeholder = tensorflow.compat.v1.placeholder(
tensorflow.dtypes.string, tensorflow.TensorShape([]))

dataset = arrow_io.ArrowDataset(
buf_placeholder,
list(range(len(truth_data.output_types))),
truth_data.output_types,
truth_data.output_shapes)
it = dataset.make_initializable_iterator()
next_element = it.get_next()

with self.test_session() as sess:
sess.run(it.initializer, feed_dict={buf_placeholder: buf.getvalue()})
for row in range(len(truth_data.data)):
value = sess.run(next_element)
self.assertEqual(value[0], truth_data.data[0][row])
self.assertAlmostEqual(value[1], truth_data.data[1][row], 4)


if __name__ == "__main__":
test.main()
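
As a usage footnote to the `from_pandas` change above (it now routes through `from_record_batches`), here is a minimal end-to-end sketch; the import alias and the 1.x-style initializable-iterator pattern are assumptions mirroring the tests in this PR:

```python
import pandas as pd
import tensorflow as tf
import tensorflow_io.arrow as arrow_io  # assumed import alias, as above

# from_pandas converts the DataFrame to a pyarrow RecordBatch, derives the
# column indices and tensor types from the Arrow schema, and then delegates
# to from_record_batches (see the Python diff above).
df = pd.DataFrame({"x": [1, 2, 3], "y": [1.1, 2.2, 3.3]})
dataset = arrow_io.ArrowDataset.from_pandas(df, preserve_index=False)

it = dataset.make_initializable_iterator()
next_element = it.get_next()
with tf.compat.v1.Session() as sess:
    sess.run(it.initializer)
    print(sess.run(next_element))  # first row -> (1, 1.1)
```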