Replaces the lazy constraint instantiation logic with memory term reduction.

PiperOrigin-RevId: 627368773
tensorflower-gardener committed Apr 23, 2024
1 parent d7e890a commit 38bbeaf
Showing 1 changed file with 111 additions and 151 deletions.
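
For orientation before the diff: the commit replaces the old approach of lazily adding per-timestep memory constraints (solve, find the peak overbudget timestep, impose a constraint there, re-solve) with an up-front pass that reduces memory terms and introduces group variables before a single solve. The snippet below is only a toy sketch of why such grouping shrinks the MIP, under the simplifying assumption that primitives sharing an identical live interval can be folded into one group term; the real MemoryTermReducer in auto_sharding_memory.h is more general, and every name and number here is made up for illustration.

// Toy illustration (not the real MemoryTermReducer): primitives whose live
// intervals are identical are merged into a single group, so each timestep's
// memory constraint references one group term instead of every member.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // live[t] lists the primitives alive at timestep t (made-up data).
  const std::vector<std::vector<int64_t>> live = {
      {0, 1, 2, 3}, {0, 1, 2, 3, 4}, {0, 1, 2, 3, 4}};

  // Compute each primitive's live interval [first, last].
  std::map<int64_t, std::pair<int64_t, int64_t>> interval;
  for (int64_t t = 0; t < static_cast<int64_t>(live.size()); ++t) {
    for (const int64_t prim : live[t]) {
      auto [it, inserted] = interval.emplace(prim, std::make_pair(t, t));
      if (!inserted) it->second.second = t;
    }
  }

  // Primitives sharing an identical interval become one group.
  std::map<std::pair<int64_t, int64_t>, std::vector<int64_t>> groups;
  for (const auto& [prim, span] : interval) groups[span].push_back(prim);

  // Without grouping: one term per live primitive per timestep.
  int64_t original_terms = 0;
  for (const auto& prims : live) original_terms += prims.size();

  // With grouping: one term per timestep of the interval, plus (for groups
  // with more than one member) one defining term per member.
  int64_t reduced_terms = 0;
  for (const auto& [span, members] : groups) {
    reduced_terms += span.second - span.first + 1;
    if (members.size() > 1) reduced_terms += members.size();
  }
  std::cout << "terms before reduction: " << original_terms
            << ", after: " << reduced_terms << std::endl;  // prints 14 vs. 9
  return 0;
}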
@@ -19,9 +19,11 @@ limitations under the License.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

@@ -38,6 +40,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/status.h"
#include "xla/status_macros.h"
@@ -67,13 +70,7 @@ constexpr double kMaxCostEpsilon = 1.0001;
// bounds, etc.) using smaller absolute values, due to limitations on precision.
// To compensate, the overbudget objective coefficient must be amplified by the
// same amount.
constexpr double kMemoryMultiplier = 1e-6;

// Any memory terms below this threshold will be dropped (to reduce MIP size).
constexpr double kTinyTermThreshold = 1e-6;

// Always include memory constraints with this number of terms or fewer.
constexpr int64_t kMemoryCardinalityThreshold = 1000;
constexpr double kMemoryMultiplier = 1e-5;
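As a rough worked example of the scaling described in the comment above (the numbers are mine, not from the commit): with kMemoryMultiplier = 1e-5, a 32 GiB budget of about 3.4e10 bytes enters the solver as about 3.4e5, and per-strategy memory costs shrink by the same factor; since the overbudget variable is then measured in these scaled units, its objective coefficient must be amplified by the inverse factor (1e5) for the penalty to reflect real bytes.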

bool AutoShardingSolverOutput::operator==(
const AutoShardingSolverOutput& other) const {
@@ -158,7 +155,7 @@ AutoShardingSolverResult SolveAndExtractSolution(
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, const MPVariable* makespan_var,
absl::Time start_time, MPSolver& solver);
MPSolver& solver);

double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) {
double min_memory_budget_required_estimate = 0.0;
@@ -221,6 +218,69 @@ AutoShardingSolverRequest ScaleRequest(
return scaled_request;
}

// Given the live matrix and memory costs (for nodes or edges), reduce terms and
// create constrained variables for the subsequent groups.
std::pair<int64_t, int64_t> ReduceMemoryTerms(
MPSolver& solver, int64_t num_lives, int64_t num_primitives,
const std::function<
tsl::protobuf::RepeatedField<int64_t>(int64_t)>& // NOLINT
live,
const tsl::protobuf::RepeatedPtrField< // NOLINT
AutoShardingSolverRequest_Costs>& memory_costs,
std::string_view prim_type,
std::vector<std::vector<MPVariable*>>& prim_vars,
std::vector<std::vector<int64_t>>& reduced_live,
std::vector<MPVariable*>& group_vars) {
MemoryTermReducer reducer;
auto num_terms = reducer.Reduce(num_lives, num_primitives, live);
reduced_live = reducer.GetReducedLive();
const auto& reduced_groups = reducer.GetReducedGroups();
solver.MakeIntVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(),
absl::StrCat("group_", prim_type), &group_vars);
for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) {
MPConstraint* constraint = solver.MakeRowConstraint(
-MPSolver::infinity(), 0.0,
absl::StrCat("group_", prim_type, "[", group_idx, "]"));
constraint->SetCoefficient(group_vars[group_idx], -1.0);
for (const int64_t prim_idx : reduced_groups[group_idx]) {
for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
double memory_cost = memory_costs.at(prim_idx).costs(j);
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(prim_vars[prim_idx][j]);
constraint->SetCoefficient(prim_vars[prim_idx][j],
accumulated_coefficient + memory_cost);
}
}
}
return num_terms;
}

// Adds the appropriate memory terms (for nodes or edges) at the given time.
void AddMemoryTerms(MPSolver& solver, int64_t num_primitives,
const std::vector<std::vector<int64_t>>& live,
const tsl::protobuf::RepeatedPtrField< // NOLINT
AutoShardingSolverRequest_Costs>& memory_costs,
LivenessIdx time_idx,
std::vector<std::vector<MPVariable*>>& prim_vars,
std::vector<MPVariable*>& group_vars,
MPConstraint* constraint) {
for (const int64_t prim_idx : live[time_idx]) {
if (prim_idx >= num_primitives) {
constraint->SetCoefficient(group_vars[prim_idx - num_primitives], 1.0);
continue;
}
for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
double memory_cost = memory_costs.at(prim_idx).costs(j);
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(prim_vars[prim_idx][j]);
constraint->SetCoefficient(prim_vars[prim_idx][j],
accumulated_coefficient + memory_cost);
}
}
}
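
In math (my own notation, paraphrasing the code above rather than quoting it), ReduceMemoryTerms adds one defining row per reduced group and AddMemoryTerms then assembles each per-timestep row from a mix of individual and group terms:

$$\mu \sum_{p \in g} \sum_{j} m_{p,j}\, x_{p,j} \;\le\; G_g \quad \text{for each group } g,$$

$$\sum_{p \in L_t,\; p < N} \mu \sum_{j} m_{p,j}\, x_{p,j} \;+\; \sum_{g \in L_t,\; g \ge N} G_{g-N} \;\le\; \mu B + \text{overbudget} \quad \text{for each timestep } t,$$

where $\mu$ is kMemoryMultiplier, $m_{p,j}$ is the memory cost of strategy $j$ for primitive $p$, $x_{p,j}$ is the corresponding 0/1 strategy variable, $N$ is num_primitives, $L_t$ is the reduced live set at time $t$, $B$ is request.memory_budget(), and the overbudget term appears only when overbudget_var is present.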

// Taking an auto-sharding problem (`request`) as an input, calls the OR tools
// CP-SAT solver and outputs a solution to the input problem.
//
@@ -474,6 +534,46 @@ AutoShardingSolverResult CallORToolsSolver(
}
// c.
if (request.memory_budget() > 0) {
auto LiveNodes =
[request](int64_t live_idx) -> tsl::protobuf::RepeatedField<int64_t> {
return request.live(live_idx).nodes();
};
auto LiveEdges =
[request](int64_t live_idx) -> tsl::protobuf::RepeatedField<int64_t> {
return request.live_edges(live_idx).edges();
};
std::vector<std::vector<int64_t>> reduced_live_nodes, reduced_live_edges;
std::vector<MPVariable*> group_node_vars, group_edge_vars;
const absl::Time term_reduction_start_time = absl::Now();
auto num_node_terms = ReduceMemoryTerms(
*solver, request.live_size(), request.num_nodes(), std::move(LiveNodes),
request.memory_costs(), "node", s, reduced_live_nodes, group_node_vars);
auto num_edge_terms = ReduceMemoryTerms(
*solver, request.live_edges_size(), request.edges_size(),
std::move(LiveEdges), request.memory_edge_costs(), "edge", e,
reduced_live_edges, group_edge_vars);
const absl::Time term_reduction_end_time = absl::Now();
const auto term_reduction_duration =
term_reduction_end_time - term_reduction_start_time;
LOG(INFO) << "Memory Term Reducer took "
<< absl::ToInt64Milliseconds(term_reduction_duration)
<< " ms and reduced the number of terms from "
<< num_node_terms.first + num_edge_terms.first << " to "
<< num_node_terms.second + num_edge_terms.second;
for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
MPConstraint* constraint = solver->MakeRowConstraint(
-MPSolver::infinity(), request.memory_budget() * kMemoryMultiplier,
absl::StrCat("mem[", time_idx, "]"));
if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
AddMemoryTerms(*solver, request.num_nodes(), reduced_live_nodes,
request.memory_costs(), time_idx, s, group_node_vars,
constraint);
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
AddMemoryTerms(*solver, request.edges_size(), reduced_live_edges,
request.memory_edge_costs(), time_idx, e,
group_edge_vars, constraint);
}
}
if (overbudget_var) {
solver->MutableObjective()->SetCoefficient(
overbudget_var,
@@ -645,7 +745,7 @@ AutoShardingSolverResult CallORToolsSolver(
VLOG(0) << "Max cost: " << request.max_cost().coeff();
}
auto result = SolveAndExtractSolution(request, s, e, overbudget_var,
makespan_var, start_time, *solver);
makespan_var, *solver);
if (result.status.ok()) {
const AutoShardingEvaluation evaluation =
Evaluate(unscaled_request, result);
@@ -711,154 +811,14 @@ std::vector<EdgeStrategyIdx> GetChosenEdgeStrategy(
return chosen_edge_strategy;
}

// Finds the timestep with the largest memory overbudget (-1 if no such value).
LivenessIdx FindPeakLiveness(const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e) {
const std::vector<NodeStrategyIdx> chosen_node_strategy =
GetChosenNodeStrategy(request, s);
const std::vector<EdgeStrategyIdx> chosen_edge_strategy =
GetChosenEdgeStrategy(request, e);
LivenessIdx peak_time_idx = -1;
double peak_overbudget = 0.0;
for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
if (request.live(time_idx).nodes_size() <= kMemoryCardinalityThreshold) {
continue; // We always enforce these, no need to consider them again.
}
double memory_usage = 0.0;
for (NodeIdx node_idx : request.live(time_idx).nodes()) {
const NodeStrategyIdx j = chosen_node_strategy[node_idx];
memory_usage += request.memory_costs(node_idx).costs(j);
}
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
const EdgeStrategyIdx j = chosen_edge_strategy[edge_idx];
memory_usage += request.memory_edge_costs(edge_idx).costs(j);
}
}
const double overbudget = memory_usage - request.memory_budget();
if (peak_overbudget < overbudget) {
peak_overbudget = overbudget;
peak_time_idx = time_idx;
}
}
return peak_time_idx;
}

// Imposes a new memory constraint at the given location. Returns the number of
// tiny terms created.
int ImposeMemoryConstraint(const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, MPSolver& solver,
LivenessIdx time_idx) {
int tiny_term_count = 0;
VLOG(1) << "Imposing a memory constraint at time index " << time_idx;
MPConstraint* constraint =
solver.MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(),
absl::StrCat("mem[", time_idx, "]"));
if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
double tiny_term_total = 0.0; // Used to trim the memory budget downward.
for (NodeIdx node_idx : request.live(time_idx).nodes()) {
double tiny_term_max = 0.0;
for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
double memory_cost = request.memory_costs(node_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(s[node_idx][j]);
constraint->SetCoefficient(s[node_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
double tiny_term_max = 0.0;
for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) {
double memory_cost = request.memory_edge_costs(edge_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(e[edge_idx][j]);
constraint->SetCoefficient(e[edge_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
}
constraint->SetUB(kMemoryMultiplier *
(request.memory_budget() - tiny_term_total));
return tiny_term_count;
}

AutoShardingSolverResult SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, const MPVariable* makespan_var,
absl::Time start_time, MPSolver& solver) {
int tiny_term_count = 0;
absl::flat_hash_set<LivenessIdx> peak_times, small_times;
if (request.memory_budget() > 0) {
// Always enforce constraints that have a relatively small number of terms.
for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
if (request.live(time_idx).nodes_size() <= kMemoryCardinalityThreshold) {
small_times.insert(time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, time_idx);
}
}
// Also add in any peak times that were encountered in previous iterations.
if (!request.deterministic_mode()) {
for (const LivenessIdx peak_time_idx : request.peak_times()) {
if (small_times.contains(peak_time_idx)) continue;
peak_times.insert(peak_time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, peak_time_idx);
}
}
}
MPSolver& solver) {
auto status = solver.Solve();
if (request.memory_budget() > 0) {
// Continue to add memory constraints until (a) they are all satisfied,
// (b) the problem becomes infeasible, or (c) the solver times out.
while (status == operations_research::MPSolver::OPTIMAL) {
std::vector<std::pair<const MPVariable*, double>> hint;
for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) {
if (request.s_follow(node_idx) >= 0) continue;
for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
hint.push_back({s[node_idx][j], s[node_idx][j]->solution_value()});
}
}
solver.SetHint(hint);
const LivenessIdx peak_time_idx = FindPeakLiveness(request, s, e);
if (peak_time_idx == -1 || peak_times.contains(peak_time_idx)) break;
if (small_times.contains(peak_time_idx)) break;
peak_times.insert(peak_time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, peak_time_idx);
if (request.has_solver_timeout()) {
auto remaining_time =
request.solver_timeout().solver_timeout_in_seconds();
remaining_time -= absl::ToInt64Seconds(absl::Now() - start_time);
solver.SetTimeLimit(absl::Seconds(std::max(remaining_time, 0L)));
}
status = solver.Solve();
}
LOG(INFO) << "Imposed " << peak_times.size() + small_times.size()
<< " memory constraints out of " << request.live_size();
}
LOG(INFO) << "Solver Status: " << status;
LOG(INFO) << "Number of tiny terms: " << tiny_term_count;

if (status == operations_research::MPSolver::INFEASIBLE) {
LOG(ERROR) << "MPSolver could not find any feasible solution.";
@@ -965,7 +925,7 @@ AutoShardingSolverResult SolveAndExtractSolution(
PrintLargestInstructions(chosen_node_strategy, request);
const AutoShardingSolverOutput output = {std::move(chosen_node_strategy),
std::move(chosen_edge_strategy),
unsalted_objective, peak_times};
unsalted_objective};
return AutoShardingSolverResult(output, false);
}
