Extend fuse_convolution_pass to support dynamic cases.
Change the op order of the pattern produced by unfuse_mhlo_batch_norm.

PiperOrigin-RevId: 609201538
tensorflower-gardener committed Feb 22, 2024
1 parent 1b8632b commit cb5bd00
Showing 6 changed files with 15 additions and 94 deletions.
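For context on the fuse_convolution_pass change: the pattern folds a per-output-channel constant multiply into the convolution filter. It relies on convolution being linear in the filter; as a sketch of the identity the rewrite assumes (not part of the diff, notation mine):

$$
\mathrm{conv}(x, W)_{b,h,w,o} \cdot m_o \;=\; \mathrm{conv}(x, W')_{b,h,w,o},
\qquad W'_{k_h,k_w,i,o} \;=\; W_{k_h,k_w,i,o} \cdot m_o .
$$

Because the multiplier varies only along the output-feature dimension, it can be baked into the statically shaped filter constant, leaving just the trailing mhlo.add bias term in the graph; the input's batch and spatial dimensions never enter the identity, which is what makes the dynamic-shape test below possible.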
3 changes: 0 additions & 3 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -358,14 +358,11 @@ cc_library(
deps = [
":passes_inc_gen",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@local_xla//xla/mlir_hlo",
],
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD
@@ -1,5 +1,5 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("//tensorflow:tensorflow.default.bzl", "filegroup")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,30 +16,3 @@ func.func @fuseMulAndConv2D(%input: tensor<1x256x256x3xf32>) -> (tensor<1x256x25
// CHECK-DAG: return %[[RESULT]]
func.return %2 : tensor<1x256x256x2xf32>
}

// -----

// CHECK-LABEL: @fuseMulAndConv2DDynamic
// CHECK-SAME: %[[INPUT:[^:[:space:]]+]]
func.func @fuseMulAndConv2DDynamic(%input: tensor<?x256x256x3xf32>) -> (tensor<?x256x256x2xf32>) {
// CHECK-DAG: %[[FILTER:.+]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]]]]> : tensor<1x1x3x2xf32>
// CHECK-DAG: %[[CST_0:.+]] = mhlo.constant dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32>
// CHECK-DAG: %[[CST_1:.+]] = mhlo.constant dense<[3.000000e-01, 4.000000e-01]> : tensor<2xf32>
// CHECK: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST_0]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x1x3x2xf32>
// CHECK: %[[NEW_FILTER:.+]] = mhlo.multiply %[[CST_BCAST]], %[[FILTER]] : tensor<1x1x3x2xf32>
// CHECK: %[[CONV:.+]] = mhlo.convolution(%[[INPUT]], %[[NEW_FILTER]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<?x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<?x256x256x2xf32>
// CHECK: %[[SHAPE:.+]] = shape.shape_of %[[CONV]] : tensor<?x256x256x2xf32> -> tensor<4xindex>
// CHECK: %[[DYNAMIC_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CST_1]], %[[SHAPE]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor<?x256x256x2xf32>
// CHECK: %[[ADD:.+]] = mhlo.add %[[CONV]], %[[DYNAMIC_BCAST]] : tensor<?x256x256x2xf32>
%filter = mhlo.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]]> : tensor<1x1x3x2xf32>
%cst_0 = mhlo.constant dense<[0.1, 0.2]> : tensor<2xf32>
%cst_1 = mhlo.constant dense<[0.3, 0.4]> : tensor<2xf32>
%0 = mhlo.convolution(%input, %filter) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<?x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<?x256x256x2xf32>
%1 = shape.shape_of %0 : tensor<?x256x256x2xf32> -> tensor<4xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%cst_0, %1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor<?x256x256x2xf32>
%3 = mhlo.multiply %0, %2 : tensor<?x256x256x2xf32>
%4 = "mhlo.dynamic_broadcast_in_dim"(%cst_1, %1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>, tensor<4xindex>) -> tensor<?x256x256x2xf32>
%5 = mhlo.add %3, %4 : tensor<?x256x256x2xf32>
// CHECK-DAG: return %[[ADD]]
func.return %5 : tensor<?x256x256x2xf32>
}
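As a concrete check of the identity above, using the constants from the fuseMulAndConv2DDynamic test (1x1x3x2 filter, multiplier [0.1, 0.2], bias [0.3, 0.4]): the following is a standalone illustrative C++ sketch, not code from the repository, and the input values are made up.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Filter W with shape [kh=1, kw=1, in=3, out=2], as in the test above.
  const double W[3][2] = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
  const double m[2] = {0.1, 0.2};  // per-output-channel multiplier (cst_0)
  const double b[2] = {0.3, 0.4};  // per-output-channel bias (cst_1)

  // A few arbitrary input "pixels" with 3 channels each; how many there are
  // (the dynamic batch/spatial extent) does not affect the identity.
  const double x[4][3] = {
      {0.5, -1.0, 2.0}, {1.5, 0.25, -0.75}, {-2.0, 3.0, 0.0}, {0.1, 0.2, 0.3}};

  double max_diff = 0.0;
  for (const auto& pixel : x) {
    for (int o = 0; o < 2; ++o) {
      // Unfused: 1x1 convolution, then multiply, then add.
      double conv = 0.0;
      for (int i = 0; i < 3; ++i) conv += pixel[i] * W[i][o];
      const double unfused = conv * m[o] + b[o];
      // Fused: convolution with the pre-scaled filter W * m, then add.
      double fused = 0.0;
      for (int i = 0; i < 3; ++i) fused += pixel[i] * (W[i][o] * m[o]);
      fused += b[o];
      max_diff = std::max(max_diff, std::fabs(unfused - fused));
    }
  }
  std::printf("max |unfused - fused| = %g\n", max_diff);  // expect ~0
  return 0;
}

Up to floating-point rounding the two paths agree, which is why the pattern can rewrite the filter constant even when the convolution's batch dimension is unknown at compile time.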
@@ -17,8 +17,8 @@ func.func @batchNormInference_2D_inner_features(
// CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32>
// CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32>
// CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<4x256xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
@@ -74,8 +74,8 @@ func.func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<?xf32>
// CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
// CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MULTIPLIER]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<?x?x?x?xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
@@ -145,8 +145,8 @@ func.func @batchNormTraining_4D_middle_features(
// CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32>
// CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32>
// CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<3x4x256x6xf32>
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<3x4x256x6xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<3x4x256x6xf32>
%0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset)
{epsilon = 1.0 : f32, feature_index = 2 : i64} :
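For reference, the decomposition these batch-norm tests encode (standard mhlo.batch_norm_inference algebra, restated here; notation mine):

$$
\text{multiplier} \;=\; \gamma \cdot (\sigma^2 + \epsilon)^{-1/2}, \qquad
\text{rhs} \;=\; \beta - \text{multiplier}\cdot\mu, \qquad
y \;=\; x \cdot \text{multiplier} + \text{rhs},
$$

with scale γ, offset β, mean μ, and variance σ². These correspond to the MULTIPLIER, RHS, X_NORMED, and RESULT values in the CHECK lines; the diff only changes the order in which the broadcast and multiply ops are emitted, not the computation.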
@@ -13,31 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

namespace mlir {
@@ -51,8 +44,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern<mhlo::MulOp> {
PatternRewriter &rewriter) const override {
// Variables for capturing values and attributes used while creating ops.
mhlo::ConvolutionOp conv_op;
Operation *bcast_or_const_op;
shape::ShapeOfOp shape_of_op;
mhlo::BroadcastInDimOp broadcast_op;
mhlo::ConstantOp filter;
mhlo::ConstantOp multiplier;
mlir::ElementsAttr filter_value, mul_value;
@@ -69,18 +61,14 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern<mhlo::MulOp> {
if (filter == nullptr) {
return failure();
}
// Try to match static broadcast or dynamic broadcast.
bcast_or_const_op = rhs.getDefiningOp();
bool is_dynamic_broadcast =
isa<mhlo::DynamicBroadcastInDimOp>(bcast_or_const_op);
multiplier = isa<mhlo::ConstantOp>(bcast_or_const_op)
? dyn_cast_or_null<mhlo::ConstantOp>(bcast_or_const_op)
: bcast_or_const_op->getOperand(0)
.getDefiningOp<mhlo::ConstantOp>();
broadcast_op = rhs.getDefiningOp<mhlo::BroadcastInDimOp>();
multiplier =
(broadcast_op == nullptr)
? rhs.getDefiningOp<mhlo::ConstantOp>()
: broadcast_op.getOperand().getDefiningOp<mhlo::ConstantOp>();
if (multiplier == nullptr) {
return failure();
}

auto result_type = OpTrait::util::getBroadcastedType(filter.getType(),
multiplier.getType());
if (!result_type) {
@@ -102,33 +90,15 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern<mhlo::MulOp> {
"unsupported dimensions";
});
}
if (!is_dynamic_broadcast &&
!((*conv_op.getODSResults(0).begin()).hasOneUse())) {
if (!((*conv_op.getODSResults(0).begin()).hasOneUse())) {
return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) {
diag << "entities 'conv' failed to satisfy constraint: has one use";
});
}
// For dynamic case, the result of conv should be used by shape_of and mul.
if (is_dynamic_broadcast) {
auto conv_uses = (*conv_op.getODSResults(0).begin()).getUses();
if (std::distance(conv_uses.begin(), conv_uses.end()) != 2 ||
quant::FindUserOfType<shape::ShapeOfOp>(conv_op) == nullptr ||
quant::FindUserOfType<mhlo::MulOp>(conv_op) == nullptr) {
return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic
&diag) {
diag << "entities 'conv' failed to satisfy constraint: has two uses "
"for dynamic case";
});
}
}

// Rewrite
// For dynamic case, we use filter's shape to create a static broadcast.
broadcast_dims =
!isa<mhlo::ConstantOp>(bcast_or_const_op) && !is_dynamic_broadcast
? dyn_cast_or_null<mhlo::BroadcastInDimOp>(bcast_or_const_op)
.getBroadcastDimensions()
: nullptr;
broadcast_op ? broadcast_op.getBroadcastDimensions() : nullptr;
if (broadcast_dims == nullptr) {
const auto filter_rank = filter_value.getShapedType().getRank();
auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64));
@@ -145,26 +115,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern<mhlo::MulOp> {
conv_op.getWindowReversalAttr(), conv_op.getDimensionNumbers(),
conv_op.getFeatureGroupCount(), conv_op.getBatchGroupCount(),
conv_op.getPrecisionConfigAttr());
// For static case, replace the convolution op now.
if (!is_dynamic_broadcast) {
rewriter.replaceOp(mul_op, {new_conv});
} else {
// For dynamic case, create new shape_of op and replace uses.
shape_of_op =
dyn_cast_or_null<mhlo::DynamicBroadcastInDimOp>(bcast_or_const_op)
.getOutputDimensions()
.getDefiningOp<shape::ShapeOfOp>();
// Check if the shape come from the original conv op.
if (!shape_of_op ||
shape_of_op.getArg().getDefiningOp<mhlo::ConvolutionOp>() !=
conv_op) {
return failure();
}
Value new_shape_of = rewriter.create<shape::ShapeOfOp>(
mul_op.getLoc(), shape_of_op.getType(), new_conv);
shape_of_op.replaceAllUsesWith(new_shape_of);
rewriter.replaceOp(mul_op, {new_conv});
}
rewriter.replaceOp(mul_op, {new_conv});

return success();
}
@@ -217,12 +217,12 @@ class UnfuseBatchNormInferencePattern
auto broadcast_multiplier =
broadcastToFeatureDim(bn_op.getLoc(), input_type, multiplier,
shape_value, feature_dim, rewriter);
auto broadcast_rhs = broadcastToFeatureDim(
bn_op.getLoc(), input_type, rhs, shape_value, feature_dim, rewriter);

// Computes x * multiplier + rhs
Value lhs = rewriter.create<mhlo::MulOp>(bn_op.getLoc(), bn_op.getOperand(),
broadcast_multiplier);
auto broadcast_rhs = broadcastToFeatureDim(
bn_op.getLoc(), input_type, rhs, shape_value, feature_dim, rewriter);
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, lhs, broadcast_rhs);

return success();