Skip to content

Commit

Permalink
unary -> element_wise (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 17, 2019
1 parent ea6071a commit 1644a3f
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 55 deletions.
2 changes: 1 addition & 1 deletion core/src/lib.rs
Expand Up @@ -125,7 +125,7 @@ pub mod internal {
pub use crate::analyser::types::*;
pub use crate::datum::FloatLike;
pub use crate::dim::{DimLike, TDim, ToDim};
pub use crate::ops::unary::UnaryMiniOp;
pub use crate::ops::element_wise::ElementWiseMiniOp;
pub use crate::framework::*;
pub use crate::model::*;
pub use crate::ops::{
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/direct.rs
Expand Up @@ -50,7 +50,7 @@ impl Op for Direct {
))));
}
}
} else if let Some(op) = succ.op_as::<ops::unary::UnaryOp>() {
} else if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>() {
if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMax>() {
return Ok(Some(tvec!(FusedSpec::Max(op.max.as_()))));
} else if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMin>() {
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/mat_mat.rs
Expand Up @@ -141,7 +141,7 @@ where
))));
}
}
} else if let Some(op) = succ.op_as::<ops::unary::UnaryOp>() {
} else if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>() {
if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMax>() {
return Ok(Some(tvec!(FusedSpec::Max(op.max.as_()))));
} else if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMin>() {
Expand Down
24 changes: 12 additions & 12 deletions core/src/ops/unary.rs → core/src/ops/element_wise.rs
Expand Up @@ -2,17 +2,17 @@ use crate::internal::*;
use downcast_rs::Downcast;
use std::fmt;

pub trait UnaryMiniOp: fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast {
pub trait ElementWiseMiniOp: fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast {
fn name(&self) -> &'static str;
fn eval_in_place(&self, t: &mut Tensor) -> TractResult<()>;
}
clone_trait_object!(UnaryMiniOp);
downcast_rs::impl_downcast!(UnaryMiniOp);
clone_trait_object!(ElementWiseMiniOp);
downcast_rs::impl_downcast!(ElementWiseMiniOp);

#[derive(Debug, Clone)]
pub struct UnaryOp(pub Box<dyn UnaryMiniOp>);
pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>);

impl Op for UnaryOp {
impl Op for ElementWiseOp {
fn name(&self) -> Cow<str> {
format!("{}", self.0.name()).into()
}
Expand All @@ -26,15 +26,15 @@ impl Op for UnaryOp {
op_as_typed_op!();
}

impl StatelessOp for UnaryOp {
impl StatelessOp for ElementWiseOp {
fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let mut t = args_1!(inputs).into_tensor();
self.0.eval_in_place(&mut t)?;
Ok(tvec!(t.into_arc_tensor()))
}
}

impl InferenceRulesOp for UnaryOp {
impl InferenceRulesOp for ElementWiseOp {
fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
s: &mut Solver<'r>,
Expand All @@ -51,7 +51,7 @@ impl InferenceRulesOp for UnaryOp {
inference_op_as_op!();
}

impl TypedOp for UnaryOp {
impl TypedOp for ElementWiseOp {
fn output_facts(&self, inputs: &[&TypedTensorInfo]) -> TractResult<TVec<TypedTensorInfo>> {
Ok(tvec!(TypedTensorInfo::dt_shape(inputs[0].datum_type, inputs[0].shape.clone())?))
}
Expand All @@ -74,11 +74,11 @@ impl TypedOp for UnaryOp {
}

#[macro_export]
macro_rules! unary {
macro_rules! element_wise {
($func:ident, $Op:ident $({$($var: ident : $var_typ: path),*})?, $( [$($typ:ident),*] => $f:expr),*) => {
#[derive(Debug, Clone)]
pub struct $Op { $( $(pub $var: $var_typ),* )? }
impl $crate::ops::unary::UnaryMiniOp for $Op {
impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
fn name(&self) -> &'static str {
stringify!($Op)
}
Expand All @@ -95,8 +95,8 @@ macro_rules! unary {
bail!("{} does not support {:?}", self.name(), t.datum_type());
}
}
pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::unary::UnaryOp {
$crate::ops::unary::UnaryOp(Box::new($Op { $( $($var),* )? } ))
pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
$crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? } ))
}
}
}
2 changes: 1 addition & 1 deletion core/src/ops/logic.rs
Expand Up @@ -19,7 +19,7 @@ bin_to_bool!(lesser_equal, LesserEqual, [bool, u8, i8, i16, i32, i64, f32, f64]
bin_to_bool!(greater, Greatser, [bool, u8, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a > b);
bin_to_bool!(greater_equal, GreaterEqual, [bool, u8, i8, i16, i32, i64, f32, f64] => |c, &a, &b | *c = a >= b);

unary!(not, Not, [bool] => |_, vs| vs.iter_mut().for_each(|a| *a = !*a));
element_wise!(not, Not, [bool] => |_, vs| vs.iter_mut().for_each(|a| *a = !*a));

#[derive(Debug, Clone, new, Default)]
pub struct Iff;
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/math/mat_mul.rs
Expand Up @@ -522,7 +522,7 @@ where
))));
}
}
} else if let Some(op) = succ.op_as::<ops::unary::UnaryOp>() {
} else if let Some(op) = succ.op_as::<ops::element_wise::ElementWiseOp>() {
if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMax>() {
return Ok(Some(tvec!(FusedSpec::Max(op.max.as_()))));
} else if let Some(op) = op.0.downcast_ref::<ops::math::ScalarMin>() {
Expand Down
50 changes: 25 additions & 25 deletions core/src/ops/math/mod.rs
Expand Up @@ -39,48 +39,48 @@ fn flip_sub(_op: &dyn BinMiniOp, t: &Arc<Tensor>) -> Option<UnaryOp> {
Some(UnaryOp::new(Box::new(Add), Arc::new(t)))
}

unary!(abs, Abs, [f16, f32, i32] => |_, xs| xs.iter_mut().for_each(|x| *x = x.abs()));
unary!(exp, Exp, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.exp()));
unary!(ln, Ln, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.ln()));
unary!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sqrt()));
unary!(recip, Recip, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.recip()));
unary!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sqrt().recip()));
element_wise!(abs, Abs, [f16, f32, i32] => |_, xs| xs.iter_mut().for_each(|x| *x = x.abs()));
element_wise!(exp, Exp, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.exp()));
element_wise!(ln, Ln, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.ln()));
element_wise!(sqrt, Sqrt, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sqrt()));
element_wise!(recip, Recip, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.recip()));
element_wise!(rsqrt, Rsqrt, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sqrt().recip()));

