[tf2xla] Validate that stride and window size are positive
PiperOrigin-RevId: 504866231
majnemer authored and tensorflower-gardener committed Jan 26, 2023
1 parent 3b1b9de commit 1295ae4
Showing 3 changed files with 112 additions and 36 deletions.
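For context: zero-valued `ksize`/`strides` attributes previously reached the XLA pooling lowerings unvalidated and could crash the process with a floating-point exception (cf. the FPE TODO in the new test). Below is a minimal sketch of the input class this commit rejects, mirroring the test added in pooling_ops_test.py; it assumes a TensorFlow build that includes this change, and `tf.function(jit_compile=True)` is an assumption standing in for the test harness's XLA scope.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops


@tf.function(jit_compile=True)
def bad_avg_pool_grad(grad):
  # ksize and strides contain zeros; the new checks reject them up front.
  return gen_nn_ops.avg_pool_grad(
      orig_input_shape=[1, 0, 0, 0],
      grad=grad,
      ksize=[1, 0, 0, 0],
      strides=[1, 0, 0, 0],
      padding="SAME",
      data_format="NCHW",
  )


try:
  bad_avg_pool_grad(np.zeros([1, 1, 1, 1], dtype=np.float32))
except tf.errors.InvalidArgumentError as e:
  # Expected: "Sliding window ksize field for dimension 1 must be positive
  # but is 0" (exact message per the validation added in this commit).
  print(e)
```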
30 changes: 30 additions & 0 deletions tensorflow/compiler/tests/pooling_ops_test.py
@@ -18,7 +18,9 @@

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
@@ -560,6 +562,34 @@ def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,

    self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)

  @test_util.disable_mlir_bridge(
      "TODO(b/266613412): investigate FPE in AvgPoolGrad for TPU"
  )
  def testAvgPoolGradSamePaddingZeroStrideZeroSize(self):
    output_gradient_vals = np.array([0.39117979], dtype=np.float32)
    output_gradient_vals = output_gradient_vals.reshape([1, 1, 1, 1])
    with self.session() as sess:
      with self.test_scope():
        output_gradients = array_ops.placeholder(
            dtypes.float32, shape=output_gradient_vals.shape
        )
        t = gen_nn_ops.avg_pool_grad(
            orig_input_shape=[1, 0, 0, 0],
            grad=output_gradients,
            ksize=[1, 0, 0, 0],
            strides=[1, 0, 0, 0],
            padding="SAME",
            data_format="NCHW",
        )
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          (
              "Sliding window ksize field for dimension 1 must be positive but"
              " is 0"
          ),
      ):
        sess.run(t, {output_gradients: output_gradient_vals})

  # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
  # the stride size, so we only run the following tests on MaxPoolGrad.
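A hedged sketch of the asymmetry the comment above describes; this is a hypothetical standalone snippet, not part of the commit, and the CPU AvgPoolGrad restriction is taken on the comment's word rather than verified here:

```python
import numpy as np
from tensorflow.python.ops import gen_nn_ops

x = np.random.rand(1, 4, 4, 1).astype(np.float32)
ksize, strides = [1, 1, 1, 1], [1, 2, 2, 1]  # 1x1 kernel, stride 2: kernel < stride

# MaxPoolGrad tolerates a kernel smaller than the stride.
y = gen_nn_ops.max_pool(x, ksize=ksize, strides=strides, padding="VALID")
dx = gen_nn_ops.max_pool_grad(x, y, np.ones_like(y), ksize=ksize,
                              strides=strides, padding="VALID")

# Per the comment above, the CPU AvgPoolGrad kernel rejects this same
# configuration, so stride-larger-than-kernel cases are exercised only
# through MaxPoolGrad.
```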

108 changes: 72 additions & 36 deletions tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -33,15 +33,41 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/tsl/platform/errors.h"

namespace tensorflow {
namespace {

template <typename T>
static Status ValidateKernelSizes(const T& ksizes) {
  for (size_t i = 0; i < ksizes.size(); ++i) {
    if (ksizes[i] <= 0) {
      return errors::InvalidArgument(
          "Sliding window ksize field for dimension ", i,
          " must be positive but is ", ksizes[i]);
    }
  }
  return OkStatus();
}

template <typename T>
static Status ValidateStrides(const T& strides) {
  for (size_t i = 0; i < strides.size(); ++i) {
    if (strides[i] <= 0) {
      return errors::InvalidArgument(
          "Sliding window stride field for dimension ", i,
          " must be positive but is ", strides[i]);
    }
  }
  return OkStatus();
}

// Superclass of pooling ops.
class PoolingOp : public XlaOpKernel {
 public:
@@ -83,50 +109,54 @@ class PoolingOp : public XlaOpKernel {

 protected:
  StatusOr<std::vector<int64_t>> GetKernelSize(XlaOpKernelContext* ctx) {
    std::vector<int64_t> ksize;
    if (ctx->num_inputs() == 1) {
      ksize = ksize_;
    } else {
      const TensorShape ksize_shape = ctx->InputShape(1);
      // Validate input sizes.
      if (!TensorShapeUtils::IsVector(ksize_shape)) {
        return errors::InvalidArgument("ksize must be a vector, not shape ",
                                       ksize_shape.DebugString());
      }
      if (ksize_shape.num_elements() != num_dims()) {
        return errors::InvalidArgument(
            "Sliding window ksize field must "
            "specify ",
            num_dims(), " dimensions");
      }
      auto status = ctx->ConstantInputAsIntVector(1, &ksize);
      if (!status.ok()) {
        return status;
      }
    }
    TF_RETURN_IF_ERROR(ValidateKernelSizes(ksize));
    return ksize;
  }

  StatusOr<std::vector<int64_t>> GetStride(XlaOpKernelContext* ctx) {
    std::vector<int64_t> stride;
    if (ctx->num_inputs() == 1) {
      stride = stride_;
    } else {
      const TensorShape stride_shape = ctx->InputShape(2);
      // Validate input sizes.
      if (!TensorShapeUtils::IsVector(stride_shape)) {
        return errors::InvalidArgument("stride must be a vector, not shape ",
                                       stride_shape.DebugString());
      }
      if (stride_shape.num_elements() != num_dims()) {
        return errors::InvalidArgument(
            "Sliding window stride field must "
            "specify ",
            num_dims(), " dimensions");
      }
      auto status = ctx->ConstantInputAsIntVector(2, &stride);
      if (!status.ok()) {
        return status;
      }
    }
    TF_RETURN_IF_ERROR(ValidateStrides(stride));
    return stride;
  }

@@ -355,10 +385,12 @@ class MaxPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
OP_REQUIRES(ctx, stride_.size() == num_dims(),
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));

const TensorShape tensor_in_shape = ctx->InputShape(0);
const TensorShape tensor_out_shape = ctx->InputShape(1);
@@ -446,11 +478,13 @@ class AvgPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
OP_REQUIRES(ctx, stride_.size() == num_dims(),
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
OP_REQUIRES(ctx, padding_ != EXPLICIT,
errors::Unimplemented(
@@ -579,10 +613,12 @@ class MaxPoolGradGradOp : public XlaOpKernel {
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
OP_REQUIRES(ctx, stride_.size() == num_dims(),
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));

const TensorShape tensor_in_shape = ctx->InputShape(0);
const TensorShape tensor_out_shape = ctx->InputShape(1);
10 changes: 10 additions & 0 deletions tensorflow/compiler/xla/client/padding.cc
@@ -35,6 +35,16 @@ Status ValidatePaddingValues(absl::Span<const int64_t> input_dimensions,
        input_dimensions.size(), window_dimensions.size(),
        window_strides.size());
  }
  for (size_t i = 0; i < input_dimensions.size(); ++i) {
    if (window_dimensions[i] <= 0) {
      return InvalidArgument("Window dimension %u has non-positive size %d", i,
                             window_dimensions[i]);
    }
    if (window_strides[i] <= 0) {
      return InvalidArgument("Window dimension %u has non-positive stride %d",
                             i, window_strides[i]);
    }
  }
  return OkStatus();
}

