Skip to content

Commit

Permalink
Reintroduces the "tiny term trick."
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623646009
  • Loading branch information
tensorflower-gardener committed Apr 11, 2024
1 parent 71e8ec2 commit 40062bf
Showing 1 changed file with 42 additions and 14 deletions.
Expand Up @@ -69,6 +69,9 @@ constexpr double kMaxCostEpsilon = 1.0001;
// same amount.
constexpr double kMemoryMultiplier = 1e-6;

// Any memory terms below this threshold will be dropped (to reduce MIP size).
constexpr double kTinyTermThreshold = 1e-6;

// Always include memory constraints with this number of terms or fewer.
constexpr int64_t kMemoryCardinalityThreshold = 1000;

Expand Down Expand Up @@ -738,40 +741,59 @@ LivenessIdx FindPeakLiveness(const AutoShardingSolverRequest& request,
return peak_time_idx;
}

// Imposes a new memory constraint at the given location.
void ImposeMemoryConstraint(const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, MPSolver& solver,
LivenessIdx time_idx) {
// Imposes a new memory constraint at the given location. Returns the number of
// tiny terms created.
int ImposeMemoryConstraint(const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, MPSolver& solver,
LivenessIdx time_idx) {
int tiny_term_count = 0;
VLOG(1) << "Imposing a memory constraint at time index " << time_idx;
MPConstraint* constraint =
solver.MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(),
absl::StrCat("mem[", time_idx, "]"));
if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
double tiny_term_total = 0.0; // Used to trim the memory budget downward.
for (NodeIdx node_idx : request.live(time_idx).nodes()) {
double tiny_term_max = 0.0;
for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
double memory_cost = request.memory_costs(node_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(s[node_idx][j]);
constraint->SetCoefficient(s[node_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
double tiny_term_max = 0.0;
for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) {
double memory_cost = request.memory_edge_costs(edge_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(e[edge_idx][j]);
constraint->SetCoefficient(e[edge_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
}
constraint->SetUB(kMemoryMultiplier * request.memory_budget());
constraint->SetUB(kMemoryMultiplier *
(request.memory_budget() - tiny_term_total));
return tiny_term_count;
}

AutoShardingSolverResult SolveAndExtractSolution(
Expand All @@ -780,21 +802,25 @@ AutoShardingSolverResult SolveAndExtractSolution(
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var, const MPVariable* makespan_var,
MPSolver& solver) {
int tiny_term_count = 0;
absl::Time start_time = absl::Now();
absl::flat_hash_set<LivenessIdx> peak_times;
absl::flat_hash_set<LivenessIdx> peak_times, small_times;
if (request.memory_budget() > 0) {
// Always enforce constraints that have a relatively small number of terms.
for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
if (request.live(time_idx).nodes_size() <= kMemoryCardinalityThreshold) {
ImposeMemoryConstraint(request, s, e, overbudget_var, solver, time_idx);
small_times.insert(time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, time_idx);
}
}
// Also add in any peak times that were encountered in previous iterations.
if (!request.deterministic_mode()) {
for (const LivenessIdx peak_time_idx : request.peak_times()) {
if (small_times.contains(peak_time_idx)) continue;
peak_times.insert(peak_time_idx);
ImposeMemoryConstraint(request, s, e, overbudget_var, solver,
peak_time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, peak_time_idx);
}
}
}
Expand All @@ -813,9 +839,10 @@ AutoShardingSolverResult SolveAndExtractSolution(
solver.SetHint(hint);
const LivenessIdx peak_time_idx = FindPeakLiveness(request, s, e);
if (peak_time_idx == -1 || peak_times.contains(peak_time_idx)) break;
if (small_times.contains(peak_time_idx)) break;
peak_times.insert(peak_time_idx);
ImposeMemoryConstraint(request, s, e, overbudget_var, solver,
peak_time_idx);
tiny_term_count += ImposeMemoryConstraint(request, s, e, overbudget_var,
solver, peak_time_idx);
if (request.has_solver_timeout()) {
auto remaining_time =
request.solver_timeout().solver_timeout_in_seconds();
Expand All @@ -824,13 +851,14 @@ AutoShardingSolverResult SolveAndExtractSolution(
}
status = solver.Solve();
}
LOG(INFO) << "Imposed " << peak_times.size()
LOG(INFO) << "Imposed " << peak_times.size() + small_times.size()
<< " memory constraints out of " << request.live_size();
}
absl::Time end_time = absl::Now();
auto duration = end_time - start_time;
LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms";
LOG(INFO) << "Solver Status: " << status;
LOG(INFO) << "Number of tiny terms: " << tiny_term_count;

if (status == operations_research::MPSolver::INFEASIBLE) {
LOG(ERROR) << "MPSolver could not find any feasible solution.";
Expand Down

0 comments on commit 40062bf

Please sign in to comment.