From 955d406584e10312daecae4aa4a8eecb3d110a39 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Tue, 4 Jun 2024 17:30:01 -0700 Subject: [PATCH] Refactor loop unroller pass. Add a knob to force unroll if needed. PiperOrigin-RevId: 640336974 --- .../xla/xla/service/while_loop_unroller.cc | 54 ++++++++++--------- .../xla/xla/service/while_loop_unroller.h | 6 ++- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index 2da3691134b698..b5f3b550b259f4 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -195,19 +195,14 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, return while_body_clone; } +// Checks the soft conditions of unrollability. Soft conditions are: +// 1. num instructions in loop body. +// 2. trip count. +// 3. unroll expansion limit (#_body_instructions * trip_count). +// These conditions can be changed per usecase. bool InitialFeasibilityCheck(HloInstruction* while_op, WhileLoopConfig config) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - // While loop must have a single tuple operand. - CHECK_EQ(while_op->operands().size(), 1); - if (while_op->operands().size() != 1) { - VLOG(5) << absl::StrCat( - "Cannot unroll while loop. While loop must have a single " - "tuple operand, instead has more than one operand: ", - while_op->operands().size()); - return false; - } - VLOG(5) << "Trying to unroll " << while_op->ToShortString(); // TODO(b/291628533): Extract this parameter to the unroller config. We don't @@ -248,10 +243,6 @@ bool InitialFeasibilityCheck(HloInstruction* while_op, WhileLoopConfig config) { absl::StatusOr UnrollInternal(HloInstruction* while_op, WhileLoopConfig config) { - if (!InitialFeasibilityCheck(while_op, config)) { - return false; - } - VLOG(3) << "Unrolling while instruction " << while_op->ToShortString() << " with body instruction count " << while_op->while_body()->instruction_count(); @@ -282,10 +273,6 @@ absl::StatusOr UnrollInternal(HloInstruction* while_op, absl::StatusOr UnrollInternalWrapped(HloInstruction* while_op, WhileLoopConfig config) { - if (!InitialFeasibilityCheck(while_op, config)) { - return false; - } - VLOG(3) << "Unrolling (wrapped) while instruction " << while_op->ToShortString() << " with body instruction count " << while_op->while_body()->instruction_count(); @@ -414,6 +401,17 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + // While loop must have a single tuple operand. + CHECK_EQ(while_op->operands().size(), 1); + if (while_op->operands().size() != 1) { + VLOG(5) << absl::StrCat( + "Cannot unroll while loop ", while_op->name(), + ". While loop must have a single " + "tuple operand, instead has more than one operand: ", + while_op->operands().size()); + return std::nullopt; + } + // TODO(b/300668690): Add support for unrolling loops with control dependency. // For now, we bail. // @@ -488,6 +486,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( std::optional trip_count = MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val); if (!trip_count.has_value()) { + VLOG(3) << "Loop doesn't have trivial trip count"; return std::nullopt; } @@ -498,11 +497,6 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( LiteralUtil::LiteralAsScalarInt64(std::move(indvar_iter_val)).value(); config.trip_count = trip_count.value(); config.induction_var_idx = *indvar_tuple_idx; - - if (!InitialFeasibilityCheck(while_op, config)) { - return std::nullopt; - } - return config; } @@ -555,6 +549,10 @@ WhileLoopUnroller::GetUnrollableLoops( for (HloInstruction* instr : all_while_ops) { std::optional config = IsLoopUnrollable(instr); if (config.has_value()) { + if (!InitialFeasibilityCheck(instr, config.value())) { + VLOG(3) << "Initial feasibility check failed for " << instr->name(); + continue; + } while_loop_configs.emplace_back(instr, config.value()); } } @@ -562,8 +560,8 @@ WhileLoopUnroller::GetUnrollableLoops( } /*static*/ absl::StatusOr WhileLoopUnroller::Unroll( - HloInstruction* while_op, int64_t unroll_factor, - bool wrap_in_trivial_loop) { + HloInstruction* while_op, int64_t unroll_factor, bool wrap_in_trivial_loop, + bool force_unroll) { bool changed = false; HloModule* module = while_op->GetModule(); // TODO(b/288130138): For now, we only support full unrolling. Will add @@ -583,6 +581,12 @@ WhileLoopUnroller::GetUnrollableLoops( // Construct the loop config std::optional config = IsLoopUnrollable(while_op); if (!config.has_value()) { + VLOG(5) << "Not attempting to unroll " << while_op->name() + << " because it is not unrollable."; + return false; + } + + if (!force_unroll && !InitialFeasibilityCheck(while_op, config.value())) { return false; } diff --git a/third_party/xla/xla/service/while_loop_unroller.h b/third_party/xla/xla/service/while_loop_unroller.h index b1684dccb51e88..13e90a98d06e60 100644 --- a/third_party/xla/xla/service/while_loop_unroller.h +++ b/third_party/xla/xla/service/while_loop_unroller.h @@ -103,10 +103,12 @@ class WhileLoopUnroller : public HloModulePass { // Unrolls the given while loop with the default behaviour set to full unroll. // If wrap_in_trivial_loop is set, the unrolled body of the loop will be - // wrapped in a loop with trip count of one. + // wrapped in a loop with trip count of one. Forcing unroll will not perform + // soft checking of the conditions. static absl::StatusOr Unroll(HloInstruction* while_op, int64_t unroll_factor = -1, - bool wrap_in_trivial_loop = false); + bool wrap_in_trivial_loop = false, + bool force_unroll = false); private: int64_t unroll_factor_;