Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tf.contrib.data.LMDBDataset support #21148

Merged
merged 6 commits into from
Aug 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/contrib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
@@LMDBDataset
@@RandomDataset
@@Reducer
@@SqlDataset
Expand Down Expand Up @@ -93,6 +94,7 @@
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
from tensorflow.contrib.data.python.ops.readers import LMDBDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/contrib/data/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ cc_library(
alwayslink = 1,
)

# Kernel library for the contrib LMDBDataset op. Links against the external
# @lmdb C library; header-only TF/protobuf deps keep this buildable as a
# standalone contrib kernel.
cc_library(
    name = "lmdb_dataset_op",
    srcs = ["lmdb_dataset_op.cc"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//third_party/eigen3",
        "@lmdb",
        "@protobuf_archive//:protobuf_headers",
    ],
)

cc_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
Expand Down Expand Up @@ -91,6 +102,7 @@ cc_library(
":csv_dataset_op",
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
":lmdb_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
Expand Down
212 changes: 212 additions & 0 deletions tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/platform/file_system.h"

#include <sys/stat.h>
#include "lmdb.h"

namespace tensorflow {
namespace {

// Dataset op that sequentially emits every (key, value) record stored in one
// or more LMDB databases. Each output element is a pair of scalar DT_STRING
// tensors: the record key and the record value, in cursor (key) order,
// iterating the given files one after another.
class LMDBDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  // Parses the `filenames` input (scalar or vector of paths) and constructs
  // the Dataset that owns the copied list of filenames.
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    const Tensor* filenames_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
    OP_REQUIRES(
        ctx, filenames_tensor->dims() <= 1,
        errors::InvalidArgument("`filenames` must be a scalar or a vector."));

    std::vector<string> filenames;
    filenames.reserve(filenames_tensor->NumElements());
    for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
      filenames.push_back(filenames_tensor->flat<string>()(i));
    }

    *output = new Dataset(ctx, filenames);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
        : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
    }

    // Each element is (key, value); both are strings.
    const DataTypeVector& output_dtypes() const override {
      // Leaked intentionally: function-local static avoids destruction-order
      // issues at process exit (common TF pattern).
      static DataTypeVector* dtypes =
          new DataTypeVector({DT_STRING, DT_STRING});
      return *dtypes;
    }

    // Both components are scalars (empty shapes).
    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}, {}});
      return *shapes;
    }

    string DebugString() const override { return "LMDBDatasetOp::Dataset"; }

   protected:
    // Serializes this dataset as a graph node: a single LMDBDataset node
    // whose only input is the constant vector of filenames.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* filenames = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
      TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      // Emits the record currently under the cursor, then advances. When the
      // current database is exhausted (MDB_NOTFOUND on MDB_NEXT), tears down
      // the LMDB handles and moves on to the next filename; end-of-sequence
      // is reported once all files are consumed.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);  // Serializes all access to the LMDB handles.
        do {
          if (mdb_cursor_) {
            // mdb_key_/mdb_value_ were positioned by the previous
            // mdb_cursor_get call (MDB_FIRST in SetupStreamsLocked, or
            // MDB_NEXT below). Copy them out: the pointed-to memory belongs
            // to the LMDB-managed map, not to us.
            Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
            key_tensor.scalar<string>()() = string(
                static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
            out_tensors->emplace_back(std::move(key_tensor));

            Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
            value_tensor.scalar<string>()() =
                string(static_cast<const char*>(mdb_value_.mv_data),
                       mdb_value_.mv_size);
            out_tensors->emplace_back(std::move(value_tensor));

            // Advance the cursor for the next call.
            int val;
            val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
            if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
              // NOTE(review): out_tensors already holds this element when a
              // hard cursor error is returned; callers are presumed to
              // discard outputs on non-OK status — confirm.
              return errors::InvalidArgument(mdb_strerror(val));
            }
            if (val == MDB_NOTFOUND) {
              // Current database exhausted: release its handles and fall
              // through to the next file on the following call.
              ResetStreamsLocked();
              ++current_file_index_;
            }
            *end_of_sequence = false;
            return Status::OK();
          }
          if (current_file_index_ == dataset()->filenames_.size()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          // No open cursor yet (first call, or previous file finished):
          // open the next database, then loop to emit its first record.
          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        } while (true);
      }

     protected:
      // Checkpointing of in-flight LMDB cursors is not implemented.
      Status SaveInternal(IteratorStateWriter* writer) override {
        return errors::Unimplemented("SaveInternal is currently not supported");
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        return errors::Unimplemented(
            "RestoreInternal is currently not supported");
      }

     private:
      // Opens env -> read-only txn -> unnamed dbi -> cursor for the file at
      // current_file_index_, and positions the cursor on the first record.
      // On an empty database the handles are released again immediately.
      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (current_file_index_ >= dataset()->filenames_.size()) {
          return errors::InvalidArgument(
              "current_file_index_:", current_file_index_,
              " >= filenames_.size():", dataset()->filenames_.size());
        }
        const string& filename = dataset()->filenames_[current_file_index_];

        int val = mdb_env_create(&mdb_env_);
        if (val != MDB_SUCCESS) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        // Read-only access; NOTLS/NOLOCK since this iterator is the only
        // reader and all access is serialized under mu_.
        int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;

        // LMDB expects a directory by default; if the path is a regular
        // file (a bare data.mdb), open it with MDB_NOSUBDIR instead.
        struct stat source_stat;
        if (stat(filename.c_str(), &source_stat) == 0 &&
            (source_stat.st_mode & S_IFREG)) {
          flags |= MDB_NOSUBDIR;
        }
        val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
        if (val != MDB_SUCCESS) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
        if (val != MDB_SUCCESS) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        // nullptr database name selects the unnamed (default) database.
        val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
        if (val != MDB_SUCCESS) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
        if (val != MDB_SUCCESS) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        // Prime mdb_key_/mdb_value_ with the first record so GetNextInternal
        // can emit it directly.
        val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
        if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
          return errors::InvalidArgument(mdb_strerror(val));
        }
        if (val == MDB_NOTFOUND) {
          // Empty database: leave mdb_cursor_ null so the caller skips it.
          // NOTE(review): current_file_index_ is not advanced here; the
          // MDB_NOTFOUND branch in GetNextInternal does that — confirm an
          // empty file cannot cause a repeated-setup loop.
          ResetStreamsLocked();
        }
        return Status::OK();
      }

      // Releases all LMDB handles in reverse order of acquisition:
      // cursor, dbi, txn, env. Safe to call when nothing is open.
      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (mdb_env_ != nullptr) {
          if (mdb_cursor_) {
            mdb_cursor_close(mdb_cursor_);
            mdb_cursor_ = nullptr;
          }
          mdb_dbi_close(mdb_env_, mdb_dbi_);
          mdb_txn_abort(mdb_txn_);
          mdb_env_close(mdb_env_);
          mdb_txn_ = nullptr;
          mdb_dbi_ = 0;
          mdb_env_ = nullptr;
        }
      }

      mutex mu_;
      // Index into dataset()->filenames_ of the file being (or about to be)
      // read.
      size_t current_file_index_ GUARDED_BY(mu_) = 0;
      // LMDB handles for the currently open database; all null/zero when no
      // file is open.
      MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
      MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
      MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
      MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;

      // The record currently under the cursor, pointing into LMDB-owned
      // memory; valid only while mdb_cursor_ is non-null.
      MDB_val mdb_key_ GUARDED_BY(mu_);
      MDB_val mdb_value_ GUARDED_BY(mu_);
    };

    const std::vector<string> filenames_;
  };
};
}
REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);

} // namespace tensorflow
6 changes: 6 additions & 0 deletions tensorflow/contrib/data/ops/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,10 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});