unary!(ceil, Ceil, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.ceil()));
unary!(floor, Floor, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.floor()));
element_wise!(ceil, Ceil, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.ceil()));
element_wise!(floor, Floor, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.floor()));

unary!(scalar_min_max, ScalarMinMax { min: f32, max: f32 },
element_wise!(scalar_min_max, ScalarMinMax { min: f32, max: f32 },
[f32, f64] => |m, xs| xs.iter_mut().for_each(|x| *x = x.max(m.max as _).min(m.min as _))
);

unary!(scalar_min, ScalarMin { min: f32 },
element_wise!(scalar_min, ScalarMin { min: f32 },
[f32, f64] => |m, xs| xs.iter_mut().for_each(|x| *x = x.min(m.min as _))
);

unary!(scalar_max, ScalarMax { max: f32 },
element_wise!(scalar_max, ScalarMax { max: f32 },
[f32, f64] => |m, xs| xs.iter_mut().for_each(|x| *x = x.max(m.max as _))
);

unary!(cos, Cos, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.cos()));
unary!(sin, Sin, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sin()));
unary!(tan, Tan, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.tan()));
unary!(acos, Acos, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.acos()));
unary!(asin, Asin, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.asin()));
unary!(atan, Atan, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.atan()));
element_wise!(cos, Cos, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.cos()));
element_wise!(sin, Sin, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sin()));
element_wise!(tan, Tan, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.tan()));
element_wise!(acos, Acos, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.acos()));
element_wise!(asin, Asin, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.asin()));
element_wise!(atan, Atan, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.atan()));

unary!(cosh, Cosh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.cosh()));
unary!(sinh, Sinh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sinh()));
unary!(tanh, Tanh,
element_wise!(cosh, Cosh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.cosh()));
element_wise!(sinh, Sinh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.sinh()));
element_wise!(tanh, Tanh,
[f32] => |_, xs| <f32 as FloatLike>::tanh().run(xs),
[f16, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.tanh())
);
unary!(acosh, Acosh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.acosh()));
unary!(asinh, Asinh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.asinh()));
unary!(atanh, Atanh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.atanh()));
element_wise!(acosh, Acosh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.acosh()));
element_wise!(asinh, Asinh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.asinh()));
element_wise!(atanh, Atanh, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = x.atanh()));

