Skip to content
Permalink
Branch: master
Find file Copy path
951 lines (837 sloc) 37.8 KB
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
namespace data {
namespace {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.
// Short name of the text-line dataset. It is appended to the iterator prefix
// and passed as the first argument to metrics::RecordTFDataBytesRead.
constexpr char kTextLineDatasetName[] = "TextLine";
class TextLineDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
// Builds a text-line dataset from the op inputs.
//
// Inputs read from `ctx`:
//   filenames:        scalar or vector of file names.
//   compression_type: "", "ZLIB" or "GZIP"; any other value fails with
//                     InvalidArgument.
//   buffer_size:      must be >= 0; a non-zero value overrides the input
//                     buffer size of the selected zlib options.
//
// On success `*output` is set to a newly allocated Dataset. On any failure
// the error is recorded on `ctx` by the OP_REQUIRES* macros.
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
  const Tensor* filenames_tensor;
  OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
  OP_REQUIRES(
      ctx, filenames_tensor->dims() <= 1,
      errors::InvalidArgument("`filenames` must be a scalar or a vector."));

  string compression_type;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
                                                  &compression_type));

  int64 buffer_size = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
  OP_REQUIRES(
      ctx, buffer_size >= 0,
      errors::InvalidArgument("`buffer_size` must be >= 0 (0 == default)"));

  // "ZLIB" and the uncompressed case both start from the default options;
  // only "GZIP" needs a different preset.
  io::ZlibCompressionOptions zlib_compression_options =
      io::ZlibCompressionOptions::DEFAULT();
  if (compression_type == "GZIP") {
    zlib_compression_options = io::ZlibCompressionOptions::GZIP();
  } else {
    OP_REQUIRES(ctx, compression_type.empty() || compression_type == "ZLIB",
                errors::InvalidArgument("Unsupported compression_type."));
  }

  if (buffer_size != 0) {
    // Set the override size.
    zlib_compression_options.input_buffer_size = buffer_size;
  }

  // Copy the file names out of the tensor. The index is int64 to match
  // NumElements(), and the flat view is built once instead of per element.
  const auto filenames_flat = filenames_tensor->flat<string>();
  const int64 num_files = filenames_tensor->NumElements();
  std::vector<string> filenames;
  filenames.reserve(num_files);
  for (int64 i = 0; i < num_files; ++i) {
    filenames.push_back(filenames_flat(i));
  }

  *output = new Dataset(ctx, std::move(filenames), compression_type,
                        zlib_compression_options);
}
private:
class Dataset : public DatasetBase {
public:
// Constructs the dataset.
//   filenames:        files to read; iterators visit them in this order.
//   compression_type: stored as given; a non-empty value enables
//                     `use_compression_`.
//   options:          zlib options, copied into `options_`.
Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type,
const io::ZlibCompressionOptions& options)
: DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
use_compression_(!compression_type.empty()),
options_(options) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::", kTextLineDatasetName)});
}
// Every element is a single DT_STRING component. The vector is allocated on
// first use and never deleted.
const DataTypeVector& output_dtypes() const override {
  static const DataTypeVector* const kOutputDtypes =
      new DataTypeVector({DT_STRING});
  return *kOutputDtypes;
}
// Returns the shape list for the single output component. The vector is
// allocated on first use and never deleted.
const std::vector<PartialTensorShape>& output_shapes() const override {
  static const std::vector<PartialTensorShape>* const kOutputShapes =
      new std::vector<PartialTensorShape>({{}});
  return *kOutputShapes;
}
// Returns a fixed, human-readable name for this dataset type.
string DebugString() const override {
  return "TextLineDatasetOp::Dataset";
}
protected:
// Serializes this dataset into the graph under construction: one node per
// input (filenames, compression_type, buffer_size), then the dataset node
// itself, which is returned through `output`.
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  Node* filenames_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames_node));

  Node* compression_type_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type_node));

  Node* buffer_size_node = nullptr;
  TF_RETURN_IF_ERROR(
      b->AddScalar(options_.input_buffer_size, &buffer_size_node));

  return b->AddDataset(
      this, {filenames_node, compression_type_node, buffer_size_node}, output);
}
private:
// Iterator over the lines of every file in the dataset. Files are visited in
// the order they appear in `filenames_`, and each successful call to
// GetNextInternal yields one line as a DT_STRING scalar.
class Iterator : public DatasetIterator<Dataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<Dataset>(params) {}

  // Appends the next line to `out_tensors` and sets `*end_of_sequence` to
  // false, or sets `*end_of_sequence` to true once every file has been
  // consumed. Any read error other than OutOfRange is returned to the
  // caller unchanged.
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    mutex_lock l(mu_);
    do {
      // We are currently processing a file, so try to read the next line.
      if (buffered_input_stream_) {
        string line_contents;
        Status s = buffered_input_stream_->ReadLine(&line_contents);

        if (s.ok()) {
          // Produce the line as output.
          metrics::RecordTFDataBytesRead(kTextLineDatasetName,
                                         line_contents.size());
          out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                    TensorShape({}));
          out_tensors->back().scalar<string>()() = std::move(line_contents);
          *end_of_sequence = false;
          return Status::OK();
        } else if (!errors::IsOutOfRange(s)) {
          // Report non-EOF errors to the caller.
          return s;
        }

        // We have reached the end of the current file, so maybe
        // move on to next file.
        ResetStreamsLocked();
        ++current_file_index_;
      }

      // Iteration ends when there are no more files to process.
      if (current_file_index_ == dataset()->filenames_.size()) {
        *end_of_sequence = true;
        return Status::OK();
      }

      TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    } while (true);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeSourceNode(std::move(args));
  }

  // Writes the current file index and, if a file is open, the read position
  // within it.
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
                                           current_file_index_));

    // `buffered_input_stream_` is empty if
    // 1. GetNext has not been called even once.
    // 2. All files have been read and iterator has been exhausted.
    if (buffered_input_stream_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name("current_pos"), buffered_input_stream_->Tell()));
    }
    return Status::OK();
  }

  // Restores the state written by SaveInternal: drops any open streams,
  // reads the file index and, when a position was saved, reopens that file
  // and seeks to the position.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(mu_);
    ResetStreamsLocked();
    int64 current_file_index;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
                                          &current_file_index));
    current_file_index_ = static_cast<size_t>(current_file_index);
    // The key "current_pos" is written only if the iterator was saved
    // with an open file.
    if (reader->Contains(full_name("current_pos"))) {
      int64 current_pos;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("current_pos"), &current_pos));

      TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
      TF_RETURN_IF_ERROR(buffered_input_stream_->Seek(current_pos));
    }
    return Status::OK();
  }

 private:
  // Sets up reader streams to read from the file at `current_file_index_`.
  // Returns InvalidArgument if the index is past the end of `filenames_`.
  Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (current_file_index_ >= dataset()->filenames_.size()) {
      return errors::InvalidArgument(
          "current_file_index_:", current_file_index_,
          " >= filenames_.size():", dataset()->filenames_.size());
    }

    // Actually move on to next file.
    TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
        dataset()->filenames_[current_file_index_], &file_));
    input_stream_ =
        absl::make_unique<io::RandomAccessInputStream>(file_.get(), false);

    if (dataset()->use_compression_) {
      zlib_input_stream_ = absl::make_unique<io::ZlibInputStream>(
          input_stream_.get(), dataset()->options_.input_buffer_size,
          dataset()->options_.input_buffer_size, dataset()->options_);
      buffered_input_stream_ = absl::make_unique<io::BufferedInputStream>(
          zlib_input_stream_.get(), dataset()->options_.input_buffer_size,
          false);
    } else {
      buffered_input_stream_ = absl::make_unique<io::BufferedInputStream>(
          input_stream_.get(), dataset()->options_.input_buffer_size,
          false);
    }
    return Status::OK();
  }

  // Resets all reader streams. Each stream holds a raw pointer to the one
  // beneath it, so they are released outermost first and the file last;
  // no stream is ever left pointing at an already-destroyed object.
  void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    buffered_input_stream_.reset();
    zlib_input_stream_.reset();
    input_stream_.reset();
    file_.reset();
  }

  mutex mu_;
  size_t current_file_index_ GUARDED_BY(mu_) = 0;
  // The members below are declared from innermost to outermost. Data members
  // are destroyed in reverse declaration order, so `file_` outlives
  // `input_stream_`, which outlives the streams layered on top of it.
  std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
  std::unique_ptr<io::RandomAccessInputStream> input_stream_
      GUARDED_BY(mu_);
  std::unique_ptr<io::ZlibInputStream> zlib_input_stream_ GUARDED_BY(mu_);
  std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
      GUARDED_BY(mu_);
};
const std::vector<string> filenames_;
const string compression_type_;
const bool use_compression_;
const io::ZlibCompressionOptions options_;
};
};
// Registers TextLineDatasetOp as the CPU kernel for the "TextLineDataset" op.
REGISTER_KERNEL_BUILDER(Name("TextLineDataset").Device(DEVICE_CPU),
TextLineDatasetOp);
// Short name of the fixed-length-record dataset. It is appended to the
// iterator prefix and passed as the first argument to
// metrics::RecordTFDataBytesRead.
constexpr char kFixedLengthRecordDatasetName[] = "FixedLengthRecord";
class FixedLengthRecordDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
// Derives the op version from the op name of the node being constructed:
// "FixedLengthRecordDataset" is version 1, any other name is version 2.
// MakeDataset only reads the `compression_type` input when the version is
// greater than 1.
explicit FixedLengthRecordDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx),
op_version_(ctx->def().op() == "FixedLengthRecordDataset" ? 1 : 2) {}
// Builds a fixed-length-record dataset from the op inputs.
//
// Inputs read from `ctx`:
//   filenames:        scalar or vector of file names.
//   header_bytes:     must be >= 0.
//   record_bytes:     must be > 0.
//   footer_bytes:     must be >= 0.
//   buffer_size:      must be >= 0; 0 selects the 256 kB default.
//   compression_type: only read when `op_version_` > 1; must be "", "ZLIB"
//                     or "GZIP". Left empty otherwise.
//
// On success `*output` is set to a newly allocated Dataset. On any failure
// the error is recorded on `ctx` by the OP_REQUIRES* macros.
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
  const Tensor* filenames_tensor;
  OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
  OP_REQUIRES(
      ctx, filenames_tensor->dims() <= 1,
      errors::InvalidArgument("`filenames` must be a scalar or a vector."));

  // Copy the file names out of the tensor. The index is int64 to match
  // NumElements(), and the flat view is built once instead of per element.
  const auto filenames_flat = filenames_tensor->flat<string>();
  const int64 num_files = filenames_tensor->NumElements();
  std::vector<string> filenames;
  filenames.reserve(num_files);
  for (int64 i = 0; i < num_files; ++i) {
    filenames.push_back(filenames_flat(i));
  }

  int64 header_bytes = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "header_bytes", &header_bytes));
  OP_REQUIRES(ctx, header_bytes >= 0,
              errors::InvalidArgument("`header_bytes` must be >= 0"));

  int64 record_bytes = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "record_bytes", &record_bytes));
  OP_REQUIRES(ctx, record_bytes > 0,
              errors::InvalidArgument("`record_bytes` must be > 0"));

  int64 footer_bytes = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "footer_bytes", &footer_bytes));
  OP_REQUIRES(ctx, footer_bytes >= 0,
              errors::InvalidArgument("`footer_bytes` must be >= 0"));

  int64 buffer_size = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0,
              errors::InvalidArgument("`buffer_size` must be >= 0"));
  if (buffer_size == 0) {
    buffer_size = 256 << 10;  // 256 kB as default.
  }

  string compression_type;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
                                                    &compression_type));
    OP_REQUIRES(ctx,
                compression_type.empty() || compression_type == "ZLIB" ||
                    compression_type == "GZIP",
                errors::InvalidArgument("Unsupported compression_type."));
  }

  *output = new Dataset(ctx, std::move(filenames), header_bytes, record_bytes,
                        footer_bytes, buffer_size, compression_type);
}
private:
class Dataset : public DatasetBase {
public:
// Constructs the dataset; every argument is stored unchanged in the member
// of the same name.
//   filenames:        files to read; iterators visit them in this order.
//   header_bytes:     bytes skipped at the start of each file.
//   record_bytes:     size of each emitted record.
//   footer_bytes:     bytes at the end of each file that are not emitted.
//   buffer_size:      buffer size handed to the input buffer / zlib stream.
//   compression_type: empty selects UncompressedIterator, anything else
//                     selects CompressedIterator.
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
int64 header_bytes, int64 record_bytes, int64 footer_bytes,
int64 buffer_size, const string& compression_type)
: DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
header_bytes_(header_bytes),
record_bytes_(record_bytes),
footer_bytes_(footer_bytes),
buffer_size_(buffer_size),
compression_type_(compression_type) {}
// Creates the iterator matching the dataset's compression setting. Both
// variants share the prefix "<prefix>::FixedLengthRecord".
std::unique_ptr<IteratorBase> MakeIteratorInternal(
    const string& prefix) const override {
  const string iterator_prefix =
      strings::StrCat(prefix, "::", kFixedLengthRecordDatasetName);
  if (!compression_type_.empty()) {
    return absl::make_unique<CompressedIterator>(
        CompressedIterator::Params{this, iterator_prefix});
  }
  return absl::make_unique<UncompressedIterator>(
      UncompressedIterator::Params{this, iterator_prefix});
}
// Every element is a single DT_STRING component. The vector is allocated on
// first use and never deleted.
const DataTypeVector& output_dtypes() const override {
  static const DataTypeVector* const kOutputDtypes =
      new DataTypeVector({DT_STRING});
  return *kOutputDtypes;
}
// Returns the shape list for the single output component. The vector is
// allocated on first use and never deleted.
const std::vector<PartialTensorShape>& output_shapes() const override {
  static const std::vector<PartialTensorShape>* const kOutputShapes =
      new std::vector<PartialTensorShape>({{}});
  return *kOutputShapes;
}
// Returns a fixed, human-readable name for this dataset type.
string DebugString() const override {
  static constexpr char kDebugName[] = "FixedLengthRecordDatasetOp::Dataset";
  return kDebugName;
}
protected:
// Serializes this dataset into the graph under construction: one node per
// input, in the order (filenames, header_bytes, record_bytes, footer_bytes,
// buffer_size, compression_type), then the dataset node itself, which is
// returned through `output`.
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  Node* filenames_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames_node));

  Node* header_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(header_bytes_, &header_node));

  Node* record_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(record_bytes_, &record_node));

  Node* footer_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(footer_bytes_, &footer_node));

  Node* buffer_size_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));

  Node* compression_type_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type_node));

  return b->AddDataset(this,
                       {filenames_node, header_node, record_node, footer_node,
                        buffer_size_node, compression_type_node},
                       output);
}
private:
// Iterator over uncompressed files made of fixed-length records. For each
// file it skips `header_bytes_`, then emits `record_bytes_`-sized records as
// DT_STRING scalars until the read position reaches the start of the
// footer.
class UncompressedIterator : public DatasetIterator<Dataset> {
 public:
  explicit UncompressedIterator(const Params& params)
      : DatasetIterator<Dataset>(params) {}

