Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix size check for large input shape and rates.
To address a check failure for exceedingly large output shapes, we need to
`AppendShapeWithStatus`.

PiperOrigin-RevId: 462885894
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 24, 2022
1 parent 7f26c09 commit 552bfce
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/random_op.cc
Expand Up @@ -166,7 +166,7 @@ class RandomGammaOp : public OpKernel {
}
const int64_t samples_per_alpha = samples_shape.num_elements();

samples_shape.AppendShape(alpha_t.shape());
OP_REQUIRES_OK(ctx, samples_shape.AppendShapeWithStatus(alpha_t.shape()));
// Allocate output samples.
Tensor* samples_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/random_poisson_op.cc
Expand Up @@ -296,8 +296,8 @@ class RandomPoissonOp : public OpKernel {
TensorShape samples_shape;
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &samples_shape));
const int64_t num_samples = samples_shape.num_elements();
OP_REQUIRES_OK(ctx, samples_shape.AppendShapeWithStatus(rate_t.shape()));

samples_shape.AppendShape(rate_t.shape());
// Allocate output samples.
Tensor* samples_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/kernel_tests/random/random_gamma_test.py
Expand Up @@ -16,7 +16,10 @@

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
Expand Down Expand Up @@ -216,6 +219,16 @@ def testPositive(self):
self.assertEqual(0, math_ops.reduce_sum(math_ops.cast(
math_ops.less_equal(x, 0.), dtype=dtypes.int64)).eval())

def testSizeTooLarge(self):
# Grappler asserts on size overflow, so this error is only caught when
# running eagerly.
if context.executing_eagerly():
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"overflow"):
rate = constant_op.constant(1.0, shape=(4, 4, 4, 4, 4))
self.evaluate(
random_ops.random_gamma(
shape=[46902, 51188, 34063, 59195], alpha=rate))

if __name__ == "__main__":
test.main()
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/random/random_poisson_test.py
Expand Up @@ -17,6 +17,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util
Expand Down Expand Up @@ -171,6 +172,14 @@ def testInfRate(self):
sample = random_ops.random_poisson(shape=[2], lam=np.inf)
self.assertAllEqual([np.inf, np.inf], self.evaluate(sample))

def testSizeTooLarge(self):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"overflow"):
rate = constant_op.constant(1.0, shape=(4, 4, 4, 4, 4))
self.evaluate(
random_ops.random_poisson(
shape=[46902, 51188, 34063, 59195], lam=rate))


if __name__ == "__main__":
test.main()

0 comments on commit 552bfce

Please sign in to comment.