Merge pull request #49356 from geetachavan1/cherrypicks_7P4OH
Eliminate a division by 0 in 3D convolutions.
mihaimaruseac committed May 24, 2021
2 parents cd3bbd4 + 07c4828 commit 69f1467
Showing 1 changed file with 42 additions and 0 deletions.
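The additions below follow the same pattern in each affected kernel: validate that every tensor involved has rank 5 before its dimensions are indexed, and reject empty work units before the ceiling division that computes the shard size. The following is a minimal standalone sketch of that pattern, purely illustrative: names are simplified and plain exceptions stand in for the OP_REQUIRES / errors::InvalidArgument machinery the real kernels use.

```cpp
// Illustrative sketch only; not the actual TensorFlow kernel code.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Rank check mirroring the new "must have 5 dimensions" validations.
void RequireRank5(int dims, const char* name) {
  if (dims != 5) {
    throw std::invalid_argument(std::string(name) +
                                " tensor must have 5 dimensions");
  }
}

// Guarded ceiling division mirroring the new work_unit_size check; without
// the guard, all-empty tensors make work_unit_size == 0 and the division
// below divides by zero.
size_t ComputeShardSize(int64_t target_working_set_size,
                        int64_t work_unit_size) {
  if (work_unit_size <= 0) {
    throw std::invalid_argument(
        "input, filter_sizes and out_backprop tensors must all have at "
        "least 1 element");
  }
  return (target_working_set_size + work_unit_size - 1) / work_unit_size;
}
```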
tensorflow/core/kernels/conv_grad_ops_3d.cc (42 additions, 0 deletions)
@@ -239,6 +239,14 @@ class Conv3DBackpropInputOp : public OpKernel {
input_shape = context->input(0).shape();
}

OP_REQUIRES(context, input_shape.dims() == 5,
errors::InvalidArgument("input tensor must have 5 dimensions"));
OP_REQUIRES(
context, filter_shape.dims() == 5,
errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
OP_REQUIRES(
context, out_backprop_shape.dims() == 5,
errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
OP_REQUIRES(
context, input_shape.dim_size(4) == filter_shape.dim_size(3),
errors::InvalidArgument("input and filter_sizes must have the same "
@@ -360,6 +368,14 @@ class Conv3DCustomBackpropInputOp : public OpKernel {
input_shape = context->input(0).shape();
}

OP_REQUIRES(context, input_shape.dims() == 5,
errors::InvalidArgument("input tensor must have 5 dimensions"));
OP_REQUIRES(
context, filter_shape.dims() == 5,
errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
OP_REQUIRES(
context, out_backprop_shape.dims() == 5,
errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
OP_REQUIRES(
context, input_shape.dim_size(4) == filter_shape.dim_size(3),
errors::InvalidArgument("input and filter_sizes must have the same "
@@ -444,6 +460,11 @@ class Conv3DCustomBackpropInputOp : public OpKernel {
// contraction compared to sharding and matmuls.
const bool use_parallel_contraction = dims.batch_size == 1;

OP_REQUIRES(
context, work_unit_size > 0,
errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
"must all have at least 1 element"));

const size_t shard_size =
use_parallel_contraction
? 1
@@ -724,6 +745,14 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}

OP_REQUIRES(context, input_shape.dims() == 5,
errors::InvalidArgument("input tensor must have 5 dimensions"));
OP_REQUIRES(
context, filter_shape.dims() == 5,
errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
OP_REQUIRES(
context, out_backprop_shape.dims() == 5,
errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
OP_REQUIRES(
context, input_shape.dim_size(4) == filter_shape.dim_size(3),
errors::InvalidArgument("input and filter_sizes must have the same "
@@ -850,6 +879,14 @@ class Conv3DCustomBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}

OP_REQUIRES(context, input_shape.dims() == 5,
errors::InvalidArgument("input tensor must have 5 dimensions"));
OP_REQUIRES(
context, filter_shape.dims() == 5,
errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
OP_REQUIRES(
context, out_backprop_shape.dims() == 5,
errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
OP_REQUIRES(
context, input_shape.dim_size(4) == filter_shape.dim_size(3),
errors::InvalidArgument("input and filter_sizes must have the same "
@@ -936,6 +973,11 @@ class Conv3DCustomBackpropFilterOp : public OpKernel {

const int64 work_unit_size = size_A + size_B + size_C;

OP_REQUIRES(
context, work_unit_size > 0,
errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
"must all have at least 1 element"));

const size_t shard_size =
(target_working_set_size + work_unit_size - 1) / work_unit_size;

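A note on the division above: if the input, filter_sizes, and out_backprop tensors are all empty, then size_A, size_B, and size_C are all zero, so work_unit_size is zero and the ceiling division that computes shard_size would divide by zero. With the new OP_REQUIRES checks, the kernels instead fail early with an InvalidArgument error before that division runs.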
