diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp index 68a08b5a431a..42a36c98f935 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp @@ -32,8 +32,13 @@ namespace mlir::iree_compiler { namespace { struct AutoInputConversionPipelinePass final : AutoInputConversionPipelineBase { + AutoInputConversionPipelinePass( + const AutoInputConversionPipelineOptions& inputOptions) + : options(inputOptions) {} void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override; + + AutoInputConversionPipelineOptions options; }; // All the features seen that should be handled during input conversion. @@ -154,10 +159,14 @@ void AutoInputConversionPipelinePass::runOnOperation() { OpPassManager::Nesting::Explicit); #ifdef IREE_HAVE_MHLO_INPUT if (features.hasStableHLO && !features.hasMHLO) { + stablehlo::StableHloOptions options; + options.demoteI64ToI32 = demoteI64ToI32; + options.demoteF64ToF32 = demoteF64ToF32; + options.promoteBF16ToF32 = promoteBF16ToF32; if (features.hasTuples) { - stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm); + stablehlo::buildStableHLOXLAInputConversionPassPipeline(pm, options); } else { - stablehlo::buildStableHLOInputConversionPassPipeline(pm); + stablehlo::buildStableHLOInputConversionPassPipeline(pm, options); } } if (features.hasMHLO) { @@ -201,8 +210,19 @@ void AutoInputConversionPipelinePass::getDependentDialects( }; #ifdef IREE_HAVE_MHLO_INPUT - appendPipelineDialects(stablehlo::buildStableHLOInputConversionPassPipeline); - appendPipelineDialects( + auto appendStablehloPipelineDialects = + [®istry](function_ref + buildFn) { + const stablehlo::StableHloOptions options; + OpPassManager pm; + buildFn(pm, options); + pm.getDependentDialects(registry); + }; + + appendStablehloPipelineDialects( + stablehlo::buildStableHLOInputConversionPassPipeline); + appendStablehloPipelineDialects( stablehlo::buildStableHLOXLAInputConversionPassPipeline); appendPipelineDialects(MHLO::buildMHLOInputConversionPassPipeline); @@ -224,7 +244,13 @@ void AutoInputConversionPipelinePass::getDependentDialects( std::unique_ptr> createAutoInputConversionPipelinePass() { - return std::make_unique(); + AutoInputConversionPipelineOptions options; + return std::make_unique(options); +} + +std::unique_ptr> createAutoInputConversionPipelinePass( + const AutoInputConversionPipelineOptions& options) { + return std::make_unique(options); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.h b/compiler/src/iree/compiler/InputConversion/Common/Passes.h index dfbff2fda103..797a29c4605e 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/Passes.h +++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_ #define IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_ +#include "iree/compiler/InputConversion/Common/PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -14,6 +15,9 @@ namespace mlir { namespace iree_compiler { +#define GEN_PASS_DECL +#include "iree/compiler/InputConversion/Common/Passes.h.inc" + //===----------------------------------------------------------------------===// // Pipelines //===----------------------------------------------------------------------===// @@ -28,6 +32,8 @@ void buildCommonInputConversionPassPipeline(OpPassManager &passManager); std::unique_ptr> createAutoInputConversionPipelinePass(); +std::unique_ptr> createAutoInputConversionPipelinePass( + const AutoInputConversionPipelineOptions& options); std::unique_ptr> createIREEImportPublicPass(); std::unique_ptr> createImportMLProgramPass(); std::unique_ptr> diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.td b/compiler/src/iree/compiler/InputConversion/Common/Passes.td index f42e5fffd68c..950cce602b91 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/Passes.td +++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.td @@ -53,6 +53,14 @@ def AutoInputConversionPipeline : conversion to run, then run that conversion. }]; let constructor = "mlir::iree_compiler::createAutoInputConversionPipelinePass()"; + let options = [ + Option<"demoteI64ToI32", "iree-autoinput-demote-i64-to-i32", "bool", + /*default=*/"true", "Convert I64 to I32 equivalents">, + Option<"demoteF64ToF32", "iree-autoinput-demote-f64-to-f32", "bool", + /*default=*/"false", "Convert F64 to F32 equivalents">, + Option<"promoteBF16ToF32", "iree-autoinput-demote-bf16-to-f32", "bool", + /*default=*/"false", "Convert BF16 to F32 equivalents">, + ]; } #endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp index d82ac7557fa3..6eef69734e3a 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp @@ -26,36 +26,19 @@ namespace { } // namespace namespace { -// TODO(#8745): remove these flags when the -iree-flow-demote-* flags can be -// used without tripping upstream verifier issues. -llvm::cl::opt clDemoteI64ToI32( - "iree-stablehlo-demote-i64-to-i32", - llvm::cl::desc( - "Converts all StableHLO i64 ops and values into i32 counterparts."), - llvm::cl::init(true)); -llvm::cl::opt clDemoteF64ToF32( - "iree-stablehlo-demote-f64-to-f32", - llvm::cl::desc( - "Converts all StableHLO f64 ops and values into f32 counterparts."), - llvm::cl::init(true)); -llvm::cl::opt clPromoteBF16ToF32( - "iree-stablehlo-promote-bf16-to-f32", - llvm::cl::desc( - "Converts all StableHLO bf16 ops and values into f32 counterparts."), - llvm::cl::init(false)); void registerStableHLOConversionPassPipeline() { - PassPipelineRegistration<> stablehlo( + PassPipelineRegistration stablehlo( "iree-stablehlo-input-transformation-pipeline", "Runs the StableHLO IREE flow dialect transformation pipeline", - [](OpPassManager &passManager) { - buildStableHLOInputConversionPassPipeline(passManager); + [](OpPassManager& passManager, const StableHloOptions& options) { + buildStableHLOInputConversionPassPipeline(passManager, options); }); } // Prepare HLO for use as an input to the Flow dialect. -void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager, - bool detuple) { +void buildStableHLOInputConversionPassPipelineImpl( + OpPassManager& passManager, const StableHloOptions& options, bool detuple) { passManager.addNestedPass(mlir::createCanonicalizerPass()); passManager.addNestedPass(createStableHLOCanonicalize()); passManager.addNestedPass(mlir::createCSEPass()); @@ -88,13 +71,13 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager, // stack. This is often required because of implicit i64 insertion by JAX/HLO // that we don't want forcing 32-bit embedded devices to support. // TODO(#8745): remove these and prefer the flow pipeline options instead. - if (clDemoteI64ToI32) { + if (options.demoteI64ToI32) { passManager.addPass(IREE::Util::createDemoteI64ToI32Pass()); } - if (clDemoteF64ToF32) { + if (options.demoteF64ToF32) { passManager.addPass(IREE::Util::createDemoteF64ToF32Pass()); } - if (clPromoteBF16ToF32) { + if (options.promoteBF16ToF32) { passManager.addPass(IREE::Util::createPromoteBF16ToF32Pass()); } @@ -123,12 +106,16 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager, } } // namespace -void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager) { - buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/false); +void buildStableHLOInputConversionPassPipeline( + OpPassManager& passManager, const StableHloOptions& options) { + buildStableHLOInputConversionPassPipelineImpl(passManager, options, + /*detuple=*/false); } -void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager) { - buildStableHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/true); +void buildStableHLOXLAInputConversionPassPipeline( + OpPassManager& passManager, const StableHloOptions& options) { + buildStableHLOInputConversionPassPipelineImpl(passManager, options, + /*detuple=*/true); } void registerStableHLOConversionPasses() { diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h index 29cfc95b47e2..be8af77ced36 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.h @@ -16,15 +16,23 @@ namespace iree_compiler::stablehlo { std::unique_ptr createStableHloToLinalgTypeConverter(); +struct StableHloOptions : public PassPipelineOptions { + bool demoteI64ToI32 = true; + bool demoteF64ToF32 = false; + bool promoteBF16ToF32 = false; +}; + //===----------------------------------------------------------------------===// // Pipelines //===----------------------------------------------------------------------===// -void buildStableHLOInputConversionPassPipeline(OpPassManager &passManager); +void buildStableHLOInputConversionPassPipeline(OpPassManager& passManager, + const StableHloOptions& options); // Performs input legalization on programs that may have originated from an XLA // import (or made to interop with it). -void buildStableHLOXLAInputConversionPassPipeline(OpPassManager &passManager); +void buildStableHLOXLAInputConversionPassPipeline( + OpPassManager& passManager, const StableHloOptions& options); //===----------------------------------------------------------------------===// // Register all Passes diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 5e0c30909500..904dc31e541c 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -68,6 +68,23 @@ void InputDialectOptions::bindOptions(OptionsBinder &binder) { ), // clang-format on llvm::cl::cat(category)); + +#ifdef IREE_HAVE_MHLO_INPUT + binder.opt( + "iree-input-demote-i64-to-i32", demoteI64ToI32, + llvm::cl::desc("Converts all i64 ops and values into i32 counterparts."), + llvm::cl::cat(category)); + + binder.opt( + "iree-input-demote-f64-to-f32", demoteF64ToF32, + llvm::cl::desc("Converts all f64 ops and values into f32 counterparts."), + llvm::cl::cat(category)); + + binder.opt( + "iree-input-promote-bf16-to-f32", promoteBF16ToF32, + llvm::cl::desc("Converts all bf16 ops and values into f32 counterparts."), + llvm::cl::cat(category)); +#endif } void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) { diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index 50052407c4ed..40b9504f9474 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -59,6 +59,10 @@ struct InputDialectOptions { }; Type type = Type::auto_detect; + bool demoteI64ToI32 = true; + bool demoteF64ToF32 = true; + bool promoteBF16ToF32 = true; + void bindOptions(OptionsBinder &binder); using FromFlags = OptionsFromFlags; }; diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index a46fedc3d7ae..9ce17f2dd471 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -53,18 +53,28 @@ void buildIREEVMTransformPassPipeline( hooks.pipelineExtensions->extendInputConversionPreprocessingPassPipeline( passManager, inputOptions.type); } + AutoInputConversionPipelineOptions autoOptions; + +#ifdef IREE_HAVE_MHLO_INPUT + stablehlo::StableHloOptions stablehloOptions; + stablehloOptions.demoteI64ToI32 = inputOptions.demoteI64ToI32; + stablehloOptions.demoteF64ToF32 = inputOptions.demoteF64ToF32; + stablehloOptions.promoteBF16ToF32 = inputOptions.promoteBF16ToF32; +#endif switch (inputOptions.type) { case InputDialectOptions::Type::none: break; case InputDialectOptions::Type::auto_detect: - passManager.addPass(createAutoInputConversionPipelinePass()); + passManager.addPass(createAutoInputConversionPipelinePass(autoOptions)); break; #ifdef IREE_HAVE_MHLO_INPUT case InputDialectOptions::Type::stablehlo: - stablehlo::buildStableHLOInputConversionPassPipeline(passManager); + stablehlo::buildStableHLOInputConversionPassPipeline(passManager, + stablehloOptions); break; case InputDialectOptions::Type::stablehlo_xla: - stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager); + stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager, + stablehloOptions); break; case InputDialectOptions::Type::mhlo_legacy: MHLO::buildMHLOInputConversionPassPipeline(passManager); diff --git a/tests/e2e/vulkan_specific/BUILD.bazel b/tests/e2e/vulkan_specific/BUILD.bazel index f9170323c6f7..9111af49a6c2 100644 --- a/tests/e2e/vulkan_specific/BUILD.bazel +++ b/tests/e2e/vulkan_specific/BUILD.bazel @@ -53,7 +53,7 @@ iree_check_single_backend_test_suite( ], compiler_flags = [ "--iree-input-type=stablehlo", - "--iree-stablehlo-demote-i64-to-i32=false", + "--iree-input-demote-i64-to-i32=false", "--iree-vulkan-target-triple=valhall-unknown-android31", ], driver = "vulkan", diff --git a/tests/e2e/vulkan_specific/CMakeLists.txt b/tests/e2e/vulkan_specific/CMakeLists.txt index 3d0d9b3bae71..65623d2501af 100644 --- a/tests/e2e/vulkan_specific/CMakeLists.txt +++ b/tests/e2e/vulkan_specific/CMakeLists.txt @@ -55,7 +55,7 @@ iree_check_single_backend_test_suite( "vulkan" COMPILER_FLAGS "--iree-input-type=stablehlo" - "--iree-stablehlo-demote-i64-to-i32=false" + "--iree-input-demote-i64-to-i32=false" "--iree-vulkan-target-triple=valhall-unknown-android31" LABELS "manual"