Merge pull request #52835 from tensorflow/mm-cp-6-on-r2.7
Add missing validation
mihaimaruseac committed Oct 28, 2021
2 parents 9cf36b5 + f43a11c commit 1a3174c
Showing 9 changed files with 293 additions and 8 deletions.
15 changes: 12 additions & 3 deletions tensorflow/core/kernels/conv_ops.cc
@@ -183,20 +183,29 @@ struct LaunchGrouped {
auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };

// Shuffle input into temporary tensor.
Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
Tensor input_shuffled;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
&input_shuffled));
input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);

// Shuffle filter into temporary tensor.
Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
Tensor filter_shuffled;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
TensorShape(post_shuffle(filter)),
&filter_shuffled));
filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);

// Wait for the completion of input/filter shuffles.
shuffles_completed.Wait();

// Write group convolution results into temporary output tensor.
Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
Tensor output_shuffled;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
TensorShape(post_shuffle(*output)),
&output_shuffled));

for (int64_t i = 0; i < num_groups; ++i) {
// TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
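Note: the hunk above replaces the Tensor(dtype, shape) constructor, which has no way to report an allocation failure, with ctx->allocate_temp wrapped in OP_REQUIRES_OK, so a failed allocation of the shuffle temporaries now surfaces as an op error instead of potentially crashing the process. A minimal sketch of a grouped convolution that exercises this LaunchGrouped path on CPU, assuming that tf.nn.conv2d with input channels a multiple of the filter's in-channels dispatches to it:

import tensorflow as tf

# 4 input channels against a filter with 2 in-channels -> 2 groups.
x = tf.random.normal([1, 8, 8, 4])
f = tf.random.normal([3, 3, 2, 6])

# With this patch, an allocation failure for the shuffle temporaries is
# reported as an error from the op rather than aborting the process.
y = tf.nn.conv2d(x, f, strides=1, padding="SAME")
print(y.shape)  # (1, 8, 8, 6)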
39 changes: 39 additions & 0 deletions tensorflow/core/kernels/linalg/tridiagonal_matmul_op_gpu.cu.cc
@@ -66,6 +66,12 @@ class TridiagonalMatMulOpGpu : public OpKernel {
const Tensor& rhs = context->input(3);

const int ndims = rhs.dims();
OP_REQUIRES(
context, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, but got ", ndims));
OP_REQUIRES_OK(context, ValidateInputTensor(superdiag, "superdiag", rhs));
OP_REQUIRES_OK(context, ValidateInputTensor(maindiag, "maindiag", rhs));
OP_REQUIRES_OK(context, ValidateInputTensor(subdiag, "subdiag", rhs));
int64 batch_size = 1;
for (int i = 0; i < ndims - 2; i++) {
batch_size *= rhs.dim_size(i);
@@ -85,6 +91,39 @@ class TridiagonalMatMulOpGpu : public OpKernel {
maindiag.flat<Scalar>().data(), subdiag.flat<Scalar>().data(),
rhs.flat<Scalar>().data(), output->flat<Scalar>().data()));
}

private:
Status ValidateInputTensor(const Tensor& tensor,
const std::string& tensor_name,
const Tensor& rhs) {
const int ndims = rhs.dims();
if (tensor.dims() != ndims) {
return errors::InvalidArgument(tensor_name,
" must have same rank as rhs, but got ",
tensor.dims(), " and ", ndims);
}
for (int i = 0; i < ndims - 2; i++) {
if (tensor.dim_size(i) != rhs.dim_size(i)) {
return errors::InvalidArgument(
tensor_name,
" must have same outer dimensions as rhs, but for index ", i,
", got ", tensor.dim_size(i), " and ", rhs.dim_size(i));
}
}
if (tensor.dim_size(ndims - 2) != 1) {
return errors::InvalidArgument(
tensor_name, "'s second-to-last dimension must be 1, but got ",
tensor.dim_size(ndims - 2));
}
if (tensor.dim_size(ndims - 1) != rhs.dim_size(ndims - 2)) {
return errors::InvalidArgument(tensor_name,
"'s last dimension size must be rhs's "
"second-to-last dimension size, but got ",
tensor.dim_size(ndims - 1), " and ",
rhs.dim_size(ndims - 2));
}
return Status::OK();
}
};

REGISTER_LINALG_OP_GPU("TridiagonalMatMul", (TridiagonalMatMulOpGpu<float>),
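Note: the new ValidateInputTensor helper rejects diagonals whose rank, batch dimensions, or inner dimensions disagree with rhs before the CUDA kernel is launched. A hedged sketch of a call the kernel now rejects; the raw-op name and argument names are inferred from the kernel's inputs, and since the check lives in the GPU kernel it assumes execution on a GPU build:

import tensorflow as tf

# rhs is [batch=1, m=3, n=2], so every diagonal must have shape [1, 1, 3].
superdiag = tf.ones([1, 1, 3])
maindiag = tf.ones([1, 1, 3])
subdiag = tf.ones([1, 1, 4])  # wrong inner size: 4 instead of 3
rhs = tf.ones([1, 3, 2])

try:
  tf.raw_ops.TridiagonalMatMul(
      superdiag=superdiag, maindiag=maindiag, subdiag=subdiag, rhs=rhs)
except tf.errors.InvalidArgumentError as e:
  print(e)  # subdiag's last dimension size must match rhs's second-to-last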
47 changes: 47 additions & 0 deletions tensorflow/core/kernels/maxpooling_op.cc
@@ -325,6 +325,14 @@ class MaxPoolingGradOp : public OpKernel {
if (!context->status().ok()) {
return;
}
OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected orig_output shape to be ",
params.forward_output_shape(),
", but got ", tensor_out.shape()));
OP_REQUIRES(context, out_backprop.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected grad shape to be ",
params.forward_output_shape(),
", but got ", out_backprop.shape()));

Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -538,6 +546,18 @@ class MaxPoolingGradGradOp : public OpKernel {
/*explicit_paddings=*/{},
FORMAT_NHWC,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected orig_output shape to be ",
params.forward_output_shape(),
", but got ", tensor_out.shape()));
OP_REQUIRES(
context, out_grad_backprop.shape() == tensor_in.shape(),
errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
", but got ", out_grad_backprop.shape()));

Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{2}, 0, tensor_out.shape(), &output));
@@ -742,6 +762,17 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
/*explicit_paddings=*/{},
data_format_,
tensor_in.shape()};
if (!context->status().ok()) {
return;
}
OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected orig_output shape to be ",
params.forward_output_shape(),
", but got ", tensor_out.shape()));
OP_REQUIRES(
context, out_grad_backprop.shape() == tensor_in.shape(),
errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
", but got ", out_grad_backprop.shape()));

functor::MaxPoolGradBackwardNoMask<T>()(
data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
@@ -1096,6 +1127,14 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
if (!context->status().ok()) {
return;
}
OP_REQUIRES(context, grad_in.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected grad shape to be ",
params.forward_output_shape(),
", but got ", grad_in.shape()));
OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected argmax shape to be ",
params.forward_output_shape(),
", but got ", argmax.shape()));

TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth});
@@ -1156,6 +1195,14 @@ class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
if (!context->status().ok()) {
return;
}
OP_REQUIRES(
context, grad_in.shape() == tensor_in.shape(),
errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
", but got ", grad_in.shape()));
OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected argmax shape to be ",
params.forward_output_shape(),
", but got ", argmax.shape()));

TensorShape out_shape({params.tensor_in_batch, params.out_height,
params.out_width, params.depth});
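
Note: the max-pooling gradient kernels above now check that orig_output, grad, and (for the argmax variants) argmax have exactly the shape the forward pass would produce, rather than trusting the caller. A minimal sketch of an input the new checks reject, using the MaxPoolGrad raw op; the same pattern applies to the GradGrad and WithArgmax variants:

import tensorflow as tf

orig_input = tf.ones([1, 4, 4, 1])
# A 2x2/VALID max pool over a 4x4 input produces a [1, 2, 2, 1] output, so a
# [1, 3, 3, 1] orig_output is inconsistent and is now rejected up front.
orig_output = tf.ones([1, 3, 3, 1])
grad = tf.ones([1, 3, 3, 1])

try:
  tf.raw_ops.MaxPoolGrad(
      orig_input=orig_input, orig_output=orig_output, grad=grad,
      ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
except tf.errors.InvalidArgumentError as e:
  print(e)  # Expected orig_output shape to be [1,2,2,1], but got [1,3,3,1]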
21 changes: 21 additions & 0 deletions tensorflow/core/kernels/pooling_ops_3d.cc
@@ -361,6 +361,19 @@ class MaxPooling3dGradOp : public OpKernel {

OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
padding_, &out, &padding));

const int64_t depth = GetTensorDim(tensor_in, data_format_, 'C');
const int64_t in_batch = GetTensorDim(tensor_in, data_format_, 'N');
TensorShape out_shape = ShapeFromFormat(data_format_, in_batch,
{{out[2], out[1], out[0]}}, depth);
OP_REQUIRES(
context, tensor_out.shape() == out_shape,
errors::InvalidArgument("Expected orig_output shape to be ", out_shape,
", but got ", tensor_out.shape()));
OP_REQUIRES(context, out_backprop.shape() == out_shape,
errors::InvalidArgument("Expected grad shape to be ", out_shape,
", but got ", out_backprop.shape()));

LaunchMaxPooling3dGradOp<Device, T>::launch(
context, tensor_in, tensor_out, out_backprop, window, stride, out,
padding, data_format_, input_backprop);
@@ -707,6 +720,14 @@ class MaxPooling3dGradGradOp : public OpKernel {
Pool3dParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
if (!context->status().ok()) return; // params is invalid
OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected orig_output shape to be ",
params.forward_output_shape(),
", but got ", tensor_out.shape()));
OP_REQUIRES(
context, out_grad_backprop.shape() == tensor_in.shape(),
errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
", but got ", out_grad_backprop.shape()));

Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/pooling_ops_common.cc
@@ -465,6 +465,16 @@ void DnnPoolingGradOp<T>::Compute(
if (!context->status().ok()) {
return;
}
if (tensor_out) {
OP_REQUIRES(context, tensor_out->shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected orig_output shape to be ",
params.forward_output_shape(),
", but got ", tensor_out->shape()));
}
OP_REQUIRES(context, out_backprop.shape() == params.forward_output_shape(),
errors::InvalidArgument("Expected grad shape to be ",
params.forward_output_shape(),
", but got ", out_backprop.shape()));

TensorFormat transformed_input_data_format = data_format;
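
Note: DnnPoolingGradOp is the shared cuDNN-backed gradient path, so the same orig_output/grad shape checks are mirrored there. A hedged sketch using AvgPoolGrad (raw-op argument names assumed from the public API); the message shown is the one added on this path and assumes a GPU/cuDNN build, while the CPU kernel reports its own error for such inputs:

import tensorflow as tf

# AvgPool of a [1, 4, 4, 1] input with a 2x2/VALID window yields [1, 2, 2, 1],
# so a [1, 4, 4, 1] grad is rejected before reaching the cuDNN call.
try:
  tf.raw_ops.AvgPoolGrad(
      orig_input_shape=[1, 4, 4, 1], grad=tf.ones([1, 4, 4, 1]),
      ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
except tf.errors.InvalidArgumentError as e:
  print(e)  # Expected grad shape to be [1,2,2,1], but got [1,4,4,1]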

14 changes: 9 additions & 5 deletions tensorflow/core/kernels/pooling_ops_common.h
@@ -83,11 +83,6 @@ struct PoolParameters {
TensorFormat data_format;
};

// Checks if the sizes of the paddings are less than the size of window.
// This is required for MaxPool because it pads with -inf, so the pooling
// window cannot fully cover the padded area.
Status CheckPaddingSize(PoolParameters& params);

// An implementation of MaxPooling (forward).
// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op,
// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now
@@ -194,6 +189,9 @@ class MaxPoolingOp : public OpKernel {
void SpatialMaxPool(OpKernelContext* context, Tensor* output,
const Tensor& tensor_in, const PoolParameters& params,
const Padding& padding) {
if (output->NumElements() == 0) {
return;
}
// On GPU, use Eigen's Spatial Max Pooling. On CPU, use an
// EigenMatrix version that is currently faster than Eigen's
// Spatial MaxPooling implementation.
@@ -448,6 +446,9 @@ class MaxPoolingV2Op : public OpKernel {
void SpatialMaxPool(OpKernelContext* context, Tensor* output,
const Tensor& tensor_in, const PoolParameters& params,
const Padding& padding) {
if (output->NumElements() == 0) {
return;
}
// On GPU, use Eigen's Spatial Max Pooling. On CPU, use an
// EigenMatrix version that is currently faster than Eigen's
// Spatial MaxPooling implementation.
@@ -566,6 +567,9 @@ template <typename Device, typename T>
void SpatialAvgPool(OpKernelContext* context, Tensor* output,
const Tensor& input, const PoolParameters& params,
const Padding& padding) {
if (output->NumElements() == 0) {
return;
}
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
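Note: the early returns added to SpatialMaxPool and SpatialAvgPool skip the Eigen pooling code entirely when the output has no elements, so degenerate inputs such as an empty batch produce an empty result instead of operating on zero-sized buffers. A small sketch of such an input through the public API, assuming these calls route through the CPU helpers above:

import tensorflow as tf

# A zero-sized batch makes the pooled output empty; with the patch the
# element-wise pooling work is skipped and an empty tensor is returned.
x = tf.ones([0, 4, 4, 1])
y = tf.nn.max_pool2d(x, ksize=2, strides=2, padding="VALID")
z = tf.nn.avg_pool2d(x, ksize=2, strides=2, padding="VALID")
print(y.shape, z.shape)  # (0, 2, 2, 1) (0, 2, 2, 1)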
44 changes: 44 additions & 0 deletions tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -20,8 +20,13 @@

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -506,5 +511,44 @@ def testAvgPoolGradSamePadding3_1_3d(self):
padding="SAME")


def testMaxPoolGradEagerShapeErrors(self):
with context.eager_mode():
orig_in = array_ops.ones((1, 1, 1, 1, 1))

# Test invalid orig_out shape
orig_out = array_ops.ones((1, 1, 1, 1, 2))
grad = array_ops.ones((1, 1, 1, 1, 1))
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
r"Expected orig_output shape to be \[1,1,1,1,1\], but got "
r"\[1,1,1,1,2\]"):
gen_nn_ops.max_pool3d_grad(
orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
strides=[1, 1, 1, 1, 1], padding="VALID")
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
r"Expected orig_output shape to be \[1,1,1,1,1\], but got "
r"\[1,1,1,1,2\]"):
gen_nn_ops.max_pool3d_grad_grad(
orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
strides=[1, 1, 1, 1, 1], padding="VALID")

# Test invalid grad shape
orig_out = array_ops.ones((1, 1, 1, 1, 1))
grad = array_ops.ones((1, 1, 1, 1, 2))
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
r"Expected grad shape to be \[1,1,1,1,1\], but got \[1,1,1,1,2\]"):
gen_nn_ops.max_pool3d_grad(
orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
strides=[1, 1, 1, 1, 1], padding="VALID")
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
r"Expected grad shape to be \[1,1,1,1,1\], but got \[1,1,1,1,2\]"):
gen_nn_ops.max_pool3d_grad_grad(
orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
strides=[1, 1, 1, 1, 1], padding="VALID")


if __name__ == "__main__":
test.main()
