Skip to content

Commit

Permalink
Merge pull request #51289 from tensorflow/mm-cherrypick-8a793b5d7f59e…
Browse files Browse the repository at this point in the history
…37ac7f3cd0954a750a2fe76bad4-on-r2.3

Prevent division by 0 in common shape functions.
  • Loading branch information
mihaimaruseac committed Aug 5, 2021
2 parents cdb6b5e + f639a0e commit b43738a
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow/core/framework/common_shape_fns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
int64 input_depth_value = c->Value(input_depth_dim),
filter_input_depth_value = c->Value(filter_input_depth_dim);
if (filter_input_depth_value == 0)
return errors::InvalidArgument("Depth of filter must not be 0");
if (input_depth_value % filter_input_depth_value != 0)
return errors::InvalidArgument(
"Depth of input (", input_depth_value,
Expand All @@ -668,6 +670,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
int64 num_groups = input_depth_value / filter_input_depth_value;
if (c->ValueKnown(output_depth_dim)) {
int64 output_depth_value = c->Value(output_depth_dim);
if (num_groups == 0)
return errors::InvalidArgument("Number of groups must not be 0");
if (output_depth_value % num_groups != 0)
return errors::InvalidArgument(
"Depth of output (", output_depth_value,
Expand Down Expand Up @@ -798,6 +802,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
int64 input_depth_value = c->Value(input_depth_dim),
filter_input_depth_value = c->Value(filter_input_depth_dim);
if (filter_input_depth_value == 0)
return errors::InvalidArgument("Depth of filter must not be 0");
if (input_depth_value % filter_input_depth_value != 0)
return errors::InvalidArgument(
"Depth of input (", input_depth_value,
Expand All @@ -807,6 +813,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
int64 num_groups = input_depth_value / filter_input_depth_value;
if (c->ValueKnown(output_depth_dim)) {
int64 output_depth_value = c->Value(output_depth_dim);
if (num_groups == 0)
return errors::InvalidArgument("Number of groups must not be 0");
if (output_depth_value % num_groups != 0)
return errors::InvalidArgument(
"Depth of output (", output_depth_value,
Expand Down Expand Up @@ -2364,6 +2372,9 @@ Status SparseReduceShapeFn(InferenceContext* c) {

int64 ndims = shape_vec.size();
absl::flat_hash_set<int64> axes;
if (ndims == 0)
return errors::InvalidArgument(
"Number of dims in shape tensor must not be 0");
for (int i = 0; i < axes_vec.size(); i++) {
axes.insert((axes_vec(i) + ndims) % ndims);
}
Expand Down

0 comments on commit b43738a

Please sign in to comment.