unary!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| xs.iter_mut().for_each(|x| *x = -x.clone()));
element_wise!(neg, Neg, [i8, i16, i32, i64, f16, f32, f64, TDim] => |_, xs| xs.iter_mut().for_each(|x| *x = -x.clone()));

unary!(sign, Sign, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() }));
element_wise!(sign, Sign, [f16, f32, f64] => |_, xs| xs.iter_mut().for_each(|x| *x = if x.is_zero() { *x } else { x.signum() }));

#[cfg(test)]
mod tests {
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/mod.rs
Expand Up @@ -8,7 +8,7 @@ use objekt;
#[macro_use]
pub mod macros;
#[macro_use]
pub mod unary;
pub mod element_wise;
#[macro_use]
pub mod binary;

Expand Down
20 changes: 10 additions & 10 deletions core/src/ops/nn/mod.rs
Expand Up @@ -16,35 +16,35 @@ use num_traits::{ AsPrimitive, Float};

pub use crate::internal::*;

unary!(softplus, Softplus, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = (x.exp() + 1.0).ln()));
unary!(softsign, Softsign, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = *x / (x.abs() + 1.0)));
unary!(sigmoid, Sigmoid, [f32] => |_, xs| f32::sigmoid().run(xs));
element_wise!(softplus, Softplus, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = (x.exp() + 1.0).ln()));
element_wise!(softsign, Softsign, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = *x / (x.abs() + 1.0)));
element_wise!(sigmoid, Sigmoid, [f32] => |_, xs| f32::sigmoid().run(xs));

unary!(elu, Elu { alpha: f32 },
element_wise!(elu, Elu { alpha: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.elu(e.alpha); })
);

unary!(hard_sigmoid, HardSigmoid { alpha: f32, beta: f32 },
element_wise!(hard_sigmoid, HardSigmoid { alpha: f32, beta: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.hard_sigmoid(e.alpha, e.beta); })
);

unary!(leaky_relu, LeakyRelu { alpha: f32 },
element_wise!(leaky_relu, LeakyRelu { alpha: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.leaky_relu(e.alpha); })
);

unary!(parametric_softplus, ParametricSoftplus { alpha: f32, beta: f32 },
element_wise!(parametric_softplus, ParametricSoftplus { alpha: f32, beta: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.parametric_softplus(e.alpha, e.beta); })
);

unary!(scaled_tanh, ScaledTanh { alpha: f32, beta: f32 },
element_wise!(scaled_tanh, ScaledTanh { alpha: f32, beta: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.scaled_tanh(e.alpha, e.beta); })
);

unary!(selu, Selu { alpha: f32, gamma: f32 },
element_wise!(selu, Selu { alpha: f32, gamma: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.selu(e.alpha, e.gamma); })
);

unary!(threshold_relu, ThresholdRelu { alpha: f32 },
element_wise!(threshold_relu, ThresholdRelu { alpha: f32 },
[f32, f64] => |e, xs| xs.iter_mut().for_each(|x| { *x = x.threshold_relu(e.alpha); })
);

Expand Down
2 changes: 1 addition & 1 deletion onnx/src/ops/math.rs
Expand Up @@ -67,7 +67,7 @@ pub fn clip(
Ok((op, vec![]))
}

unary!(erf, Erf, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = erf_f32(*x)));
element_wise!(erf, Erf, [f32] => |_, xs| xs.iter_mut().for_each(|x| *x = erf_f32(*x)));

#[allow(non_upper_case_globals)]
fn erf_f32(x: f32) -> f32 {
Expand Down
2 changes: 1 addition & 1 deletion onnx/src/ops/nn/mod.rs
Expand Up @@ -257,7 +257,7 @@ pub fn scaled_tanh(
Ok((Box::new(tractops::nn::scaled_tanh(alpha, beta)), vec![]))
}

unary!(shrink_op, Shrink { bias: f32, lambd: f32 },
element_wise!(shrink_op, Shrink { bias: f32, lambd: f32 },
[f16,f32,f64] => |s, xs| xs.iter_mut().for_each(|x| *x = shrink_value(*x, s))
);

Expand Down

0 comments on commit 1644a3f

Please sign in to comment.