From 3133e030fe9c52b32ef5657a15e0fdcef931b24d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 24 Apr 2019 22:04:02 +0000 Subject: [PATCH 1/3] Refactor data input to split FileInput from StreamInput Signed-off-by: Yong Tang --- .../cifar/kernels/cifar_dataset_ops.cc | 20 +- tensorflow_io/core/kernels/dataset_ops.h | 416 ++++++++++++++---- .../mnist/kernels/mnist_dataset_ops.cc | 28 +- tensorflow_io/text/kernels/text_input.cc | 12 +- 4 files changed, 371 insertions(+), 105 deletions(-) diff --git a/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc b/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc index ea5330b45..79e64390c 100644 --- a/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc +++ b/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc @@ -23,9 +23,9 @@ limitations under the License. namespace tensorflow { namespace data { namespace { -class CIFAR10Input: public DataInput { +class CIFAR10Input: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); } @@ -45,7 +45,7 @@ class CIFAR10Input: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { return Status::OK(); } void EncodeAttributes(VariantTensorData* data) const override { @@ -56,9 +56,9 @@ class CIFAR10Input: public DataInput { protected: }; -class CIFAR100Input: public DataInput { +class CIFAR100Input: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); } @@ -81,7 +81,7 @@ class CIFAR100Input: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { return Status::OK(); } void EncodeAttributes(VariantTensorData* data) const override { @@ -96,13 +96,13 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CIFAR10Input, "tensorflow::CIFAR10Input") REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CIFAR100Input, "tensorflow::CIFAR100Input"); REGISTER_KERNEL_BUILDER(Name("CIFAR10Input").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("CIFAR100Input").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("CIFAR10Dataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); REGISTER_KERNEL_BUILDER(Name("CIFAR100Dataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/dataset_ops.h b/tensorflow_io/core/kernels/dataset_ops.h index 1c4b31714..d0fc12231 100644 --- a/tensorflow_io/core/kernels/dataset_ops.h +++ b/tensorflow_io/core/kernels/dataset_ops.h @@ -147,36 +147,84 @@ class ArchiveInputStream : public io::InputStreamInterface { TF_DISALLOW_COPY_AND_ASSIGN(ArchiveInputStream); }; +// Note: Forward declaration for friend class. +template class FileInput; +template class StreamInput; + template class DataInput { public: DataInput() {} virtual ~DataInput() {} - virtual Status FromStream(io::InputStreamInterface& s) = 0; - virtual Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + protected: virtual void EncodeAttributes(VariantTensorData* data) const = 0; virtual bool DecodeAttributes(const VariantTensorData& data) = 0; - - Status ReadInputStream(io::InputStreamInterface& s, int64 chunk, int64 count, string* buffer, int64* returned) const { - int64 offset = s.Tell(); - int64 bytes_to_read = count * chunk; - Status status = (buffer == nullptr) ? s.SkipNBytes(bytes_to_read) : s.ReadNBytes(bytes_to_read, buffer); - if (!(status.ok() || status == errors::OutOfRange("EOF reached"))) { - return status; - } - int64 bytes_read = s.Tell() - offset; - if (bytes_read % chunk != 0) { - return errors::DataLoss("corrupted data, expected multiple of ", chunk, ", received ", bytes_read); + virtual Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + Status ReadReferenceBatchRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + int64 record_read = 0; + int64 record_to_read = count - (*returned); + std::vector chunk_tensors; + TF_RETURN_IF_ERROR(ReadReferenceRecord(s, ctx, state, record_to_read, &record_read, &chunk_tensors)); + if (record_read > 0) { + if (out_tensors->size() == 0) { + // Replace out_tensors with chunk_tensors + out_tensors->reserve(chunk_tensors.size()); + // batch == 0 could only read at most one record + // so it only happens here. + if (batch == 0) { + for (size_t i = 0; i < chunk_tensors.size(); i++) { + TensorShape shape = chunk_tensors[i].shape(); + shape.RemoveDim(0); + Tensor value_tensor(ctx->allocator({}), chunk_tensors[i].dtype(), shape); + value_tensor.CopyFrom(chunk_tensors[i], shape); + out_tensors->emplace_back(std::move(value_tensor)); + } + } else { + for (size_t i = 0; i < chunk_tensors.size(); i++) { + out_tensors->emplace_back(std::move(chunk_tensors[i])); + } + } + } else { + // Append out_tensors with chunk_tensors + for (size_t i = 0; i < out_tensors->size(); i++) { + TensorShape shape = (*out_tensors)[i].shape(); + shape.set_dim(0, shape.dim_size(0) + record_read); + Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); + TensorShape element_shape = shape; + element_shape.RemoveDim(0); + Tensor element(ctx->allocator({}), (*out_tensors)[i].dtype(), element_shape); + for (size_t index = 0; index < (*out_tensors)[i].shape().dim_size(0); index++) { + TF_RETURN_IF_ERROR(batch_util::CopySliceToElement((*out_tensors)[i], &element, index)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, index)); + } + for (size_t index = 0; index < record_read; index++) { + TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(chunk_tensors[i], &element, index)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, (*out_tensors)[i].shape().dim_size(0) + index)); + } + (*out_tensors)[i] = std::move(value_tensor); + } + } + (*returned) += record_read; } - *returned = bytes_read / chunk; return Status::OK(); } - Status FromInputStream(io::InputStreamInterface& s, const string& filename, const string& entryname, const string& filtername) { + friend class FileInput; + friend class StreamInput; +}; +template +class FileInput : public DataInput { + public: + FileInput() {} + virtual ~FileInput() {} + Status FromInputStream(io::InputStreamInterface* s, const string& filename, const string& entryname, const string& filtername) { filename_ = filename; entryname_ = entryname; filtername_ = filtername; return FromStream(s); } + Status ReadBatchRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + return (static_cast *>(this))->ReadReferenceBatchRecord(static_cast(s), ctx, state, batch, count, returned, out_tensors); + } void Encode(VariantTensorData* data) const { data->tensors_ = {Tensor(DT_STRING, TensorShape({})), Tensor(DT_STRING, TensorShape({})), Tensor(DT_STRING, TensorShape({}))}; data->tensors_[0].scalar()() = filename_; @@ -202,15 +250,36 @@ class DataInput { return filtername_; } protected: + virtual Status FromStream(io::InputStreamInterface* s) = 0; + virtual Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + virtual void EncodeAttributes(VariantTensorData* data) const = 0; + virtual bool DecodeAttributes(const VariantTensorData& data) = 0; + Status ReadInputStream(io::InputStreamInterface* s, int64 chunk, int64 count, string* buffer, int64* returned) const { + int64 offset = s->Tell(); + int64 bytes_to_read = count * chunk; + Status status = (buffer == nullptr) ? s->SkipNBytes(bytes_to_read) : s->ReadNBytes(bytes_to_read, buffer); + if (!(status.ok() || status == errors::OutOfRange("EOF reached"))) { + return status; + } + int64 bytes_read = s->Tell() - offset; + if (bytes_read % chunk != 0) { + return errors::DataLoss("corrupted data, expected multiple of ", chunk, ", received ", bytes_read); + } + *returned = bytes_read / chunk; + return Status::OK(); + } + Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + return ReadRecord(static_cast(s), ctx, state, record_to_read, record_read, out_tensors); + } string filename_; string entryname_; string filtername_; }; template -class DataInputOp: public OpKernel { +class FileInputOp: public OpKernel { public: - explicit DataInputOp(OpKernelConstruction* context) : OpKernel(context) { + explicit FileInputOp(OpKernelConstruction* context) : OpKernel(context) { env_ = context->env(); OP_REQUIRES_OK(context, context->GetAttr("filters", &filters_)); } @@ -236,7 +305,7 @@ class DataInputOp: public OpKernel { // No filter means only a file stream. io::RandomAccessInputStream file_stream(file.get()); T entry; - OP_REQUIRES_OK(ctx, entry.FromInputStream(file_stream, filename, string(""), string(""))); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&file_stream, filename, string(""), string(""))); output.emplace_back(std::move(entry)); continue; } @@ -264,7 +333,7 @@ class DataInputOp: public OpKernel { // none with text type correctly (not reading data in none archive) // So use the shortcut here. io::RandomAccessInputStream file_stream(file.get()); - OP_REQUIRES_OK(ctx, entry.FromInputStream(file_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&file_stream, filename, entryname, filtername)); } else if (filtername == "gz") { // Treat gz file specially. Looks like libarchive always have issue // with text file so use ZlibInputStream. Now libarchive @@ -272,10 +341,10 @@ class DataInputOp: public OpKernel { io::RandomAccessInputStream file_stream(file.get()); io::ZlibCompressionOptions zlib_compression_options = zlib_compression_options = io::ZlibCompressionOptions::GZIP(); io::ZlibInputStream compression_stream(&file_stream, 65536, 65536, zlib_compression_options); - OP_REQUIRES_OK(ctx, entry.FromInputStream(compression_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&compression_stream, filename, entryname, filtername)); } else { archive_stream.ResetEntryOffset(); - OP_REQUIRES_OK(ctx, entry.FromInputStream(archive_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&archive_stream, filename, entryname, filtername)); } output.emplace_back(std::move(entry)); } @@ -297,9 +366,9 @@ class DataInputOp: public OpKernel { std::vector filters_ GUARDED_BY(mu_); }; template -class InputDatasetBase : public DatasetBase { +class FileInputDatasetBase : public DatasetBase { public: - InputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) + FileInputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), ctx_(ctx), input_(input), @@ -349,11 +418,11 @@ class InputDatasetBase : public DatasetBase { return Status::OK(); } private: - class Iterator : public DatasetIterator> { + class Iterator : public DatasetIterator> { public: - using tensorflow::data::DatasetIterator>::dataset; - explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) - : DatasetIterator>(params), stream_(nullptr), archive_(nullptr, [](struct archive *a){ archive_read_free(a);}), file_(nullptr){} + using tensorflow::data::DatasetIterator>::dataset; + explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) + : DatasetIterator>(params), stream_(nullptr), archive_(nullptr, [](struct archive *a){ archive_read_free(a);}), file_(nullptr){} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -363,51 +432,7 @@ class InputDatasetBase : public DatasetBase { int64 count = dataset()->batch_ == 0 ? 1 : dataset()->batch_; while (returned < count) { if (stream_) { - int64 record_read = 0; - int64 record_to_read = count - returned; - std::vector chunk_tensors; - TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadRecord((*stream_.get()), ctx, current_input_state_, count - returned, &record_read, &chunk_tensors)); - if (record_read > 0) { - if (out_tensors->size() == 0) { - // Replace out_tensors with chunk_tensors - out_tensors->reserve(chunk_tensors.size()); - // dataset()->batch_ == 0 could only read at most one record - // so it only happens here. - if (dataset()->batch_ == 0) { - for (size_t i = 0; i < chunk_tensors.size(); i++) { - TensorShape shape = chunk_tensors[i].shape(); - shape.RemoveDim(0); - Tensor value_tensor(ctx->allocator({}), chunk_tensors[i].dtype(), shape); - value_tensor.CopyFrom(chunk_tensors[i], shape); - out_tensors->emplace_back(std::move(value_tensor)); - } - } else { - for (size_t i = 0; i < chunk_tensors.size(); i++) { - out_tensors->emplace_back(std::move(chunk_tensors[i])); - } - } - } else { - // Append out_tensors with chunk_tensors - for (size_t i = 0; i < out_tensors->size(); i++) { - TensorShape shape = (*out_tensors)[i].shape(); - shape.set_dim(0, shape.dim_size(0) + record_read); - Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); - TensorShape element_shape = shape; - element_shape.RemoveDim(0); - Tensor element(ctx->allocator({}), (*out_tensors)[i].dtype(), element_shape); - for (size_t index = 0; index < (*out_tensors)[i].shape().dim_size(0); index++) { - TF_RETURN_IF_ERROR(batch_util::CopySliceToElement((*out_tensors)[i], &element, index)); - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, index)); - } - for (size_t index = 0; index < record_read; index++) { - TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(chunk_tensors[i], &element, index)); - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, (*out_tensors)[i].shape().dim_size(0) + index)); - } - (*out_tensors)[i] = std::move(value_tensor); - } - } - returned += record_read; - } + TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadBatchRecord(stream_.get(), ctx, current_input_state_, dataset()->batch_, count, &returned, out_tensors)); if (returned == count) { *end_of_sequence = false; return Status::OK(); @@ -510,10 +535,251 @@ class InputDatasetBase : public DatasetBase { }; template -class InputDatasetOp : public DatasetOpKernel { +class FileInputDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + explicit FileInputDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + OP_REQUIRES( + ctx, (input_tensor->dtype() == DT_VARIANT || input_tensor->dtype() == DT_STRING), + errors::InvalidArgument("`input` must be a variant or string, received ", input_tensor->dtype())); + OP_REQUIRES( + ctx, input_tensor->dims() <= 1, + errors::InvalidArgument("`input` must be a scalar or a vector, dim = ", input_tensor->dims())); + std::vector input; + input.reserve(input_tensor->NumElements()); + if (input_tensor->dtype() == DT_VARIANT) { + for (int i = 0; i < input_tensor->NumElements(); ++i) { + input.push_back(*(input_tensor->flat()(i).get())); + } + } else { + for (int i = 0; i < input_tensor->NumElements(); ++i) { + string message = input_tensor->flat()(i); + VariantTensorDataProto serialized_proto_f; + VariantTensorData serialized_data_f; + DecodeVariant(&message, &serialized_proto_f); + serialized_data_f.FromProto(serialized_proto_f); + InputType entry; + entry.Decode(serialized_data_f); + input.emplace_back(entry); + } + } + const Tensor* batch_tensor; + OP_REQUIRES_OK(ctx, ctx->input("batch", &batch_tensor)); + int64 batch = batch_tensor->scalar()(); + *output = new FileInputDatasetBase(ctx, input, batch, output_types_, output_shapes_); + } + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +template +class StreamInput : public DataInput { + public: + StreamInput() {} + virtual ~StreamInput() {} + Status FromInputEndpoint(const string& endpoint) { + endpoint_ = endpoint; + return FromEndpoint(endpoint); + } + void Encode(VariantTensorData* data) const { + data->tensors_ = {Tensor(DT_STRING, TensorShape({}))}; + data->tensors_[0].scalar()() = endpoint_; + + EncodeAttributes(data); + } + bool Decode(const VariantTensorData& data) { + endpoint_ = data.tensors(0).scalar()(); + + return DecodeAttributes(data); + } + const string& endpoint() const { + return endpoint_; + } + Status ReadBatchRecord(IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + return (static_cast *>(this))->ReadReferenceBatchRecord(nullptr, ctx, state, batch, count, returned, out_tensors); + } + protected: + virtual Status FromEndpoint(const string& endpoint) = 0; + virtual Status ReadRecord(IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + virtual void EncodeAttributes(VariantTensorData* data) const = 0; + virtual bool DecodeAttributes(const VariantTensorData& data) = 0; + Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + return ReadRecord(ctx, state, record_to_read, record_read, out_tensors); + } + string endpoint_; +}; + +template +class StreamInputOp: public OpKernel { + public: + explicit StreamInputOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + void Compute(OpKernelContext* ctx) override { + const Tensor* source_tensor; + OP_REQUIRES_OK(ctx, ctx->input("source", &source_tensor)); + OP_REQUIRES( + ctx, source_tensor->dims() <= 1, + errors::InvalidArgument("`source` must be a scalar or a vector.")); + + std::vector source; + source.reserve(source_tensor->NumElements()); + for (int i = 0; i < source_tensor->NumElements(); ++i) { + source.push_back(source_tensor->flat()(i)); + } + + std::vector output; + + for (const auto& endpoint: source) { + T entry; + OP_REQUIRES_OK(ctx, entry.FromInputEndpoint(endpoint)); + output.emplace_back(std::move(entry)); + } + + Tensor* output_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({static_cast(output.size())}), &output_tensor)); + for (int i = 0; i < output.size(); i++) { + output_tensor->flat()(i) = output[i]; + } + } + protected: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; +template +class StreamInputDatasetBase : public DatasetBase { + public: + StreamInputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + ctx_(ctx), + input_(input), + batch_(batch), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, DebugString())})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "InputDatasetBase::Dataset"; + } + + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { + Node* input_node; + Tensor input_tensor(DT_STRING, TensorShape({static_cast(input_.size())})); + // GraphDefInternal has some trouble with Variant so use serialized string. + for (size_t i = 0; i < input_.size(); i++) { + string message; + VariantTensorData serialized_data_f; + VariantTensorDataProto serialized_proto_f; + input_[i].Encode(&serialized_data_f); + serialized_data_f.ToProto(&serialized_proto_f); + EncodeVariant(serialized_proto_f, &message); + input_tensor.flat()(i) = message; + } + TF_RETURN_IF_ERROR(b->AddTensor(input_tensor, &input_node)); + Node* batch_node; + Tensor batch_tensor(DT_INT64, TensorShape({})); + batch_tensor.scalar()() = batch_; + TF_RETURN_IF_ERROR(b->AddTensor(batch_tensor, &batch_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, batch_node}, node)); + return Status::OK(); + } + private: + class Iterator : public DatasetIterator> { + public: + using tensorflow::data::DatasetIterator>::dataset; + explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) + : DatasetIterator>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + int64 returned = 0; + int64 count = dataset()->batch_ == 0 ? 1 : dataset()->batch_; + while (returned < count) { + if (current_input_index_ < dataset()->input_.size()) { + TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadBatchRecord(ctx, current_input_state_, dataset()->batch_, count, &returned, out_tensors)); + if (returned == count) { + *end_of_sequence = false; + return Status::OK(); + } + // We have reached the end of the current input, move next. + ResetStreamsLocked(); + ++current_input_index_; + } + // Iteration ends when there are no more input to process. + if (current_input_index_ == dataset()->input_.size()) { + if (out_tensors->size() != 0) { + *end_of_sequence = false; + return Status::OK(); + } + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + }; + } + + private: + // Sets up streams to read from `current_input_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_input_index_ >= dataset()->input_.size()) { + return errors::InvalidArgument( + "current_input_index_:", current_input_index_, + " >= input_.size():", dataset()->input_.size()); + } + + // Actually move on to next entry. + current_input_state_.reset(nullptr); + + return Status::OK(); + } + + // Resets all streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + current_input_state_.reset(nullptr); + } + + mutex mu_; + size_t current_input_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr current_input_state_ GUARDED_BY(mu_); + }; + OpKernelContext* ctx_; + protected: + std::vector input_; + int64 batch_; + const DataTypeVector output_types_; + const std::vector output_shapes_; +}; + +template +class StreamInputDatasetOp : public DatasetOpKernel { public: using DatasetOpKernel::DatasetOpKernel; - explicit InputDatasetOp(OpKernelConstruction* ctx) + explicit StreamInputDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -548,7 +814,7 @@ class InputDatasetOp : public DatasetOpKernel { const Tensor* batch_tensor; OP_REQUIRES_OK(ctx, ctx->input("batch", &batch_tensor)); int64 batch = batch_tensor->scalar()(); - *output = new InputDatasetBase(ctx, input, batch, output_types_, output_shapes_); + *output = new StreamInputDatasetBase(ctx, input, batch, output_types_, output_shapes_); } DataTypeVector output_types_; std::vector output_shapes_; diff --git a/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc b/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc index 2a05962e2..2698170fd 100644 --- a/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc +++ b/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc @@ -18,12 +18,12 @@ limitations under the License. namespace tensorflow { namespace data { -class MNISTImageInput: public DataInput { +class MNISTImageInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); - TF_RETURN_IF_ERROR(s.SkipNBytes(16)); + TF_RETURN_IF_ERROR(s->SkipNBytes(16)); } string buffer; Status status = ReadInputStream(s, (rows_ * cols_), record_to_read, &buffer, record_read); @@ -38,9 +38,9 @@ class MNISTImageInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { string header; - TF_RETURN_IF_ERROR(s.ReadNBytes(16, &header)); + TF_RETURN_IF_ERROR(s->ReadNBytes(16, &header)); if (header[0] != 0x00 || header[1] != 0x00 || header[2] != 0x08 || header[3] != 0x03) { return errors::InvalidArgument("mnist image file header must starts with `0x00000803`"); } @@ -68,12 +68,12 @@ class MNISTImageInput: public DataInput { int64 rows_; int64 cols_; }; -class MNISTLabelInput: public DataInput { +class MNISTLabelInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); - TF_RETURN_IF_ERROR(s.SkipNBytes(8)); + TF_RETURN_IF_ERROR(s->SkipNBytes(8)); } string buffer; TF_RETURN_IF_ERROR(ReadInputStream(s, 1, record_to_read, &buffer, record_read)); @@ -85,9 +85,9 @@ class MNISTLabelInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { string header; - TF_RETURN_IF_ERROR(s.ReadNBytes(8, &header)); + TF_RETURN_IF_ERROR(s->ReadNBytes(8, &header)); if (header[0] != 0x00 || header[1] != 0x00 || header[2] != 0x08 || header[3] != 0x01) { return errors::InvalidArgument("mnist label file header must starts with `0x00000801`"); } @@ -110,12 +110,12 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MNISTLabelInput, "tensorflow::data::MNIST REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MNISTImageInput, "tensorflow::data::MNISTImageInput"); REGISTER_KERNEL_BUILDER(Name("MNISTLabelInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("MNISTImageInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("MNISTLabelDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); REGISTER_KERNEL_BUILDER(Name("MNISTImageDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/text/kernels/text_input.cc b/tensorflow_io/text/kernels/text_input.cc index 14effc47f..d47b5cc4c 100644 --- a/tensorflow_io/text/kernels/text_input.cc +++ b/tensorflow_io/text/kernels/text_input.cc @@ -19,11 +19,11 @@ limitations under the License. namespace tensorflow { namespace data { -class TextInput: public DataInput { +class TextInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { - state.reset(new io::BufferedInputStream(&s, 4096)); + state.reset(new io::BufferedInputStream(s, 4096)); } std::vector records; records.reserve(record_to_read); @@ -49,7 +49,7 @@ class TextInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { // TODO: Read 4K buffer to detect BOM. //string header; //TF_RETURN_IF_ERROR(s.ReadNBytes(4096, &header)); @@ -71,8 +71,8 @@ class TextInput: public DataInput { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TextInput, "tensorflow::data::TextInput"); REGISTER_KERNEL_BUILDER(Name("TextInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("TextDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace data } // namespace tensorflow From f1e1e90cc192642b62c291ed3682190099cef451 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 24 Apr 2019 22:04:39 +0000 Subject: [PATCH 2/3] Add GRPCDataset to allow pulling data from a gRPC server This fix tries to have a way to pull data from a gRPC server, as long as a gRPC server implement the protocol in endpoint.proto. The main purpose is to allow reading partial data from numpy array in memory. In the from_numpy method, the following is done: - Create a gRPC server and exposes ReadRecord endpoint. - Start the gRPC server on a random port in local host. - Pass endpoint () to GRPCDataset - GRPCDataset will pull the data from gRPC server. Note from_numpy is just a facility method to setup a gRPC server. In theory, a GRPC server could be created by any other process and by any language. Only the endpoint information is needed to create GRPCDataset. In case numpy array is huge, then this method could be helpful as it is not required to save numpy file, and load back by tf.data, or pass the whole numpy data into a tensor. This could open up doors for other languages as well. For example, in case of R, we could setup a grpc server with R, and pass R dataframes in memory and allows GRPCDataset to pull a chunk of the data at a time. Note we create gRPC server locally, but it is possible to expose gRPC server remotely so that the workload could be distributed. Signed-off-by: Yong Tang --- .travis/python3.7+.release.sh | 6 +- .travis/wheel.test.sh | 3 +- configure.sh | 2 +- setup.py | 4 +- tensorflow_io/grpc/BUILD | 65 +++++++++++++++ tensorflow_io/grpc/__init__.py | 32 +++++++ tensorflow_io/grpc/endpoint.proto | 17 ++++ tensorflow_io/grpc/kernels/grpc_input.cc | 76 +++++++++++++++++ tensorflow_io/grpc/ops/grpc_ops.cc | 43 ++++++++++ tensorflow_io/grpc/python/__init__.py | 0 tensorflow_io/grpc/python/ops/__init__.py | 0 .../grpc/python/ops/grpc_endpoint.py | 63 ++++++++++++++ tensorflow_io/grpc/python/ops/grpc_ops.py | 83 +++++++++++++++++++ tests/test_grpc.py | 46 ++++++++++ third_party/libarchive.BUILD | 2 +- 15 files changed, 437 insertions(+), 5 deletions(-) create mode 100644 tensorflow_io/grpc/BUILD create mode 100644 tensorflow_io/grpc/__init__.py create mode 100644 tensorflow_io/grpc/endpoint.proto create mode 100644 tensorflow_io/grpc/kernels/grpc_input.cc create mode 100644 tensorflow_io/grpc/ops/grpc_ops.cc create mode 100644 tensorflow_io/grpc/python/__init__.py create mode 100644 tensorflow_io/grpc/python/ops/__init__.py create mode 100644 tensorflow_io/grpc/python/ops/grpc_endpoint.py create mode 100644 tensorflow_io/grpc/python/ops/grpc_ops.py create mode 100644 tests/test_grpc.py diff --git a/.travis/python3.7+.release.sh b/.travis/python3.7+.release.sh index 111505ae1..da98ce3d8 100755 --- a/.travis/python3.7+.release.sh +++ b/.travis/python3.7+.release.sh @@ -36,7 +36,11 @@ if [[ "$#" -gt 0 ]]; then shift fi -apt-get -y -qq update && apt-get -y -qq install $PYTHON_VERSION +apt-get -y -qq update && apt-get -y -qq install python $PYTHON_VERSION +python get-pip.py -q +python -m pip --version +python -m pip install -q grpcio-tools + $PYTHON_VERSION get-pip.py -q $PYTHON_VERSION -m pip --version diff --git a/.travis/wheel.test.sh b/.travis/wheel.test.sh index 6afd48fa5..bb6a9da9c 100755 --- a/.travis/wheel.test.sh +++ b/.travis/wheel.test.sh @@ -5,8 +5,9 @@ run_test() { CPYTHON_VERSION=$($entry -c 'import sys; print(str(sys.version_info[0])+str(sys.version_info[1]))') (cd wheelhouse && $entry -m pip install *-cp${CPYTHON_VERSION}-*.whl) $entry -m pip install -q pytest boto3 google-cloud-pubsub==0.39.1 pyarrow==0.11.1 pandas==0.19.2 - (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! -iname "test_*_eager.py" \))) + (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_eager.py" -o -iname "test_grpc.py" \) \))) (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*_eager.py" \))) + (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_grpc.py" \))) } PYTHON_VERSION=python diff --git a/configure.sh b/configure.sh index 90a27a0b0..16d7e9006 100755 --- a/configure.sh +++ b/configure.sh @@ -20,5 +20,5 @@ if python -c "import tensorflow" &> /dev/null; then else pip install tensorflow fi - +python -m pip install grpcio-tools python config_helper.py diff --git a/setup.py b/setup.py index dc5f67174..ba74903e5 100644 --- a/setup.py +++ b/setup.py @@ -149,7 +149,9 @@ def has_ext_modules(self): os.path.join(datapath, "tensorflow_io")): if (not fnmatch.fnmatch(rootname, "*test*") and not fnmatch.fnmatch(rootname, "*runfiles*")): - for filename in fnmatch.filter(filenames, "*.so"): + for filename in [ + f for f in filenames if fnmatch.fnmatch( + f, "*.so") or fnmatch.fnmatch(f, "*.py")]: src = os.path.join(rootname, filename) dst = os.path.join( rootpath, diff --git a/tensorflow_io/grpc/BUILD b/tensorflow_io/grpc/BUILD new file mode 100644 index 000000000..1b108d311 --- /dev/null +++ b/tensorflow_io/grpc/BUILD @@ -0,0 +1,65 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") + +genrule( + name = "endpoint_py", + srcs = [ + "endpoint.proto", + ], + outs = [ + "python/ops/__init__.py", + "python/ops/endpoint_pb2.py", + "python/ops/endpoint_pb2_grpc.py", + ], + cmd = "python -m grpc_tools.protoc -Itensorflow_io/grpc --python_out=$(BINDIR)/tensorflow_io/grpc/python/ops/ --grpc_python_out=$(BINDIR)/tensorflow_io/grpc/python/ops/ $< ; touch $(BINDIR)/tensorflow_io/grpc/python/ops/__init__.py", + output_to_bindir = True, +) + +proto_library( + name = "_any_proto_only", + deps = ["@com_google_protobuf//:any_proto"], +) + +cc_proto_library( + name = "any_proto", + deps = ["@com_google_protobuf//:any_proto"], +) + +cc_grpc_library( + name = "endpoint_cc", + srcs = [ + "endpoint.proto", + ], + proto_only = False, + well_known_protos = True, + deps = [":any_proto"], +) + +cc_binary( + name = "python/ops/_grpc_ops.so", + srcs = [ + "kernels/grpc_input.cc", + "ops/grpc_ops.cc", + ], + copts = [ + "-pthread", + "-std=c++11", + "-DNDEBUG", + ], + includes = [ + ".", + ], + linkshared = 1, + deps = [ + ":endpoint_cc", + "//tensorflow_io/core:dataset_ops", + "@com_github_grpc_grpc//:grpc++", + "@com_google_protobuf//:protobuf", + "@libarchive", + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], +) diff --git a/tensorflow_io/grpc/__init__.py b/tensorflow_io/grpc/__init__.py new file mode 100644 index 000000000..625d25f84 --- /dev/null +++ b/tensorflow_io/grpc/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GRPCInput + +@@GRPCDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow_io.grpc.python.ops.grpc_ops import GRPCDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "GRPCDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow_io/grpc/endpoint.proto b/tensorflow_io/grpc/endpoint.proto new file mode 100644 index 000000000..84dedb4c3 --- /dev/null +++ b/tensorflow_io/grpc/endpoint.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; + +message Request { + int64 offset = 1; + int64 length = 2; +} + +message Response { + google.protobuf.Any record = 1; +} + +service GRPCEndpoint { + rpc ReadRecord(Request) returns (Response){} +} + diff --git a/tensorflow_io/grpc/kernels/grpc_input.cc b/tensorflow_io/grpc/kernels/grpc_input.cc new file mode 100644 index 000000000..c9c5aca45 --- /dev/null +++ b/tensorflow_io/grpc/kernels/grpc_input.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "kernels/dataset_ops.h" +#include "endpoint.grpc.pb.h" +#include + +namespace tensorflow { +namespace data { + +class GRPCInputState { +public: + GRPCInputState(const string& endpoint) : offset_(0) { + stub_ = GRPCEndpoint::NewStub(grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())); + } + int64 offset_; + std::unique_ptr stub_; +}; + +class GRPCInput: public StreamInput { + public: + Status ReadRecord(IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + if (state.get() == nullptr) { + state.reset(new GRPCInputState(endpoint_)); + } + Request request; + request.set_offset(state.get()->offset_); + request.set_length(record_to_read); + Response response; + grpc::ClientContext context; + grpc::Status status = state.get()->stub_.get()->ReadRecord(&context, request, &response); + if (!status.ok()) { + return errors::InvalidArgument("unable to fetch data from grpc (", status.error_code(), "): ", status.error_message()); + } + TensorProto record; + response.record().UnpackTo(&record); + Tensor value_tensor; + value_tensor.FromProto(ctx->allocator({}), record); + out_tensors->emplace_back(std::move(value_tensor)); + + *record_read = value_tensor.dim_size(0); + state.get()->offset_ += *record_read; + + return Status::OK(); + } + Status FromEndpoint(const string& endpoint) override { + return Status::OK(); + } + void EncodeAttributes(VariantTensorData* data) const override { + } + bool DecodeAttributes(const VariantTensorData& data) override { + return true; + } + protected: +}; + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(GRPCInput, "tensorflow::data::GRPCInput"); + +REGISTER_KERNEL_BUILDER(Name("GRPCInput").Device(DEVICE_CPU), + StreamInputOp); +REGISTER_KERNEL_BUILDER(Name("GRPCDataset").Device(DEVICE_CPU), + StreamInputDatasetOp); +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/grpc/ops/grpc_ops.cc b/tensorflow_io/grpc/ops/grpc_ops.cc new file mode 100644 index 000000000..411e9af9b --- /dev/null +++ b/tensorflow_io/grpc/ops/grpc_ops.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GRPCInput") + .Input("source: string") + .Output("handle: variant") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("GRPCDataset") + .Input("input: T") + .Input("batch: int64") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("T: {string, variant} = DT_VARIANT") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({})); + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow_io/grpc/python/__init__.py b/tensorflow_io/grpc/python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/grpc/python/ops/__init__.py b/tensorflow_io/grpc/python/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/grpc/python/ops/grpc_endpoint.py b/tensorflow_io/grpc/python/ops/grpc_endpoint.py new file mode 100644 index 000000000..b6a9d1d0f --- /dev/null +++ b/tensorflow_io/grpc/python/ops/grpc_endpoint.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GRPCEndpoint.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import concurrent.futures +import grpc +import google.protobuf.any_pb2 + +import tensorflow +# Incase test is done with TFIO_DATAPATH specified, the +# import path need to be extended to capture generated +# grpc files: +datapath = os.environ.get('TFIO_DATAPATH') +sys.path.append(os.path.abspath( + os.path.dirname(__file__) if datapath is None else os.path.join( + datapath, "tensorflow_io", "grpc", "python", "ops"))) +import endpoint_pb2 # pylint: disable=wrong-import-position,unused-import +import endpoint_pb2_grpc # pylint: disable=wrong-import-position,unused-import + +class GRPCEndpoint(endpoint_pb2_grpc.GRPCEndpointServicer): + """GRPCEndpoint""" + def __init__(self, data): + self._grpc_server = grpc.server( + concurrent.futures.ThreadPoolExecutor(max_workers=4)) + port = self._grpc_server.add_insecure_port("localhost:0") + self._endpoint = "localhost:"+str(port) + self._data = data + super(GRPCEndpoint, self).__init__() + endpoint_pb2_grpc.add_GRPCEndpointServicer_to_server( + self, self._grpc_server) + + def start(self): + self._grpc_server.start() + + def stop(self): + self._grpc_server.stop(0) + + def endpoint(self): + return self._endpoint + + def ReadRecord(self, request, context): # pylint: disable=unused-argument + tensor = tensorflow.compat.v1.make_tensor_proto( + self._data[request.offset:request.offset+request.length, :]) + record = google.protobuf.any_pb2.Any() + record.Pack(tensor) + return endpoint_pb2.Response(record=record) diff --git a/tensorflow_io/grpc/python/ops/grpc_ops.py b/tensorflow_io/grpc/python/ops/grpc_ops.py new file mode 100644 index 000000000..60975177e --- /dev/null +++ b/tensorflow_io/grpc/python/ops/grpc_ops.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GRPCInput.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow +from tensorflow.compat.v1 import data +from tensorflow_io import _load_library +grpc_ops = _load_library('_grpc_ops.so') + +class GRPCDataset(data.Dataset): + """A GRPC Dataset + """ + + def __init__(self, endpoint, shape, dtype, batch=None): + """Create a GRPC Reader. + + Args: + endpoint: A `tf.string` tensor containing one or more endpoints. + """ + self._data_input = grpc_ops.grpc_input(endpoint) + self._batch = 0 if batch is None else batch + shape[0] = None + self._output_shapes = tuple([ + tensorflow.TensorShape(shape[1:])]) if self._batch == 0 else tuple([ + tensorflow.TensorShape(shape)]) + self._output_types = tuple([dtype]) + self._batch = 0 if batch is None else batch + super(GRPCDataset, self).__init__() + + @staticmethod + def from_numpy(a, batch=None): + """from_numpy""" + from tensorflow_io.grpc.python.ops import grpc_endpoint + grpc_server = grpc_endpoint.GRPCEndpoint(a) + grpc_server.start() + endpoint = grpc_server.endpoint() + dtype = a.dtype + shape = list(a.shape) + batch = batch + dataset = GRPCDataset(endpoint, shape, dtype, batch=batch) + dataset._grpc_server = grpc_server # pylint: disable=protected-access + return dataset + + def __del__(self): + if self._grpc_server is not None: + self._grpc_server.stop() + + def _inputs(self): + return [] + + def _as_variant_tensor(self): + return grpc_ops.grpc_dataset( + self._data_input, + self._batch, + output_types=self.output_types, + output_shapes=self.output_shapes) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return tensorflow.Tensor + + @property + def output_types(self): + return self._output_types diff --git a/tests/test_grpc.py b/tests/test_grpc.py new file mode 100644 index 000000000..fa8957b99 --- /dev/null +++ b/tests/test_grpc.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for GRPC Input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import pytest +import tensorflow +tensorflow.compat.v1.disable_eager_execution() + +from tensorflow import errors # pylint: disable=wrong-import-position +import tensorflow_io.grpc as grpc_io # pylint: disable=wrong-import-position + +def test_grpc_input(): + """test_grpc_input""" + data = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + dataset = grpc_io.GRPCDataset.from_numpy(data, batch=2) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with tensorflow.compat.v1.Session() as sess: + sess.run(init_op) + v = sess.run(get_next) + assert np.alltrue(data[0:2] == v) + v = sess.run(get_next) + assert np.alltrue(data[2:3] == v) + with pytest.raises(errors.OutOfRangeError): + sess.run(get_next) + +if __name__ == "__main__": + test.main() diff --git a/third_party/libarchive.BUILD b/third_party/libarchive.BUILD index 4e215a5d7..eb2621811 100644 --- a/third_party/libarchive.BUILD +++ b/third_party/libarchive.BUILD @@ -33,7 +33,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "@zlib", + "@com_github_madler_zlib//:z", ], ) From 1ad6a2f78384bd8bafee226fa598b31e2bdef806 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 27 Apr 2019 16:46:20 +0000 Subject: [PATCH 3/3] Add addiitonal test case with tf.keras, and covers 1D situation Signed-off-by: Yong Tang --- .../grpc/python/ops/grpc_endpoint.py | 8 ++- tensorflow_io/grpc/python/ops/grpc_ops.py | 1 - tests/test_grpc_eager.py | 54 +++++++++++++++++++ 3 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 tests/test_grpc_eager.py diff --git a/tensorflow_io/grpc/python/ops/grpc_endpoint.py b/tensorflow_io/grpc/python/ops/grpc_endpoint.py index b6a9d1d0f..35b548097 100644 --- a/tensorflow_io/grpc/python/ops/grpc_endpoint.py +++ b/tensorflow_io/grpc/python/ops/grpc_endpoint.py @@ -56,8 +56,12 @@ def endpoint(self): return self._endpoint def ReadRecord(self, request, context): # pylint: disable=unused-argument - tensor = tensorflow.compat.v1.make_tensor_proto( - self._data[request.offset:request.offset+request.length, :]) + if len(self._data.shape) == 1: + tensor = tensorflow.compat.v1.make_tensor_proto( + self._data[request.offset:request.offset+request.length]) + else: + tensor = tensorflow.compat.v1.make_tensor_proto( + self._data[request.offset:request.offset+request.length, :]) record = google.protobuf.any_pb2.Any() record.Pack(tensor) return endpoint_pb2.Response(record=record) diff --git a/tensorflow_io/grpc/python/ops/grpc_ops.py b/tensorflow_io/grpc/python/ops/grpc_ops.py index 60975177e..ade5dca94 100644 --- a/tensorflow_io/grpc/python/ops/grpc_ops.py +++ b/tensorflow_io/grpc/python/ops/grpc_ops.py @@ -51,7 +51,6 @@ def from_numpy(a, batch=None): endpoint = grpc_server.endpoint() dtype = a.dtype shape = list(a.shape) - batch = batch dataset = GRPCDataset(endpoint, shape, dtype, batch=batch) dataset._grpc_server = grpc_server # pylint: disable=protected-access return dataset diff --git a/tests/test_grpc_eager.py b/tests/test_grpc_eager.py new file mode 100644 index 000000000..11c7ae378 --- /dev/null +++ b/tests/test_grpc_eager.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for GRPC Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest +import numpy as np + +import tensorflow as tf +import tensorflow_io.grpc as grpc_io + +@pytest.mark.skipif( + not (hasattr(tf, "version") and + tf.version.VERSION.startswith("2.0.")), reason=None) +def test_grpc_with_mnist_tutorial(): + """test_mnist_tutorial""" + (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data() + x = grpc_io.GRPCDataset.from_numpy(x_train, batch=1000) + y = grpc_io.GRPCDataset.from_numpy(y_train, batch=1000) + for (i, v) in zip(range(0, 50000, 1000), x): + assert np.alltrue(x_train[i:i+1000, :] == v.numpy()) + for (i, v) in zip(range(0, 50000, 1000), y): + assert np.alltrue(y_train[i:i+1000] == v.numpy()) + d_train = tf.data.Dataset.zip((x, y)) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(512, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation=tf.nn.softmax) + ]) + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(d_train, epochs=5) + +if __name__ == "__main__": + test.main()