[StableHLO] Make reduce lowering more robust (iree-org#14046)
Check if reduce ops are supported. This is so that these patterns can be
given any reduce, even those that would normally be folded away by
canonicalization patterns.

Issue: iree-org#14042
Issue: iree-org#12678
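For illustration (this snippet is not part of the commit, and the SSA names are made up), the kind of reduce the message refers to is one whose dimension list is empty. Canonicalization is expected to rewrite such a reduce to its operand, so the lowering patterns previously never had to handle it; a minimal sketch:

// Hypothetical no-op reduce over dimensions = []. Canonicalization would
// normally replace all uses of %1 with %arg0; after this change the lowering
// patterns bail out with a match failure instead of crashing if they run
// before that fold happens.
%init = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1 = stablehlo.reduce(%arg0 init: %init) across dimensions = []
    : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
 reducer(%acc: tensor<f32>, %elem: tensor<f32>) {
  %sum = stablehlo.add %acc, %elem : tensor<f32>
  stablehlo.return %sum : tensor<f32>
 }

The new lit tests at the bottom of this diff exercise this no-op case as well as a zero-extent input case.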
kuhar authored and nhasabni committed Aug 24, 2023
1 parent 59c3fe1 commit b4101c5
Showing 2 changed files with 69 additions and 5 deletions.
@@ -10,13 +10,31 @@

#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
/// Returns true when reduction `op` is not supported and should be filtered
/// out.
static bool isUnsupported(mlir::stablehlo::ReduceOp op) {
  // Empty reductions are not supported. We expect canonicalization patterns to
  // handle them.
  if (op.getDimensions().empty()) return true;

  // We require all reduce shapes to be the same, up to the element types, so
  // we can just use the first operand and the first result as a
  // representative.
  if (auto inputTy =
          dyn_cast<RankedTensorType>(op.getInputs().getType().front())) {
    return llvm::is_contained(inputTy.getShape(), 0);
  }

  return false;
}

/// Returns a permutation AffineMap that puts all reduction dimensions to the
/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
/// if `rank` is 4 and `reductionDims` is {1, 3}, then
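(The helper's body is collapsed in this view. Assuming the ordering the comment describes, with sorted parallel dims followed by sorted reduction dims, the rank-4 example with reduction dims {1, 3} would presumably produce a map along the lines of

// Hypothetical map derived from the description above; not copied from the
// collapsed function body.
affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>

i.e. the parallel dims d0 and d2 first and the reduction dims d1 and d3 last.)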
@@ -85,6 +103,11 @@ struct ReduceOpToGenericConverter final
  LogicalResult matchAndRewrite(
      mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (isUnsupported(op)) {
      return rewriter.notifyMatchFailure(op,
                                         "unsupported reduce (noop or empty)");
    }

    Location loc = op.getLoc();

    int numOperands = static_cast<int>(adaptor.getInputs().size());
@@ -154,11 +177,11 @@ struct ReduceOpToGenericConverter final
    rewriter.inlineRegionBefore(op.getBody(), region, region.end());
    TypeConverter::SignatureConversion signatureConverter(numOperands * 2);

    // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
    // the region, and the seed is encoded in linalg in the intial out value, so
    // modify the signature of the block and the value mappings, so the output
    // args will correlate with the original LHS and the inputs correlate with
    // the original RHS.
    // The stablehlo ReduceOp requires that the seed be used as a LHS operand
    // inside the region, and the seed is encoded in linalg in the initial out
    // value, so modify the signature of the block and the value mappings, so
    // the output args will correlate with the original LHS and the inputs
    // correlate with the original RHS.
    for (auto [idx, val] : llvm::enumerate(op.getInputs())) {
      signatureConverter.addInputs(
          /*origInputNo=*/idx + numOperands,
@@ -188,6 +211,11 @@ struct ReduceOpToReduceConverter final
  LogicalResult matchAndRewrite(
      mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (isUnsupported(op)) {
      return rewriter.notifyMatchFailure(op,
                                         "unsupported reduce (noop or empty)");
    }

    auto reductionDims =
        llvm::to_vector(op.getDimensions().getValues<int64_t>());
    // stablehlo.reduce doesn't specify the order of the reduction dimensions.
@@ -358,6 +358,42 @@ func.func @variadic_diff_type_reduce(%arg0: tensor<128x10xf32>, %arg1: tensor<12

// -----

// Make sure we do not crash on unsupported reductions.

// CHECK-LABEL: func.func @reduce_noop
// CHECK: stablehlo.reduce
// CHECK-PRIMITIVE-LABEL: func.func @reduce_noop
// CHECK-PRIMITIVE: stablehlo.reduce
func.func @reduce_noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
   reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
    %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
    stablehlo.return %4 : tensor<f32>
  }
  func.return %1 : tensor<4x8xf32>
}

// CHECK-LABEL: func.func @reduce_zero_ext
// CHECK: stablehlo.reduce
// CHECK-PRIMITIVE-LABEL: func.func @reduce_zero_ext
// CHECK-PRIMITIVE: stablehlo.reduce
func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
  %0 = stablehlo.constant dense<false> : tensor<i1>
  %1 = stablehlo.constant dense<false> : tensor<0xi1>
  %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
  %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
  %4 = stablehlo.constant dense<0> : tensor<i32>
  %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
   reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) {
    %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
    stablehlo.return %6 : tensor<i32>
  }
  return %5 : tensor<i32>
}

// -----

// CHECK-LABEL: func @reduce_window_min_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
