Skip to content

Commit

Permalink
Merge pull request #375 from schungx/master
Browse files Browse the repository at this point in the history
Add checked_add and checked_norm_pdf.
  • Loading branch information
paupino committed May 23, 2021
2 parents a729c0d + 4066765 commit 864c8af
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
50 changes: 42 additions & 8 deletions src/maths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,21 @@ pub trait MathematicalOps {
/// tolerance of roughly `0.0000002`.
fn exp(&self) -> Decimal;

/// The estimated exponential function, e<sup>x</sup>. Stops calculating when it is within
/// tolerance of roughly `0.0000002`. Returns `None` on overflow.
fn checked_exp(&self) -> Option<Decimal>;

/// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint
/// as to when to stop calculating. A larger tolerance will cause the number to stop calculating
/// sooner at the potential cost of a slightly less accurate result.
fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal;

/// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint
/// as to when to stop calculating. A larger tolerance will cause the number to stop calculating
/// sooner at the potential cost of a slightly less accurate result.
/// Returns `None` on overflow.
fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal>;

/// Raise self to the given integer exponent: x<sup>y</sup>
fn powi(&self, exp: i64) -> Decimal;

Expand Down Expand Up @@ -93,26 +103,40 @@ pub trait MathematicalOps {
/// The Cumulative distribution function for a Normal distribution
fn norm_cdf(&self) -> Decimal;

/// The Probability density function for a Normal distribution
/// The Probability density function for a Normal distribution.
fn norm_pdf(&self) -> Decimal;

/// The Probability density function for a Normal distribution returning `None` on overflow.
fn checked_norm_pdf(&self) -> Option<Decimal>;
}

impl MathematicalOps for Decimal {
fn exp(&self) -> Decimal {
self.exp_with_tolerance(EXP_TOLERANCE)
}

#[inline]
fn checked_exp(&self) -> Option<Decimal> {
self.checked_exp_with_tolerance(EXP_TOLERANCE)
}

fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal {
match self.checked_exp_with_tolerance(tolerance) {
Some(d) => d,
None => panic!("Exp overflowed"),
}
}

#[inline]
fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal> {
if self.is_zero() {
return Decimal::ONE;
return Some(Decimal::ONE);
}

let mut term = *self;
let mut result = self + Decimal::ONE;

for factorial in FACTORIAL.iter().skip(2) {
term = self * term;
term = self.checked_mul(term)?;
let next = result + (term / factorial);
let diff = (next - result).abs();
result = next;
Expand All @@ -121,7 +145,7 @@ impl MathematicalOps for Decimal {
}
}

result
Some(result)
}

fn powi(&self, exp: i64) -> Decimal {
Expand Down Expand Up @@ -248,7 +272,7 @@ impl MathematicalOps for Decimal {
Some(e) => e,
None => return None,
};
let mut result = e.exp();
let mut result = e.checked_exp()?;
result.set_sign_negative(negative);
Some(result)
}
Expand Down Expand Up @@ -330,10 +354,20 @@ impl MathematicalOps for Decimal {
(Decimal::ONE + (self / Decimal::from_parts(2318911239, 3292722, 0, false, 16)).erf()) / TWO
}

/// The Probability density function for a Normal distribution
/// The Probability density function for a Normal distribution.
fn norm_pdf(&self) -> Decimal {
match self.checked_norm_pdf() {
Some(d) => d,
None => panic!("Norm Pdf overflowed"),
}
}

/// The Probability density function for a Normal distribution, return `None` on overflow.
fn checked_norm_pdf(&self) -> Option<Decimal> {
let sqrt2pi = Decimal::from_parts_raw(2133383024, 2079885984, 1358845910, 1835008);
(-self.powi(2) / TWO).exp() / sqrt2pi
let factor = -self.checked_powi(2)?;
let factor = factor.checked_div(TWO)?;
factor.checked_exp()?.checked_div(sqrt2pi)
}
}

Expand Down
29 changes: 27 additions & 2 deletions tests/decimal_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3374,13 +3374,18 @@ mod maths {
let x = Decimal::from_str(x).unwrap();
let expected = Decimal::from_str(expected).unwrap();
assert_eq!(expected, x.exp());
assert_eq!(Some(expected), x.checked_exp());
}
}

#[cfg(not(feature = "legacy-ops"))]
#[test]
fn test_exp_with_tolerance() {
let test_cases = &[
// e^0 = 1
("0", "0.0002", "1"),
// e^1 ~= 2.7182539682539682539682539683
("1", "0.0002", "2.7182539682539682539682539683"),
// e^10 ~= 22026.465794806703
(
"10",
Expand All @@ -3389,6 +3394,8 @@ mod maths {
),
// e^11 ~= 59874.14171519778
("11", "0.0002", "59873.388231055804982198781924"),
// e^11.7578 ~= 127741.03548949540892948423052
("11.7578", "0.0002", "127741.03548949540892948423052"),
// e^3 ~= 20.085536923187664
("3", "0.00002", "20.085534430970814899386327955"),
// e^8 ~= 2980.957987041727
Expand All @@ -3397,12 +3404,30 @@ mod maths {
("0.1", "0.0002", "1.1051666666666666666666666667"),
// e^2.0 ~= 7.3890560989306495
("2.0", "0.0002", "7.3890460157126823793490460156"),
// e^11.7578+ starts to overflow
("11.7579", "0.0002", ""),
// e^11.7578+ starts to overflow
("123", "0.0002", ""),
// e^-8+ starts to flip and underflow
("-8", "0.0002", "0.0002858169660624369145768176"),
// e^-1024 starts to flip and underflow
("-1024", "0.0002", ""),
];
for &(x, tolerance, expected) in test_cases {
let x = Decimal::from_str(x).unwrap();
let tolerance = Decimal::from_str(tolerance).unwrap();
let expected = Decimal::from_str(expected).unwrap();
assert_eq!(expected, x.exp_with_tolerance(tolerance));
let expected = if expected.is_empty() {
None
} else {
Some(Decimal::from_str(expected).unwrap())
};

if let Some(expected) = expected {
assert_eq!(expected, x.exp_with_tolerance(tolerance));
assert_eq!(Some(expected), x.checked_exp_with_tolerance(tolerance));
} else {
assert_eq!(None, x.checked_exp_with_tolerance(tolerance));
}
}
}

Expand Down

0 comments on commit 864c8af

Please sign in to comment.