From 37a5a962f8237dc856db0e3558c1c426ded3fc35 Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 8 Aug 2019 10:17:56 -0700 Subject: [PATCH] Refactor RandomDataset into .h and .cc files Refactor RandomDataset into .h and .cc files Cleanup after refactoring --- .../core/kernels/data/experimental/BUILD | 1 + .../data/experimental/random_dataset_op.cc | 218 +++++++++--------- .../data/experimental/random_dataset_op.h | 50 ++++ 3 files changed, 162 insertions(+), 107 deletions(-) create mode 100644 tensorflow/core/kernels/data/experimental/random_dataset_op.h diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 65d8a1dbbd2d66..c95e3bf03e523c 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -285,6 +285,7 @@ tf_kernel_library( tf_kernel_library( name = "random_dataset_op", srcs = ["random_dataset_op.cc"], + hdrs = ["random_dataset_op.h"], deps = [ "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc index 404cfdf7cb987f..87fa32309ee521 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/random_dataset_op.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -23,134 +24,137 @@ limitations under the License. namespace tensorflow { namespace data { namespace experimental { -namespace { -class RandomDatasetOp : public DatasetOpKernel { +// Constants declared in random_dataset_op.h and used both here and in test +// cases. +/* static */ constexpr const char* const RandomDatasetOp::kDatasetType; +/* static */ constexpr const char* const RandomDatasetOp::kSeed; +/* static */ constexpr const char* const RandomDatasetOp::kSeed2; +/* static */ constexpr const char* const RandomDatasetOp::kOutputTypes; +/* static */ constexpr const char* const RandomDatasetOp::kOutputShapes; + +class RandomDatasetOp::Dataset : public DatasetBase { public: - explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + Dataset(OpKernelContext* ctx, int64 seed, int64 seed2) + : DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {} - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - int64 seed; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed", &seed)); + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + Iterator::Params{this, strings::StrCat(prefix, "::Random")}); + } - int64 seed2; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed2", &seed2)); + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_INT64}); + return *dtypes; + } - // By TensorFlow convention, passing 0 for both seeds indicates - // that the shuffling should be seeded non-deterministically. - if (seed == 0 && seed2 == 0) { - seed = random::New64(); - seed2 = random::New64(); - } + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } - *output = new Dataset(ctx, seed, seed2); + string DebugString() const override { + return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_, + ")::Dataset"); + } + + int64 Cardinality() const override { return kInfiniteCardinality; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* seed = nullptr; + Node* seed2 = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); + TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output)); + return Status::OK(); } private: - class Dataset : public DatasetBase { + class Iterator : public DatasetIterator { public: - Dataset(OpKernelContext* ctx, int64 seed, int64 seed2) - : DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {} - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::Random")}); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_INT64}); - return *dtypes; + explicit Iterator(const Params& params) + : DatasetIterator(params), + parent_generator_(dataset()->seed_, dataset()->seed2_), + generator_(&parent_generator_) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({})); + out_tensors->back().scalar()() = Random(); + *end_of_sequence = false; + return Status::OK(); } - const std::vector& output_shapes() const override { - static std::vector* shapes = - new std::vector({{}}); - return *shapes; + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); } - string DebugString() const override { - return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_, - ")::Dataset"); + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), + num_random_samples_)); + return Status::OK(); } - int64 Cardinality() const override { return kInfiniteCardinality; } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* seed = nullptr; - Node* seed2 = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); - TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output)); + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"), + &num_random_samples_)); + parent_generator_ = + random::PhiloxRandom(dataset()->seed_, dataset()->seed2_); + generator_ = + random::SingleSampleAdapter(&parent_generator_); + generator_.Skip(num_random_samples_); return Status::OK(); } private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params), - parent_generator_(dataset()->seed_, dataset()->seed2_), - generator_(&parent_generator_) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - out_tensors->emplace_back(ctx->allocator({}), DT_INT64, - TensorShape({})); - out_tensors->back().scalar()() = Random(); - *end_of_sequence = false; - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeSourceNode(std::move(args)); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), - num_random_samples_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"), - &num_random_samples_)); - parent_generator_ = - random::PhiloxRandom(dataset()->seed_, dataset()->seed2_); - generator_ = random::SingleSampleAdapter( - &parent_generator_); - generator_.Skip(num_random_samples_); - return Status::OK(); - } - - private: - random::SingleSampleAdapter::ResultType Random() - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - num_random_samples_++; - auto out = generator_(); - return out; - } - mutex mu_; - random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); - random::SingleSampleAdapter generator_ - GUARDED_BY(mu_); - int64 num_random_samples_ GUARDED_BY(mu_) = 0; - }; - - const int64 seed_; - const int64 seed2_; + random::SingleSampleAdapter::ResultType Random() + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + num_random_samples_++; + auto out = generator_(); + return out; + } + mutex mu_; + random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); + random::SingleSampleAdapter generator_ + GUARDED_BY(mu_); + int64 num_random_samples_ GUARDED_BY(mu_) = 0; }; -}; + + const int64 seed_; + const int64 seed2_; +}; // RandomDatasetOp::Dataset + +RandomDatasetOp::RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + +void RandomDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { + int64 seed; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed", &seed)); + + int64 seed2; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed2", &seed2)); + + // By TensorFlow convention, passing 0 for both seeds indicates + // that the shuffling should be seeded non-deterministically. + if (seed == 0 && seed2 == 0) { + seed = random::New64(); + seed2 = random::New64(); + } + + *output = new Dataset(ctx, seed, seed2); +} +namespace { REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU), RandomDatasetOp); diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.h b/tensorflow/core/kernels/data/experimental/random_dataset_op.h new file mode 100644 index 00000000000000..649da90572d3f2 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// See tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt for the +// API definition that corresponds to this kernel. +class RandomDatasetOp : public DatasetOpKernel { + public: + // Names of op parameters, public so that they can be accessed by test cases. + // Make sure that these are kept in sync with the REGISTER_OP call in + // tensorflow/core/ops/experimental_dataset_ops.cc + static constexpr const char* const kDatasetType = "Random"; + static constexpr const char* const kSeed = "seed"; + static constexpr const char* const kSeed2 = "seed2"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit RandomDatasetOp(OpKernelConstruction* ctx); + + protected: + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_