wavenet pulse optims are back
kali committed Mar 15, 2019
1 parent 2e692e6 commit 970593e
Showing 20 changed files with 80 additions and 74 deletions.
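
Taken together, the cli/src/main.rs hunks below set up a three-stage flow: declutter the typed model, optionally pulsify it, then run codegen as a separate final step instead of fusing it into the pulsify branch. The following is a minimal sketch assembled from the diff, not code from the commit: the `prepare` wrapper and its arguments are hypothetical; only `declutter`, `into_normalized`, `pulsify`, `into_typed`, and `codegen` are names taken from the changes.

```rust
use tract_core::model::TypedModel;
use tract_core::TractResult;

// Hypothetical wrapper illustrating the pipeline this commit sets up.
fn prepare(model: TypedModel, optimize: bool, pulse: Option<usize>) -> TractResult<TypedModel> {
    let mut model = model;
    // Stage 1: declutter (the pass formerly called "normalize") runs on TypedModel.
    if optimize || pulse.is_some() {
        model = model.declutter()?;
    }
    // Stage 2: pulsify still consumes a normalized model, then converts back to TypedModel.
    if let Some(pulse) = pulse {
        let (pulsed, _ifact, _ofact) =
            tract_core::pulse::pulsify(&model.into_normalized()?, pulse)?;
        model = pulsed.into_typed()?;
    }
    // Stage 3: codegen now also runs on TypedModel, so it can fire after pulsification.
    if optimize {
        model = model.codegen()?;
    }
    Ok(model)
}
```

This mirrors the core/src/model/mod.rs change where `TypedModel::optimize` becomes `declutter` and `codegen` moves from `NormalizedModel` onto `TypedModel`, which is what lets the pulsed network be optimized again.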
19 changes: 10 additions & 9 deletions cli/src/main.rs
@@ -375,6 +375,8 @@ impl Parameters {
None
};

let pulse: Option<usize> = matches.value_of("pulse").map(|s| s.parse()).inside_out()?;

let mut tract_model = if !matches.is_present("skip_analyse") {
info!("Running analyse");
SomeModel::Typed(raw_model.into_typed()?)
@@ -383,12 +385,10 @@ impl Parameters {
SomeModel::Inference(raw_model)
};

let pulse: Option<usize> = matches.value_of("pulse").map(|s| s.parse()).inside_out()?;

if matches.is_present("optimize") || pulse.is_some() {
if let SomeModel::Typed(typed) = tract_model {
info!("Optimize");
tract_model = SomeModel::Typed(typed.optimize()?);
tract_model = SomeModel::Typed(typed.declutter()?);
} else {
bail!("Can not run optimize without analyse")
}
@@ -397,17 +397,18 @@ impl Parameters {
let pulse_facts = if let (Some(pulse), &SomeModel::Typed(ref model)) = (pulse, &tract_model) {
info!("Pulsify {}", pulse);
let (model, ifact, ofact) = ::tract_core::pulse::pulsify(&model.clone().into_normalized()?, pulse)?;
if matches.is_present("optimize") {
info!("Optimize pulsing network");
tract_model = SomeModel::Typed(model.into_typed()?.optimize()?);
} else {
tract_model = SomeModel::Typed(model.into_typed()?);
};
tract_model = SomeModel::Typed(model.into_typed()?);
Some((ifact, ofact))
} else {
None
};

if matches.is_present("optimize") {
if let SomeModel::Typed(typed) = tract_model {
tract_model = SomeModel::Typed(typed.codegen()?);
}
}

info!("Model ready");

Ok(Parameters {
46 changes: 25 additions & 21 deletions core/src/model/mod.rs
@@ -62,21 +62,17 @@ impl InferenceModel {
compact::compact(&mut self)
}

pub fn into_normalized(self) -> TractResult<NormalizedModel> {
self.into_typed()?.into_normalized()
}

pub fn into_optimized(self) -> TractResult<NormalizedModel> {
self.into_normalized()?.into_codegen()
pub fn into_optimized(self) -> TractResult<TypedModel> {
self.into_typed()?.declutter()?.codegen()
}
}

impl TypedModel {
pub fn optimize(self) -> TractResult<TypedModel> {
pub fn declutter(self) -> TractResult<TypedModel> {
let mut model = self;
loop {
let mut done_something = false;
for p in crate::optim::normalization() {
for p in crate::optim::declutter() {
done_something = done_something || p.pass(&mut model)?;
if cfg!(debug_assertions) {
model.check_edges()?;
@@ -90,29 +86,37 @@ impl TypedModel {
Ok(model)
}

pub fn into_normalized(self) -> TractResult<NormalizedModel> {
let model = self.optimize()?;
compact::compact(&model)
}
}

impl NormalizedModel {
pub fn into_typed(self) -> TractResult<TypedModel> {
compact::compact(&self)
}
pub fn into_codegen(mut self) -> TractResult<NormalizedModel> {
pub fn codegen(self) -> TractResult<TypedModel> {
let mut model = self;
loop {
let mut done_something = false;
for p in crate::optim::codegen() {
done_something = done_something || p.pass(&mut self)?;
done_something = done_something || p.pass(&mut model)?;
if cfg!(debug_assertions) {
self.check_edges()?;
model.check_edges()?;
}
}
if !done_something {
break;
}
model = compact::compact(&model)?;
}
Ok(model)
}

pub fn into_normalized(self) -> TractResult<NormalizedModel> {
let model = self.declutter()?;
compact::compact(&model)
}

pub fn into_optimized(self) -> TractResult<TypedModel> {
let model = self.codegen()?;
compact::compact(&model)
}
}

impl NormalizedModel {
pub fn into_typed(self) -> TractResult<TypedModel> {
compact::compact(&self)
}
}
2 changes: 1 addition & 1 deletion core/src/model/tensor_info.rs
@@ -146,7 +146,7 @@ impl TryInto<NormalizedTensorInfo> for TypedTensorInfo {
None => {
Ok(NormalizedTensorInfo { shape: self.shape.clone(), datum_type: self.datum_type })
}
_ => bail!("Constant tensor are excluded from normalized stage: {:?}", self),
_ => bail!("Constant tensor are excluded from declutterd stage: {:?}", self),
}
}
}
2 changes: 1 addition & 1 deletion core/src/ops/array/squeeze.rs
@@ -34,7 +34,7 @@ impl Op for Squeeze {
"Squeeze".into()
}

fn normalize(
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
2 changes: 1 addition & 1 deletion core/src/ops/macros.rs
@@ -164,7 +164,7 @@ macro_rules! element_bin {
concat!(stringify!($name), "::Binary").into()
}

fn normalize(&self, model: &$crate::model::TypedModel, node: &$crate::model::TypedNode)
fn declutter(&self, model: &$crate::model::TypedModel, node: &$crate::model::TypedNode)
-> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if let Some(b) = inputs[1].konst.clone() {
8 changes: 4 additions & 4 deletions core/src/ops/math/mat_mul.rs
@@ -238,14 +238,14 @@ impl Op for MatMulUnaryA {

fn codegen(
&self,
model: &NormalizedModel,
node: &NormalizedNode,
) -> TractResult<Option<NormalizedModelPatch>> {
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if let Some(a_shape) = inputs[0].shape.as_finite() {
let dt = inputs[0].datum_type;
if let Some(op) = dispatch_floatlike!(Self::codegen(dt)(self, &*a_shape))? {
return Ok(Some(NormalizedModelPatch::single_unary_op(model, node, op)?));
return Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?));
}
}
Ok(None)
8 changes: 4 additions & 4 deletions core/src/ops/mod.rs
@@ -118,7 +118,7 @@ pub trait Op:
Ok((infered_inputs, infered_outputs))
}

fn normalize(
fn declutter(
&self,
_model: &TypedModel,
_node: &TypedNode,
@@ -135,9 +135,9 @@

fn codegen(
&self,
_model: &NormalizedModel,
_node: &NormalizedNode,
) -> TractResult<Option<NormalizedModelPatch>> {
_model: &TypedModel,
_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
Ok(None)
}

4 changes: 2 additions & 2 deletions core/src/ops/nn/avgpool.rs
@@ -42,7 +42,7 @@ impl Op for AvgPool {
"AvgPool".into()
}

fn codegen(&self, model: &NormalizedModel, node: &NormalizedNode) -> TractResult<Option<NormalizedModelPatch>> {
fn codegen(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if let Some(shape) = inputs[0].shape.as_finite() {
let dt = inputs[0].datum_type;
@@ -55,7 +55,7 @@ impl Op for AvgPool {
Box::new(FixedAvgPool::new(patch, count_include_pad))
}
let op = dispatch_floatlike!(fixed(dt)(patch, self.count_include_pad));
return Ok(Some(NormalizedModelPatch::single_unary_op(model, node, op)?));
return Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?));
}
Ok(None)
}
2 changes: 1 addition & 1 deletion core/src/ops/nn/conv/gen.rs
@@ -96,7 +96,7 @@ impl Op for Conv {
"Conv".into()
}

fn normalize(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
fn declutter(&self, model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if let Some(op) = self.to_unary(inputs)? {
return Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?));
19 changes: 10 additions & 9 deletions core/src/ops/nn/conv/unary.rs
@@ -312,7 +312,7 @@ impl Op for ConvUnary {
"ConvUnary".into()
}

fn normalize(
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
@@ -337,9 +337,9 @@

fn codegen(
&self,
model: &NormalizedModel,
node: &NormalizedNode,
) -> TractResult<Option<NormalizedModelPatch>> {
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
let spatial_rank = self.full_input_shape.len() - 2;
let kernel_spatial_shape = &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
@@ -354,7 +354,7 @@ impl Op for ConvUnary {
use crate::ops::math::mat_mul::MatMulUnaryA;
let kernel_shape = &self.kernel.shape()[spatial_rank..];
let kernel = self.kernel.clone().into_shape(&kernel_shape)?;
return Ok(Some(NormalizedModelPatch::single_unary_op(
return Ok(Some(TypedModelPatch::single_unary_op(
model,
node,
MatMulUnaryA::new(kernel),
@@ -369,18 +369,19 @@
&& self.bias.is_none()
{
let op = self.to_direct(&*shape)?;
return Ok(Some(NormalizedModelPatch::single_unary_op(model, node, op)?));
return Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?));
} else {
let (op1, shape, op2) =
dispatch_floatlike!(Self::to_boxed_im2col_pair(dt)(self, &shape))?;
let mut patch = NormalizedModelPatch::default();
let mut patch = TypedModelPatch::default();
let _ = patch.tap_model(&model, node.inputs[0])?;
patch.chain_facts(
format!("{}-im2col", node.name),
op1,
tvec!(NormalizedTensorInfo {
tvec!(TypedTensorInfo {
shape: ShapeInfo::from(&*shape),
datum_type: dt
datum_type: dt,
konst: None,
}),
)?;
let mm = patch.chain_facts(
14 changes: 7 additions & 7 deletions core/src/optim/mod.rs
@@ -1,4 +1,4 @@
use crate::model::{NormalizedModel, TypedModel};
use crate::model::TypedModel;
use crate::TractResult;
use std::fmt::Debug;

@@ -8,15 +8,15 @@ mod push_split_down;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;

pub trait NormalizationPass: Debug + Send + Sync {
pub trait DeclutterPass: Debug + Send + Sync {
fn pass(&self, model: &mut TypedModel) -> TractResult<bool>;
}

pub trait CodegenPass: Debug + Send + Sync {
fn pass(&self, model: &mut NormalizedModel) -> TractResult<bool>;
fn pass(&self, model: &mut TypedModel) -> TractResult<bool>;
}

pub fn normalization() -> Vec<Box<NormalizationPass>> {
pub fn declutter() -> Vec<Box<DeclutterPass>> {
vec![Box::new(PropConst) as _, Box::new(NormalizeOps)]
}

@@ -27,7 +27,7 @@ pub fn codegen() -> Vec<Box<CodegenPass>> {
#[derive(Debug)]
pub struct NormalizeOps;

impl NormalizationPass for NormalizeOps {
impl DeclutterPass for NormalizeOps {
fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
let mut done_something = false;
loop {
@@ -42,7 +42,7 @@ impl NormalizationPass for NormalizeOps {
node.op().name()
);
node.op
.normalize(model, node)
.declutter(model, node)
.map_err(|e| format!("{:?} node {:?}, {:?}", self, node, e))?
};
if let Some(red) = reduced {
@@ -76,7 +76,7 @@
pub struct CodegenOps;

impl CodegenPass for CodegenOps {
fn pass(&self, model: &mut NormalizedModel) -> TractResult<bool> {
fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
let mut done_something = false;
loop {
let mut done_something_this_time = false;
2 changes: 1 addition & 1 deletion core/src/optim/prop_const.rs
@@ -5,7 +5,7 @@ use bit_set;
#[derive(Debug)]
pub struct PropConst;

impl super::NormalizationPass for PropConst {
impl super::DeclutterPass for PropConst {
fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
let mut replaced = 0;
let mut done = bit_set::BitSet::with_capacity(model.nodes().len());
4 changes: 2 additions & 2 deletions core/src/optim/push_split_down.rs
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use crate::model::{InletId, NormalizedModel, OutletId};
use crate::model::{InletId, OutletId};
use crate::ops::prelude::*;

use itertools::Itertools;
@@ -9,7 +9,7 @@ use itertools::Itertools;
pub struct PushSplitDown;

impl super::CodegenPass for PushSplitDown {
fn pass(&self, model: &mut NormalizedModel) -> TractResult<bool> {
fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
let mut done_something = false;
loop {
let mut remap = HashMap::<usize, usize>::new();
2 changes: 1 addition & 1 deletion core/src/optim/reduce.rs
@@ -33,7 +33,7 @@ impl super::OptimizerPass for Reduce {
node.outputs.iter().map(|o| &o.fact).collect();
*/
match self.0 {
ReductionPhase::Normalize => node.op.normalize(&model, &node),
ReductionPhase::Normalize => node.op.declutter(&model, &node),
ReductionPhase::Codegen => node.op.codegen(&model, &node),
}.map_err(|e| format!("{:?} node {:?}, {:?}", self.0, node, e))?
};
8 changes: 4 additions & 4 deletions core/src/pulse/mod.rs
@@ -135,7 +135,7 @@ mod tests {
let _a = model
.add_source_fact("a", TensorFact::dt_shape(DatumType::F32, vec![1, 2, 3]))
.unwrap();
assert!(PulsifiedModel::new(&model.into_normalized().unwrap(), 4).is_err());
assert!(PulsifiedModel::new(&model.into_declutterd().unwrap(), 4).is_err());

let mut model = Model::default();
let _a = model
@@ -144,7 +144,7 @@
TensorFact::dt_shape(DatumType::F32, vec![1.to_dim(), TDim::s(), 3.to_dim()]),
)
.unwrap();
let pulse = PulsifiedModel::new(&model.into_normalized().unwrap(), 4).unwrap();
let pulse = PulsifiedModel::new(&model.into_declutterd().unwrap(), 4).unwrap();
assert_eq!(
pulse.model.fact(OutletId::new(0, 0)).unwrap().to_tensor_fact(),
TensorFact::dt_shape(DatumType::F32, vec!(1, 4, 3))
@@ -161,7 +161,7 @@
)
.unwrap();

let pulse = PulsifiedModel::new(&model.into_normalized().unwrap(), 4).unwrap();
let pulse = PulsifiedModel::new(&model.into_declutterd().unwrap(), 4).unwrap();

assert_eq!(
pulse.model.input_fact().unwrap().to_tensor_fact(),
@@ -191,7 +191,7 @@ mod tests {
let input = [1.0f32, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
let t_input = Tensor::from(arr3(&[[input]]));

let model = model.into_normalized().unwrap();
let model = model.into_declutterd().unwrap();

assert_eq!(model.nodes().len(), 2);
let plan = crate::plan::SimplePlan::new(&model).unwrap();
