Skip to content

Commit

Permalink
Simplify arith.minsi and maxsi.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636495600
  • Loading branch information
jreiffers authored and tensorflower-gardener committed May 23, 2024
1 parent 3740048 commit 146ed63
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 26 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ cc_library(
"//xla/service/gpu/fusions/mlir:computation_partitioner",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
"//xla/service/gpu/fusions/mlir/ir:xla_gpu",
"//xla/service/gpu/model:indexing_analysis",
"//xla/service/gpu/model:indexing_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TensorDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,8 @@ MlirFusionEmitterBase::CreateLLVMModule(
// simplify-affine has maximally folded expressions to work with.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<mlir::func::FuncOp>(CreateSimplifyArithPass());
pm.addPass(CreateSimplifyAffinePass());
// Replace comparisons that result in constant values (e.g. due to ranges not
// overlapping). This pass must run after SimplifyAffinePass, since that
// generates the range information.
pm.addPass(CreateSimplifyArithPass());

// simplify-affine lowers most affine.apply ops, but if it can't prove a
// division or modulo is unsigned, affine.apply ops will remain.
Expand Down Expand Up @@ -442,7 +439,7 @@ MlirFusionEmitterBase::CreateMLIRModule(

// Run a minimal simplification pipeline.
mlir::PassManager pm(&context);
pm.addPass(CreateSimplifyArithPass());
pm.addNestedPass<mlir::func::FuncOp>(CreateSimplifyArithPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
// We won't dump the trace here if the pipeline fails. This is acceptable,
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/mlir/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def MergePointersToSameSlicePass :
let constructor = "CreateMergePointersToSameSlicePass()";
}

def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::ModuleOp"> {
def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> {
let summary = "Simplifies arith using XLA's range-aware simplifier.";

let description = [{
Expand All @@ -98,7 +98,8 @@ def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::ModuleOp"> {
}];

let dependentDialects = [
"mlir::arith::ArithDialect"
"mlir::arith::ArithDialect",
"mlir::func::FuncDialect",
];

let constructor = "CreateSimplifyArithPass()";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,13 @@ std::optional<Interval> GetRange(mlir::Value value) {
return {{values[0].getSExtValue(), values[1].getSExtValue()}};
};

if (value.getDefiningOp()) {
if (auto apply = value.getDefiningOp<ApplyIndexingOp>()) {
return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange(
apply.getIndexingMap().GetAffineMap().getResult(
mlir::cast<mlir::OpResult>(value).getResultNumber()));
} else if (auto cst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
return {{cst.value(), cst.value()}};
} else if (value.getDefiningOp()) {
return attr_to_range(value.getDefiningOp()->getAttr("xla.range"));
}

Expand Down
113 changes: 100 additions & 13 deletions third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
Expand All @@ -27,6 +28,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
#include "xla/service/gpu/fusions/mlir/passes.h"
#include "xla/service/gpu/model/indexing_map.h"

Expand Down Expand Up @@ -65,30 +67,115 @@ struct RewriteCmpI : mlir::OpRewritePattern<mlir::arith::CmpIOp> {

mlir::LogicalResult matchAndRewrite(
mlir::arith::CmpIOp op, mlir::PatternRewriter& rewriter) const override {
// We don't need to support constants on the LHS, since comparisons are
// canonicalized to have them on the RHS.
auto rhs = mlir::getConstantIntValue(op.getRhs());
auto rhs = GetRange(op.getRhs());
auto lhs = GetRange(op.getLhs());
if (lhs && rhs) {
Interval::ComparisonResult result =
EvaluateCmpI(op.getPredicate(), *lhs, {*rhs, *rhs});
if (result != std::nullopt) {
rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
op, *result, rewriter.getI1Type());
return mlir::success();
}
if (!lhs || !rhs) {
return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
}
Interval::ComparisonResult result =
EvaluateCmpI(op.getPredicate(), *lhs, *rhs);
if (result != std::nullopt) {
rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
op, *result, rewriter.getI1Type());
return mlir::success();
}
// TODO(jreiffers): Consider supporting ranges on the RHS as well.
return rewriter.notifyMatchFailure(op, "not a constant result");
}
};

struct RewriteMaxSi : mlir::OpRewritePattern<mlir::arith::MaxSIOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(
mlir::arith::MaxSIOp op, mlir::PatternRewriter& rewriter) const override {
auto lhs = GetRange(op.getLhs());
auto rhs = GetRange(op.getRhs());
if (!lhs || !rhs) {
return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
}
if (auto lhs_ge_rhs = *lhs >= *rhs; lhs_ge_rhs == true) {
rewriter.replaceOp(op, op.getLhs());
} else if (auto rhs_ge_lhs = *rhs >= *lhs; rhs_ge_lhs == true) {
rewriter.replaceOp(op, op.getRhs());
} else {
return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
}
return mlir::success();
}
};

struct RewriteMinSi : mlir::OpRewritePattern<mlir::arith::MinSIOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(
mlir::arith::MinSIOp op, mlir::PatternRewriter& rewriter) const override {
auto lhs = GetRange(op.getLhs());
auto rhs = GetRange(op.getRhs());
if (!lhs || !rhs) {
return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
}
if (auto lhs_le_rhs = *lhs <= *rhs; lhs_le_rhs == true) {
rewriter.replaceOp(op, op.getLhs());
} else if (auto rhs_le_lhs = *rhs <= *lhs; rhs_le_lhs == true) {
rewriter.replaceOp(op, op.getRhs());
} else {
return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
}
return mlir::success();
}
};

void AnnotateRanges(mlir::func::FuncOp func) {
func->walk([](mlir::Operation* op) {
if (op->getNumResults() != 1) {
return;
}

auto result = op->getResult(0);
if (GetRange(result).has_value()) {
return;
}

auto get_range = [](mlir::Value value) -> Interval {
auto range = GetRange(value);
if (range) {
return *range;
}
return {std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max()};
};

std::optional<Interval> out_range = std::nullopt;
if (mlir::isa<mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
mlir::arith::AddIOp, mlir::arith::MulIOp>(op)) {
auto lhs_range = get_range(op->getOperand(0));
auto rhs_range = get_range(op->getOperand(1));
if (mlir::isa<mlir::arith::MaxSIOp>(op)) {
out_range = lhs_range.max(rhs_range);
} else if (mlir::isa<mlir::arith::MinSIOp>(op)) {
out_range = lhs_range.min(rhs_range);
} else if (mlir::isa<mlir::arith::AddIOp>(op)) {
out_range = lhs_range + rhs_range;
} else {
out_range = lhs_range * rhs_range;
}
}

if (out_range) {
mlir::OpBuilder b(op);
op->setAttr("xla.range",
b.getIndexArrayAttr({out_range->lower, out_range->upper}));
}
});
}

class SimplifyArithPass
: public impl::SimplifyArithPassBase<SimplifyArithPass> {
public:
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<RewriteCmpI>(&getContext());
AnnotateRanges(getOperation());
patterns.add<RewriteCmpI, RewriteMaxSi, RewriteMinSi>(&getContext());
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,98 @@ module {
module {
func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 {
// This is true, but we don't support it yet.
%eq = arith.cmpi slt, %arg0, %arg1 : index
return %eq : i1
}
}

// CHECK: @both_range
// CHECK-NEXT: cmpi
// CHECK-NEXT: return
// CHECK-LABEL: @both_range
// CHECK-NEXT: constant true
// CHECK-NEXT: return

// -----

module {
func.func @minsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%min = arith.minsi %arg0, %arg1 : index
return %min : index
}
}

// CHECK-LABEL: @minsi_lhs
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG0]]

// -----

module {
func.func @minsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%min = arith.minsi %arg1, %arg0 : index
return %min : index
}
}

// CHECK-LABEL: @minsi_rhs
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG0]]

// -----

module {
func.func @maxsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%min = arith.maxsi %arg1, %arg0 : index
return %min : index
}
}

// CHECK-LABEL: @maxsi_lhs
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG1]]

// -----

module {
func.func @maxsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%min = arith.maxsi %arg0, %arg1 : index
return %min : index
}
}

// CHECK-LABEL: @maxsi_rhs
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG1]]

// -----

module {
func.func @maxsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%add = arith.addi %arg0, %arg1 : index
%min = arith.maxsi %add, %arg1 : index
return %min : index
}
}

// CHECK-LABEL: @maxsi_add
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG1]]
// CHECK-NEXT: return %[[ADD]]

// -----

module {
func.func @minsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
%arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
%add = arith.addi %arg0, %arg1 : index
%min = arith.minsi %add, %arg1 : index
return %min : index
}
}

// CHECK-LABEL: @minsi_add
// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
// CHECK-NEXT: return %[[ARG1]]

0 comments on commit 146ed63

Please sign in to comment.