// Registers the LMDBDataset op: takes a scalar-or-vector of LMDB file paths
// and produces a scalar variant handle to the dataset.
// NOTE(review): presumably marked stateful to keep the dataset-producing op
// from being constant-folded, matching the other reader datasets — confirm.
REGISTER_OP("LMDBDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

} // namespace tensorflow
25 changes: 25 additions & 0 deletions tensorflow/contrib/data/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,31 @@ py_test(
],
)

# Python test for the contrib LMDBDataset reader. Depends on the LMDB test
# database shipped as //tensorflow/core:lmdb_testdata; excluded from pip and
# Windows test runs.
py_test(
    name = "lmdb_dataset_op_test",
    size = "medium",
    srcs = ["lmdb_dataset_op_test.py"],
    data = ["//tensorflow/core:lmdb_testdata"],
    srcs_version = "PY2AND3",
    tags = [
        "no_pip",
        "no_windows",
    ],
    deps = [
        "//tensorflow/contrib/data/python/ops:readers",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:session",
        "//third_party/py/numpy",
    ],
)

py_test(
name = "map_dataset_op_test",
size = "medium",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LMDBDatasetOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.util import compat

class LMDBDatasetTest(test.TestCase):
  """Tests for the contrib `LMDBDataset` reader."""

  def setUp(self):
    """Copies the checked-in LMDB test database to a writable location.

    The bazel runfiles copy is read-only, and LMDB needs the database path
    to be writable in order to create its lock files.
    """
    super(LMDBDatasetTest, self).setUp()
    path = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        "tensorflow",
        "core",
        "lib",
        "lmdb",
        "testdata",
        "data.mdb")

    # Copy database out because we need the path to be writable to use locks.
    # (Removed leftover debug print of `path`.)
    self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
    shutil.copy(path, self.db_path)

  def testReadFromFile(self):
    """Reads every (key, value) pair twice and checks order and contents.

    The test database contains 10 records mapping b"0".."9" to b"a".."j".
    """
    filename = self.db_path

    filenames = constant_op.constant([filename], dtypes.string)
    num_repeats = 2

    dataset = readers.LMDBDataset(
        filenames).repeat(num_repeats)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(num_repeats):  # Dataset is repeated.
        for i in range(10):  # 10 records.
          k = compat.as_bytes(str(i))
          # chr() already returns a str; the redundant str() wrapper is gone.
          v = compat.as_bytes(chr(ord("a") + i))
          self.assertEqual((k, v), sess.run(get_next))
      # The iterator must be exhausted after both passes.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()