Add option to initialize a decode_avro resource so that the compiled Avro schema can be reused #628

Merged
2 commits merged on Dec 4, 2019
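
In short, the PR adds a `decode_avro_init` op that compiles an Avro JSON schema into a reusable resource, and lets `decode_avro` accept that resource in place of the raw schema string, so the schema is no longer recompiled on every call of the op. A minimal usage sketch, adapted from the test added in this PR; the `avro-test` topic and group are the test suite's fixtures and must exist, as in the test environment:

```python
import tensorflow as tf
import tensorflow_io.kafka as kafka_io

schema = ('{"type":"record","name":"myrecord","fields":['
          '{"name":"f1","type":"string"},'
          '{"name":"f2","type":"long"},'
          '{"name":"f3","type":["null","string"],"default":null}'
          ']}')

# Compile the schema once; the resulting resource handle can be reused by
# every decode_avro call instead of recompiling the JSON on each invocation.
schema_resource = kafka_io.decode_avro_init(schema)

dataset = kafka_io.KafkaDataset(["avro-test:0"], group="avro-test", eof=True)
dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1))  # strip Kafka framing
dataset = dataset.map(
    lambda e: kafka_io.decode_avro(
        e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string]))
```
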
2 changes: 2 additions & 0 deletions tensorflow_io/kafka/__init__.py
@@ -19,6 +19,7 @@
@@write_kafka
@@decode_avro
@@encode_avro
@@decode_avro_init
"""

from __future__ import absolute_import
@@ -30,6 +31,7 @@
from tensorflow_io.kafka.python.ops.kafka_dataset_ops import write_kafka
from tensorflow_io.kafka.python.ops.kafka_dataset_ops import decode_avro
from tensorflow_io.kafka.python.ops.kafka_dataset_ops import encode_avro
from tensorflow_io.kafka.python.ops.kafka_dataset_ops import decode_avro_init

from tensorflow.python.util.all_util import remove_undocumented

90 changes: 79 additions & 11 deletions tensorflow_io/kafka/kernels/kafka_kernels.cc
@@ -318,30 +318,72 @@ REGISTER_KERNEL_BUILDER(Name("IO>KafkaReadableInit").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("IO>KafkaReadableRead").Device(DEVICE_CPU),
IOReadableReadOp<KafkaReadable>);

class DecodeAvroResource : public ResourceBase {
public:
DecodeAvroResource(Env* env) : env_(env) {}
~DecodeAvroResource() {}

Status Init(const string& input) {
mutex_lock lock(mu_);
schema_ = input;
schema_stream_ = std::istringstream(schema_);

string error;
if (!(avro::compileJsonSchema(schema_stream_, avro_schema_, error))) {
return errors::Unimplemented("Avro schema error: ", error);
}

return Status::OK();
}
const avro::ValidSchema& avro_schema() {
return avro_schema_;
}
string DebugString() const override {
return "DecodeAvroResource";
}
private:
mutable mutex mu_;
Env* env_ GUARDED_BY(mu_);
string schema_ GUARDED_BY(mu_);
std::istringstream schema_stream_;
avro::ValidSchema avro_schema_;
};

class DecodeAvroOp : public OpKernel {
public:
explicit DecodeAvroOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("schema", &schema_));
env_ = context->env();
}

void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));

avro::ValidSchema avro_schema;
std::istringstream ss(schema_);
string error;
OP_REQUIRES(context, (avro::compileJsonSchema(ss, avro_schema, error)), errors::Unimplemented("Avro schema error: ", error));
DecodeAvroResource* resource;
std::unique_ptr<DecodeAvroResource> resource_scope;
if (context->input_dtype(1) == DT_RESOURCE) {
OP_REQUIRES_OK(context, GetResourceFromContext(context, "schema", &resource));
} else {
const Tensor* schema_tensor;
OP_REQUIRES_OK(context, context->input("schema", &schema_tensor));
const string& schema = schema_tensor->scalar<string>()();

resource_scope.reset(new DecodeAvroResource(env_));
OP_REQUIRES_OK(context, resource_scope->Init(schema));
resource_scope->Ref();
resource = resource_scope.get();
}
core::ScopedUnref unref(resource);

avro::GenericDatum datum(avro_schema);
std::vector<Tensor*> value;
value.reserve(avro_schema.root()->names());
for (size_t i = 0; i < avro_schema.root()->names(); i++) {
value.reserve(resource->avro_schema().root()->names());
for (size_t i = 0; i < resource->avro_schema().root()->names(); i++) {
Tensor* value_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(static_cast<int64>(i), input_tensor->shape(), &value_tensor));
value.push_back(value_tensor);
}

avro::GenericDatum datum(resource->avro_schema());
for (int64 entry_index = 0; entry_index < input_tensor->NumElements(); entry_index++) {
const string& entry = input_tensor->flat<string>()(entry_index);
std::unique_ptr<avro::InputStream> in = avro::memoryInputStream((const uint8_t*)entry.data(), entry.size());
@@ -350,7 +392,7 @@ class DecodeAvroOp : public OpKernel {
d->init(*in);
avro::decode(*d, datum);
const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
for (int i = 0; i < avro_schema.root()->names(); i++) {
for (int i = 0; i < resource->avro_schema().root()->names(); i++) {
const avro::GenericDatum& field = record.fieldAt(i);
switch(field.type()) {
case avro::AVRO_NULL:
@@ -435,8 +477,9 @@ class DecodeAvroOp : public OpKernel {
}
}
}
private:
string schema_;
private:
mutable mutex mu_;
Env* env_ GUARDED_BY(mu_);
};

class EncodeAvroOp : public OpKernel {
@@ -525,12 +568,37 @@ class EncodeAvroOp : public OpKernel {
private:
string schema_;
};
class DecodeAvroInitOp : public ResourceOpKernel<DecodeAvroResource> {
public:
explicit DecodeAvroInitOp(OpKernelConstruction* context)
: ResourceOpKernel<DecodeAvroResource>(context) {
env_ = context->env();
}
private:
void Compute(OpKernelContext* context) override {
ResourceOpKernel<DecodeAvroResource>::Compute(context);

const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));

OP_REQUIRES_OK(context, resource_->Init(input_tensor->scalar<string>()()));
}
Status CreateResource(DecodeAvroResource** resource)
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
*resource = new DecodeAvroResource(env_);
return Status::OK();
}
private:
mutable mutex mu_;
Env* env_ GUARDED_BY(mu_);
};

REGISTER_KERNEL_BUILDER(Name("IO>DecodeAvro").Device(DEVICE_CPU),
DecodeAvroOp);
REGISTER_KERNEL_BUILDER(Name("IO>EncodeAvro").Device(DEVICE_CPU),
EncodeAvroOp);
REGISTER_KERNEL_BUILDER(Name("IO>DecodeAvroInit").Device(DEVICE_CPU),
DecodeAvroInitOp);

} // namespace data
} // namespace tensorflow
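
To recap the kernel change above: `DecodeAvroOp::Compute` now branches on `context->input_dtype(1)`. A `DT_RESOURCE` input is resolved to the `DecodeAvroResource` that `DecodeAvroInitOp` created and initialized, while a plain string input still constructs a temporary resource and recompiles the schema on every call. A small Python-side sketch of the two paths; the one-field schema is used purely for illustration:

```python
import tensorflow as tf
import tensorflow_io.kafka as kafka_io

schema = ('{"type":"record","name":"myrecord","fields":['
          '{"name":"f1","type":"string"}]}')

# decode_avro_init runs DecodeAvroInitOp (a ResourceOpKernel): it compiles
# the JSON schema once and returns a DT_RESOURCE handle to the result.
schema_resource = kafka_io.decode_avro_init(schema)
print(schema_resource.dtype)  # tf.resource

def decode(serialized):
  # serialized: a string tensor of Avro-encoded records (e.g. Kafka values
  # with the framing stripped, as in the tests). Passing the handle takes
  # the DT_RESOURCE branch of DecodeAvroOp::Compute; passing the schema
  # string instead would take the per-call compilation branch.
  return kafka_io.decode_avro(
      serialized, schema=schema_resource, dtype=[tf.string])
```
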
13 changes: 12 additions & 1 deletion tensorflow_io/kafka/ops/kafka_ops.cc
@@ -29,11 +29,22 @@ REGISTER_OP("IO>EncodeAvro")
return Status::OK();
});

REGISTER_OP("IO>DecodeAvroInit")
.Input("input: string")
.Output("resource: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
return Status::OK();
});

REGISTER_OP("IO>DecodeAvro")
.Input("input: string")
.Input("schema: T")
.Output("value: dtype")
.Attr("schema: string")
.Attr("dtype: list({float,double,int32,int64,string})")
.Attr("T: {string, resource}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
for (int64 i = 0; i < c->num_outputs(); i++) {
c->set_output(i, c->input(0));
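
`IO>DecodeAvroInit` also carries the standard `container` and `shared_name` attrs of a resource-producing op, so a compiled schema can in principle be shared by name. A sketch under the assumption that the generated Python wrapper exposes those attrs as keyword arguments, as TensorFlow op bindings normally do; the shared name here is made up:

```python
import tensorflow_io.kafka as kafka_io

schema = ('{"type":"record","name":"myrecord","fields":['
          '{"name":"f1","type":"string"}]}')

# Handles created with the same shared_name are meant to refer to one
# compiled DecodeAvroResource rather than two independent copies.
handle_a = kafka_io.decode_avro_init(schema, shared_name="myrecord_schema")
handle_b = kafka_io.decode_avro_init(schema, shared_name="myrecord_schema")
```
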
1 change: 1 addition & 0 deletions tensorflow_io/kafka/python/ops/kafka_dataset_ops.py
@@ -24,6 +24,7 @@
from tensorflow.compat.v1 import data
from tensorflow_io.core.python.ops import core_ops

decode_avro_init = core_ops.io_decode_avro_init
decode_avro = core_ops.io_decode_avro
encode_avro = core_ops.io_encode_avro

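
The new public name is a direct alias of the generated op binding, so the symbol exported from `tensorflow_io/kafka/__init__.py` above and the one defined here are the same callable; a quick sanity check, assuming the package is built with this change:

```python
import tensorflow_io.kafka as kafka_io
from tensorflow_io.core.python.ops import core_ops

# The public symbol is the generated wrapper for the IO>DecodeAvroInit op.
assert kafka_io.decode_avro_init is core_ops.io_decode_avro_init
```
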
19 changes: 19 additions & 0 deletions tests/test_kafka_eager.py
@@ -107,6 +107,25 @@ def test_avro_kafka_dataset():
entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset]
np.all(entries == [('value1', 1), ('value2', 2), ('value3', 3)])

def test_avro_kafka_dataset_with_resource():
"""test_avro_kafka_dataset_with_resource"""
schema = ('{"type":"record","name":"myrecord","fields":['
'{"name":"f1","type":"string"},'
'{"name":"f2","type":"long"},'
'{"name":"f3","type":["null","string"],"default":null}'
']}"')
schema_resource = kafka_io.decode_avro_init(schema)
dataset = kafka_io.KafkaDataset(
["avro-test:0"], group="avro-test", eof=True)
# remove kafka framing
dataset = dataset.map(lambda e: tf.strings.substr(e, 5, -1))
# deserialize avro
dataset = dataset.map(
lambda e: kafka_io.decode_avro(
e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string]))
entries = [(f1.numpy(), f2.numpy(), f3.numpy()) for (f1, f2, f3) in dataset]
np.all(entries == [('value1', 1), ('value2', 2), ('value3', 3)])

def test_kafka_stream_dataset():
dataset = tfio.IODataset.stream().from_kafka("test").batch(2)
assert np.all([
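
The new test exercises a single dataset; the reuse the PR title refers to also extends to several pipelines sharing one handle. A sketch of that pattern under the same assumptions as the test (a local Kafka broker with the `avro-test` topic):

```python
import tensorflow as tf
import tensorflow_io.kafka as kafka_io

schema = ('{"type":"record","name":"myrecord","fields":['
          '{"name":"f1","type":"string"},'
          '{"name":"f2","type":"long"},'
          '{"name":"f3","type":["null","string"],"default":null}'
          ']}')
schema_resource = kafka_io.decode_avro_init(schema)  # compiled exactly once

def decoded_kafka_dataset(topic):
  """Builds a decoded dataset that reuses the already-compiled schema."""
  ds = kafka_io.KafkaDataset([topic], group="avro-test", eof=True)
  ds = ds.map(lambda e: tf.strings.substr(e, 5, -1))  # strip Kafka framing
  return ds.map(
      lambda e: kafka_io.decode_avro(
          e, schema=schema_resource, dtype=[tf.string, tf.int64, tf.string]))

# Both pipelines share the single DecodeAvroResource behind schema_resource.
first = decoded_kafka_dataset("avro-test:0")
second = decoded_kafka_dataset("avro-test:0")
```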