Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add true_classes input validation for candidate sampler ops.
The values must be within the valid range of the sampler.  Added a
check for this.

PiperOrigin-RevId: 479441496
  • Loading branch information
cantonios authored and tensorflower-gardener committed Oct 6, 2022
1 parent 42d7ad5 commit b389f5c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tensorflow/core/kernels/candidate_sampler_ops.cc
Expand Up @@ -73,6 +73,14 @@ class BaseCandidateSamplerOp : public OpKernel {

gtl::ArraySlice<int64_t> true_candidate(
true_classes.matrix<int64_t>().data(), batch_size * num_true_);

for (const auto& candidate : true_candidate) {
OP_REQUIRES(context, candidate >= 0 && candidate < sampler_->range(),
errors::InvalidArgument("`true_candidate` out of range [", 0,
", ", sampler_->range(),
"), received ", candidate));
}

gtl::MutableArraySlice<int64_t> sampled_candidate(
out_sampled_candidates->vec<int64_t>().data(), num_sampled_);
gtl::MutableArraySlice<float> true_expected_count(
Expand Down
Expand Up @@ -18,6 +18,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
Expand Down Expand Up @@ -127,6 +128,27 @@ def draw(seed):
# twice very rarely.
self.assertLessEqual(num_same, 2)

def testCandidateOutOfRange(self):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"out of range"):
self.evaluate(
candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=[[0, 10]],
num_true=2,
num_sampled=1000,
unique=False,
range_max=2))

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"out of range"):
self.evaluate(
candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=[[0, -10]],
num_true=2,
num_sampled=1000,
unique=False,
range_max=2))


if __name__ == "__main__":
test.main()

0 comments on commit b389f5c

Please sign in to comment.