Skip to content

Commit

Permalink
test(transforms): add tests for Exp, Identity, Logistic and Softplus
Browse files Browse the repository at this point in the history
  • Loading branch information
tspooner committed Mar 4, 2019
1 parent 5aa8c6b commit 0d269a0
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/transforms/exponential.rs
Expand Up @@ -9,6 +9,33 @@ impl Transform<f64> for Exp {
}

fn grad(&self, x: f64) -> f64 {
x * self.transform(x)
self.transform(x)
}
}

#[cfg(test)]
mod tests {
use quickcheck::quickcheck;
use std::f64::consts::E;
use super::{Exp, Transform};

#[test]
fn test_f64() {
assert!((Exp.transform(0.0) - 1.0).abs() < 1e-7);
assert!((Exp.transform(1.0) - E).abs() < 1e-7);
assert!((Exp.transform(2.0) - E * E).abs() < 1e-7);

assert!((Exp.transform(0.0) - Exp.grad(0.0)).abs() < 1e-7);
assert!((Exp.transform(1.0) - Exp.grad(1.0)).abs() < 1e-7);
assert!((Exp.transform(2.0) - Exp.grad(2.0)).abs() < 1e-7);
}

#[test]
fn test_f64_positive() {
fn prop_positive(x: f64) -> bool {
Exp.transform(x).is_sign_positive()
}

quickcheck(prop_positive as fn(f64) -> bool);
}
}
16 changes: 16 additions & 0 deletions src/transforms/identity.rs
Expand Up @@ -26,3 +26,19 @@ impl Transform<Vector<f64>> for Identity {
x
}
}

#[cfg(test)]
mod tests {
use super::{Identity, Transform};

#[test]
fn test_f64() {
assert_eq!(Identity.transform(0.0), 0.0);
assert_eq!(Identity.transform(1.0), 1.0);
assert_eq!(Identity.transform(2.0), 2.0);

assert_eq!(Identity.grad(0.0), 1.0);
assert_eq!(Identity.grad(1.0), 1.0);
assert_eq!(Identity.grad(2.0), 1.0);
}
}
27 changes: 27 additions & 0 deletions src/transforms/logistic.rs
Expand Up @@ -15,3 +15,30 @@ impl Transform<f64> for Logistic {
exp_x / exp_x_plus_1 / exp_x_plus_1
}
}

#[cfg(test)]
mod tests {
use quickcheck::quickcheck;
use std::f64::consts::E;
use super::{Logistic, Transform};

#[test]
fn test_f64() {
assert!((Logistic.transform(0.0) - 0.5).abs() < 1e-7);
assert!((Logistic.transform(1.0) - 1.0 / (1.0 + 1.0 / E)).abs() < 1e-7);
assert!((Logistic.transform(2.0) - 1.0 / (1.0 + 1.0 / E / E)).abs() < 1e-7);

assert!((Logistic.grad(0.0) - 0.25).abs() < 1e-5);
assert!((Logistic.grad(1.0) - 0.196612).abs() < 1e-5);
assert!((Logistic.grad(2.0) - 0.104994).abs() < 1e-5);
}

#[test]
fn test_f64_positive() {
fn prop_positive(x: f64) -> bool {
Logistic.transform(x).is_sign_positive()
}

quickcheck(prop_positive as fn(f64) -> bool);
}
}
28 changes: 28 additions & 0 deletions src/transforms/softplus.rs
Expand Up @@ -12,3 +12,31 @@ impl Transform<f64> for Softplus {
Logistic.transform(x)
}
}

#[cfg(test)]
mod tests {
use crate::transforms::Logistic;
use quickcheck::quickcheck;
use std::f64::consts::E;
use super::{Softplus, Transform};

#[test]
fn test_f64() {
assert!((Softplus.transform(0.0) - 0.693147).abs() < 1e-5);
assert!((Softplus.transform(1.0) - 1.31326).abs() < 1e-5);
assert!((Softplus.transform(2.0) - 2.12693).abs() < 1e-5);

assert!((Softplus.grad(0.0) - Logistic.transform(0.0)).abs() < 1e-7);
assert!((Softplus.grad(1.0) - Logistic.transform(1.0)).abs() < 1e-7);
assert!((Softplus.grad(2.0) - Logistic.transform(2.0)).abs() < 1e-7);
}

#[test]
fn test_f64_positive() {
fn prop_positive(x: f64) -> bool {
Softplus.transform(x).is_sign_positive()
}

quickcheck(prop_positive as fn(f64) -> bool);
}
}

0 comments on commit 0d269a0

Please sign in to comment.