Skip to content

Commit

Permalink
Save greedy solution for ILP
Browse files Browse the repository at this point in the history
  • Loading branch information
yycdavid committed Aug 11, 2020
1 parent 625f4a3 commit 329cd6a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/main.rs
Expand Up @@ -121,6 +121,11 @@ fn main() {
.long("no_order")
.help("No ordering constraints in ILP"),
)
.arg(
Arg::with_name("initial_with_greedy")
.long("initial_with_greedy")
.help("Initialize ILP with greedy solution"),
)
.get_matches();

let run_mode = matches.value_of("mode").unwrap();
Expand Down Expand Up @@ -336,6 +341,24 @@ fn extract_by_ilp(
create_dir_all("./tmp");
write("./tmp/ilp_data.json", data_str).expect("Unable to write file");

let initialize = matches.is_present("initial_with_greedy");
if initialize {
// Get node_to_i map
let node_to_i: HashMap<Mdl, usize> = (&i_to_nodes).iter().enumerate().map(|(i, node)| (node.clone(), i)).collect();

let tnsr_cost = TensorCost { egraph: egraph };
let mut extractor = Extractor::new(egraph, tnsr_cost);
let (i_list, m_list) = get_init_solution(egraph, root, &extractor.costs, &g_i, &node_to_i);

// Store initial solution
let solution_data = json!({
"i_list": i_list,
"m_list": m_list,
});
let sol_data_str = serde_json::to_string(&solution_data).expect("Fail to convert json to string");
write("./tmp/init_sol.json", sol_data_str).expect("Unable to write file");
}

// Call python script to run ILP
let order_var_int = matches.is_present("order_var_int");
let class_constraint = matches.is_present("class_constraint");
Expand All @@ -350,6 +373,9 @@ fn extract_by_ilp(
if no_order {
arg_vec.push("--no_order");
}
if initialize {
arg_vec.push("--initialize")
}
let child = Command::new("python")
.args(&arg_vec)
.spawn()
Expand Down
37 changes: 37 additions & 0 deletions src/optimize.rs
Expand Up @@ -3,6 +3,7 @@ use egg::*;
use root::taso::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
use std::convert::TryInto;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -407,3 +408,39 @@ pub fn construct_best_rec(
}
}
}


pub fn get_init_solution(
egraph: &EGraph<Mdl, TensorAnalysis>,
root: Id,
costs: &HashMap<Id, (f32, Mdl)>,
g_i: &[usize],
nodes_to_i: &HashMap<Mdl, usize>,
) -> (
Vec<usize>,
Vec<usize>,
) {
let mut nodes: Vec<Mdl> = Vec::new();
// added_memo maps eclass id to id in expr
let mut added_memo: HashSet<Id> = Default::default();
get_init_rec(egraph, root, &mut added_memo, costs, &mut nodes);

let i_list: Vec<usize> = nodes.iter().map(|node| *nodes_to_i.get(node).unwrap()).collect();
let m_list: Vec<usize> = i_list.iter().map(|i| g_i[*i]).collect();

(i_list, m_list)
}

fn get_init_rec(egraph: &EGraph<Mdl, TensorAnalysis>, eclass: Id, added_memo: &mut HashSet<Id>, costs: &HashMap<Id, (f32, Mdl)>, nodes: &mut Vec<Mdl>) {
let id = egraph.find(eclass);

if !added_memo.contains(&id) {
let (_, best_node) = match costs.get(&id) {
Some(result) => result.clone(),
None => panic!("Failed to extract from eclass {}", id),
};
best_node.for_each(|child| get_init_rec(egraph, child, added_memo, costs, nodes));
nodes.push(best_node);
added_memo.insert(id);
}
}

0 comments on commit 329cd6a

Please sign in to comment.