diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index fae97f270f455e..365376b6c2eef0 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -39,6 +39,22 @@ namespace experimental { PrivateThreadPoolDatasetOp::kDatasetType; /* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp; +namespace { +// To prevent integer overflow issues when allocating threadpool memory for an +// unreasonable number of threads. +constexpr int kThreadLimit = 65536; + +Status ValidateNumThreads(int32_t num_threads) { + if (num_threads < 0) { + return errors::InvalidArgument("`num_threads` must be >= 0"); + } + if (num_threads >= kThreadLimit) { + return errors::InvalidArgument("`num_threads` must be < ", kThreadLimit); + } + return Status::OK(); +} +} // namespace + class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, @@ -83,9 +99,7 @@ class ThreadPoolHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", &max_intra_op_parallelism_)); - OP_REQUIRES( - ctx, num_threads_ > 0, - errors::InvalidArgument("`num_threads` must be greater than zero.")); + OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads_)); } // The resource is deleted from the resource manager only when it is private @@ -529,8 +543,7 @@ void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, int32_t num_threads, DatasetBase** output) { - OP_REQUIRES(ctx, num_threads >= 0, - errors::InvalidArgument("`num_threads` must be >= 0")); + OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads)); *output = new Dataset(ctx, DatasetContext(DatasetContext::Params( {PrivateThreadPoolDatasetOp::kDatasetType, @@ -544,8 +557,7 @@ void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx, int64_t num_threads = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "num_threads", &num_threads)); - OP_REQUIRES(ctx, num_threads >= 0, - errors::InvalidArgument("`num_threads` must be >= 0")); + OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads)); *output = new Dataset(ctx, input, num_threads); }