Fix tensor shape dtype bug in parameterized_truncated_normal.
The original code assumed the shape tensor's dtype was int32, but it can
also be int64, which led to a crash due to the mismatched type.

PiperOrigin-RevId: 456685278
cantonios authored and tensorflower-gardener committed Jun 23, 2022
1 parent 2232d17 commit 72180be
Showing 2 changed files with 29 additions and 7 deletions.
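As the commit message notes, the kernel read the incoming shape tensor through an int32 accessor even though the op also accepts int64 shapes. The sketch below mirrors what the new test checks, but through the public tf.random.stateless_parameterized_truncated_normal API rather than the internal random_ops/stateless wrappers used in the test; treat it as an illustration of the fixed behaviour, not the original crash report.

import numpy as np
import tensorflow as tf

# The output shape may be given as int32 or int64. Before this fix, an int64
# shape tensor was read as int32 inside the kernel and could crash the op;
# afterwards both dtypes yield a [1000]-shaped sample tensor.
for shape_dtype in (np.int32, np.int64):
  shape = np.array([1000], dtype=shape_dtype)
  samples = tf.random.stateless_parameterized_truncated_normal(
      shape=shape, seed=[1, 2], means=0.0, stddevs=0.1,
      minvals=-1.0, maxvals=1.0)
  assert samples.shape == (1000,)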
13 changes: 6 additions & 7 deletions tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/stateless_random_ops.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
@@ -630,20 +631,18 @@ class ParameterizedTruncatedNormalOp : public OpKernel {
     OP_REQUIRES(ctx, shape_tensor.NumElements() > 0,
                 errors::InvalidArgument("Shape tensor must not be empty, got ",
                                         shape_tensor.DebugString()));
-    int32_t num_batches = shape_tensor.flat<int32>()(0);
+    TensorShape tensor_shape;
+    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &tensor_shape));
 
+    int32_t num_batches = tensor_shape.dim_size(0);
     int32_t samples_per_batch = 1;
-    const int32_t num_dims = shape_tensor.dim_size(0);
+    const int32_t num_dims = tensor_shape.dims();
     for (int32_t i = 1; i < num_dims; i++) {
-      samples_per_batch *= shape_tensor.flat<int32>()(i);
+      samples_per_batch *= tensor_shape.dim_size(i);
     }
     const int32_t num_elements = num_batches * samples_per_batch;
 
     // Allocate the output before fudging num_batches and samples_per_batch.
-    auto shape_vec = shape_tensor.flat<int32>();
-    TensorShape tensor_shape;
-    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
-                            shape_vec.data(), shape_vec.size(), &tensor_shape));
     Tensor* samples_tensor;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));

23 additions & 0 deletions in the Python kernel test for the op

@@ -303,6 +303,29 @@ def testSamplingWithSmallStdDevFarFromBound(self):
     self.assertAllGreater(samples, 0.)
     self.assertAllGreater(samples_stateless, 0.)
 
+  def testShapeTypes(self):
+    for shape_dtype in [np.int32, np.int64]:
+      shape = np.array([1000], dtype=shape_dtype)
+      sample_op = random_ops.parameterized_truncated_normal(
+          shape=shape, means=0.0, stddevs=0.1, minvals=-1., maxvals=1.)
+      new_seed = random_ops.random_uniform([2],
+                                           seed=1234,
+                                           minval=0,
+                                           maxval=(2**31 - 1),
+                                           dtype=np.int32)
+      sample_op_stateless = stateless.stateless_parameterized_truncated_normal(
+          shape=shape,
+          seed=new_seed,
+          means=0.0,
+          stddevs=0.1,
+          minvals=-1.,
+          maxvals=1.)
+
+      samples = self.evaluate(sample_op)
+      stateless_samples = self.evaluate(sample_op_stateless)
+      self.assertAllEqual(samples.shape, shape)
+      self.assertAllEqual(stateless_samples.shape, shape)
+
   def testStatelessParameterizedTruncatedNormalHasGrads(self):
     mean = variables.Variable(0.01)
     stddev = variables.Variable(1.)
