[StableHLO] Add broadcast_in_dim canon patterns (iree-org#13746)
These are based on the equivalent fold and canonicalizer from MHLO.

Also run the StableHLO canonicalization pass between other lowering passes.

Issue: iree-org#12678
kuhar authored and nhasabni committed Aug 24, 2023
1 parent 5dd4065 commit e2f3b70
Showing 4 changed files with 125 additions and 17 deletions.
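Sketched in MLIR, the new patterns rewrite broadcast_in_dim ops roughly as follows (an illustrative summary; the lit test added below exercises the same cases):

  // No-op broadcast (same type, iota dims) folds to its operand.
  %0 = stablehlo.broadcast_in_dim %x, dims = [0, 1] : (tensor<3x3xi32>) -> tensor<3x3xi32>
  // -> %x

  // Broadcast of a splat constant folds to a splat constant of the result type.
  %c = stablehlo.constant dense<5> : tensor<i32>
  %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<6xi32>
  // -> stablehlo.constant dense<5> : tensor<6xi32>

  // Element-count-preserving broadcasts become reshape (iota dims) or
  // transpose (same rank, non-iota dims).
  %2 = stablehlo.broadcast_in_dim %x, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32>
  // -> stablehlo.transpose %x, dims = [1, 0]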
@@ -57,6 +57,8 @@ void registerStableHLOConversionPassPipeline() {
 void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
                                                    bool detuple) {
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
+  passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
   passManager.addNestedPass<func::FuncOp>(
       stablehlo::createLegalizeControlFlow());

@@ -68,7 +70,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,

   passManager.addNestedPass<func::FuncOp>(
       createStableHLOToStableHLOPreprocessing());
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());

   // Various shape functions may have been materialized in the `shape.shape_of`
   // style of treating shapes as tensors. We prefer to legalize these to
@@ -77,6 +79,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
   passManager.addNestedPass<func::FuncOp>(createShapeToShapeLowering());
   passManager.addPass(createConvertShapeToStandardPass());
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());

   // We also don't handle calls well on the old codepath; until we remove the
   // use of the CFG we can continue inlining.
@@ -99,6 +102,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
   // Perform initial cleanup. createLegalizeInputTypes could rewrite types. In
   // this context, some operations could be folded away.
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
   passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());

   // Convert to Linalg. After this point, StableHLO will be eliminated.
@@ -113,6 +117,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
   // Note that some StableHLO ops are left by the above and must resolve via
   // canonicalization. See comments in the above pass and find a better way.
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());

   passManager.addPass(stablehlo::createVerifyCompilerStableHloInputLegality());
 }
@@ -14,6 +14,7 @@
 #include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -32,6 +33,84 @@ namespace {
 // allowed to materialize as new constants.
 constexpr int64_t kFoldOpEltLimit = 65536;

+static bool isIotaRange(ElementsAttr attr) {
+  auto elems = attr.tryGetValues<APInt>();
+  if (!elems) return false;
+
+  for (auto [idx, value] : llvm::enumerate(*elems)) {
+    if (idx != value) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+struct BroadcastInDimOpCanon final
+    : OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = dyn_cast<RankedTensorType>(op.getType());
+    if (!type) return failure();
+
+    Value operand = op.getOperand();
+    auto operandTy = dyn_cast<RankedTensorType>(operand.getType());
+    if (!operandTy) return failure();
+
+    // Fold when broadcast is a noop.
+    DenseIntElementsAttr dims = op.getBroadcastDimensions();
+    bool isDimsIota = isIotaRange(dims);
+    if (type == operandTy && isDimsIota) {
+      rewriter.replaceOp(op, operand);
+      return success();
+    }
+
+    // Handle splat broadcasts.
+    if (SplatElementsAttr cstAttr;
+        matchPattern(operand, m_Constant(&cstAttr))) {
+      rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
+          op, SplatElementsAttr::get(op.getType(),
+                                     cstAttr.getSplatValue<Attribute>()));
+      return success();
+    }
+
+    auto bsDimIndices = dims.getValues<int64_t>();
+    if (operandTy.hasStaticShape() && type.hasStaticShape() &&
+        type.getNumElements() == operandTy.getNumElements()) {
+      // BroadcastInDim equivalent to reshape.
+      if (isDimsIota) {
+        rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, type,
+                                                                operand);
+        return success();
+      }
+      // BroadcastInDim equivalent to transpose.
+      if (type.getRank() == operandTy.getRank()) {
+        rewriter.replaceOpWithNewOp<mlir::stablehlo::TransposeOp>(
+            op, type, operand, dims);
+        return success();
+      }
+    }
+
+    // Eliminate redundant nested BroadcastInDim.
+    if (auto broadcastInDimOp =
+            operand.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>()) {
+      auto newIndices = cast<DenseIntElementsAttr>(
+          broadcastInDimOp.getBroadcastDimensions().mapValues(
+              dims.getElementType(), [&bsDimIndices](const APInt &dim) {
+                return APInt(dim.getBitWidth(),
+                             bsDimIndices[dim.getSExtValue()], true);
+              }));
+      rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
+          op, type, broadcastInDimOp.getOperand(), newIndices);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct ConcatenateOpCanon final
     : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -206,15 +285,7 @@ struct TransposeOpCanon final : OpRewritePattern<mlir::stablehlo::TransposeOp> {
   LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     // Check if this transpose is a noop and use the operand instead.
-    auto dims = op.getPermutation().tryGetValues<APInt>();
-    if (failed(dims)) return failure();
-
-    // Check if dims is an iota range.
-    for (auto [idx, dim] : llvm::enumerate(*dims)) {
-      if (idx != dim.getLimitedValue()) {
-        return failure();
-      }
-    }
+    if (!isIotaRange(op.getPermutation())) return failure();

     rewriter.replaceOp(op, op.getOperand());
     return success();
@@ -238,9 +309,9 @@ struct StableHLOCanonicalize final
 void populateCanonicalizationPatterns(MLIRContext *context,
                                       RewritePatternSet *patterns,
                                       PatternBenefit benefit) {
-  patterns->add<ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon,
-                GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
-                GetDimensionSizeOpCanon, ReshapeOpCanon, TransposeOpCanon>(
-      context, benefit);
+  patterns->add<BroadcastInDimOpCanon, ConcatenateOpCanon, ConvertOpCanon,
+                DynamicReshapeOpCanon, GetTupleElementOpCanon, RealOpCanon,
+                ImagOpCanon, GetDimensionSizeOpCanon, ReshapeOpCanon,
+                TransposeOpCanon>(context, benefit);
 }
 } // namespace mlir::iree_compiler::stablehlo
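For the nested-broadcast rewrite above, the mapValues lambda composes the two dimension maps as newDims[i] = outerDims[innerDims[i]]. A minimal sketch (illustrative IR mirroring the new lit test below):

  %0 = stablehlo.broadcast_in_dim %x, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3x2xi32>
  %1 = stablehlo.broadcast_in_dim %0, dims = [0, 2, 1] : (tensor<3x3x2xi32>) -> tensor<3x2x3x3xi32>
  // innerDims = [1, 0], outerDims = [0, 2, 1], so
  // newDims = [outerDims[1], outerDims[0]] = [2, 0]:
  %1 = stablehlo.broadcast_in_dim %x, dims = [2, 0] : (tensor<3x3xi32>) -> tensor<3x2x3x3xi32>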
@@ -1,5 +1,38 @@
 // RUN: iree-opt --iree-stablehlo-canonicalize --split-input-file %s | FileCheck %s

+// CHECK-LABEL: func.func @broadcast_in_dim
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
+func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
+    -> (tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>, tensor<3x3xi32>, tensor<3x3xi32>, tensor<3x3x1xi32>, tensor<3x2x3x3xi32>) {
+  %c0 = stablehlo.constant dense<5> : tensor<i32>
+  %c1 = stablehlo.constant dense<3.0> : tensor<f32>
+  %c2 = stablehlo.constant dense<1> : tensor<1x3xi32>
+
+  %0 = stablehlo.broadcast_in_dim %c0, dims = [] : (tensor<i32>) -> tensor<6xi32>
+  %1 = stablehlo.broadcast_in_dim %c1, dims = [] : (tensor<f32>) -> tensor<3xf32>
+  %2 = stablehlo.broadcast_in_dim %c2, dims = [1, 0] : (tensor<1x3xi32>) -> tensor<3x3xi32>
+
+  %3 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<3x3xi32>) -> tensor<3x3xi32>
+  %4 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32>
+  %5 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<3x3xi32>) -> tensor<3x3x1xi32>
+
+  %6 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3x2xi32>
+  %7 = stablehlo.broadcast_in_dim %6, dims = [0, 2, 1] : (tensor<3x3x2xi32>) -> tensor<3x2x3x3xi32>
+
+  // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<5> : tensor<6xi32>
+  // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3xf32>
+  // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<1> : tensor<3x3xi32>
+
+  // CHECK-DAG: [[R4:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] : (tensor<3x3xi32>) -> tensor<3x3xi32>
+  // CHECK-DAG: [[R5:%.+]] = stablehlo.reshape [[ARG0]] : (tensor<3x3xi32>) -> tensor<3x3x1xi32>
+  // CHECK-DAG: [[R6:%.+]] = stablehlo.broadcast_in_dim [[ARG0]], dims = [2, 0] : (tensor<3x3xi32>) -> tensor<3x2x3x3xi32>
+
+  // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[ARG0]], [[R4]], [[R5]], [[R6]]
+  return %0, %1, %2, %3, %4, %5, %7 : tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>, tensor<3x3xi32>, tensor<3x3xi32>, tensor<3x3x1xi32>, tensor<3x2x3x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @concatenate
 func.func @concatenate() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>) {
   %c0 = stablehlo.constant dense<[0, 1]> : tensor<2xi32>
@@ -17,7 +50,7 @@ func.func @concatenate() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tens
   %3 = stablehlo.concatenate %c3, %c5, dim = 1 : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>

   // CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>
-  // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<[0, 1, 5]> : tensor<3xi32
+  // CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<[0, 1, 5]> : tensor<3xi32>
   // CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]}}> : tensor<3x3xi32>
   // CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32>
   // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]]
@@ -11,8 +11,7 @@ func.func @batch_norm_inference(
     %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
     %mean: tensor<256xf32>, %variance: tensor<256xf32>)
     -> (tensor<4x256xf32>) {
-  // CHECK-DAG: %[[EPS:.+]] = stablehlo.constant dense<1.001000e-05> : tensor<f32>
-  // CHECK-DAG: %[[EPS_BCAST:.+]] = stablehlo.broadcast_in_dim %[[EPS]], dims = [] : (tensor<f32>) -> tensor<256xf32>
+  // CHECK-DAG: %[[EPS_BCAST:.+]] = stablehlo.constant dense<1.001000e-05> : tensor<256xf32>
   // CHECK-DAG: %[[VARIANCE_EPS:.+]] = stablehlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
   // CHECK-DAG: %[[STDDEV:.+]] = stablehlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32>
   // CHECK-DAG: %[[STDDEV_BCAST:.+]] = stablehlo.broadcast_in_dim %[[STDDEV]], dims = [1] : (tensor<256xf32>) -> tensor<4x256xf32>
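This test update follows from the new splat-broadcast fold: broadcasting the scalar epsilon constant now folds into a single splat constant of the broadcasted type, so the separate broadcast_in_dim no longer appears in the expected output. Schematically (illustrative IR):

  %eps = stablehlo.constant dense<1.001000e-05> : tensor<f32>
  %eps_bcast = stablehlo.broadcast_in_dim %eps, dims = [] : (tensor<f32>) -> tensor<256xf32>
  // folds to:
  %eps_bcast = stablehlo.constant dense<1.001000e-05> : tensor<256xf32>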
