Skip to content

Commit

Permalink
Refactor RandomDataset into .h and .cc files
Browse files Browse the repository at this point in the history
Refactor RandomDataset into .h and .cc files

Cleanup after refactoring
  • Loading branch information
frreiss committed Aug 13, 2019
1 parent deabf71 commit 37a5a96
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 107 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/kernels/data/experimental/BUILD
Expand Up @@ -285,6 +285,7 @@ tf_kernel_library(
tf_kernel_library(
name = "random_dataset_op",
srcs = ["random_dataset_op.cc"],
hdrs = ["random_dataset_op.h"],
deps = [
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
Expand Down
218 changes: 111 additions & 107 deletions tensorflow/core/kernels/data/experimental/random_dataset_op.cc
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/random_dataset_op.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
Expand All @@ -23,134 +24,137 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace experimental {
namespace {

class RandomDatasetOp : public DatasetOpKernel {
// Constants declared in random_dataset_op.h and used both here and in test
// cases.
/* static */ constexpr const char* const RandomDatasetOp::kDatasetType;
/* static */ constexpr const char* const RandomDatasetOp::kSeed;
/* static */ constexpr const char* const RandomDatasetOp::kSeed2;
/* static */ constexpr const char* const RandomDatasetOp::kOutputTypes;
/* static */ constexpr const char* const RandomDatasetOp::kOutputShapes;

class RandomDatasetOp::Dataset : public DatasetBase {
public:
explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
: DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {}

void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
int64 seed;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this, strings::StrCat(prefix, "::Random")});
}

int64 seed2;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
return *dtypes;
}

// By TensorFlow convention, passing 0 for both seeds indicates
// that the shuffling should be seeded non-deterministically.
if (seed == 0 && seed2 == 0) {
seed = random::New64();
seed2 = random::New64();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}

*output = new Dataset(ctx, seed, seed2);
string DebugString() const override {
return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
")::Dataset");
}

int64 Cardinality() const override { return kInfiniteCardinality; }

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* seed = nullptr;
Node* seed2 = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output));
return Status::OK();
}

private:
class Dataset : public DatasetBase {
class Iterator : public DatasetIterator<Dataset> {
public:
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
: DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {}

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this, strings::StrCat(prefix, "::Random")});
}

const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
return *dtypes;
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
parent_generator_(dataset()->seed_, dataset()->seed2_),
generator_(&parent_generator_) {}

Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({}));
out_tensors->back().scalar<int64>()() = Random();
*end_of_sequence = false;
return Status::OK();
}

const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeSourceNode(std::move(args));
}

string DebugString() const override {
return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
")::Dataset");
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
num_random_samples_));
return Status::OK();
}

int64 Cardinality() const override { return kInfiniteCardinality; }

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* seed = nullptr;
Node* seed2 = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output));
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
&num_random_samples_));
parent_generator_ =
random::PhiloxRandom(dataset()->seed_, dataset()->seed2_);
generator_ =
random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
generator_.Skip(num_random_samples_);
return Status::OK();
}

private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
parent_generator_(dataset()->seed_, dataset()->seed2_),
generator_(&parent_generator_) {}

Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
out_tensors->emplace_back(ctx->allocator({}), DT_INT64,
TensorShape({}));
out_tensors->back().scalar<int64>()() = Random();
*end_of_sequence = false;
return Status::OK();
}

protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeSourceNode(std::move(args));
}

Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
num_random_samples_));
return Status::OK();
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
&num_random_samples_));
parent_generator_ =
random::PhiloxRandom(dataset()->seed_, dataset()->seed2_);
generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
&parent_generator_);
generator_.Skip(num_random_samples_);
return Status::OK();
}

private:
random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
num_random_samples_++;
auto out = generator_();
return out;
}
mutex mu_;
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};

const int64 seed_;
const int64 seed2_;
random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
num_random_samples_++;
auto out = generator_();
return out;
}
mutex mu_;
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};
};

const int64 seed_;
const int64 seed2_;
}; // RandomDatasetOp::Dataset

RandomDatasetOp::RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

void RandomDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
int64 seed;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));

int64 seed2;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));

// By TensorFlow convention, passing 0 for both seeds indicates
// that the shuffling should be seeded non-deterministically.
if (seed == 0 && seed2 == 0) {
seed = random::New64();
seed2 = random::New64();
}

*output = new Dataset(ctx, seed, seed2);
}
namespace {

REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);
Expand Down
50 changes: 50 additions & 0 deletions tensorflow/core/kernels/data/experimental/random_dataset_op.h
@@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {
namespace experimental {

// See tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt for the
// API definition that corresponds to this kernel.
class RandomDatasetOp : public DatasetOpKernel {
public:
// Names of op parameters, public so that they can be accessed by test cases.
// Make sure that these are kept in sync with the REGISTER_OP call in
// tensorflow/core/ops/experimental_dataset_ops.cc
static constexpr const char* const kDatasetType = "Random";
static constexpr const char* const kSeed = "seed";
static constexpr const char* const kSeed2 = "seed2";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit RandomDatasetOp(OpKernelConstruction* ctx);

protected:
virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;

private:
class Dataset;
};

} // namespace experimental
} // namespace data
} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_

0 comments on commit 37a5a96

Please sign in to comment.