Skip to content

Commit

Permalink
Minor changes that help debuggability:
Browse files Browse the repository at this point in the history
1. Cast the printed memory budget to double so that it is printed in the same format as the estimated minimum memory required, making the two values easy to compare.
2. Guard some checks performed while computing the memory lower bound behind a VLOG-level check, as these checks are often slow.
3. Increase the VLOG level for printing the memory usage of the solution as computing this also often takes a while.

PiperOrigin-RevId: 617677396
  • Loading branch information
tensorflower-gardener committed Mar 21, 2024
1 parent c1f19cb commit 0f08501
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
34 changes: 19 additions & 15 deletions third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2637,24 +2637,28 @@ int64_t MemoryBudgetLowerBound(const HloModule& module,
// as aliasing HloValues are mapped to the same buffer.
absl::flat_hash_map<HloBuffer::Id, const HloValue*>
buffer_to_sharded_value_mapping;
bool vlog_is_on_5 = VLOG_IS_ON(5);
for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) {
for (const HloValue* value : liveness_set[time_idx]) {
const auto& buffer = alias_analysis->GetBufferContainingValue(*value);
if (value->instruction()->has_sharding()) {
auto this_value_sharding = get_value_sharding(value);
auto iter = buffer_to_sharded_value_mapping.find(buffer.id());
if (iter != buffer_to_sharded_value_mapping.end()) {
auto buffer_value_sharding = get_value_sharding(iter->second);
if (this_value_sharding != buffer_value_sharding) {
// TODO(pratikf): This is an unavoidable situation, but possibly
// there is a better design decision that can be made here.
VLOG(1) << "We have a situation where two HloValues alias, but "
"they have different shardings. This can happen in the "
"presence of user-specified shardings, and is expected. "
"This, however, means that the memory budget estimate "
"is not very accurate. The aliasing HLOs are "
<< value->ToShortString() << " and "
<< iter->second->ToShortString();
if (vlog_is_on_5) {
auto this_value_sharding = get_value_sharding(value);
auto iter = buffer_to_sharded_value_mapping.find(buffer.id());
if (iter != buffer_to_sharded_value_mapping.end()) {
auto buffer_value_sharding = get_value_sharding(iter->second);
if (this_value_sharding != buffer_value_sharding) {
// TODO(pratikf): This is an unavoidable situation, but possibly
// there is a better design decision that can be made here.
VLOG(1)
<< "We have a situation where two HloValues alias, but "
"they have different shardings. This can happen in the "
"presence of user-specified shardings, and is expected. "
"This, however, means that the memory budget estimate "
"is not very accurate. The aliasing HLOs are "
<< value->ToShortString() << " and "
<< iter->second->ToShortString();
}
}
}
buffer_to_sharded_value_mapping[buffer.id()] = value;
Expand Down Expand Up @@ -3768,7 +3772,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
XLA_VLOG_LINES(5, PrintAutoShardingSolution(sequence, liveness_set,
strategy_map, strategy_groups,
cost_graph, s_val, objective));
XLA_VLOG_LINES(1, PrintSolutionMemoryUsage(liveness_set, strategy_map,
XLA_VLOG_LINES(6, PrintSolutionMemoryUsage(liveness_set, strategy_map,
cost_graph, s_val));

// ----- Substitute all-reduce with reduce-scatter -----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ AutoShardingSolverResult CallORToolsSolver(
}
LOG(INFO) << "Minimum memory budget estimate: "
<< MinimumMemoryBudgetRequired(request);
LOG(INFO) << "Using memory budget: " << request.memory_budget();
LOG(INFO) << "Using memory budget: "
<< static_cast<double>(request.memory_budget());
}

// d. specified via "BoolVarArray"
Expand Down

0 comments on commit 0f08501

Please sign in to comment.