Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent division by 0 in common shape functions.
PiperOrigin-RevId: 387712197
Change-Id: Id25c7460e35b68aeeeac23b9a88e455b443ee149
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 30, 2021
1 parent 1071f55 commit 8a793b5
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow/core/framework/common_shape_fns.cc
Expand Up @@ -672,6 +672,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
int64_t input_depth_value = c->Value(input_depth_dim),
filter_input_depth_value = c->Value(filter_input_depth_dim);
if (filter_input_depth_value == 0)
return errors::InvalidArgument("Depth of filter must not be 0");
if (input_depth_value % filter_input_depth_value != 0)
return errors::InvalidArgument(
"Depth of input (", input_depth_value,
Expand All @@ -681,6 +683,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
int64_t num_groups = input_depth_value / filter_input_depth_value;
if (c->ValueKnown(output_depth_dim)) {
int64_t output_depth_value = c->Value(output_depth_dim);
if (num_groups == 0)
return errors::InvalidArgument("Number of groups must not be 0");
if (output_depth_value % num_groups != 0)
return errors::InvalidArgument(
"Depth of output (", output_depth_value,
Expand Down Expand Up @@ -816,6 +820,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
int64_t input_depth_value = c->Value(input_depth_dim),
filter_input_depth_value = c->Value(filter_input_depth_dim);
if (filter_input_depth_value == 0)
return errors::InvalidArgument("Depth of filter must not be 0");
if (input_depth_value % filter_input_depth_value != 0)
return errors::InvalidArgument(
"Depth of input (", input_depth_value,
Expand All @@ -825,6 +831,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
int64_t num_groups = input_depth_value / filter_input_depth_value;
if (c->ValueKnown(output_depth_dim)) {
int64_t output_depth_value = c->Value(output_depth_dim);
if (num_groups == 0)
return errors::InvalidArgument("Number of groups must not be 0");
if (output_depth_value % num_groups != 0)
return errors::InvalidArgument(
"Depth of output (", output_depth_value,
Expand Down Expand Up @@ -2456,6 +2464,9 @@ Status SparseReduceShapeFn(InferenceContext* c) {

int64_t ndims = shape_vec.size();
absl::flat_hash_set<int64> axes;
if (ndims == 0)
return errors::InvalidArgument(
"Number of dims in shape tensor must not be 0");
for (int i = 0; i < axes_vec.size(); i++) {
axes.insert((axes_vec(i) + ndims) % ndims);
}
Expand Down

0 comments on commit 8a793b5

Please sign in to comment.