Docs of main, formatting
yycdavid committed Aug 28, 2020
1 parent 7e776fe commit bf8dd37
Showing 9 changed files with 337 additions and 135 deletions.
7 changes: 6 additions & 1 deletion src/bert.rs
@@ -4,7 +4,12 @@ use egg::*;
 const SEQ_LENGTH: i32 = 64;
 const HIDDEN_DIMS: i32 = 1024;
 
-fn attention(graph: &mut GraphConverter, input: TensorInfo, heads: i32, input_dim_1: i32) -> (TensorInfo, TensorInfo) {
+fn attention(
+    graph: &mut GraphConverter,
+    input: TensorInfo,
+    heads: i32,
+    input_dim_1: i32,
+) -> (TensorInfo, TensorInfo) {
     let d_model = input_dim_1;
     let d_k = d_model / heads;
     assert!(input_dim_1 % heads == 0);
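
A side note on the assert in the hunk above (not part of this commit): attention splits d_model evenly across the heads, which is why input_dim_1 must be divisible by heads. A minimal standalone check, with a hypothetical head count of 16:

    fn main() {
        // d_model = 1024 matches HIDDEN_DIMS above; heads = 16 is made up.
        let (d_model, heads) = (1024, 16);
        assert!(d_model % heads == 0);
        let d_k = d_model / heads;
        assert_eq!(d_k, 64); // each head works in a 64-dimensional subspace
    }
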
56 changes: 42 additions & 14 deletions src/input.rs
@@ -17,12 +17,12 @@ pub struct GraphConverter {
     name_gen: NameGen,
 }
 
-/// Struct for storing information of a tensor. This is passed between functions 
+/// Struct for storing information of a tensor. This is passed between functions
 /// during graph creation.
 #[derive(Copy, Clone, Default)]
 pub struct TensorInfo {
     /// Id into the RecExpr constructed
-    pub id: Id, 
+    pub id: Id,
     /// Shape of the tensor. We deal with tensor up to MAX_DIM dimensions
     pub shape: [i32; MAX_DIM],
     /// Number of dimensions of this tensor
@@ -98,7 +98,9 @@ impl GraphConverter {
         let kernel_h = wght.shape[2];
         let kernel_w = wght.shape[3];
 
-        let (output_h, output_w) = self.get_conv_shape(input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding);
+        let (output_h, output_w) = self.get_conv_shape(
+            input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding,
+        );
         shape[0] = inpt.shape[0];
         shape[1] = wght.shape[0];
         shape[2] = output_h;
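
The shape bookkeeping in the conv2d hunk above follows the NCHW layout: batch size is carried over from the input, the channel count comes from the number of filters (wght.shape[0]), and the spatial dims come from get_conv_shape. A standalone sketch with made-up sizes, assuming "same" padding and stride 2:

    fn main() {
        // Hypothetical shapes: batch 8, 3 -> 64 channels, 224x224 input.
        let inpt_shape = [8, 3, 224, 224];
        let wght_shape = [64, 3, 3, 3];
        let (stride_h, stride_w) = (2, 2);
        // The PSAME branch of get_conv_shape: ceiling division by the stride.
        let output_h = (inpt_shape[2] + stride_h - 1) / stride_h;
        let output_w = (inpt_shape[3] + stride_w - 1) / stride_w;
        let out_shape = [inpt_shape[0], wght_shape[0], output_h, output_w];
        assert_eq!(out_shape, [8, 64, 112, 112]);
    }
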
@@ -133,7 +135,7 @@ impl GraphConverter {
 
     pub fn sigmoid(&mut self, inpt: TensorInfo) -> TensorInfo {
         let new_node = Mdl::Sigmoid(inpt.id);
-        
+
         TensorInfo {
             id: self.rec_expr.add(new_node),
             shape: inpt.shape,
@@ -143,7 +145,7 @@ impl GraphConverter {
 
     pub fn add(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
         let new_node = Mdl::Ewadd([inpt_1.id, inpt_2.id]);
-        
+
         TensorInfo {
             id: self.rec_expr.add(new_node),
             shape: inpt_1.shape,
@@ -170,15 +172,21 @@ impl GraphConverter {
 
     pub fn mul(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
         let new_node = Mdl::Ewmul([inpt_1.id, inpt_2.id]);
-        
+
         TensorInfo {
             id: self.rec_expr.add(new_node),
             shape: inpt_1.shape,
             n_dim: inpt_1.n_dim,
         }
     }
 
-    pub fn concat(&mut self, axis: i32, ndim: i32, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
+    pub fn concat(
+        &mut self,
+        axis: i32,
+        ndim: i32,
+        inpt_1: TensorInfo,
+        inpt_2: TensorInfo,
+    ) -> TensorInfo {
         // Only support concat of 2 inputs for now
         // To support more, pass in a slice and create more concat nodes here
         let axis_id = self.add_or_get_val(axis);
@@ -199,21 +207,29 @@ impl GraphConverter {
 
     pub fn concat_multi(&mut self, axis: i32, ndim: i32, inputs: &[TensorInfo]) -> TensorInfo {
         let n_inputs = inputs.len();
-        // We can add supports for other number of inputs later when needed. 
+        // We can add supports for other number of inputs later when needed.
         // We need to add a new Concat op for each number of inputs
         assert!(n_inputs == 5);
 
        let axis_id = self.add_or_get_val(axis);
        let ndim_id = self.add_or_get_val(ndim);
 
-        let new_node = Mdl::Concat5([axis_id, ndim_id, inputs[0].id, inputs[1].id, inputs[2].id, inputs[3].id, inputs[4].id]);
+        let new_node = Mdl::Concat5([
+            axis_id,
+            ndim_id,
+            inputs[0].id,
+            inputs[1].id,
+            inputs[2].id,
+            inputs[3].id,
+            inputs[4].id,
+        ]);
 
        let mut shape = inputs[0].shape;
        let n_dim = inputs[0].n_dim;
        for i in 1..n_inputs {
            shape[axis as usize] += inputs[i].shape[axis as usize];
        }
-        
+
        TensorInfo {
            id: self.rec_expr.add(new_node),
            shape: shape,
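
The two comments in the hunks above sketch the options for going beyond a fixed arity: add a dedicated ConcatN op per input count (as Concat5 does here), or pass in a slice and chain the existing 2-input concat. A hypothetical sketch of the chaining approach, not part of this commit (concat_all and its call shape are assumptions):

    use tamago::input::{GraphConverter, TensorInfo};

    // Fold the 2-input concat left-to-right; each step adds one Concat node
    // and grows the shape along `axis`, so no new Mdl variants are needed.
    fn concat_all(
        graph: &mut GraphConverter,
        axis: i32,
        ndim: i32,
        inputs: &[TensorInfo],
    ) -> TensorInfo {
        assert!(!inputs.is_empty());
        inputs[1..]
            .iter()
            .fold(inputs[0], |acc, &inpt| graph.concat(axis, ndim, acc, inpt))
    }

The trade-off is more e-graph nodes per multi-way concat, which is presumably why the commit keeps the dedicated Concat5 for the 5-input case.
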
@@ -253,7 +269,9 @@ impl GraphConverter {
         let input_h = inpt.shape[2];
         let input_w = inpt.shape[3];
 
-        let (output_h, output_w) = self.get_conv_shape(input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding);
+        let (output_h, output_w) = self.get_conv_shape(
+            input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding,
+        );
         shape[0] = inpt.shape[0];
         shape[1] = inpt.shape[1];
         shape[2] = output_h;
@@ -298,7 +316,9 @@ impl GraphConverter {
         let input_h = inpt.shape[2];
         let input_w = inpt.shape[3];
 
-        let (output_h, output_w) = self.get_conv_shape(input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding);
+        let (output_h, output_w) = self.get_conv_shape(
+            input_h, input_w, stride_h, stride_w, kernel_h, kernel_w, padding,
+        );
         shape[0] = inpt.shape[0];
         shape[1] = inpt.shape[1];
         shape[2] = output_h;
@@ -312,7 +332,6 @@ impl GraphConverter {
     }
 
     pub fn enlarge(&mut self, inpt_1: TensorInfo, inpt_2: TensorInfo) -> TensorInfo {
-
         let mut shape = inpt_1.shape;
         shape[2] = inpt_2.shape[2];
         shape[3] = inpt_2.shape[3];
@@ -417,7 +436,16 @@ impl GraphConverter {
         (shape, dims.len())
     }
 
-    fn get_conv_shape(&self, input_h: i32, input_w: i32, stride_h: i32, stride_w: i32, kernel_h: i32, kernel_w: i32, padding: i32) -> (i32, i32) {
+    fn get_conv_shape(
+        &self,
+        input_h: i32,
+        input_w: i32,
+        stride_h: i32,
+        stride_w: i32,
+        kernel_h: i32,
+        kernel_w: i32,
+        padding: i32,
+    ) -> (i32, i32) {
         if padding == PSAME {
             let output_h = (input_h + stride_h - 1) / stride_h;
             let output_w = (input_w + stride_w - 1) / stride_w;
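
For reference, the PSAME branch above is integer ceiling division: for positive n and s, (n + s - 1) / s equals ceil(n / s), so "same" padding always yields output = ceil(input / stride). A quick standalone check of the identity:

    fn main() {
        let ceil_div = |n: i32, s: i32| (n + s - 1) / s;
        assert_eq!(ceil_div(7, 2), 4); // 7 / 2 = 3.5 rounds up to 4
        assert_eq!(ceil_div(8, 2), 4); // exact multiples are unchanged
        assert_eq!(ceil_div(64, 1), 64); // stride 1 preserves spatial size
    }
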
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -2,14 +2,14 @@ pub mod benchnet;
 pub mod bert;
 pub mod input;
 pub mod model;
+pub mod nasneta;
 pub mod nasrnn;
 pub mod optimize;
 pub mod parse;
 pub mod resnet50;
 pub mod resnext50;
 pub mod rewrites;
 pub mod testnet;
-pub mod nasneta;
 
 pub mod verify {
     use crate::model::*;
32 changes: 23 additions & 9 deletions src/main.rs
@@ -8,13 +8,13 @@ use std::time::{Duration, Instant};
 use tamago::benchnet;
 use tamago::bert;
 use tamago::model::*;
+use tamago::nasneta;
 use tamago::nasrnn;
 use tamago::optimize::*;
 use tamago::resnet50;
 use tamago::resnext50;
 use tamago::rewrites::*;
 use tamago::testnet;
-use tamago::nasneta;
 use tamago::{parse::*, verify::*};
 
 use serde::{Deserialize, Serialize};
@@ -196,8 +196,7 @@ fn convert_learned_rules(matches: clap::ArgMatches) {
     write(outf, converted).expect("Unable to write file");
 }
 
-fn test(matches: clap::ArgMatches) {
-}
+fn test(matches: clap::ArgMatches) {}
 
 /// Main procedure to run optimization
 ///
@@ -255,6 +254,7 @@ fn optimize(matches: clap::ArgMatches) {
     let learned_rules =
         read_to_string(rule_file).expect("Something went wrong reading the rule file");
     let pre_defined_multi = PRE_DEFINED_MULTI.iter().map(|&x| (x, /*symmetric=*/ false));
+    // The learned rules we have are symmetric. Predefined ones are not
     let multi_rules: Vec<(&str, bool)> = learned_rules
         .split("\n")
         .map(|x| (x, /*symmetric=*/ true))
@@ -302,6 +302,7 @@ fn optimize(matches: clap::ArgMatches) {
     let start_time = Instant::now();
     let mut runner = runner.run(&rules[..]);
     if do_filter_after {
+        // Do cycle removal after the final iteration
         remove_cycle_by_order(&mut runner);
     }
     let sat_duration = start_time.elapsed();
@@ -314,7 +315,8 @@ fn optimize(matches: clap::ArgMatches) {
     println!(" Time taken: {:?}", sat_duration);
     println!(" Number of iterations: {:?}", num_iter_sat);
 
-    let (num_enodes, num_classes, avg_nodes_per_class, num_edges, num_programs) = get_stats(&runner.egraph);
+    let (num_enodes, num_classes, avg_nodes_per_class, num_edges, num_programs) =
+        get_stats(&runner.egraph);
     println!(" Average nodes per class: {}", avg_nodes_per_class);
     println!(" Number of edges: {}", num_edges);
     println!(" Number of programs: {}", num_programs);
@@ -327,7 +329,9 @@ fn optimize(matches: clap::ArgMatches) {
 
     // Run extraction
     let extract_mode = matches.value_of("extract").unwrap();
-    let cost_model = CostModel::with_setting(/*ignore_all_weight_only=*/matches.is_present("all_weight_only"));
+    let cost_model = CostModel::with_setting(
+        /*ignore_all_weight_only=*/ matches.is_present("all_weight_only"),
+    );
     let (best, ext_secs) = match extract_mode {
         "ilp" => extract_by_ilp(&egraph, root, &matches, &cost_model),
         "greedy" => {
@@ -375,7 +379,8 @@ fn optimize(matches: clap::ArgMatches) {
            .open(outf)
            .unwrap();
 
-        // Stats to write: original runtime, optimized runtime, saturation time, extraction time, number of nodes, number of eclasses, number of possible programs
+        // Stats to write: original runtime, optimized runtime, saturation time, extraction time,
+        // number of nodes, number of eclasses, number of possible programs
         let data = json!({
             "original": time_start,
             "optimized": time_ext,
@@ -386,7 +391,8 @@ fn optimize(matches: clap::ArgMatches) {
             "programs": num_programs,
             "iter": num_iter_sat,
         });
-        let sol_data_str = serde_json::to_string(&data).expect("Fail to convert json to string");
+        let sol_data_str =
+            serde_json::to_string(&data).expect("Fail to convert json to string");
 
         if let Err(e) = writeln!(file, "{}", sol_data_str) {
             eprintln!("Couldn't write to file: {}", e);
@@ -530,8 +536,16 @@ fn get_stats(egraph: &EGraph<Mdl, TensorAnalysis>) -> (usize, usize, f32, usize, f32) {
     let num_edges = egraph
         .classes()
         .fold(0, |acc, c| c.iter().fold(0, |sum, n| n.len() + sum) + acc);
-    let num_programs = egraph.classes().fold(0.0, |acc, c| acc + (c.len() as f32).log2());
-    (num_enodes, num_classes, avg_nodes_per_class, num_edges, num_programs)
+    let num_programs = egraph
+        .classes()
+        .fold(0.0, |acc, c| acc + (c.len() as f32).log2());
+    (
+        num_enodes,
+        num_classes,
+        avg_nodes_per_class,
+        num_edges,
+        num_programs,
+    )
 }
 
 fn get_full_graph_runtime(runner: &Runner<Mdl, TensorAnalysis, ()>, process: bool) -> f32 {
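
Why num_programs sums logs (context, not part of the commit): the number of distinct programs an e-graph represents is roughly the product of the e-class sizes, which overflows even floating point quickly, so get_stats reports the base-2 log of that product by summing per-class logs. A tiny illustration with made-up class sizes:

    fn main() {
        let class_sizes = [2usize, 4, 8];
        let log2_programs: f32 = class_sizes.iter().map(|&n| (n as f32).log2()).sum();
        assert_eq!(log2_programs, 6.0); // 2 * 4 * 8 = 64 = 2^6 candidate programs
    }
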
9 changes: 6 additions & 3 deletions src/model.rs
@@ -110,7 +110,7 @@ pub struct TensorAnalysis {
     pub graph: std::cell::RefCell<Box<Graph>>,
     /// Record blacklisted nodes for filtering cycles
     pub blacklist_nodes: HashSet<Mdl>,
-    /// Newly added nodes by order during single output rule application
+    /// Newly added nodes by order
     pub newly_added: Vec<Mdl>,
 }
 
@@ -404,7 +404,11 @@
                 let t_4 = x(input4).meta;
                 let t_5 = x(input5).meta;
                 let axis_val = x(axis).val;
-                let all_weights = x(input1).all_weights && x(input2).all_weights && x(input3).all_weights && x(input4).all_weights && x(input5).all_weights;
+                let all_weights = x(input1).all_weights
+                    && x(input2).all_weights
+                    && x(input3).all_weights
+                    && x(input4).all_weights
+                    && x(input5).all_weights;
 
                 // Create tensorhandle and get metadata
                 let t = [t_1, t_2, t_3, t_4, t_5];
@@ -513,7 +517,6 @@
                 }
             }
 
-
             Mdl::Split([axis, inpt]) => {
                 // Check types
                 assert!(x(axis).dtype == DataKind::Scalar);