Skip to content

Commit

Permalink
Refactor loop unroller pass.
Browse files Browse the repository at this point in the history
Add a knob to force unroll if needed.

PiperOrigin-RevId: 636179776
  • Loading branch information
farzinhoushmand authored and tensorflower-gardener committed Jun 5, 2024
1 parent 5a23cdf commit e2872e7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 27 deletions.
54 changes: 29 additions & 25 deletions third_party/xla/xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,14 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op,
return while_body_clone;
}

// Checks the soft conditions of unrollability. Soft conditions are:
// 1. num instructions in loop body.
// 2. trip count.
// 3. unroll expansion limit (#_body_instructions * trip_count).
// These conditions can be changed per usecase.
bool InitialFeasibilityCheck(HloInstruction* while_op, WhileLoopConfig config) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);

// While loop must have a single tuple operand.
CHECK_EQ(while_op->operands().size(), 1);
if (while_op->operands().size() != 1) {
VLOG(5) << absl::StrCat(
"Cannot unroll while loop. While loop must have a single "
"tuple operand, instead has more than one operand: ",
while_op->operands().size());
return false;
}

VLOG(5) << "Trying to unroll " << while_op->ToShortString();

// TODO(b/291628533): Extract this parameter to the unroller config. We don't
Expand Down Expand Up @@ -248,10 +243,6 @@ bool InitialFeasibilityCheck(HloInstruction* while_op, WhileLoopConfig config) {

absl::StatusOr<bool> UnrollInternal(HloInstruction* while_op,
WhileLoopConfig config) {
if (!InitialFeasibilityCheck(while_op, config)) {
return false;
}

VLOG(3) << "Unrolling while instruction " << while_op->ToShortString()
<< " with body instruction count "
<< while_op->while_body()->instruction_count();
Expand Down Expand Up @@ -282,10 +273,6 @@ absl::StatusOr<bool> UnrollInternal(HloInstruction* while_op,

absl::StatusOr<bool> UnrollInternalWrapped(HloInstruction* while_op,
WhileLoopConfig config) {
if (!InitialFeasibilityCheck(while_op, config)) {
return false;
}

VLOG(3) << "Unrolling (wrapped) while instruction "
<< while_op->ToShortString() << " with body instruction count "
<< while_op->while_body()->instruction_count();
Expand Down Expand Up @@ -414,6 +401,17 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
HloInstruction* while_op) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);

// While loop must have a single tuple operand.
CHECK_EQ(while_op->operands().size(), 1);
if (while_op->operands().size() != 1) {
VLOG(5) << absl::StrCat(
"Cannot unroll while loop ", while_op->name(),
". While loop must have a single "
"tuple operand, instead has more than one operand: ",
while_op->operands().size());
return std::nullopt;
}

// TODO(b/300668690): Add support for unrolling loops with control dependency.
// For now, we bail.
//
Expand Down Expand Up @@ -488,6 +486,7 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
std::optional<int64_t> trip_count =
MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val);
if (!trip_count.has_value()) {
VLOG(3) << "Loop doesn't have trivial trip count";
return std::nullopt;
}

Expand All @@ -498,11 +497,6 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
LiteralUtil::LiteralAsScalarInt64(std::move(indvar_iter_val)).value();
config.trip_count = trip_count.value();
config.induction_var_idx = *indvar_tuple_idx;

if (!InitialFeasibilityCheck(while_op, config)) {
return std::nullopt;
}

return config;
}

Expand Down Expand Up @@ -555,15 +549,19 @@ WhileLoopUnroller::GetUnrollableLoops(
for (HloInstruction* instr : all_while_ops) {
std::optional<WhileLoopConfig> config = IsLoopUnrollable(instr);
if (config.has_value()) {
if (!InitialFeasibilityCheck(instr, config.value())) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
}
}
return while_loop_configs;
}

/*static*/ absl::StatusOr<bool> WhileLoopUnroller::Unroll(
HloInstruction* while_op, int64_t unroll_factor,
bool wrap_in_trivial_loop) {
HloInstruction* while_op, int64_t unroll_factor, bool wrap_in_trivial_loop,
bool force_unroll) {
bool changed = false;
HloModule* module = while_op->GetModule();
// TODO(b/288130138): For now, we only support full unrolling. Will add
Expand All @@ -583,6 +581,12 @@ WhileLoopUnroller::GetUnrollableLoops(
// Construct the loop config
std::optional<WhileLoopConfig> config = IsLoopUnrollable(while_op);
if (!config.has_value()) {
VLOG(5) << "Not attempting to unroll " << while_op->name()
<< " because it is not unrollable.";
return false;
}

if (!force_unroll && !InitialFeasibilityCheck(while_op, config.value())) {
return false;
}

Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/while_loop_unroller.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ class WhileLoopUnroller : public HloModulePass {

// Unrolls the given while loop with the default behaviour set to full unroll.
// If wrap_in_trivial_loop is set, the unrolled body of the loop will be
// wrapped in a loop with trip count of one.
// wrapped in a loop with trip count of one. Forcing unroll will not perform
// soft checking of the conditions.
static absl::StatusOr<bool> Unroll(HloInstruction* while_op,
int64_t unroll_factor = -1,
bool wrap_in_trivial_loop = false);
bool wrap_in_trivial_loop = false,
bool force_unroll = false);

private:
int64_t unroll_factor_;
Expand Down

0 comments on commit e2872e7

Please sign in to comment.