diff --git a/tensorflow_io/core/kernels/serialization_kernels.cc b/tensorflow_io/core/kernels/serialization_kernels.cc index 705aba31f..ce0de5010 100644 --- a/tensorflow_io/core/kernels/serialization_kernels.cc +++ b/tensorflow_io/core/kernels/serialization_kernels.cc @@ -13,17 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/resource_op_kernel.h" - -#include "rapidjson/document.h" -#include "rapidjson/pointer.h" - #include "api/Compiler.hh" #include "api/DataFile.hh" #include "api/Generic.hh" #include "api/Stream.hh" #include "api/Validator.hh" +#include "rapidjson/document.h" +#include "rapidjson/pointer.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace data { @@ -33,7 +30,6 @@ class DecodeJSONOp : public OpKernel { public: explicit DecodeJSONOp(OpKernelConstruction* context) : OpKernel(context) { env_ = context->env(); - OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_)); } void Compute(OpKernelContext* context) override { @@ -45,28 +41,26 @@ class DecodeJSONOp : public OpKernel { const Tensor* names_tensor; OP_REQUIRES_OK(context, context->input("names", &names_tensor)); - OP_REQUIRES(context, (names_tensor->NumElements() == shapes_.size()), - errors::InvalidArgument( - "shapes and names should have same number: ", - shapes_.size(), " vs. ", names_tensor->NumElements())); + OP_REQUIRES( + context, (names_tensor->NumElements() == context->num_outputs()), + errors::InvalidArgument("names should have same number as outputs: ", + names_tensor->NumElements(), " vs. ", + context->num_outputs())); rapidjson::Document d; d.Parse(input.c_str()); OP_REQUIRES(context, d.IsObject(), errors::InvalidArgument("not a valid JSON object")); - for (size_t i = 0; i < shapes_.size(); i++) { - Tensor* value_tensor; - OP_REQUIRES_OK(context, - context->allocate_output(i, shapes_[i], &value_tensor)); + for (size_t i = 0; i < names_tensor->NumElements(); i++) { rapidjson::Value* entry = rapidjson::Pointer(names_tensor->flat()(i).c_str()).Get(d); OP_REQUIRES(context, (entry != nullptr), errors::InvalidArgument("no value for ", names_tensor->flat()(i))); + Tensor* value_tensor; if (entry->IsArray()) { - OP_REQUIRES(context, entry->Size() == value_tensor->NumElements(), - errors::InvalidArgument( - "number of elements in JSON does not match spec: ", - entry->Size(), " vs. ", value_tensor->NumElements())); + OP_REQUIRES_OK(context, + context->allocate_output(i, TensorShape({entry->Size()}), + &value_tensor)); switch (value_tensor->dtype()) { case DT_INT32: @@ -103,21 +97,23 @@ class DecodeJSONOp : public OpKernel { } } else { + OP_REQUIRES_OK(context, context->allocate_output(i, TensorShape({1}), + &value_tensor)); switch (value_tensor->dtype()) { case DT_INT32: - value_tensor->scalar()() = entry->GetInt(); + value_tensor->flat()(0) = entry->GetInt(); break; case DT_INT64: - value_tensor->scalar()() = entry->GetInt64(); + value_tensor->flat()(0) = entry->GetInt64(); break; case DT_FLOAT: - value_tensor->scalar()() = entry->GetDouble(); + value_tensor->flat()(0) = entry->GetDouble(); break; case DT_DOUBLE: - value_tensor->scalar()() = entry->GetDouble(); + value_tensor->flat()(0) = entry->GetDouble(); break; case DT_STRING: - value_tensor->scalar()() = entry->GetString(); + value_tensor->flat()(0) = entry->GetString(); break; default: OP_REQUIRES( @@ -133,7 +129,6 @@ class DecodeJSONOp : public OpKernel { private: mutable mutex mu_; Env* env_ TF_GUARDED_BY(mu_); - std::vector shapes_ TF_GUARDED_BY(mu_); }; class DecodeAvroOp : public OpKernel { diff --git a/tensorflow_io/core/ops/serialization_ops.cc b/tensorflow_io/core/ops/serialization_ops.cc index a25a84160..db45534fe 100644 --- a/tensorflow_io/core/ops/serialization_ops.cc +++ b/tensorflow_io/core/ops/serialization_ops.cc @@ -25,24 +25,13 @@ REGISTER_OP("IO>DecodeJSON") .Input("input: string") .Input("names: string") .Output("value: dtypes") - .Attr("shapes: list(shape)") .Attr("dtypes: list(type)") .SetShapeFn([](shape_inference::InferenceContext* c) { // TODO: support batch (1-D) input shape_inference::ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused)); - std::vector shapes; - TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); - if (shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "shapes and types should be the same: ", shapes.size(), " vs. ", - c->num_outputs()); - } - for (size_t i = 0; i < shapes.size(); ++i) { - shape_inference::ShapeHandle shape; - TF_RETURN_IF_ERROR( - c->MakeShapeFromPartialTensorShape(shapes[i], &shape)); - c->set_output(static_cast(i), shape); + for (size_t i = 0; i < c->num_outputs(); ++i) { + c->set_output(static_cast(i), c->MakeShape({c->UnknownDim()})); } return Status::OK(); }); diff --git a/tensorflow_io/core/python/experimental/serialization_ops.py b/tensorflow_io/core/python/experimental/serialization_ops.py index 25e3a0f79..fddfb8b5b 100644 --- a/tensorflow_io/core/python/experimental/serialization_ops.py +++ b/tensorflow_io/core/python/experimental/serialization_ops.py @@ -67,10 +67,14 @@ def decode_json(data, specs, name=None): named_spec(named) named = tf.nest.flatten(named) names = [e.named() for e in named] - shapes = [e.shape for e in named] + shapes = [ + tf.constant([-1 if d is None else d for d in e.shape.as_list()], tf.int32) + for e in named + ] dtypes = [e.dtype for e in named] - values = core_ops.io_decode_json(data, names, shapes, dtypes, name=name) + values = core_ops.io_decode_json(data, names, dtypes, name=name) + values = [tf.reshape(value, shape) for value, shape in zip(values, shapes)] return tf.nest.pack_sequence_as(specs, values) diff --git a/tests/test_serialization_eager.py b/tests/test_serialization_eager.py index 613ff9052..85968ceac 100644 --- a/tests/test_serialization_eager.py +++ b/tests/test_serialization_eager.py @@ -15,6 +15,7 @@ """Test Serialization""" import os +import json import numpy as np import pytest @@ -185,3 +186,17 @@ def test_serialization_decode_in_dataset( for v, r in zip(tf.nest.flatten(value), tf.nest.flatten(returned)) ] ) + + +def test_json_partial_shape(): + """Test case for partial shape GitHub 918.""" + r = json.dumps({"foo": [1, 2, 3, 4, 5]}) + + @tf.function(autograph=False) + def parse_json(json_text): + specs = {"foo": tf.TensorSpec(tf.TensorShape([None]), tf.int32)} + parsed = tfio.experimental.serialization.decode_json(json_text, specs) + return parsed["foo"] + + v = parse_json(r) + assert np.array_equal(v, [1, 2, 3, 4, 5])