Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions tensorflow_io/core/kernels/serialization_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"

#include "rapidjson/document.h"
#include "rapidjson/pointer.h"

#include "api/Compiler.hh"
#include "api/DataFile.hh"
#include "api/Generic.hh"
#include "api/Stream.hh"
#include "api/Validator.hh"
#include "rapidjson/document.h"
#include "rapidjson/pointer.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace data {
Expand All @@ -33,7 +30,6 @@ class DecodeJSONOp : public OpKernel {
public:
explicit DecodeJSONOp(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
}

void Compute(OpKernelContext* context) override {
Expand All @@ -45,28 +41,26 @@ class DecodeJSONOp : public OpKernel {
const Tensor* names_tensor;
OP_REQUIRES_OK(context, context->input("names", &names_tensor));

OP_REQUIRES(context, (names_tensor->NumElements() == shapes_.size()),
errors::InvalidArgument(
"shapes and names should have same number: ",
shapes_.size(), " vs. ", names_tensor->NumElements()));
OP_REQUIRES(
context, (names_tensor->NumElements() == context->num_outputs()),
errors::InvalidArgument("names should have same number as outputs: ",
names_tensor->NumElements(), " vs. ",
context->num_outputs()));
rapidjson::Document d;
d.Parse(input.c_str());
OP_REQUIRES(context, d.IsObject(),
errors::InvalidArgument("not a valid JSON object"));
for (size_t i = 0; i < shapes_.size(); i++) {
Tensor* value_tensor;
OP_REQUIRES_OK(context,
context->allocate_output(i, shapes_[i], &value_tensor));
for (size_t i = 0; i < names_tensor->NumElements(); i++) {
rapidjson::Value* entry =
rapidjson::Pointer(names_tensor->flat<tstring>()(i).c_str()).Get(d);
OP_REQUIRES(context, (entry != nullptr),
errors::InvalidArgument("no value for ",
names_tensor->flat<tstring>()(i)));
Tensor* value_tensor;
if (entry->IsArray()) {
OP_REQUIRES(context, entry->Size() == value_tensor->NumElements(),
errors::InvalidArgument(
"number of elements in JSON does not match spec: ",
entry->Size(), " vs. ", value_tensor->NumElements()));
OP_REQUIRES_OK(context,
context->allocate_output(i, TensorShape({entry->Size()}),
&value_tensor));

switch (value_tensor->dtype()) {
case DT_INT32:
Expand Down Expand Up @@ -103,21 +97,23 @@ class DecodeJSONOp : public OpKernel {
}

} else {
OP_REQUIRES_OK(context, context->allocate_output(i, TensorShape({1}),
&value_tensor));
switch (value_tensor->dtype()) {
case DT_INT32:
value_tensor->scalar<int32>()() = entry->GetInt();
value_tensor->flat<int32>()(0) = entry->GetInt();
break;
case DT_INT64:
value_tensor->scalar<int64>()() = entry->GetInt64();
value_tensor->flat<int64>()(0) = entry->GetInt64();
break;
case DT_FLOAT:
value_tensor->scalar<float>()() = entry->GetDouble();
value_tensor->flat<float>()(0) = entry->GetDouble();
break;
case DT_DOUBLE:
value_tensor->scalar<double>()() = entry->GetDouble();
value_tensor->flat<double>()(0) = entry->GetDouble();
break;
case DT_STRING:
value_tensor->scalar<tstring>()() = entry->GetString();
value_tensor->flat<tstring>()(0) = entry->GetString();
break;
default:
OP_REQUIRES(
Expand All @@ -133,7 +129,6 @@ class DecodeJSONOp : public OpKernel {
private:
mutable mutex mu_;
Env* env_ TF_GUARDED_BY(mu_);
std::vector<TensorShape> shapes_ TF_GUARDED_BY(mu_);
};

class DecodeAvroOp : public OpKernel {
Expand Down
15 changes: 2 additions & 13 deletions tensorflow_io/core/ops/serialization_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,13 @@ REGISTER_OP("IO>DecodeJSON")
.Input("input: string")
.Input("names: string")
.Output("value: dtypes")
.Attr("shapes: list(shape)")
.Attr("dtypes: list(type)")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// TODO: support batch (1-D) input
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
std::vector<TensorShape> shapes;
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
if (shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"shapes and types should be the same: ", shapes.size(), " vs. ",
c->num_outputs());
}
for (size_t i = 0; i < shapes.size(); ++i) {
shape_inference::ShapeHandle shape;
TF_RETURN_IF_ERROR(
c->MakeShapeFromPartialTensorShape(shapes[i], &shape));
c->set_output(static_cast<int64>(i), shape);
for (size_t i = 0; i < c->num_outputs(); ++i) {
c->set_output(static_cast<int64>(i), c->MakeShape({c->UnknownDim()}));
}
return Status::OK();
});
Expand Down
8 changes: 6 additions & 2 deletions tensorflow_io/core/python/experimental/serialization_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,14 @@ def decode_json(data, specs, name=None):
named_spec(named)
named = tf.nest.flatten(named)
names = [e.named() for e in named]
shapes = [e.shape for e in named]
shapes = [
tf.constant([-1 if d is None else d for d in e.shape.as_list()], tf.int32)
for e in named
]
dtypes = [e.dtype for e in named]

values = core_ops.io_decode_json(data, names, shapes, dtypes, name=name)
values = core_ops.io_decode_json(data, names, dtypes, name=name)
values = [tf.reshape(value, shape) for value, shape in zip(values, shapes)]
return tf.nest.pack_sequence_as(specs, values)


Expand Down
15 changes: 15 additions & 0 deletions tests/test_serialization_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Test Serialization"""

import os
import json
import numpy as np

import pytest
Expand Down Expand Up @@ -185,3 +186,17 @@ def test_serialization_decode_in_dataset(
for v, r in zip(tf.nest.flatten(value), tf.nest.flatten(returned))
]
)


def test_json_partial_shape():
    """Regression test for partial shapes (GitHub 918).

    Decodes a JSON array inside a `tf.function` using a `TensorSpec` whose
    single dimension is `None`, and checks the decoded values.
    """
    expected = [1, 2, 3, 4, 5]
    payload = json.dumps({"foo": expected})

    @tf.function(autograph=False)
    def decode_foo(text):
        # The spec deliberately leaves the length of "foo" unspecified.
        foo_spec = tf.TensorSpec(tf.TensorShape([None]), tf.int32)
        decoded = tfio.experimental.serialization.decode_json(
            text, {"foo": foo_spec}
        )
        return decoded["foo"]

    result = decode_foo(payload)
    assert np.array_equal(result, expected)