Skip to content

Commit

Permalink
Merge pull request #47060 from WindQAQ:partially-infer-conv-return-types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 357681939
Change-Id: Iad97b80eda7008fa8bbe594a1a087916c50c46b1
  • Loading branch information
tensorflower-gardener committed Feb 16, 2021
2 parents fe3a2d8 + e7cf096 commit e7a5816
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 34 deletions.
72 changes: 38 additions & 34 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1628,44 +1628,48 @@ static LogicalResult inferConvReturnTypes(
return failure();
}

// For operands having dynamic shape.
// Output always have `num_dims` rank. All dimensions are initialized to
// dynamic size and can be partially inferred.
SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
if (!input_ty.hasStaticShape() || !filter_ty.hasStaticShape()) {
inferredReturnTypes.assign(
{RankedTensorType::get(return_shape, input_ty.getElementType())});
return success();
}

// Checks the size of each of the output dimension.
for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
int64_t stride = get_int(strides[dim]);
tensorflow::int64 expected_output_size;
tensorflow::int64 pad_low;
tensorflow::int64 pad_high;
// Retrieve padding, if defined explicitly.
if (padding == tensorflow::Padding::EXPLICIT) {
pad_low = get_int(explicit_padding[2 * dim]);
pad_high = get_int(explicit_padding[2 * dim + 1]);
// Output batch and channel dimension can be obtained using utilities from
// tensorflow/core/util/tensor_format.h.
if (input_ty.hasRank()) {
return_shape[GetTensorBatchDimIndex(num_dims, format)] =
input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
}
if (filter_ty.hasRank()) {
return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
num_dims, tensorflow::FORMAT_HWIO));
}
// Spatial dimensions can be inferred only when both input and filter are
// ranked because we need to get their spatial dimensions.
if (input_ty.hasRank() && filter_ty.hasRank()) {
// Checks the size of each of the output spatial dimensions.
for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
int64_t stride = get_int(strides[dim]);
tensorflow::int64 expected_output_size;
tensorflow::int64 pad_low;
tensorflow::int64 pad_high;
// Retrieve padding, if defined explicitly.
if (padding == tensorflow::Padding::EXPLICIT) {
pad_low = get_int(explicit_padding[2 * dim]);
pad_high = get_int(explicit_padding[2 * dim + 1]);
}
// Skip if input or filter size is dynamic.
if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
// Calculate the expected_output_size.
tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
input_ty.getDimSize(dim), filter_ty.getDimSize(i),
get_int(dilations[dim]), stride, padding, &expected_output_size,
&pad_low, &pad_high);
// Return failure if expected_output_size could not be calculated.
if (!status.ok()) return failure();
return_shape[dim] = expected_output_size;
}
// Calculate the expected_output_size.
tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
input_ty.getDimSize(dim), filter_ty.getDimSize(i),
get_int(dilations[dim]), stride, padding, &expected_output_size,
&pad_low, &pad_high);
// Return failure if expected_output_size could not be calculated.
if (!status.ok()) return failure();
return_shape[dim] = expected_output_size;
}

// The remaining dimensions can be obtained using utilities from
// tensorflow/core/util/tensor_format.h.
return_shape[GetTensorBatchDimIndex(num_dims, format)] =
input_ty.getShape()[GetTensorBatchDimIndex(num_dims, format)];
return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
filter_ty.getShape()[GetFilterTensorOutputChannelsDimIndex(
num_dims, tensorflow::FORMAT_HWIO)];

inferredReturnTypes.assign(
{RankedTensorType::get(return_shape, input_ty.getElementType())});
return success();
Expand Down
91 changes: 91 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1045,4 +1045,95 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-SAME: tensor<i32>
return %arg0 : tensor<*xi32>
}

// Tests that Conv2D's inferReturnTypes hook can infer partial output-shape
// information when the input or filter does not have a fully static shape.

// Both operands are unranked, so no dimension can be inferred: the result is
// ranked (rank 4 for Conv2D) but every dimension stays dynamic.
// CHECK-LABEL: func @conv2d_unranked_input_and_filter
func @conv2d_unranked_input_and_filter(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x?x?xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Only the input is ranked: the batch dimension (256) propagates from the
// input, while the spatial and output-channel dimensions (which need the
// filter) stay dynamic.
// CHECK-LABEL: func @conv2d_unranked_filter
func @conv2d_unranked_filter(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<256x?x?x?xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// The input's batch dimension is already dynamic and the filter is unranked,
// so every output dimension remains dynamic.
// CHECK-LABEL: func @conv2d_unranked_filter_and_dynamic_batch
func @conv2d_unranked_filter_and_dynamic_batch(%arg0: tensor<?x32x32x3xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x?x?xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x32x32x3xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Only the filter is ranked: the output-channel dimension (16) propagates
// from the filter's last dimension, while batch and spatial dimensions
// (which need a ranked input) stay dynamic.
// CHECK-LABEL: func @conv2d_unranked_input
func @conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x?x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// The filter's output-channel dimension is dynamic and the input is
// unranked, so nothing concrete can be inferred: all dimensions stay dynamic.
// CHECK-LABEL: func @conv2d_unranked_input_and_dynamic_channel
func @conv2d_unranked_input_and_dynamic_channel(%arg0: tensor<*xf32>, %arg1: tensor<3x3x3x?xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x?x?xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<3x3x3x?xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Both operands are ranked, so spatial and channel dimensions are fully
// computed; only the dynamic batch dimension propagates as '?'.
// CHECK-LABEL: func @conv2d_dynamic_batch
func @conv2d_dynamic_batch(%arg0: tensor<?x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x32x32x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Everything is inferred except the output-channel dimension, which is
// dynamic in the filter and therefore stays '?' in the result.
// CHECK-LABEL: func @conv2d_dynamic_channel
func @conv2d_dynamic_channel(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x?xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<256x32x32x?xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x?xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Batch (from the input) and output channels (from the filter) are inferred;
// both dynamic input spatial dimensions stay dynamic in the result.
// CHECK-LABEL: func @conv2d_fully_dynamic_spatial_dim
func @conv2d_fully_dynamic_spatial_dim(%arg0: tensor<256x?x?x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<256x?x?x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x?x?x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Spatial dimensions are inferred independently: the static one (32) is
// computed while the dynamic one stays '?'.
// CHECK-LABEL: func @conv2d_partially_dynamic_spatial_dim
func @conv2d_partially_dynamic_spatial_dim(%arg0: tensor<256x?x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<256x?x32x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x?x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Dynamic batch and one dynamic spatial dimension stay '?'; the static
// spatial dimension (32) and output channels (16) are still inferred.
// CHECK-LABEL: func @conv2d_dynamic_batch_and_partially_dynamic_spatial_dim
func @conv2d_dynamic_batch_and_partially_dynamic_spatial_dim(%arg0: tensor<?x?x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x32x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x?x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Only the filter's output-channel dimension (16) is static enough to
// infer; batch and both spatial dimensions remain dynamic.
// CHECK-LABEL: func @conv2d_dynamic_batch_and_fully_dynamic_spatial_dim
func @conv2d_dynamic_batch_and_fully_dynamic_spatial_dim(%arg0: tensor<?x?x?x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Conv2D"
  // CHECK-SAME: -> tensor<?x?x?x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x?x?x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
}

0 comments on commit e7a5816

Please sign in to comment.