Skip to content
Permalink
Browse files Browse the repository at this point in the history
[tf.data] Set limit on number of threads used in threadpool_dataset.
PiperOrigin-RevId: 410922677
Change-Id: Ib25814a99043ab10805b5d2d7088ae0e0b7b04fd
  • Loading branch information
aaudiber authored and tensorflower-gardener committed Nov 19, 2021
1 parent dc94fe9 commit e3749a6
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
Expand Up @@ -39,6 +39,22 @@ namespace experimental {
PrivateThreadPoolDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp;

namespace {
// To prevent integer overflow issues when allocating threadpool memory for an
// unreasonable number of threads.
constexpr int kThreadLimit = 65536;

Status ValidateNumThreads(int32_t num_threads) {
if (num_threads < 0) {
return errors::InvalidArgument("`num_threads` must be >= 0");
}
if (num_threads >= kThreadLimit) {
return errors::InvalidArgument("`num_threads` must be < ", kThreadLimit);
}
return Status::OK();
}
} // namespace

class ThreadPoolResource : public ResourceBase {
public:
ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
Expand Down Expand Up @@ -83,9 +99,7 @@ class ThreadPoolHandleOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
&max_intra_op_parallelism_));
OP_REQUIRES(
ctx, num_threads_ > 0,
errors::InvalidArgument("`num_threads` must be greater than zero."));
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads_));
}

// The resource is deleted from the resource manager only when it is private
Expand Down Expand Up @@ -531,8 +545,7 @@ void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
DatasetBase* input,
int32_t num_threads,
DatasetBase** output) {
OP_REQUIRES(ctx, num_threads >= 0,
errors::InvalidArgument("`num_threads` must be >= 0"));
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads));
*output = new Dataset(ctx,
DatasetContext(DatasetContext::Params(
{PrivateThreadPoolDatasetOp::kDatasetType,
Expand All @@ -546,8 +559,7 @@ void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx,
int64_t num_threads = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64_t>(ctx, "num_threads", &num_threads));
OP_REQUIRES(ctx, num_threads >= 0,
errors::InvalidArgument("`num_threads` must be >= 0"));
OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads));
*output = new Dataset(ctx, input, num_threads);
}

Expand Down

0 comments on commit e3749a6

Please sign in to comment.