Skip to content

Commit

Permalink
Avoid Segfault for scalar shapes.
Browse files Browse the repository at this point in the history
Calling tensor::FromElementsOp with an empty vector of elements and no type
causes a segfault. We need to let the FromElementsOp know which scalar type it
should have.
Also add back the DynamicBroadcastInDimOp canonicalization patterns, which
previously prevented this bug from happening.
Add a regression test that demonstrates the bug.

PiperOrigin-RevId: 417561444
Change-Id: I6d1d6cfb71aabbad6102422625a00bbe253ac95a
  • Loading branch information
akuegel authored and tensorflower-gardener committed Dec 21, 2021
1 parent 69db6c4 commit 35f0fab
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
Expand Up @@ -157,6 +157,10 @@ llvm::Optional<Value> simplifyBroadcast(ShapeComponentAnalysis& analysis,
shapes_found.push_back(*found_shape);
maxRank = std::max(maxRank, found_shape->size());
}
if (maxRank == 0) {
return Value(builder->create<tensor::FromElementsOp>(
loc, shapes[0].getType(), SmallVector<Value>()));
}

SmallVector<const ShapeComponentAnalysis::SymbolicExpr*> joined_dimensions(
maxRank);
Expand Down
@@ -0,0 +1,11 @@
builtin.func @test(%V__0 : tensor<i1> { python_test_attrs.static_type = tensor<i1> }, %V__1 : tensor<f32> { python_test_attrs.static_type = tensor<f32> }, %V__2 : tensor<f32> { python_test_attrs.static_type = tensor<f32> }) -> tensor<f32> {
%0 = "tf.Cast"(%V__0) : (tensor<i1>) -> tensor<i1>
%1 = "tf.Selu"(%V__2) : (tensor<f32>) -> tensor<f32>
%2 = "tf.NextAfter"(%1, %V__2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = "tf.Elu"(%2) : (tensor<f32>) -> tensor<f32>
%4 = "tf.Cosh"(%3) : (tensor<f32>) -> tensor<f32>
%5 = "tf.Elu"(%4) : (tensor<f32>) -> tensor<f32>
%6 = "tf.Div"(%V__1, %5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%7 = "tf.Select"(%0, %6, %V__1) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %7 : tensor<f32>
}

0 comments on commit 35f0fab

Please sign in to comment.