  // Appends the next record to `out_tensors` and sets `*end_of_sequence` to
  // false, or sets `*end_of_sequence` to true once every file has been
  // consumed. Returns InvalidArgument when a file is shorter than header
  // plus footer, or when its body is not a whole number of records.
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    mutex_lock l(mu_);
    do {
      // We are currently processing a file, so try to read the next record.
      if (input_buffer_) {
        const int64 current_pos = input_buffer_->Tell();
        DCHECK_GE(file_pos_limit_, 0);
        if (current_pos < file_pos_limit_) {
          string record;
          TF_RETURN_IF_ERROR(
              input_buffer_->ReadNBytes(dataset()->record_bytes_, &record));
          metrics::RecordTFDataBytesRead(kFixedLengthRecordDatasetName,
                                         dataset()->record_bytes_);

          // Produce the record as output. `record` is not used afterwards,
          // so it is moved into the tensor rather than copied.
          Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
          record_tensor.scalar<string>()() = std::move(record);
          out_tensors->emplace_back(std::move(record_tensor));
          *end_of_sequence = false;
          return Status::OK();
        }

        // We have reached the end of the current file, so maybe
        // move on to next file.
        input_buffer_.reset();
        file_.reset();
        ++current_file_index_;
      }

      // Iteration ends when there are no more files to process.
      if (current_file_index_ == dataset()->filenames_.size()) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Actually move on to next file.
      uint64 body_size = 0;
      TF_RETURN_IF_ERROR(ComputeFileLimitsLocked(ctx->env(), &body_size));
      if (body_size % dataset()->record_bytes_ != 0) {
        return errors::InvalidArgument(
            "Excluding the header (", dataset()->header_bytes_,
            " bytes) and footer (", dataset()->footer_bytes_,
            " bytes), input file \"",
            dataset()->filenames_[current_file_index_],
            "\" has body length ", body_size,
            " bytes, which is not an exact multiple of the record length (",
            dataset()->record_bytes_, " bytes).");
      }
      TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(
          dataset()->filenames_[current_file_index_], &file_));
      input_buffer_ = absl::make_unique<io::InputBuffer>(
          file_.get(), dataset()->buffer_size_);
      TF_RETURN_IF_ERROR(
          input_buffer_->SkipNBytes(dataset()->header_bytes_));
    } while (true);
  }

 protected:
  // Same node type as the other file-reading iterators in this file.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeSourceNode(std::move(args));
  }

  // Writes the current file index and the read position within the open
  // file; the position is -1 when no file is open.
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
                                           current_file_index_));

    // `input_buffer_` is empty if
    // 1. GetNext has not been called even once.
    // 2. All files have been read and iterator has been exhausted.
    int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name("current_pos"), current_pos));
    return Status::OK();
  }

  // Restores the state written by SaveInternal. When a non-negative position
  // was saved, the file is reopened and the buffer is positioned there.
  // Returns InvalidArgument if the saved file index does not refer to an
  // entry of `filenames_`.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(mu_);
    int64 current_file_index;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
                                          &current_file_index));
    current_file_index_ = static_cast<size_t>(current_file_index);
    int64 current_pos;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name("current_pos"), &current_pos));

    // Seek to current_pos.
    input_buffer_.reset();
    file_.reset();
    if (current_pos >= 0) {  // There was an active input_buffer_.
      // The index comes from checkpoint data; validate it before using it
      // to subscript `filenames_`.
      if (current_file_index_ >= dataset()->filenames_.size()) {
        return errors::InvalidArgument(
            "current_file_index_:", current_file_index_,
            " >= filenames_.size():", dataset()->filenames_.size());
      }
      uint64 unused_body_size = 0;
      TF_RETURN_IF_ERROR(
          ComputeFileLimitsLocked(ctx->env(), &unused_body_size));
      TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(
          dataset()->filenames_[current_file_index_], &file_));
      input_buffer_ = absl::make_unique<io::InputBuffer>(
          file_.get(), dataset()->buffer_size_);
      TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos));
    }

    return Status::OK();
  }

 private:
  // Looks up the size of the file at `current_file_index_`, sets
  // `file_pos_limit_` to the offset where its footer begins and stores the
  // body length (file size minus header and footer) in `*body_size`.
  // Returns InvalidArgument when the file is shorter than header plus
  // footer; without this check the unsigned subtractions below would wrap
  // around to a huge value.
  Status ComputeFileLimitsLocked(Env* env, uint64* body_size)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const string& filename = dataset()->filenames_[current_file_index_];
    uint64 file_size;
    TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));

    const uint64 header_bytes = static_cast<uint64>(dataset()->header_bytes_);
    const uint64 footer_bytes = static_cast<uint64>(dataset()->footer_bytes_);
    const uint64 overhead_bytes = header_bytes + footer_bytes;
    if (file_size < overhead_bytes) {
      return errors::InvalidArgument(
          "Input file \"", filename, "\" has length ", file_size,
          " bytes, which is smaller than the header (",
          dataset()->header_bytes_, " bytes) plus the footer (",
          dataset()->footer_bytes_, " bytes).");
    }

    file_pos_limit_ = static_cast<int64>(file_size - footer_bytes);
    *body_size = file_size - overhead_bytes;
    return Status::OK();
  }

  mutex mu_;
  size_t current_file_index_ GUARDED_BY(mu_) = 0;
  std::unique_ptr<RandomAccessFile> file_
      GUARDED_BY(mu_);  // must outlive input_buffer_
  std::unique_ptr<io::InputBuffer> input_buffer_ GUARDED_BY(mu_);
  int64 file_pos_limit_ GUARDED_BY(mu_) = -1;
};
// Iterator over ZLIB/GZIP-compressed files made of fixed-length records.
// Dataset::MakeIteratorInternal only creates it when `compression_type_` is
// non-empty.
//
// The decompressed length is not known up front, so the stream is always
// read `footer_bytes_` ahead of the record being returned: every new chunk
// is appended to `lookahead_cache_` and the oldest `record_bytes_` bytes are
// emitted. When the stream ends, the bytes still in the cache are the footer
// and are never emitted.
class CompressedIterator : public DatasetIterator<Dataset> {
 public:
  explicit CompressedIterator(const Params& params)
      : DatasetIterator<Dataset>(params) {}

