diff --git a/.travis/python3.7+.release.sh b/.travis/python3.7+.release.sh index 111505ae1..da98ce3d8 100755 --- a/.travis/python3.7+.release.sh +++ b/.travis/python3.7+.release.sh @@ -36,7 +36,11 @@ if [[ "$#" -gt 0 ]]; then shift fi -apt-get -y -qq update && apt-get -y -qq install $PYTHON_VERSION +apt-get -y -qq update && apt-get -y -qq install python $PYTHON_VERSION +python get-pip.py -q +python -m pip --version +python -m pip install -q grpcio-tools + $PYTHON_VERSION get-pip.py -q $PYTHON_VERSION -m pip --version diff --git a/.travis/wheel.test.sh b/.travis/wheel.test.sh index 6afd48fa5..bb6a9da9c 100755 --- a/.travis/wheel.test.sh +++ b/.travis/wheel.test.sh @@ -5,8 +5,9 @@ run_test() { CPYTHON_VERSION=$($entry -c 'import sys; print(str(sys.version_info[0])+str(sys.version_info[1]))') (cd wheelhouse && $entry -m pip install *-cp${CPYTHON_VERSION}-*.whl) $entry -m pip install -q pytest boto3 google-cloud-pubsub==0.39.1 pyarrow==0.11.1 pandas==0.19.2 - (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! -iname "test_*_eager.py" \))) + (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*.py" ! \( -iname "test_*_eager.py" -o -iname "test_grpc.py" \) \))) (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_*_eager.py" \))) + (cd tests && $entry -m pytest -v --import-mode=append $(find . -type f \( -iname "test_grpc.py" \))) } PYTHON_VERSION=python diff --git a/configure.sh b/configure.sh index 90a27a0b0..16d7e9006 100755 --- a/configure.sh +++ b/configure.sh @@ -20,5 +20,5 @@ if python -c "import tensorflow" &> /dev/null; then else pip install tensorflow fi - +python -m pip install grpcio-tools python config_helper.py diff --git a/setup.py b/setup.py index dc5f67174..ba74903e5 100644 --- a/setup.py +++ b/setup.py @@ -149,7 +149,9 @@ def has_ext_modules(self): os.path.join(datapath, "tensorflow_io")): if (not fnmatch.fnmatch(rootname, "*test*") and not fnmatch.fnmatch(rootname, "*runfiles*")): - for filename in fnmatch.filter(filenames, "*.so"): + for filename in [ + f for f in filenames if fnmatch.fnmatch( + f, "*.so") or fnmatch.fnmatch(f, "*.py")]: src = os.path.join(rootname, filename) dst = os.path.join( rootpath, diff --git a/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc b/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc index ea5330b45..79e64390c 100644 --- a/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc +++ b/tensorflow_io/cifar/kernels/cifar_dataset_ops.cc @@ -23,9 +23,9 @@ limitations under the License. 
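Before the C++ changes below: the setup.py hunk above widens the packaged-file filter from shared objects only to shared objects plus Python sources, presumably so the grpc stubs generated by the new BUILD genrule (endpoint_pb2.py and friends) can ship alongside the .so files. A minimal sketch of the new filter in isolation — the file names here are made up for illustration:

    import fnmatch

    filenames = ["_grpc_ops.so", "endpoint_pb2.py", "endpoint_pb2_grpc.py", "BUILD"]
    # Keep anything matching *.so or *.py, mirroring the list comprehension in setup.py.
    selected = [f for f in filenames
                if fnmatch.fnmatch(f, "*.so") or fnmatch.fnmatch(f, "*.py")]
    print(selected)  # ['_grpc_ops.so', 'endpoint_pb2.py', 'endpoint_pb2_grpc.py']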
namespace tensorflow { namespace data { namespace { -class CIFAR10Input: public DataInput { +class CIFAR10Input: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); } @@ -45,7 +45,7 @@ class CIFAR10Input: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { return Status::OK(); } void EncodeAttributes(VariantTensorData* data) const override { @@ -56,9 +56,9 @@ class CIFAR10Input: public DataInput { protected: }; -class CIFAR100Input: public DataInput { +class CIFAR100Input: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); } @@ -81,7 +81,7 @@ class CIFAR100Input: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { return Status::OK(); } void EncodeAttributes(VariantTensorData* data) const override { @@ -96,13 +96,13 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CIFAR10Input, "tensorflow::CIFAR10Input") REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CIFAR100Input, "tensorflow::CIFAR100Input"); REGISTER_KERNEL_BUILDER(Name("CIFAR10Input").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("CIFAR100Input").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("CIFAR10Dataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); REGISTER_KERNEL_BUILDER(Name("CIFAR100Dataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/dataset_ops.h b/tensorflow_io/core/kernels/dataset_ops.h index 1c4b31714..d0fc12231 100644 --- a/tensorflow_io/core/kernels/dataset_ops.h +++ b/tensorflow_io/core/kernels/dataset_ops.h @@ -147,36 +147,84 @@ class ArchiveInputStream : public io::InputStreamInterface { TF_DISALLOW_COPY_AND_ASSIGN(ArchiveInputStream); }; +// Note: Forward declaration for friend class. +template class FileInput; +template class StreamInput; + template class DataInput { public: DataInput() {} virtual ~DataInput() {} - virtual Status FromStream(io::InputStreamInterface& s) = 0; - virtual Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + protected: virtual void EncodeAttributes(VariantTensorData* data) const = 0; virtual bool DecodeAttributes(const VariantTensorData& data) = 0; - - Status ReadInputStream(io::InputStreamInterface& s, int64 chunk, int64 count, string* buffer, int64* returned) const { - int64 offset = s.Tell(); - int64 bytes_to_read = count * chunk; - Status status = (buffer == nullptr) ? 
s.SkipNBytes(bytes_to_read) : s.ReadNBytes(bytes_to_read, buffer); - if (!(status.ok() || status == errors::OutOfRange("EOF reached"))) { - return status; - } - int64 bytes_read = s.Tell() - offset; - if (bytes_read % chunk != 0) { - return errors::DataLoss("corrupted data, expected multiple of ", chunk, ", received ", bytes_read); + virtual Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + Status ReadReferenceBatchRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + int64 record_read = 0; + int64 record_to_read = count - (*returned); + std::vector chunk_tensors; + TF_RETURN_IF_ERROR(ReadReferenceRecord(s, ctx, state, record_to_read, &record_read, &chunk_tensors)); + if (record_read > 0) { + if (out_tensors->size() == 0) { + // Replace out_tensors with chunk_tensors + out_tensors->reserve(chunk_tensors.size()); + // batch == 0 could only read at most one record + // so it only happens here. + if (batch == 0) { + for (size_t i = 0; i < chunk_tensors.size(); i++) { + TensorShape shape = chunk_tensors[i].shape(); + shape.RemoveDim(0); + Tensor value_tensor(ctx->allocator({}), chunk_tensors[i].dtype(), shape); + value_tensor.CopyFrom(chunk_tensors[i], shape); + out_tensors->emplace_back(std::move(value_tensor)); + } + } else { + for (size_t i = 0; i < chunk_tensors.size(); i++) { + out_tensors->emplace_back(std::move(chunk_tensors[i])); + } + } + } else { + // Append out_tensors with chunk_tensors + for (size_t i = 0; i < out_tensors->size(); i++) { + TensorShape shape = (*out_tensors)[i].shape(); + shape.set_dim(0, shape.dim_size(0) + record_read); + Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); + TensorShape element_shape = shape; + element_shape.RemoveDim(0); + Tensor element(ctx->allocator({}), (*out_tensors)[i].dtype(), element_shape); + for (size_t index = 0; index < (*out_tensors)[i].shape().dim_size(0); index++) { + TF_RETURN_IF_ERROR(batch_util::CopySliceToElement((*out_tensors)[i], &element, index)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, index)); + } + for (size_t index = 0; index < record_read; index++) { + TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(chunk_tensors[i], &element, index)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, (*out_tensors)[i].shape().dim_size(0) + index)); + } + (*out_tensors)[i] = std::move(value_tensor); + } + } + (*returned) += record_read; } - *returned = bytes_read / chunk; return Status::OK(); } - Status FromInputStream(io::InputStreamInterface& s, const string& filename, const string& entryname, const string& filtername) { + friend class FileInput; + friend class StreamInput; +}; +template +class FileInput : public DataInput { + public: + FileInput() {} + virtual ~FileInput() {} + Status FromInputStream(io::InputStreamInterface* s, const string& filename, const string& entryname, const string& filtername) { filename_ = filename; entryname_ = entryname; filtername_ = filtername; return FromStream(s); } + Status ReadBatchRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + return (static_cast *>(this))->ReadReferenceBatchRecord(static_cast(s), ctx, state, batch, count, returned, out_tensors); + } void Encode(VariantTensorData* data) const { 
data->tensors_ = {Tensor(DT_STRING, TensorShape({})), Tensor(DT_STRING, TensorShape({})), Tensor(DT_STRING, TensorShape({}))}; data->tensors_[0].scalar()() = filename_; @@ -202,15 +250,36 @@ class DataInput { return filtername_; } protected: + virtual Status FromStream(io::InputStreamInterface* s) = 0; + virtual Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + virtual void EncodeAttributes(VariantTensorData* data) const = 0; + virtual bool DecodeAttributes(const VariantTensorData& data) = 0; + Status ReadInputStream(io::InputStreamInterface* s, int64 chunk, int64 count, string* buffer, int64* returned) const { + int64 offset = s->Tell(); + int64 bytes_to_read = count * chunk; + Status status = (buffer == nullptr) ? s->SkipNBytes(bytes_to_read) : s->ReadNBytes(bytes_to_read, buffer); + if (!(status.ok() || status == errors::OutOfRange("EOF reached"))) { + return status; + } + int64 bytes_read = s->Tell() - offset; + if (bytes_read % chunk != 0) { + return errors::DataLoss("corrupted data, expected multiple of ", chunk, ", received ", bytes_read); + } + *returned = bytes_read / chunk; + return Status::OK(); + } + Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + return ReadRecord(static_cast(s), ctx, state, record_to_read, record_read, out_tensors); + } string filename_; string entryname_; string filtername_; }; template -class DataInputOp: public OpKernel { +class FileInputOp: public OpKernel { public: - explicit DataInputOp(OpKernelConstruction* context) : OpKernel(context) { + explicit FileInputOp(OpKernelConstruction* context) : OpKernel(context) { env_ = context->env(); OP_REQUIRES_OK(context, context->GetAttr("filters", &filters_)); } @@ -236,7 +305,7 @@ class DataInputOp: public OpKernel { // No filter means only a file stream. io::RandomAccessInputStream file_stream(file.get()); T entry; - OP_REQUIRES_OK(ctx, entry.FromInputStream(file_stream, filename, string(""), string(""))); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&file_stream, filename, string(""), string(""))); output.emplace_back(std::move(entry)); continue; } @@ -264,7 +333,7 @@ class DataInputOp: public OpKernel { // none with text type correctly (not reading data in none archive) // So use the shortcut here. io::RandomAccessInputStream file_stream(file.get()); - OP_REQUIRES_OK(ctx, entry.FromInputStream(file_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&file_stream, filename, entryname, filtername)); } else if (filtername == "gz") { // Treat gz file specially. Looks like libarchive always have issue // with text file so use ZlibInputStream. 
Now libarchive @@ -272,10 +341,10 @@ class DataInputOp: public OpKernel { io::RandomAccessInputStream file_stream(file.get()); io::ZlibCompressionOptions zlib_compression_options = zlib_compression_options = io::ZlibCompressionOptions::GZIP(); io::ZlibInputStream compression_stream(&file_stream, 65536, 65536, zlib_compression_options); - OP_REQUIRES_OK(ctx, entry.FromInputStream(compression_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&compression_stream, filename, entryname, filtername)); } else { archive_stream.ResetEntryOffset(); - OP_REQUIRES_OK(ctx, entry.FromInputStream(archive_stream, filename, entryname, filtername)); + OP_REQUIRES_OK(ctx, entry.FromInputStream(&archive_stream, filename, entryname, filtername)); } output.emplace_back(std::move(entry)); } @@ -297,9 +366,9 @@ class DataInputOp: public OpKernel { std::vector filters_ GUARDED_BY(mu_); }; template -class InputDatasetBase : public DatasetBase { +class FileInputDatasetBase : public DatasetBase { public: - InputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) + FileInputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), ctx_(ctx), input_(input), @@ -349,11 +418,11 @@ class InputDatasetBase : public DatasetBase { return Status::OK(); } private: - class Iterator : public DatasetIterator> { + class Iterator : public DatasetIterator> { public: - using tensorflow::data::DatasetIterator>::dataset; - explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) - : DatasetIterator>(params), stream_(nullptr), archive_(nullptr, [](struct archive *a){ archive_read_free(a);}), file_(nullptr){} + using tensorflow::data::DatasetIterator>::dataset; + explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) + : DatasetIterator>(params), stream_(nullptr), archive_(nullptr, [](struct archive *a){ archive_read_free(a);}), file_(nullptr){} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -363,51 +432,7 @@ class InputDatasetBase : public DatasetBase { int64 count = dataset()->batch_ == 0 ? 1 : dataset()->batch_; while (returned < count) { if (stream_) { - int64 record_read = 0; - int64 record_to_read = count - returned; - std::vector chunk_tensors; - TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadRecord((*stream_.get()), ctx, current_input_state_, count - returned, &record_read, &chunk_tensors)); - if (record_read > 0) { - if (out_tensors->size() == 0) { - // Replace out_tensors with chunk_tensors - out_tensors->reserve(chunk_tensors.size()); - // dataset()->batch_ == 0 could only read at most one record - // so it only happens here. 
- if (dataset()->batch_ == 0) { - for (size_t i = 0; i < chunk_tensors.size(); i++) { - TensorShape shape = chunk_tensors[i].shape(); - shape.RemoveDim(0); - Tensor value_tensor(ctx->allocator({}), chunk_tensors[i].dtype(), shape); - value_tensor.CopyFrom(chunk_tensors[i], shape); - out_tensors->emplace_back(std::move(value_tensor)); - } - } else { - for (size_t i = 0; i < chunk_tensors.size(); i++) { - out_tensors->emplace_back(std::move(chunk_tensors[i])); - } - } - } else { - // Append out_tensors with chunk_tensors - for (size_t i = 0; i < out_tensors->size(); i++) { - TensorShape shape = (*out_tensors)[i].shape(); - shape.set_dim(0, shape.dim_size(0) + record_read); - Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); - TensorShape element_shape = shape; - element_shape.RemoveDim(0); - Tensor element(ctx->allocator({}), (*out_tensors)[i].dtype(), element_shape); - for (size_t index = 0; index < (*out_tensors)[i].shape().dim_size(0); index++) { - TF_RETURN_IF_ERROR(batch_util::CopySliceToElement((*out_tensors)[i], &element, index)); - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, index)); - } - for (size_t index = 0; index < record_read; index++) { - TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(chunk_tensors[i], &element, index)); - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(element, &value_tensor, (*out_tensors)[i].shape().dim_size(0) + index)); - } - (*out_tensors)[i] = std::move(value_tensor); - } - } - returned += record_read; - } + TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadBatchRecord(stream_.get(), ctx, current_input_state_, dataset()->batch_, count, &returned, out_tensors)); if (returned == count) { *end_of_sequence = false; return Status::OK(); @@ -510,10 +535,251 @@ class InputDatasetBase : public DatasetBase { }; template -class InputDatasetOp : public DatasetOpKernel { +class FileInputDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + explicit FileInputDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + OP_REQUIRES( + ctx, (input_tensor->dtype() == DT_VARIANT || input_tensor->dtype() == DT_STRING), + errors::InvalidArgument("`input` must be a variant or string, received ", input_tensor->dtype())); + OP_REQUIRES( + ctx, input_tensor->dims() <= 1, + errors::InvalidArgument("`input` must be a scalar or a vector, dim = ", input_tensor->dims())); + std::vector input; + input.reserve(input_tensor->NumElements()); + if (input_tensor->dtype() == DT_VARIANT) { + for (int i = 0; i < input_tensor->NumElements(); ++i) { + input.push_back(*(input_tensor->flat()(i).get())); + } + } else { + for (int i = 0; i < input_tensor->NumElements(); ++i) { + string message = input_tensor->flat()(i); + VariantTensorDataProto serialized_proto_f; + VariantTensorData serialized_data_f; + DecodeVariant(&message, &serialized_proto_f); + serialized_data_f.FromProto(serialized_proto_f); + InputType entry; + entry.Decode(serialized_data_f); + input.emplace_back(entry); + } + } + const Tensor* batch_tensor; + OP_REQUIRES_OK(ctx, ctx->input("batch", &batch_tensor)); + int64 batch = batch_tensor->scalar()(); + *output = new FileInputDatasetBase(ctx, input, batch, 
output_types_, output_shapes_); + } + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +template +class StreamInput : public DataInput { + public: + StreamInput() {} + virtual ~StreamInput() {} + Status FromInputEndpoint(const string& endpoint) { + endpoint_ = endpoint; + return FromEndpoint(endpoint); + } + void Encode(VariantTensorData* data) const { + data->tensors_ = {Tensor(DT_STRING, TensorShape({}))}; + data->tensors_[0].scalar()() = endpoint_; + + EncodeAttributes(data); + } + bool Decode(const VariantTensorData& data) { + endpoint_ = data.tensors(0).scalar()(); + + return DecodeAttributes(data); + } + const string& endpoint() const { + return endpoint_; + } + Status ReadBatchRecord(IteratorContext* ctx, std::unique_ptr& state, int64 batch, int64 count, int64* returned, std::vector* out_tensors) const { + return (static_cast *>(this))->ReadReferenceBatchRecord(nullptr, ctx, state, batch, count, returned, out_tensors); + } + protected: + virtual Status FromEndpoint(const string& endpoint) = 0; + virtual Status ReadRecord(IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const = 0; + virtual void EncodeAttributes(VariantTensorData* data) const = 0; + virtual bool DecodeAttributes(const VariantTensorData& data) = 0; + Status ReadReferenceRecord(void* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + return ReadRecord(ctx, state, record_to_read, record_read, out_tensors); + } + string endpoint_; +}; + +template +class StreamInputOp: public OpKernel { + public: + explicit StreamInputOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + void Compute(OpKernelContext* ctx) override { + const Tensor* source_tensor; + OP_REQUIRES_OK(ctx, ctx->input("source", &source_tensor)); + OP_REQUIRES( + ctx, source_tensor->dims() <= 1, + errors::InvalidArgument("`source` must be a scalar or a vector.")); + + std::vector source; + source.reserve(source_tensor->NumElements()); + for (int i = 0; i < source_tensor->NumElements(); ++i) { + source.push_back(source_tensor->flat()(i)); + } + + std::vector output; + + for (const auto& endpoint: source) { + T entry; + OP_REQUIRES_OK(ctx, entry.FromInputEndpoint(endpoint)); + output.emplace_back(std::move(entry)); + } + + Tensor* output_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({static_cast(output.size())}), &output_tensor)); + for (int i = 0; i < output.size(); i++) { + output_tensor->flat()(i) = output[i]; + } + } + protected: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; +template +class StreamInputDatasetBase : public DatasetBase { + public: + StreamInputDatasetBase(OpKernelContext* ctx, const std::vector& input, const int64 batch, const DataTypeVector& output_types, const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + ctx_(ctx), + input_(input), + batch_(batch), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, DebugString())})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "InputDatasetBase::Dataset"; + } + + Status AsGraphDefInternal(SerializationContext* ctx, + 
DatasetGraphDefBuilder* b, + Node** node) const override { + Node* input_node; + Tensor input_tensor(DT_STRING, TensorShape({static_cast(input_.size())})); + // GraphDefInternal has some trouble with Variant so use serialized string. + for (size_t i = 0; i < input_.size(); i++) { + string message; + VariantTensorData serialized_data_f; + VariantTensorDataProto serialized_proto_f; + input_[i].Encode(&serialized_data_f); + serialized_data_f.ToProto(&serialized_proto_f); + EncodeVariant(serialized_proto_f, &message); + input_tensor.flat()(i) = message; + } + TF_RETURN_IF_ERROR(b->AddTensor(input_tensor, &input_node)); + Node* batch_node; + Tensor batch_tensor(DT_INT64, TensorShape({})); + batch_tensor.scalar()() = batch_; + TF_RETURN_IF_ERROR(b->AddTensor(batch_tensor, &batch_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, batch_node}, node)); + return Status::OK(); + } + private: + class Iterator : public DatasetIterator> { + public: + using tensorflow::data::DatasetIterator>::dataset; + explicit Iterator(const typename tensorflow::data::DatasetIterator>::Params& params) + : DatasetIterator>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + int64 returned = 0; + int64 count = dataset()->batch_ == 0 ? 1 : dataset()->batch_; + while (returned < count) { + if (current_input_index_ < dataset()->input_.size()) { + TF_RETURN_IF_ERROR(dataset()->input_[current_input_index_].ReadBatchRecord(ctx, current_input_state_, dataset()->batch_, count, &returned, out_tensors)); + if (returned == count) { + *end_of_sequence = false; + return Status::OK(); + } + // We have reached the end of the current input, move next. + ResetStreamsLocked(); + ++current_input_index_; + } + // Iteration ends when there are no more input to process. + if (current_input_index_ == dataset()->input_.size()) { + if (out_tensors->size() != 0) { + *end_of_sequence = false; + return Status::OK(); + } + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + }; + } + + private: + // Sets up streams to read from `current_input_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_input_index_ >= dataset()->input_.size()) { + return errors::InvalidArgument( + "current_input_index_:", current_input_index_, + " >= input_.size():", dataset()->input_.size()); + } + + // Actually move on to next entry. + current_input_state_.reset(nullptr); + + return Status::OK(); + } + + // Resets all streams. 
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + current_input_state_.reset(nullptr); + } + + mutex mu_; + size_t current_input_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr current_input_state_ GUARDED_BY(mu_); + }; + OpKernelContext* ctx_; + protected: + std::vector input_; + int64 batch_; + const DataTypeVector output_types_; + const std::vector output_shapes_; +}; + +template +class StreamInputDatasetOp : public DatasetOpKernel { public: using DatasetOpKernel::DatasetOpKernel; - explicit InputDatasetOp(OpKernelConstruction* ctx) + explicit StreamInputDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -548,7 +814,7 @@ class InputDatasetOp : public DatasetOpKernel { const Tensor* batch_tensor; OP_REQUIRES_OK(ctx, ctx->input("batch", &batch_tensor)); int64 batch = batch_tensor->scalar()(); - *output = new InputDatasetBase(ctx, input, batch, output_types_, output_shapes_); + *output = new StreamInputDatasetBase(ctx, input, batch, output_types_, output_shapes_); } DataTypeVector output_types_; std::vector output_shapes_; diff --git a/tensorflow_io/grpc/BUILD b/tensorflow_io/grpc/BUILD new file mode 100644 index 000000000..1b108d311 --- /dev/null +++ b/tensorflow_io/grpc/BUILD @@ -0,0 +1,65 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") + +genrule( + name = "endpoint_py", + srcs = [ + "endpoint.proto", + ], + outs = [ + "python/ops/__init__.py", + "python/ops/endpoint_pb2.py", + "python/ops/endpoint_pb2_grpc.py", + ], + cmd = "python -m grpc_tools.protoc -Itensorflow_io/grpc --python_out=$(BINDIR)/tensorflow_io/grpc/python/ops/ --grpc_python_out=$(BINDIR)/tensorflow_io/grpc/python/ops/ $< ; touch $(BINDIR)/tensorflow_io/grpc/python/ops/__init__.py", + output_to_bindir = True, +) + +proto_library( + name = "_any_proto_only", + deps = ["@com_google_protobuf//:any_proto"], +) + +cc_proto_library( + name = "any_proto", + deps = ["@com_google_protobuf//:any_proto"], +) + +cc_grpc_library( + name = "endpoint_cc", + srcs = [ + "endpoint.proto", + ], + proto_only = False, + well_known_protos = True, + deps = [":any_proto"], +) + +cc_binary( + name = "python/ops/_grpc_ops.so", + srcs = [ + "kernels/grpc_input.cc", + "ops/grpc_ops.cc", + ], + copts = [ + "-pthread", + "-std=c++11", + "-DNDEBUG", + ], + includes = [ + ".", + ], + linkshared = 1, + deps = [ + ":endpoint_cc", + "//tensorflow_io/core:dataset_ops", + "@com_github_grpc_grpc//:grpc++", + "@com_google_protobuf//:protobuf", + "@libarchive", + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], +) diff --git a/tensorflow_io/grpc/__init__.py b/tensorflow_io/grpc/__init__.py new file mode 100644 index 000000000..625d25f84 --- /dev/null +++ b/tensorflow_io/grpc/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GRPCInput + +@@GRPCDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow_io.grpc.python.ops.grpc_ops import GRPCDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "GRPCDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow_io/grpc/endpoint.proto b/tensorflow_io/grpc/endpoint.proto new file mode 100644 index 000000000..84dedb4c3 --- /dev/null +++ b/tensorflow_io/grpc/endpoint.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; + +message Request { + int64 offset = 1; + int64 length = 2; +} + +message Response { + google.protobuf.Any record = 1; +} + +service GRPCEndpoint { + rpc ReadRecord(Request) returns (Response){} +} + diff --git a/tensorflow_io/grpc/kernels/grpc_input.cc b/tensorflow_io/grpc/kernels/grpc_input.cc new file mode 100644 index 000000000..c9c5aca45 --- /dev/null +++ b/tensorflow_io/grpc/kernels/grpc_input.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
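The endpoint.proto just added defines a deliberately small service: a Request carrying (offset, length) and a Response carrying the records packed into a google.protobuf.Any. The C++ GRPCInput kernel in the next file issues exactly these calls from the dataset iterator. As a rough Python analogue of one such call — assuming the generated endpoint_pb2 / endpoint_pb2_grpc modules from the BUILD genrule are importable (grpc_endpoint.py later in this diff shows the sys.path handling) and that some GRPCEndpoint server is already listening at a hypothetical address:

    import grpc
    from tensorflow.core.framework import tensor_pb2
    from tensorflow.python.framework import tensor_util
    import endpoint_pb2
    import endpoint_pb2_grpc

    channel = grpc.insecure_channel("localhost:50051")  # hypothetical server address
    stub = endpoint_pb2_grpc.GRPCEndpointStub(channel)
    # Ask for two records starting at offset 0, as the kernel does on each iterator step.
    response = stub.ReadRecord(endpoint_pb2.Request(offset=0, length=2))
    record = tensor_pb2.TensorProto()
    response.record.Unpack(record)           # Response.record is a google.protobuf.Any
    print(tensor_util.MakeNdarray(record))   # leading dimension = number of records returned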
+==============================================================================*/ + +#include "kernels/dataset_ops.h" +#include "endpoint.grpc.pb.h" +#include + +namespace tensorflow { +namespace data { + +class GRPCInputState { +public: + GRPCInputState(const string& endpoint) : offset_(0) { + stub_ = GRPCEndpoint::NewStub(grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())); + } + int64 offset_; + std::unique_ptr stub_; +}; + +class GRPCInput: public StreamInput { + public: + Status ReadRecord(IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + if (state.get() == nullptr) { + state.reset(new GRPCInputState(endpoint_)); + } + Request request; + request.set_offset(state.get()->offset_); + request.set_length(record_to_read); + Response response; + grpc::ClientContext context; + grpc::Status status = state.get()->stub_.get()->ReadRecord(&context, request, &response); + if (!status.ok()) { + return errors::InvalidArgument("unable to fetch data from grpc (", status.error_code(), "): ", status.error_message()); + } + TensorProto record; + response.record().UnpackTo(&record); + Tensor value_tensor; + value_tensor.FromProto(ctx->allocator({}), record); + out_tensors->emplace_back(std::move(value_tensor)); + + *record_read = value_tensor.dim_size(0); + state.get()->offset_ += *record_read; + + return Status::OK(); + } + Status FromEndpoint(const string& endpoint) override { + return Status::OK(); + } + void EncodeAttributes(VariantTensorData* data) const override { + } + bool DecodeAttributes(const VariantTensorData& data) override { + return true; + } + protected: +}; + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(GRPCInput, "tensorflow::data::GRPCInput"); + +REGISTER_KERNEL_BUILDER(Name("GRPCInput").Device(DEVICE_CPU), + StreamInputOp); +REGISTER_KERNEL_BUILDER(Name("GRPCDataset").Device(DEVICE_CPU), + StreamInputDatasetOp); +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/grpc/ops/grpc_ops.cc b/tensorflow_io/grpc/ops/grpc_ops.cc new file mode 100644 index 000000000..411e9af9b --- /dev/null +++ b/tensorflow_io/grpc/ops/grpc_ops.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GRPCInput") + .Input("source: string") + .Output("handle: variant") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("GRPCDataset") + .Input("input: T") + .Input("batch: int64") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("T: {string, variant} = DT_VARIANT") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({})); + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow_io/grpc/python/__init__.py b/tensorflow_io/grpc/python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/grpc/python/ops/__init__.py b/tensorflow_io/grpc/python/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/grpc/python/ops/grpc_endpoint.py b/tensorflow_io/grpc/python/ops/grpc_endpoint.py new file mode 100644 index 000000000..35b548097 --- /dev/null +++ b/tensorflow_io/grpc/python/ops/grpc_endpoint.py @@ -0,0 +1,67 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""GRPCEndpoint.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import concurrent.futures +import grpc +import google.protobuf.any_pb2 + +import tensorflow +# Incase test is done with TFIO_DATAPATH specified, the +# import path need to be extended to capture generated +# grpc files: +datapath = os.environ.get('TFIO_DATAPATH') +sys.path.append(os.path.abspath( + os.path.dirname(__file__) if datapath is None else os.path.join( + datapath, "tensorflow_io", "grpc", "python", "ops"))) +import endpoint_pb2 # pylint: disable=wrong-import-position,unused-import +import endpoint_pb2_grpc # pylint: disable=wrong-import-position,unused-import + +class GRPCEndpoint(endpoint_pb2_grpc.GRPCEndpointServicer): + """GRPCEndpoint""" + def __init__(self, data): + self._grpc_server = grpc.server( + concurrent.futures.ThreadPoolExecutor(max_workers=4)) + port = self._grpc_server.add_insecure_port("localhost:0") + self._endpoint = "localhost:"+str(port) + self._data = data + super(GRPCEndpoint, self).__init__() + endpoint_pb2_grpc.add_GRPCEndpointServicer_to_server( + self, self._grpc_server) + + def start(self): + self._grpc_server.start() + + def stop(self): + self._grpc_server.stop(0) + + def endpoint(self): + return self._endpoint + + def ReadRecord(self, request, context): # pylint: disable=unused-argument + if len(self._data.shape) == 1: + tensor = tensorflow.compat.v1.make_tensor_proto( + self._data[request.offset:request.offset+request.length]) + else: + tensor = tensorflow.compat.v1.make_tensor_proto( + self._data[request.offset:request.offset+request.length, :]) + record = google.protobuf.any_pb2.Any() + record.Pack(tensor) + return endpoint_pb2.Response(record=record) diff --git a/tensorflow_io/grpc/python/ops/grpc_ops.py b/tensorflow_io/grpc/python/ops/grpc_ops.py new file mode 100644 index 000000000..ade5dca94 --- /dev/null +++ b/tensorflow_io/grpc/python/ops/grpc_ops.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GRPCInput.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow +from tensorflow.compat.v1 import data +from tensorflow_io import _load_library +grpc_ops = _load_library('_grpc_ops.so') + +class GRPCDataset(data.Dataset): + """A GRPC Dataset + """ + + def __init__(self, endpoint, shape, dtype, batch=None): + """Create a GRPC Reader. + + Args: + endpoint: A `tf.string` tensor containing one or more endpoints. 
+ """ + self._data_input = grpc_ops.grpc_input(endpoint) + self._batch = 0 if batch is None else batch + shape[0] = None + self._output_shapes = tuple([ + tensorflow.TensorShape(shape[1:])]) if self._batch == 0 else tuple([ + tensorflow.TensorShape(shape)]) + self._output_types = tuple([dtype]) + self._batch = 0 if batch is None else batch + super(GRPCDataset, self).__init__() + + @staticmethod + def from_numpy(a, batch=None): + """from_numpy""" + from tensorflow_io.grpc.python.ops import grpc_endpoint + grpc_server = grpc_endpoint.GRPCEndpoint(a) + grpc_server.start() + endpoint = grpc_server.endpoint() + dtype = a.dtype + shape = list(a.shape) + dataset = GRPCDataset(endpoint, shape, dtype, batch=batch) + dataset._grpc_server = grpc_server # pylint: disable=protected-access + return dataset + + def __del__(self): + if self._grpc_server is not None: + self._grpc_server.stop() + + def _inputs(self): + return [] + + def _as_variant_tensor(self): + return grpc_ops.grpc_dataset( + self._data_input, + self._batch, + output_types=self.output_types, + output_shapes=self.output_shapes) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return tensorflow.Tensor + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc b/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc index 2a05962e2..2698170fd 100644 --- a/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc +++ b/tensorflow_io/mnist/kernels/mnist_dataset_ops.cc @@ -18,12 +18,12 @@ limitations under the License. namespace tensorflow { namespace data { -class MNISTImageInput: public DataInput { +class MNISTImageInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); - TF_RETURN_IF_ERROR(s.SkipNBytes(16)); + TF_RETURN_IF_ERROR(s->SkipNBytes(16)); } string buffer; Status status = ReadInputStream(s, (rows_ * cols_), record_to_read, &buffer, record_read); @@ -38,9 +38,9 @@ class MNISTImageInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { string header; - TF_RETURN_IF_ERROR(s.ReadNBytes(16, &header)); + TF_RETURN_IF_ERROR(s->ReadNBytes(16, &header)); if (header[0] != 0x00 || header[1] != 0x00 || header[2] != 0x08 || header[3] != 0x03) { return errors::InvalidArgument("mnist image file header must starts with `0x00000803`"); } @@ -68,12 +68,12 @@ class MNISTImageInput: public DataInput { int64 rows_; int64 cols_; }; -class MNISTLabelInput: public DataInput { +class MNISTLabelInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { state.reset(new int64(0)); - TF_RETURN_IF_ERROR(s.SkipNBytes(8)); + TF_RETURN_IF_ERROR(s->SkipNBytes(8)); } string buffer; 
TF_RETURN_IF_ERROR(ReadInputStream(s, 1, record_to_read, &buffer, record_read)); @@ -85,9 +85,9 @@ class MNISTLabelInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { string header; - TF_RETURN_IF_ERROR(s.ReadNBytes(8, &header)); + TF_RETURN_IF_ERROR(s->ReadNBytes(8, &header)); if (header[0] != 0x00 || header[1] != 0x00 || header[2] != 0x08 || header[3] != 0x01) { return errors::InvalidArgument("mnist label file header must starts with `0x00000801`"); } @@ -110,12 +110,12 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MNISTLabelInput, "tensorflow::data::MNIST REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MNISTImageInput, "tensorflow::data::MNISTImageInput"); REGISTER_KERNEL_BUILDER(Name("MNISTLabelInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("MNISTImageInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("MNISTLabelDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); REGISTER_KERNEL_BUILDER(Name("MNISTImageDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/text/kernels/text_input.cc b/tensorflow_io/text/kernels/text_input.cc index 14effc47f..d47b5cc4c 100644 --- a/tensorflow_io/text/kernels/text_input.cc +++ b/tensorflow_io/text/kernels/text_input.cc @@ -19,11 +19,11 @@ limitations under the License. namespace tensorflow { namespace data { -class TextInput: public DataInput { +class TextInput: public FileInput { public: - Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { - state.reset(new io::BufferedInputStream(&s, 4096)); + state.reset(new io::BufferedInputStream(s, 4096)); } std::vector records; records.reserve(record_to_read); @@ -49,7 +49,7 @@ class TextInput: public DataInput { } return Status::OK(); } - Status FromStream(io::InputStreamInterface& s) override { + Status FromStream(io::InputStreamInterface* s) override { // TODO: Read 4K buffer to detect BOM. //string header; //TF_RETURN_IF_ERROR(s.ReadNBytes(4096, &header)); @@ -71,8 +71,8 @@ class TextInput: public DataInput { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TextInput, "tensorflow::data::TextInput"); REGISTER_KERNEL_BUILDER(Name("TextInput").Device(DEVICE_CPU), - DataInputOp); + FileInputOp); REGISTER_KERNEL_BUILDER(Name("TextDataset").Device(DEVICE_CPU), - InputDatasetOp); + FileInputDatasetOp); } // namespace data } // namespace tensorflow diff --git a/tests/test_grpc.py b/tests/test_grpc.py new file mode 100644 index 000000000..fa8957b99 --- /dev/null +++ b/tests/test_grpc.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for GRPC Input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import pytest +import tensorflow +tensorflow.compat.v1.disable_eager_execution() + +from tensorflow import errors # pylint: disable=wrong-import-position +import tensorflow_io.grpc as grpc_io # pylint: disable=wrong-import-position + +def test_grpc_input(): + """test_grpc_input""" + data = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + dataset = grpc_io.GRPCDataset.from_numpy(data, batch=2) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with tensorflow.compat.v1.Session() as sess: + sess.run(init_op) + v = sess.run(get_next) + assert np.alltrue(data[0:2] == v) + v = sess.run(get_next) + assert np.alltrue(data[2:3] == v) + with pytest.raises(errors.OutOfRangeError): + sess.run(get_next) + +if __name__ == "__main__": + test.main() diff --git a/tests/test_grpc_eager.py b/tests/test_grpc_eager.py new file mode 100644 index 000000000..11c7ae378 --- /dev/null +++ b/tests/test_grpc_eager.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for GRPC Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest +import numpy as np + +import tensorflow as tf +import tensorflow_io.grpc as grpc_io + +@pytest.mark.skipif( + not (hasattr(tf, "version") and + tf.version.VERSION.startswith("2.0.")), reason=None) +def test_grpc_with_mnist_tutorial(): + """test_mnist_tutorial""" + (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data() + x = grpc_io.GRPCDataset.from_numpy(x_train, batch=1000) + y = grpc_io.GRPCDataset.from_numpy(y_train, batch=1000) + for (i, v) in zip(range(0, 50000, 1000), x): + assert np.alltrue(x_train[i:i+1000, :] == v.numpy()) + for (i, v) in zip(range(0, 50000, 1000), y): + assert np.alltrue(y_train[i:i+1000] == v.numpy()) + d_train = tf.data.Dataset.zip((x, y)) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(512, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation=tf.nn.softmax) + ]) + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(d_train, epochs=5) + +if __name__ == "__main__": + test.main() diff --git a/third_party/libarchive.BUILD b/third_party/libarchive.BUILD index 4e215a5d7..eb2621811 100644 --- a/third_party/libarchive.BUILD +++ b/third_party/libarchive.BUILD @@ -33,7 +33,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "@zlib", + "@com_github_madler_zlib//:z", ], )
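Stepping back to the dataset_ops.h refactor at the top of this diff: ReadReferenceBatchRecord is the accumulation logic shared by FileInput and StreamInput, called through ReadBatchRecord from both dataset iterators. It keeps appending each ReadRecord chunk to out_tensors until `count` records are gathered, and when batch == 0 it removes the leading dimension so a single record comes back unbatched. The sketch below models only that accumulation policy in Python, not the Tensor allocation or CopySliceToElement mechanics; the chunk sizes are made up:

    import numpy as np

    def read_batch(chunks, batch):
        """Roughly what ReadReferenceBatchRecord does with successive ReadRecord chunks."""
        count = 1 if batch == 0 else batch
        out, returned = None, 0
        for chunk in chunks:                   # each chunk has shape (records_read, ...)
            take = chunk[:count - returned]    # never accept more than is still needed
            out = take if out is None else np.concatenate([out, take])  # append path
            returned += len(take)
            if returned == count:
                break
        if batch == 0 and out is not None:
            out = out[0]                       # batch == 0: drop the leading dimension
        return out

    chunks = [np.arange(6).reshape(2, 3), np.arange(6, 15).reshape(3, 3)]
    print(read_batch(chunks, batch=4).shape)   # (4, 3) -- records accumulated across chunks
    print(read_batch(chunks, batch=0).shape)   # (3,)   -- single record, leading dim removed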