Skip to content

Commit

Permalink
Route demotion flag to Input options (iree-org#13993)
Browse files Browse the repository at this point in the history
Demotion should be configuration via the shared object file. The
currentl flags are frontend specific. Rerouted the passes so it is
configurable via `setFlags` for the libIREECompile.so file.
  • Loading branch information
rsuderman authored and nhasabni committed Aug 24, 2023
1 parent 643210c commit 74d2728
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ namespace mlir::iree_compiler {
namespace {
struct AutoInputConversionPipelinePass final
: AutoInputConversionPipelineBase<AutoInputConversionPipelinePass> {
AutoInputConversionPipelinePass(
const AutoInputConversionPipelineOptions& inputOptions)
: options(inputOptions) {}
void runOnOperation() override;
void getDependentDialects(DialectRegistry& registry) const override;

AutoInputConversionPipelineOptions options;
};

// All the features seen that should be handled during input conversion.
Expand Down Expand Up @@ -154,10 +159,14 @@ void AutoInputConversionPipelinePass::runOnOperation() {
OpPassManager::Nesting::Explicit);
#ifdef IREE_HAVE_MHLO_INPUT
if (features.hasStableHLO && !features.hasMHLO) {
stablehlo::StableHloOptions options;
options.demoteI64ToI32 = demoteI64ToI32;
options.demoteF64ToF32 = demoteF64ToF32;
options.promoteBF16ToF32 = promoteBF16ToF32;
if (features.hasTuples) {
stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm);
stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm, options);
} else {
stablehlo::buildStableHLOInputConversionPassPipeline(pm);
stablehlo::buildStableHLOInputConversionPassPipeline(pm, options);
}
}
if (features.hasMHLO) {
Expand Down Expand Up @@ -201,8 +210,19 @@ void AutoInputConversionPipelinePass::getDependentDialects(
};

#ifdef IREE_HAVE_MHLO_INPUT
appendPipelineDialects(stablehlo::buildStableHLOInputConversionPassPipeline);
appendPipelineDialects(
auto appendStablehloPipelineDialects =
[&registry](function_ref<void(OpPassManager&,
const stablehlo::StableHloOptions& options)>
buildFn) {
const stablehlo::StableHloOptions options;
OpPassManager pm;
buildFn(pm, options);
pm.getDependentDialects(registry);
};

appendStablehloPipelineDialects(
stablehlo::buildStableHLOInputConversionPassPipeline);
appendStablehloPipelineDialects(
stablehlo::buildStableHLOXLAInputConversionPassPipeline);

appendPipelineDialects(MHLO::buildMHLOInputConversionPassPipeline);
Expand All @@ -224,7 +244,13 @@ void AutoInputConversionPipelinePass::getDependentDialects(

std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass() {
return std::make_unique<AutoInputConversionPipelinePass>();
AutoInputConversionPipelineOptions options;
return std::make_unique<AutoInputConversionPipelinePass>(options);
}

std::unique_ptr<OperationPass<ModuleOp>> createAutoInputConversionPipelinePass(
const AutoInputConversionPipelineOptions& options) {
return std::make_unique<AutoInputConversionPipelinePass>(options);
}

} // namespace mlir::iree_compiler
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/InputConversion/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
#ifndef IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_
#define IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_

#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace iree_compiler {

#define GEN_PASS_DECL
#include "iree/compiler/InputConversion/Common/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//
Expand All @@ -28,6 +32,8 @@ void buildCommonInputConversionPassPipeline(OpPassManager &passManager);

std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass();
std::unique_ptr<OperationPass<ModuleOp>> createAutoInputConversionPipelinePass(
const AutoInputConversionPipelineOptions& options);
std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass();
std::unique_ptr<OperationPass<ModuleOp>> createImportMLProgramPass();
std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/InputConversion/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def AutoInputConversionPipeline :
conversion to run, then run that conversion.
}];
let constructor = "mlir::iree_compiler::createAutoInputConversionPipelinePass()";
let options = [
Option<"demoteI64ToI32", "iree-autoinput-demote-i64-to-i32", "bool",
/*default=*/"true", "Convert I64 to I32 equivalents">,
Option<"demoteF64ToF32", "iree-autoinput-demote-f64-to-f32", "bool",
/*default=*/"false", "Convert F64 to F32 equivalents">,
Option<"promoteBF16ToF32", "iree-autoinput-demote-bf16-to-f32", "bool",
/*default=*/"false", "Convert BF16 to F32 equivalents">,
];
}

#endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
45 changes: 16 additions & 29 deletions compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,19 @@ namespace {
} // namespace

namespace {
// TODO(#8745): remove these flags when the -iree-flow-demote-* flags can be
// used without tripping upstream verifier issues.
llvm::cl::opt<bool> clDemoteI64ToI32(
"iree-stablehlo-demote-i64-to-i32",
llvm::cl::desc(
"Converts all StableHLO i64 ops and values into i32 counterparts."),
llvm::cl::init(true));
llvm::cl::opt<bool> clDemoteF64ToF32(
"iree-stablehlo-demote-f64-to-f32",
llvm::cl::desc(
"Converts all StableHLO f64 ops and values into f32 counterparts."),
llvm::cl::init(true));
llvm::cl::opt<bool> clPromoteBF16ToF32(
"iree-stablehlo-promote-bf16-to-f32",
llvm::cl::desc(
"Converts all StableHLO bf16 ops and values into f32 counterparts."),
llvm::cl::init(false));

void registerStableHLOConversionPassPipeline() {
PassPipelineRegistration<> stablehlo(
PassPipelineRegistration<StableHloOptions> stablehlo(
"iree-stablehlo-input-transformation-pipeline",
"Runs the StableHLO IREE flow dialect transformation pipeline",
[](OpPassManager &passManager) {
buildStableHLOInputConversionPassPipeline(passManager);
[](OpPassManager& passManager, const StableHloOptions& options) {
buildStableHLOInputConversionPassPipeline(passManager, options);
});
}

// Prepare HLO for use as an input to the Flow dialect.
void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
bool detuple) {
void buildStableHLOInputConversionPassPipelineImpl(
OpPassManager& passManager, const StableHloOptions& options, bool detuple) {
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
Expand Down Expand Up @@ -88,13 +71,13 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
// stack. This is often required because of implicit i64 insertion by JAX/HLO
// that we don't want forcing 32-bit embedded devices to support.
// TODO(#8745): remove these and prefer the flow pipeline options instead.
if (clDemoteI64ToI32) {
if (options.demoteI64ToI32) {
passManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
}
if (clDemoteF64ToF32) {
if (options.demoteF64ToF32) {
passManager.addPass(IREE::Util::createDemoteF64ToF32Pass());
}
if (clPromoteBF16ToF32) {
if (options.promoteBF16ToF32) {
passManager.addPass(IREE::Util::createPromoteBF16ToF32Pass());
}

Expand Down Expand Up @@ -123,12 +106,16 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
}
} // namespace

void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager) {
buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/false);
void buildStableHLOInputConversionPassPipeline(
OpPassManager& passManager, const StableHloOptions& options) {
buildStableHLOInputConversionPassPipelineImpl(passManager, options,
/*detuple=*/false);
}

void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager) {
buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/true);
void buildStableHLOXLAInputConversionPassPipeline(
OpPassManager& passManager, const StableHloOptions& options) {
buildStableHLOInputConversionPassPipelineImpl(passManager, options,
/*detuple=*/true);
}

void registerStableHLOConversionPasses() {
Expand Down
12 changes: 10 additions & 2 deletions compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,23 @@ namespace iree_compiler::stablehlo {

std::unique_ptr<TypeConverter> createStableHloToLinalgTypeConverter();

struct StableHloOptions : public PassPipelineOptions<StableHloOptions> {
bool demoteI64ToI32 = true;
bool demoteF64ToF32 = false;
bool promoteBF16ToF32 = false;
};

//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//

void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager);
void buildStableHLOInputConversionPassPipeline(OpPassManager& passManager,
const StableHloOptions& options);

// Performs input legalization on programs that may have originated from an XLA
// import (or made to interop with it).
void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager);
void buildStableHLOXLAInputConversionPassPipeline(
OpPassManager& passManager, const StableHloOptions& options);

//===----------------------------------------------------------------------===//
// Register all Passes
Expand Down
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ void InputDialectOptions::bindOptions(OptionsBinder &binder) {
),
// clang-format on
llvm::cl::cat(category));

#ifdef IREE_HAVE_MHLO_INPUT
binder.opt<bool>(
"iree-input-demote-i64-to-i32", demoteI64ToI32,
llvm::cl::desc("Converts all i64 ops and values into i32 counterparts."),
llvm::cl::cat(category));

binder.opt<bool>(
"iree-input-demote-f64-to-f32", demoteF64ToF32,
llvm::cl::desc("Converts all f64 ops and values into f32 counterparts."),
llvm::cl::cat(category));

binder.opt<bool>(
"iree-input-promote-bf16-to-f32", promoteBF16ToF32,
llvm::cl::desc("Converts all bf16 ops and values into f32 counterparts."),
llvm::cl::cat(category));
#endif
}

void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ struct InputDialectOptions {
};
Type type = Type::auto_detect;

bool demoteI64ToI32 = true;
bool demoteF64ToF32 = true;
bool promoteBF16ToF32 = true;

void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<InputDialectOptions>;
};
Expand Down
16 changes: 13 additions & 3 deletions compiler/src/iree/compiler/Pipelines/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,28 @@ void buildIREEVMTransformPassPipeline(
hooks.pipelineExtensions->extendInputConversionPreprocessingPassPipeline(
passManager, inputOptions.type);
}
AutoInputConversionPipelineOptions autoOptions;

