Skip to content

Commit

Permalink
Add flag for filter cycle and multi-pattern iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
yycdavid committed Aug 12, 2020
1 parent fcf21ce commit bda00b7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
3 changes: 3 additions & 0 deletions extractor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def main():
print('The problem does not have an optimal solution.')
print(status)
print('Objective value =', solver.Objective().Value())
print('Problem solved in %f milliseconds' % solver.wall_time())
print('Problem solved in %d iterations' % solver.iterations())
print('Problem solved in %d branch-and-bound nodes' % solver.nodes())

# Store results
solved_x = [int(x[j].solution_value()) for j in range(num_nodes)]
Expand Down
18 changes: 16 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ fn main() {
.takes_value(true)
.help("Number of threads for ILP solver"),
)
.arg(
Arg::with_name("iter_multi")
.long("iter_multi")
.takes_value(true)
.default_value("1")
.help("Max number of iterations to apply multi-pattern rules"),
)
.arg(
Arg::with_name("no_cycle")
.long("no_cycle")
.help("Not allowing cycles in EGraph"),
)
.get_matches();

let run_mode = matches.value_of("mode").unwrap();
Expand Down Expand Up @@ -222,15 +234,17 @@ fn optimize(matches: clap::ArgMatches) {

// Get multi-pattern rules. learned_rules are the learned rules from TASO,
// pre_defined_multi are the hand-specified rules from TASO
let no_cycle = matches.is_present("no_cycle");
let iter_multi = matches.value_of("iter_multi").unwrap().parse::<usize>().unwrap();
let multi_patterns = if let Some(rule_file) = matches.value_of("multi_rules") {
let learned_rules =
read_to_string(rule_file).expect("Something went wrong reading the rule file");
let pre_defined_multi = PRE_DEFINED_MULTI.iter().map(|&x| x);
let multi_rules: Vec<&str> = learned_rules.split("\n").chain(pre_defined_multi).collect();
MultiPatterns::with_rules(multi_rules)
MultiPatterns::with_rules(multi_rules, no_cycle, iter_multi)
} else {
let multi_rules: Vec<&str> = PRE_DEFINED_MULTI.iter().map(|&x| x).collect();
MultiPatterns::with_rules(multi_rules)
MultiPatterns::with_rules(multi_rules, no_cycle, iter_multi)
};

// Run saturation
Expand Down
42 changes: 27 additions & 15 deletions src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::model::*;
use egg::{rewrite as rw, *};
use itertools::Itertools;
use root::taso::*;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -732,6 +732,10 @@ pub struct MultiPatterns {
canonical_src_pat: Vec<Pattern<Mdl>>,
/// Mapping information for each src pattern. The order is the same as in rules
src_pat_maps: Vec<(MapToCanonical, MapToCanonical)>,
/// Whether to allow cycles in EGraph
no_cycle: bool,
/// Number of iterations to run multi-pattern rules
iter_limit: usize,
}

impl MultiPatterns {
Expand All @@ -740,7 +744,7 @@ impl MultiPatterns {
/// # Parameters
///
/// - `rules`: every adjacent pair of entries should belong to the same multi-pattern rule.
pub fn with_rules(rules: Vec<&str>) -> MultiPatterns {
pub fn with_rules(rules: Vec<&str>, no_cycle: bool, iter_limit: usize) -> MultiPatterns {
assert!(rules.len() % 2 == 0);

let mut multi_rules =
Expand Down Expand Up @@ -788,6 +792,8 @@ impl MultiPatterns {
rules: multi_rules,
canonical_src_pat: canonical_pats,
src_pat_maps: src_pat_maps,
no_cycle: no_cycle,
iter_limit: iter_limit,
}
}

Expand All @@ -798,7 +804,7 @@ impl MultiPatterns {
/// it checks and applies the dst patterns. It won't apply if src_1 and src_2 matches with
/// the same eclass. It always returns Ok()
pub fn run_one(&self, runner: &mut Runner<Mdl, TensorAnalysis, ()>) -> Result<(), String> {
if runner.iterations.len() < 2 {
if runner.iterations.len() < self.iter_limit {
println!("Run one");
// Construct Vec to store matches for each canonicalized pattern
let matches: Vec<Vec<SearchMatches>> = self
Expand Down Expand Up @@ -840,6 +846,7 @@ impl MultiPatterns {
map_2: &MapToCanonical,
runner: &mut Runner<Mdl, TensorAnalysis, ()>,
) {
let mut descendents: HashMap<Id, HashSet<Id>> = Default::default();
for subst_1 in &match_1.substs {
for subst_2 in &match_2.substs {
// De-canonicalize the substitutions
Expand All @@ -849,21 +856,26 @@ impl MultiPatterns {
if compatible(&subst_1_dec, &subst_2_dec, &map_1.var_map) {
// If so, merge two substitutions
let merged_subst = merge_subst(subst_1_dec, subst_2_dec, &map_1.var_map);

// check_pat on both dst patterns
if check_pat(rule.2.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
if check_pat(rule.3.ast.as_ref(), &mut runner.egraph, &merged_subst).0 {
// apply dst patterns, union
let id_1 =
rule.2
.apply_one(&mut runner.egraph, match_1.eclass, &merged_subst)
[0];
runner.egraph.union(id_1, match_1.eclass);
let id_2 =
rule.3
.apply_one(&mut runner.egraph, match_2.eclass, &merged_subst)
[0];
runner.egraph.union(id_2, match_2.eclass);
let cycle_check_passed = if self.no_cycle {
println!("No cycle");
true
} else { true };
if cycle_check_passed {
// apply dst patterns, union
let id_1 =
rule.2
.apply_one(&mut runner.egraph, match_1.eclass, &merged_subst)
[0];
runner.egraph.union(id_1, match_1.eclass);
let id_2 =
rule.3
.apply_one(&mut runner.egraph, match_2.eclass, &merged_subst)
[0];
runner.egraph.union(id_2, match_2.eclass);
}
}
}
}
Expand Down

0 comments on commit bda00b7

Please sign in to comment.