  // Appends the next record to `out_tensors` and sets `*end_of_sequence` to
  // false, or sets `*end_of_sequence` to true once every file has been
  // consumed.
  //
  // Errors:
  //   DataLoss - the stream ended in the middle of a record.
  //   other    - any read error that is not OutOfRange is returned as is,
  //              after advancing to the next file so that a retry does not
  //              hit the same broken stream again.
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    mutex_lock l(mu_);
    do {
      // We are currently processing a file, so try to read the next record.
      if (buffered_input_stream_) {
        const int64 current_pos = buffered_input_stream_->Tell();
        string record;
        Status s = buffered_input_stream_->ReadNBytes(
            dataset()->record_bytes_, &record);
        if (s.ok()) {
          metrics::RecordTFDataBytesRead(kFixedLengthRecordDatasetName,
                                         dataset()->record_bytes_);
          // Shift the freshly read bytes through the lookahead window.
          lookahead_cache_.append(record);
          record = lookahead_cache_.substr(0, dataset()->record_bytes_);
          lookahead_cache_ =
              lookahead_cache_.substr(dataset()->record_bytes_);
          // Produce the record as output.
          Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
          record_tensor.scalar<string>()() = std::move(record);
          out_tensors->emplace_back(std::move(record_tensor));
          *end_of_sequence = false;
          return Status::OK();
        }
        if (!errors::IsOutOfRange(s)) {
          // A genuine read error must not be mistaken for end of file.
          // Report it, but move past this file first.
          ResetStreamsLocked();
          ++current_file_index_;
          return s;
        }
        if (!record.empty()) {
          // The stream ended part-way through a record.
          uint64 body_size =
              current_pos + record.size() -
              (dataset()->header_bytes_ + dataset()->footer_bytes_);
          return errors::DataLoss(
              "Excluding the header (", dataset()->header_bytes_,
              " bytes) and footer (", dataset()->footer_bytes_,
              " bytes), input file \"",
              dataset()->filenames_[current_file_index_],
              "\" has body length ", body_size,
              " bytes, which is not an exact multiple of the record "
              "length (",
              dataset()->record_bytes_, " bytes).");
        }

        // We have reached the end of the current file, so maybe
        // move on to next file.
        ResetStreamsLocked();
        ++current_file_index_;
      }

      // Iteration ends when there are no more files to process.
      if (current_file_index_ == dataset()->filenames_.size()) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Actually move on to next file: skip the header, then prime the
      // lookahead window with `footer_bytes_` bytes.
      TF_RETURN_IF_ERROR(OpenStreamLocked(ctx->env()));
      TF_RETURN_IF_ERROR(
          buffered_input_stream_->SkipNBytes(dataset()->header_bytes_));
      lookahead_cache_.clear();
      TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes(
          dataset()->footer_bytes_, &lookahead_cache_));
    } while (true);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeSourceNode(std::move(args));
  }

  // Writes the current file index and the position in the decompressed
  // stream; the position is -1 when no stream is open.
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
                                           current_file_index_));

    // `buffered_input_stream_` is empty if
    // 1. GetNext has not been called even once.
    // 2. All files have been read and iterator has been exhausted.
    int64 current_pos =
        buffered_input_stream_ ? buffered_input_stream_->Tell() : -1;
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name("current_pos"), current_pos));
    return Status::OK();
  }

  // Restores the state written by SaveInternal. When a non-negative position
  // was saved, the file is reopened, the stream is advanced to
  // `footer_bytes_` before that position and the lookahead window is
  // refilled from there. Returns InvalidArgument if the saved file index
  // does not refer to an entry of `filenames_`.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(mu_);
    int64 current_file_index;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
                                          &current_file_index));
    current_file_index_ = static_cast<size_t>(current_file_index);
    int64 current_pos;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name("current_pos"), &current_pos));

    // Seek to current_pos.
    ResetStreamsLocked();
    if (current_pos >= 0) {  // There was an active buffered_input_stream_.
      // The index comes from checkpoint data; validate it before using it
      // to subscript `filenames_`.
      if (current_file_index_ >= dataset()->filenames_.size()) {
        return errors::InvalidArgument(
            "current_file_index_:", current_file_index_,
            " >= filenames_.size():", dataset()->filenames_.size());
      }
      TF_RETURN_IF_ERROR(OpenStreamLocked(ctx->env()));
      lookahead_cache_.clear();
      TF_RETURN_IF_ERROR(buffered_input_stream_->SkipNBytes(
          current_pos - dataset()->footer_bytes_));
      TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes(
          dataset()->footer_bytes_, &lookahead_cache_));
    }

    return Status::OK();
  }

 private:
  // Opens the file at `current_file_index_` and layers a zlib stream on top
  // of it, configured for "ZLIB" or, for any other compression type, GZIP.
  Status OpenStreamLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
        dataset()->filenames_[current_file_index_], &file_));
    const io::ZlibCompressionOptions zlib_options =
        dataset()->compression_type_ == "ZLIB"
            ? io::ZlibCompressionOptions::DEFAULT()
            : io::ZlibCompressionOptions::GZIP();
    file_stream_ =
        absl::make_unique<io::RandomAccessInputStream>(file_.get());
    buffered_input_stream_ = absl::make_unique<io::ZlibInputStream>(
        file_stream_.get(), dataset()->buffer_size_,
        dataset()->buffer_size_, zlib_options);
    return Status::OK();
  }

  // Releases the streams outermost first and the file last, so nothing is
  // left holding a pointer to an already-destroyed object.
  void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    buffered_input_stream_.reset();
    file_stream_.reset();
    file_.reset();
  }

  mutex mu_;
  size_t current_file_index_ GUARDED_BY(mu_) = 0;
  std::unique_ptr<RandomAccessFile> file_
      GUARDED_BY(mu_);  // must outlive buffered_input_stream_
  std::unique_ptr<io::RandomAccessInputStream> file_stream_
      GUARDED_BY(mu_);  // must outlive buffered_input_stream_
  std::unique_ptr<io::InputStreamInterface> buffered_input_stream_
      GUARDED_BY(mu_);
  // Bytes read ahead of the record most recently returned; see class comment.
  string lookahead_cache_ GUARDED_BY(mu_);
};
const std::vector<string> filenames_;
const int64 header_bytes_;
const int64 record_bytes_;
const int64 footer_bytes_;
const int64 buffer_size_;
const string compression_type_;
};
const int op_version_;
};
// Registers the same kernel class as the CPU kernel for both op names; the
// kernel's constructor tells them apart by the op name of the node.
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordDataset").Device(DEVICE_CPU),
FixedLengthRecordDatasetOp);
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordDatasetV2").Device(DEVICE_CPU),
FixedLengthRecordDatasetOp);
// Short name of the TFRecord dataset. It is appended to the iterator prefix
// and passed as the first argument to metrics::RecordTFDataBytesRead.
constexpr char kTFRecordDatasetName[] = "TFRecord";
class TFRecordDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
// Builds a TFRecord dataset from the op inputs.
//
// Inputs read from `ctx`:
//   filenames:        scalar or vector of file names.
//   compression_type: passed through to the Dataset without validation here.
//   buffer_size:      must be >= 0.
//
// On success `*output` is set to a newly allocated Dataset. On any failure
// the error is recorded on `ctx` by the OP_REQUIRES* macros.
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
  const Tensor* filenames_tensor;
  OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
  OP_REQUIRES(
      ctx, filenames_tensor->dims() <= 1,
      errors::InvalidArgument("`filenames` must be a scalar or a vector."));

  // Copy the file names out of the tensor. The index is int64 to match
  // NumElements(), and the flat view is built once instead of per element.
  const auto filenames_flat = filenames_tensor->flat<string>();
  const int64 num_files = filenames_tensor->NumElements();
  std::vector<string> filenames;
  filenames.reserve(num_files);
  for (int64 i = 0; i < num_files; ++i) {
    filenames.push_back(filenames_flat(i));
  }

  string compression_type;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
                                                  &compression_type));

  int64 buffer_size = -1;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0,
              errors::InvalidArgument(
                  "`buffer_size` must be >= 0 (0 == no buffering)"));

  *output =
      new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}