#ifdef IREE_HAVE_MHLO_INPUT
stablehlo::StableHloOptions stablehloOptions;
stablehloOptions.demoteI64ToI32 = inputOptions.demoteI64ToI32;
stablehloOptions.demoteF64ToF32 = inputOptions.demoteF64ToF32;
stablehloOptions.promoteBF16ToF32 = inputOptions.promoteBF16ToF32;
#endif
switch (inputOptions.type) {
case InputDialectOptions::Type::none:
break;
case InputDialectOptions::Type::auto_detect:
passManager.addPass(createAutoInputConversionPipelinePass());
passManager.addPass(createAutoInputConversionPipelinePass(autoOptions));
break;
#ifdef IREE_HAVE_MHLO_INPUT
case InputDialectOptions::Type::stablehlo:
stablehlo::buildStableHLOInputConversionPassPipeline(passManager);
stablehlo::buildStableHLOInputConversionPassPipeline(passManager,
stablehloOptions);
break;
case InputDialectOptions::Type::stablehlo_xla:
stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager);
stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager,
stablehloOptions);
break;
case InputDialectOptions::Type::mhlo_legacy:
MHLO::buildMHLOInputConversionPassPipeline(passManager);
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/vulkan_specific/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ iree_check_single_backend_test_suite(
],
compiler_flags = [
"--iree-input-type=stablehlo",
"--iree-stablehlo-demote-i64-to-i32=false",
"--iree-input-demote-i64-to-i32=false",
"--iree-vulkan-target-triple=valhall-unknown-android31",
],
driver = "vulkan",
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/vulkan_specific/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ iree_check_single_backend_test_suite(
"vulkan"
COMPILER_FLAGS
"--iree-input-type=stablehlo"
"--iree-stablehlo-demote-i64-to-i32=false"
"--iree-input-demote-i64-to-i32=false"
"--iree-vulkan-target-triple=valhall-unknown-android31"
LABELS
"manual"
Expand Down

0 comments on commit 74d2728

Please sign in to comment.