From 9947518fe4adeb2ecaced967e20297f3eea97cfa Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Wed, 17 Apr 2019 14:51:38 -0700
Subject: [PATCH] Enable ArrowDataset to be fed by a placeholder Tensor

---
 .../arrow/kernels/arrow_dataset_ops.cc       | 36 +++++++++------
 .../arrow/python/ops/arrow_dataset_ops.py    | 46 +++++++++++++------
 tests/test_arrow.py                          | 44 ++++++++++++++++--
 3 files changed, 95 insertions(+), 31 deletions(-)

diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
index 40e7112f8..fd082c3cd 100644
--- a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
+++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "arrow/util/io-util.h"
 #include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/graph/graph.h"
 
 #define CHECK_ARROW(arrow_status)             \
   do {                                        \
@@ -362,8 +363,6 @@ class ArrowOpKernelBase : public DatasetOpKernel {
 // memory in a Python process, or a Pandas DataFrame.
 class ArrowDatasetOp : public ArrowOpKernelBase {
  public:
-  //using DatasetOpKernel::DatasetOpKernel;
-
   explicit ArrowDatasetOp(OpKernelConstruction* ctx) : ArrowOpKernelBase(ctx) {}
 
   virtual void MakeArrowDataset(
@@ -374,11 +373,14 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
     const Tensor* batches_tensor;
     OP_REQUIRES_OK(ctx, ctx->input("serialized_batches", &batches_tensor));
     OP_REQUIRES(
-        ctx, batches_tensor->dims() <= 0,
-        errors::InvalidArgument("`serialized_batches` must be a scalar."));
-    string batches = batches_tensor->flat<string>()(0);
-
-    *output = new Dataset(ctx, batches, columns, output_types_, output_shapes_);
+        ctx, TensorShapeUtils::IsScalar(batches_tensor->shape()),
+        errors::InvalidArgument("serialized_batches must be a scalar"));
+    *output = new Dataset(
+        ctx,
+        *batches_tensor,
+        columns,
+        output_types_,
+        output_shapes_);
   }
 
  private:
@@ -386,12 +388,14 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
   class Dataset : public ArrowDatasetBase {
    public:
     // Construct a Dataset that consumes Arrow batches from serialized bytes
    // in a string. Record batches should be serialized in Arrow File format.
-    Dataset(OpKernelContext* ctx, const string& serialized_batches,
+    Dataset(OpKernelContext* ctx,
+            const Tensor batches_tensor,
             const std::vector<int32>& columns,
             const DataTypeVector& output_types,
             const std::vector<PartialTensorShape>& output_shapes)
         : ArrowDatasetBase(ctx, columns, output_types, output_shapes),
-          batches_(serialized_batches) {}
+          batches_(std::move(batches_tensor)) {
+    }
 
     string DebugString() const override { return "ArrowDatasetOp::Dataset"; }
 
@@ -400,7 +404,13 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
                               DatasetGraphDefBuilder* b,
                               Node** output) const override {
       Node* batches = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(batches_, &batches));
+      if (ctx->optimization_only()) {
+        TF_RETURN_IF_ERROR(b->AddPlaceholder(batches_, &batches));
+        DCHECK_NE(ctx->input_list(), nullptr);
+        ctx->input_list()->emplace_back(batches->name(), batches_);
+      } else {
+        TF_RETURN_IF_ERROR(b->AddTensor(batches_, &batches));
+      }
       Node* columns = nullptr;
       TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns));
       TF_RETURN_IF_ERROR(b->AddDataset(this, {batches, columns}, output));
@@ -422,8 +432,8 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
    private:
     Status SetupStreamsLocked(Env* env)
         EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-      std::shared_ptr<arrow::Buffer> buffer;
-      CHECK_ARROW(arrow::Buffer::FromString(dataset()->batches_, &buffer));
+      const string& batches = dataset()->batches_.scalar<string>()();
+      auto buffer = std::make_shared<arrow::Buffer>(batches);
       auto buffer_reader = std::make_shared<arrow::io::BufferReader>(buffer);
       CHECK_ARROW(
           arrow::ipc::RecordBatchFileReader::Open(buffer_reader, &reader_));
@@ -458,7 +468,7 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
       int num_batches_ GUARDED_BY(mu_) = 0;
     };
 
-    const string batches_;
+    const Tensor batches_;
   };
 };
 
diff --git a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
index 14a9651a0..9f1e33364 100644
--- a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
+++ b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -120,25 +120,52 @@ class ArrowDataset(ArrowBaseDataset):
   """
 
   def __init__(self,
-               record_batches,
+               serialized_batches,
                columns,
                output_types,
                output_shapes=None):
-    """Create an ArrowDataset directly from Arrow record batches.
+    """Create an ArrowDataset from a Tensor of serialized batches.
     This constructor requires pyarrow to be installed.
 
     Args:
-      record_batches: An Arrow record batch or sequence of record batches
+      serialized_batches: A string Tensor holding a serialized buffer of
+        Arrow record batches in the Arrow File format
       columns: A list of column indices to be used in the Dataset
       output_types: Tensor dtypes of the output tensors
       output_shapes: TensorShapes of the output tensors or None to
                      infer partial
     """
+    self._serialized_batches = serialized_batches
     self._columns = columns
     self._output_types = output_types
     self._output_shapes = output_shapes or \
         nest.map_structure(
             lambda _: tensorflow.TensorShape(None), self._output_types)
+    super(ArrowDataset, self).__init__(columns, output_types, output_shapes)
+
+  def _as_variant_tensor(self):
+    return arrow_ops.arrow_dataset(
+        self._serialized_batches,
+        self._columns,
+        nest.flatten(self.output_types),
+        nest.flatten(self.output_shapes))
+
+  @classmethod
+  def from_record_batches(cls,
+                          record_batches,
+                          columns,
+                          output_types,
+                          output_shapes=None):
+    """Create an ArrowDataset directly from Arrow record batches.
+    This constructor requires pyarrow to be installed.
+
+    Args:
+      record_batches: An Arrow record batch or sequence of record batches
+      columns: A list of column indices to be used in the Dataset
+      output_types: Tensor dtypes of the output tensors
+      output_shapes: TensorShapes of the output tensors or None to
+                     infer partial
+    """
     import pyarrow as pa
     if isinstance(record_batches, pa.RecordBatch):
       record_batches = [record_batches]
@@ -148,18 +175,11 @@ def __init__(self,
     for batch in record_batches:
       writer.write_batch(batch)
     writer.close()
-    self._serialized_batches = tensorflow.convert_to_tensor(
+    serialized_batches = tensorflow.convert_to_tensor(
         buf.getvalue(),
         dtype=dtypes.string,
         name="serialized_batches")
-    super(ArrowDataset, self).__init__(columns, output_types, output_shapes)
-
-  def _as_variant_tensor(self):
-    return arrow_ops.arrow_dataset(
-        self._serialized_batches,
-        self._columns,
-        nest.flatten(self.output_types),
-        nest.flatten(self.output_shapes))
+    return cls(serialized_batches, columns, output_types, output_shapes)
 
   @classmethod
   def from_pandas(cls, df, columns=None, preserve_index=True):
@@ -179,7 +199,7 @@ def from_pandas(cls, df, columns=None, preserve_index=True):
     batch = pa.RecordBatch.from_pandas(df, preserve_index=preserve_index)
     columns = tuple(range(batch.num_columns))
     output_types, output_shapes = arrow_schema_to_tensor_types(batch.schema)
-    return cls(batch, columns, output_types, output_shapes)
+    return cls.from_record_batches(batch, columns, output_types, output_shapes)
 
 
 class ArrowFeatherDataset(ArrowBaseDataset):
diff --git a/tests/test_arrow.py b/tests/test_arrow.py
index e8eea2742..2e8983c78 100644
--- a/tests/test_arrow.py
+++ b/tests/test_arrow.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 from collections import namedtuple
+import io
 import os
 import sys
 import socket
@@ -175,7 +176,7 @@ def test_arrow_dataset(self):
     batch = self.make_record_batch(truth_data)
 
     # test all columns selected
-    dataset = arrow_io.ArrowDataset(
+    dataset = arrow_io.ArrowDataset.from_record_batches(
         batch,
         list(range(len(truth_data.output_types))),
         truth_data.output_types,
@@ -184,7 +185,7 @@ def test_arrow_dataset(self):
 
     # test column selection
     columns = (1, 3, len(truth_data.output_types) - 1)
-    dataset = arrow_io.ArrowDataset(
+    dataset = arrow_io.ArrowDataset.from_record_batches(
         batch,
         columns,
         tuple([truth_data.output_types[c] for c in columns]),
@@ -326,7 +327,7 @@ def test_bool_array_type(self):
 
     batch = self.make_record_batch(truth_data)
 
-    dataset = arrow_io.ArrowDataset(
+    dataset = arrow_io.ArrowDataset.from_record_batches(
         batch,
         (0,),
         truth_data.output_types,
@@ -339,7 +340,7 @@ def test_incorrect_column_type(self):
         self.scalar_shapes)
     batch = self.make_record_batch(truth_data)
 
-    dataset = arrow_io.ArrowDataset(
+    dataset = arrow_io.ArrowDataset.from_record_batches(
         batch,
         list(range(len(truth_data.output_types))),
         tuple([dtypes.int32 for _ in truth_data.output_types]),
@@ -357,7 +358,7 @@ def test_map_and_batch(self):
         (dtypes.int32,),
         (tensorflow.TensorShape([]),))
     batch = self.make_record_batch(truth_data)
-    dataset = arrow_io.ArrowDataset(
+    dataset = arrow_io.ArrowDataset.from_record_batches(
         batch,
         list(range(len(truth_data.output_types))),
         truth_data.output_types,
@@ -379,6 +380,39 @@ def test_map_and_batch(self):
       except tensorflow.errors.OutOfRangeError:
         break
 
+  def test_feed_batches(self):
+    """
+    Test feeding an ArrowDataset with a placeholder of serialized batches
+    """
+    truth_data = TruthData(
+        [list(range(10)), [x * 1.1 for x in range(10)]],
+        (dtypes.int32, dtypes.float64),
+        (tensorflow.TensorShape([]), tensorflow.TensorShape([])))
+    batch = self.make_record_batch(truth_data)
+
+    buf = io.BytesIO()
+    writer = pa.RecordBatchFileWriter(buf, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+
+    buf_placeholder = tensorflow.compat.v1.placeholder(
+        tensorflow.dtypes.string, tensorflow.TensorShape([]))
+
+    dataset = arrow_io.ArrowDataset(
+        buf_placeholder,
+        list(range(len(truth_data.output_types))),
+        truth_data.output_types,
+        truth_data.output_shapes)
+    it = dataset.make_initializable_iterator()
+    next_element = it.get_next()
+
+    with self.test_session() as sess:
+      sess.run(it.initializer, feed_dict={buf_placeholder: buf.getvalue()})
+      for row in range(len(truth_data.data[0])):
+        value = sess.run(next_element)
+        self.assertEqual(value[0], truth_data.data[0][row])
+        self.assertAlmostEqual(value[1], truth_data.data[1][row], 4)
+
 
 if __name__ == "__main__":
   test.main()
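
Note for reviewers (not part of the patch): the sketch below shows how the new placeholder path is meant to be used end to end, mirroring test_feed_batches above. It is a minimal TF 1.x graph-mode example; the import aliases (tensorflow_io.arrow as arrow_io style, pyarrow as pa) follow the tests, and the single int32 column is made up for illustration.

    import io

    import pyarrow as pa
    import tensorflow as tf
    import tensorflow_io.arrow as arrow_io

    # Serialize one record batch to bytes in the Arrow File format.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3], type=pa.int32())], ["col"])
    buf = io.BytesIO()
    writer = pa.RecordBatchFileWriter(buf, batch.schema)
    writer.write_batch(batch)
    writer.close()

    # Build the dataset on a scalar string placeholder instead of a constant.
    placeholder = tf.compat.v1.placeholder(tf.string, tf.TensorShape([]))
    dataset = arrow_io.ArrowDataset(
        placeholder, [0], (tf.int32,), (tf.TensorShape([]),))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with tf.compat.v1.Session() as sess:
      # The serialized buffer is supplied only at iterator initialization.
      sess.run(iterator.initializer, feed_dict={placeholder: buf.getvalue()})
      while True:
        try:
          print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
          break

One benefit of feeding through the placeholder is that the serialized buffer never has to be embedded in the graph as a constant, which matters when the record batches are large.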