private:
class Dataset : public DatasetBase {
public:
// Constructs the dataset. The reader options are derived from
// `compression_type`; a positive `buffer_size` then replaces the buffer size
// in those options, while any other value leaves them as created.
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
                 const string& compression_type, int64 buffer_size)
    : DatasetBase(DatasetContext(ctx)),
      filenames_(std::move(filenames)),
      compression_type_(compression_type),
      options_(io::RecordReaderOptions::CreateRecordReaderOptions(
          compression_type)) {
  const bool override_buffer_size = buffer_size > 0;
  if (override_buffer_size) {
    options_.buffer_size = buffer_size;
  }
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::", kTFRecordDatasetName)});
}
// Every element is a single DT_STRING component. The vector is allocated on
// first use and never deleted.
const DataTypeVector& output_dtypes() const override {
  static const DataTypeVector* const kOutputDtypes =
      new DataTypeVector({DT_STRING});
  return *kOutputDtypes;
}
// Returns the shape list for the single output component. The vector is
// allocated on first use and never deleted.
const std::vector<PartialTensorShape>& output_shapes() const override {
  static const std::vector<PartialTensorShape>* const kOutputShapes =
      new std::vector<PartialTensorShape>({{}});
  return *kOutputShapes;
}
// Returns a fixed, human-readable name for this dataset type.
string DebugString() const override {
  return "TFRecordDatasetOp::Dataset";
}
protected:
// Serializes this dataset into the graph under construction: one node per
// input (filenames, compression_type, buffer_size), then the dataset node
// itself, which is returned through `output`.
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  Node* filenames_node = nullptr;
  Node* compression_type_node = nullptr;
  Node* buffer_size_node = nullptr;

  TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames_node));
  TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type_node));
  TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size_node));

  return b->AddDataset(
      this, {filenames_node, compression_type_node, buffer_size_node}, output);
}
private:
// Iterator over the records of every file in the dataset. Files are visited
// in the order they appear in `filenames_`, and each successful call to
// GetNextInternal yields one record as a DT_STRING scalar.
class Iterator : public DatasetIterator<Dataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<Dataset>(params) {}

  // Appends the next record to `out_tensors` and sets `*end_of_sequence` to
  // false, or sets `*end_of_sequence` to true once every file has been
  // consumed. On a read error other than OutOfRange the error is returned,
  // but only after the iterator has moved on to the next file.
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    mutex_lock l(mu_);
    while (true) {
      // A file is open: try to pull the next record out of it.
      if (reader_) {
        out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                  TensorShape({}));
        const Status read_status =
            reader_->ReadRecord(&out_tensors->back().scalar<string>()());
        if (read_status.ok()) {
          metrics::RecordTFDataBytesRead(
              kTFRecordDatasetName,
              out_tensors->back().scalar<string>()().size());
          *end_of_sequence = false;
          return Status::OK();
        }

        // Nothing was produced; drop the placeholder tensor.
        out_tensors->pop_back();

        // Whether this is end-of-file or a real error (e.g. DataLoss), the
        // file index moves forward. For errors this keeps ignore_errors
        // from re-reading the same file over and over.
        ResetStreamsLocked();
        ++current_file_index_;
        if (!errors::IsOutOfRange(read_status)) {
          return read_status;
        }
      }

      // Iteration ends when there are no more files to process.
      const bool no_more_files =
          current_file_index_ == dataset()->filenames_.size();
      if (no_more_files) {
        *end_of_sequence = true;
        return Status::OK();
      }

      TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    }
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeSourceNode(std::move(args));
  }

  // Writes the current file index and, if a reader is open, its offset.
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
                                           current_file_index_));
    if (!reader_) {
      return Status::OK();
    }
    return writer->WriteScalar(full_name("offset"), reader_->TellOffset());
  }

  // Restores the state written by SaveInternal: drops any open reader,
  // reads the file index and, when an offset was saved, reopens that file
  // and seeks the reader to the offset.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(mu_);
    ResetStreamsLocked();

    int64 saved_file_index;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
                                          &saved_file_index));
    current_file_index_ = size_t(saved_file_index);

    if (!reader->Contains(full_name("offset"))) {
      return Status::OK();
    }
    int64 saved_offset;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name("offset"), &saved_offset));
    TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
    TF_RETURN_IF_ERROR(reader_->SeekOffset(saved_offset));
    return Status::OK();
  }

 private:
  // Sets up reader streams to read from the file at `current_file_index_`.
  // Returns InvalidArgument if the index is past the end of `filenames_`.
  Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const size_t num_files = dataset()->filenames_.size();
    if (current_file_index_ >= num_files) {
      return errors::InvalidArgument(
          "current_file_index_:", current_file_index_,
          " >= filenames_.size():", dataset()->filenames_.size());
    }

    // Actually move on to next file.
    TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
        dataset()->filenames_[current_file_index_], &file_));
    reader_ = absl::make_unique<io::SequentialRecordReader>(
        file_.get(), dataset()->options_);
    return Status::OK();
  }

  // Resets all reader streams: the reader first, then the file it reads.
  void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    reader_.reset();
    file_.reset();
  }

  mutex mu_;
  size_t current_file_index_ GUARDED_BY(mu_) = 0;

  // `reader_` will borrow the object that `file_` points to, so
  // we must destroy `reader_` before `file_`.
  std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
  std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
};
const std::vector<string> filenames_;
const string compression_type_;
io::RecordReaderOptions options_;
};
};
// Registers TFRecordDatasetOp as the CPU kernel for the "TFRecordDataset" op.
REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
You can’t perform that action at this time.