[StableHLO] Add canonicalization patterns for add, sub, and mul (iree-org#13755)

These should cover the same rewrites as the equivalent
folds/canonicalization functions from MHLO, but the implementation is
not the same.

Issue: iree-org#12678
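
For illustration (not part of the commit), a minimal sketch of the rewrites these patterns perform, in the style of the tests added below; %c0, %c1, and %c5 are assumed to be splat constants 0, 1, and 5:

    %0 = stablehlo.add %arg0, %c0 : tensor<2xi32>        // add(x, 0) -> x
    %1 = stablehlo.multiply %arg0, %c1 : tensor<2xi32>   // mul(x, 1) -> x
    %2 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32> // sub(x, x) -> 0 (integers only)
    %3 = stablehlo.add %c5, %c5 : tensor<i32>            // add(5, 5) -> constant 10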
kuhar authored and nhasabni committed Aug 24, 2023
1 parent d3228ff commit cad373f
Showing 4 changed files with 306 additions and 4 deletions.
1 change: 1 addition & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -62,6 +62,7 @@ def __init__(self, repo_map: Dict[str, str]):

# MLIR
"@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"],
"@llvm-project//mlir:CommonFolders": [""],
"@llvm-project//mlir:DialectUtils": [""],
"@llvm-project//mlir:GPUDialect": ["MLIRGPUOps"],
"@llvm-project//mlir:GPUTransforms": ["MLIRGPUTransforms"],
@@ -79,6 +79,7 @@ iree_compiler_cc_library(
":PassHeaders",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
@@ -12,13 +12,16 @@

#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
#include "llvm/ADT/APFloat.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"

@@ -46,6 +49,193 @@ static bool isIotaRange(ElementsAttr attr) {
return true;
}

/// Matches when either of the submatchers matches.
template <typename MatcherA, typename MatcherB>
struct m_AnyOf {
m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {}

bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); }

MatcherA matcherA;
MatcherB matcherB;
};

template <typename MatcherA, typename MatcherB>
m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
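
// Usage sketch (as used in AddOpCanon below):
//   matchPattern(rhs, m_AnyOf(m_Zero(), m_NegZeroFloat()))
// succeeds when `rhs` is either an integer zero or a float -0.0 constant.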

/// Binary constant folder that uses a generic folder function to handle both
/// ints and floats.
template <typename Fn>
static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs,
Fn &&folder) {
Attribute operands[2] = {lhs, rhs};
Type elemTy = getElementTypeOrSelf(cast<TypedAttr>(lhs).getType());

if (isa<IntegerType>(elemTy)) {
if (Attribute res = constFoldBinaryOp<IntegerAttr>(
operands, [&folder](const APInt &lhs, const APInt &rhs) {
return folder(lhs, rhs);
})) {
return cast<TypedAttr>(res);
}
return nullptr;
}

if (isa<FloatType>(elemTy)) {
if (Attribute res = constFoldBinaryOp<FloatAttr>(
operands, [&folder](const APFloat &lhs, const APFloat &rhs) {
return folder(lhs, rhs);
})) {
return cast<TypedAttr>(res);
}
return nullptr;
}

return nullptr;
}
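
// Example use (mirrors the patterns below):
//   TypedAttr res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{});
// `res` is the elementwise-folded constant attribute, or nullptr when the
// element type is neither integer nor float.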

struct AddOpCanon final : OpRewritePattern<mlir::stablehlo::AddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type) return failure();

Value lhs = op.getLhs();
Value rhs = op.getRhs();

if (matchPattern(lhs, m_Zero())) {
rewriter.replaceOp(op, rhs);
return success();
}

if (matchPattern(rhs, m_AnyOf(m_Zero(), m_NegZeroFloat()))) {
rewriter.replaceOp(op, lhs);
return success();
}

TypedAttr lhsAttr;
matchPattern(lhs, m_Constant(&lhsAttr));

TypedAttr rhsAttr;
matchPattern(rhs, m_Constant(&rhsAttr));

// The canonical form has the constant operand as the RHS.
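// For example, add(%c, %x) is rewritten to add(%x, %c).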
if (isa<IntegerType>(type.getElementType()) && lhsAttr && !rhsAttr) {
rewriter.updateRootInPlace(op, [op, lhs, rhs] {
op->setOperands(ValueRange{rhs, lhs});
});
return success();
}

if (lhsAttr && rhsAttr) {
if (TypedAttr res =
foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{})) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
return success();
}
}

return failure();
}
};

struct SubtractOpCanon final : OpRewritePattern<mlir::stablehlo::SubtractOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type) return failure();

Value lhs = op.getLhs();
Value rhs = op.getRhs();

if (isa<IntegerType>(type.getElementType()) && lhs == rhs) {
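// For floats, x - x is NaN when x is NaN or infinite, hence the
// IntegerType guard above.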
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
op, rewriter.getZeroAttr(op.getType()));
return success();
}

// Subtraction of 0.
if (matchPattern(rhs, m_AnyOf(m_Zero(), m_PosZeroFloat()))) {
rewriter.replaceOp(op, lhs);
return success();
}

TypedAttr lhsAttr;
matchPattern(lhs, m_Constant(&lhsAttr));

TypedAttr rhsAttr;
matchPattern(rhs, m_Constant(&rhsAttr));

if (lhsAttr && rhsAttr) {
if (TypedAttr res =
foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{})) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
return success();
}
}

return failure();
}
};

struct MulOpCanon final : OpRewritePattern<mlir::stablehlo::MulOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type) return failure();

Value lhs = op.getLhs();
Value rhs = op.getRhs();

// Multiplication by 0. This fold does not hold for floats in the presence
// of NaN values (NaN * 0 is NaN, not 0); m_Zero matches only integer zeros,
// so the rewrite below applies to integer element types only.
if (matchPattern(lhs, m_Zero())) {
rewriter.replaceOp(op, lhs);
return success();
}
if (matchPattern(rhs, m_Zero())) {
rewriter.replaceOp(op, rhs);
return success();
}

// Multiplication by 1.
if (matchPattern(rhs, m_One())) {
rewriter.replaceOp(op, lhs);
return success();
}

TypedAttr lhsAttr;
matchPattern(lhs, m_Constant(&lhsAttr));

TypedAttr rhsAttr;
matchPattern(rhs, m_Constant(&rhsAttr));

// The canonical form has the constant operand as the RHS.
if (isa<IntegerType>(type.getElementType()) && lhsAttr && !rhsAttr) {
rewriter.updateRootInPlace(op, [op, lhs, rhs] {
op->setOperands(ValueRange{rhs, lhs});
});
return success();
}

if (lhsAttr && rhsAttr) {
if (TypedAttr res =
foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{})) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
return success();
}
}

return failure();
}
};

struct BroadcastInDimOpCanon final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
@@ -309,9 +499,10 @@ struct StableHLOCanonicalize final
void populateCanonicalizationPatterns(MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit) {
-  patterns->add<BroadcastInDimOpCanon, ConcatenateOpCanon, ConvertOpCanon,
-                DynamicReshapeOpCanon, GetTupleElementOpCanon, RealOpCanon,
-                ImagOpCanon, GetDimensionSizeOpCanon, ReshapeOpCanon,
-                TransposeOpCanon>(context, benefit);
+  patterns->add<AddOpCanon, SubtractOpCanon, MulOpCanon, BroadcastInDimOpCanon,
+                ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon,
+                GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
+                GetDimensionSizeOpCanon, ReshapeOpCanon, TransposeOpCanon>(
+      context, benefit);
}
} // namespace mlir::iree_compiler::stablehlo
@@ -1,5 +1,114 @@
// RUN: iree-opt --iree-stablehlo-canonicalize --split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @add
// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<f32>)
func.func @add(%arg0: tensor<2xi32>, %arg1: tensor<f32>)
-> (tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) {
%c0 = stablehlo.constant dense<0> : tensor<i32>
%cn0 = stablehlo.constant dense<-0.0> : tensor<f32>
%c0_2 = stablehlo.constant dense<0> : tensor<2xi32>
%c1 = stablehlo.constant dense<5> : tensor<i32>
%c2 = stablehlo.constant dense<3.0> : tensor<f32>
%c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32>

%0 = stablehlo.add %c0, %c1 : tensor<i32>
%1 = stablehlo.add %c1, %c1 : tensor<i32>
%2 = stablehlo.add %c2, %c2 : tensor<f32>
%3 = stablehlo.add %arg1, %cn0 : tensor<f32>

%4 = stablehlo.add %c0_2, %arg0 : tensor<2xi32>
%5 = stablehlo.add %c3, %arg0 : tensor<2xi32>
%6 = stablehlo.add %c3, %c3 : tensor<2xi32>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<5> : tensor<i32>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<10> : tensor<i32>
// CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<6.000000e+00> : tensor<f32>
// CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<[2, 4]> : tensor<2xi32>
// CHECK-DAG: [[C4:%.+]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32>

// CHECK-DAG: [[A0:%.+]] = stablehlo.add [[ARG0]], [[C4]] : tensor<2xi32>

// CHECK-NEXT: return [[C0]], [[C1]], [[C2]], [[ARG1]], [[ARG0]], [[A0]], [[C3]]
return %0, %1, %2, %3, %4, %5, %6 : tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>
}

// -----

// CHECK-LABEL: func.func @subtract
// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<f32>)
func.func @subtract(%arg0: tensor<2xi32>, %arg1: tensor<f32>)
-> (tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>) {
%c0 = stablehlo.constant dense<0> : tensor<i32>
%cp0 = stablehlo.constant dense<0.0> : tensor<f32>
%c0_2 = stablehlo.constant dense<0> : tensor<2xi32>
%c1 = stablehlo.constant dense<5> : tensor<i32>
%c2 = stablehlo.constant dense<3.0> : tensor<f32>
%c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<[0, 1]> : tensor<2xi32>

%0 = stablehlo.subtract %c1, %c0 : tensor<i32>
%1 = stablehlo.subtract %c1, %c4 : tensor<i32>

%2 = stablehlo.subtract %arg1, %cp0 : tensor<f32>
%3 = stablehlo.subtract %arg1, %arg1 : tensor<f32>

%4 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32>

%5 = stablehlo.subtract %c3, %c5 : tensor<2xi32>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<5> : tensor<i32>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<1> : tensor<i32>
// CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<0> : tensor<2xi32>
// CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<1> : tensor<2xi32>

// CHECK-DAG: [[S0:%.+]] = stablehlo.subtract [[ARG1]], [[ARG1]] : tensor<f32>

// CHECK-NEXT: return [[C0]], [[C1]], [[ARG1]], [[S0]], [[C2]], [[C3]]
return %0, %1, %2, %3, %4, %5 : tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>
}

// -----

// CHECK-LABEL: func.func @multiply
// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<f32>)
func.func @multiply(%arg0: tensor<2xi32>, %arg1: tensor<f32>)
-> (tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) {
%c0 = stablehlo.constant dense<0> : tensor<i32>
%cp0 = stablehlo.constant dense<0.0> : tensor<f32>
%c0_2 = stablehlo.constant dense<0> : tensor<2xi32>
%c1 = stablehlo.constant dense<5> : tensor<i32>
%c2 = stablehlo.constant dense<3.0> : tensor<f32>
%c3 = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<1> : tensor<2xi32>

%0 = stablehlo.multiply %c1, %c0 : tensor<i32>
%1 = stablehlo.multiply %c4, %c4 : tensor<i32>

%2 = stablehlo.multiply %arg1, %cp0 : tensor<f32>
%3 = stablehlo.multiply %c2, %c2 : tensor<f32>

%4 = stablehlo.multiply %arg0, %c0_2 : tensor<2xi32>
%5 = stablehlo.multiply %arg0, %c5 : tensor<2xi32>
%6 = stablehlo.multiply %c3, %arg0 : tensor<2xi32>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<0> : tensor<i32>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<16> : tensor<i32>
// CHECK-DAG: [[C2:%.+]] = stablehlo.constant dense<9.000000e+00> : tensor<f32>
// CHECK-DAG: [[C3:%.+]] = stablehlo.constant dense<0> : tensor<2xi32>
// CHECK-DAG: [[C4:%.+]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32>
// CHECK-DAG: [[CP0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>

// CHECK-DAG: [[M0:%.+]] = stablehlo.multiply [[ARG1]], [[CP0]] : tensor<f32>
// CHECK-DAG: [[M1:%.+]] = stablehlo.multiply [[ARG0]], [[C4]] : tensor<2xi32>

// CHECK-NEXT: return [[C0]], [[C1]], [[M0]], [[C2]], [[C3]], [[ARG0]], [[M1]]
return %0, %1, %2, %3, %4, %5, %6 : tensor<i32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>
}

// -----

// CHECK-LABEL: func.func @broadcast_in_dim
// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
