diff --git a/src/lib.rs b/src/lib.rs index 661b67b..829f286 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -824,6 +824,43 @@ impl Div> for Complex { } } +pub trait DivAdd { + type Output; + fn div_add(self, rhs: Rhs, addend: Addend) -> Self::Output; +} + +// (a + i b) / (c + i d) + (e + i f) == [(a + i b) * (c - i d)] / (c*c + d*d) + (e + i f) +// == {(a*c + b*d) + i (-a*d + b*c)} / n + (e + i f) for n=(c*c + d*d) +impl DivAdd> for Complex +where + T: Clone + Num + Mul + MulAdd + Neg, +{ + type Output = Self; + + #[inline] + fn div_add(self, other: Complex, add: Complex) -> Self::Output { + let n = other.norm_sqr(); + let (a, b) = (self.re, self.im); + let (c, d) = (other.re, other.im); + + let re = a.clone().mul_add(c.clone(), b.clone() * d.clone()); + let im = a.mul_add(-d, b * c); + + Self::new(re, im) / n + add + } +} +impl DivAdd<&Complex> for &Complex +where + T: Clone + Num + Mul + MulAdd + Neg, +{ + type Output = Complex; + + #[inline] + fn div_add(self, other: &Complex, add: &Complex) -> Self::Output { + self.clone().div_add(other.clone(), add.clone()) + } +} + forward_all_binop!(impl Rem, rem); impl Complex { @@ -1815,7 +1852,7 @@ pub(crate) mod test { close_to_tol(a, b, 1e-10) } - fn close_to_tol(a: Complex64, b: Complex64, tol: f64) -> bool { + pub(crate) fn close_to_tol(a: Complex64, b: Complex64, tol: f64) -> bool { // returns true if a and b are reasonably close let close = (a == b) || (a - b).norm() < tol; if !close { @@ -2502,6 +2539,8 @@ pub(crate) mod test { } mod complex_arithmetic { + use crate::test::float::close_to_tol; + use super::{_05_05i, _0_0i, _0_1i, _1_0i, _1_1i, _4_2i, _neg1_1i, all_consts}; use num_traits::{MulAdd, MulAddAssign, Zero}; @@ -2603,6 +2642,39 @@ pub(crate) mod test { } } + #[test] + fn test_div_add() { + use crate::{Complex, DivAdd}; + + const _0_0i: Complex = Complex { re: 0, im: 0 }; + const _1_0i: Complex = Complex { re: 1, im: 0 }; + const _2_0i: Complex = Complex { re: 2, im: 0 }; + const _1_1i: Complex = Complex { re: 1, im: 1 }; + const _0_1i: Complex = Complex { re: 0, im: 1 }; + const _neg1_1i: Complex = Complex { re: -1, im: 1 }; + const all_consts: [Complex; 6] = [_0_0i, _1_0i, _2_0i, _1_1i, _0_1i, _neg1_1i]; + const non_zero_consts: [Complex; 5] = [_1_0i, _2_0i, _1_1i, _0_1i, _neg1_1i]; + + assert_eq!(_1_0i.div_add(_1_0i, _0_0i), _1_0i); + assert_eq!(_0_1i.div_add(_0_1i, _0_1i), _1_1i); + assert_eq!(_1_0i.div_add(_1_0i, _1_0i), _2_0i); + + // a/b+c ~= 6.34 * e^(2i) + const _a: Complex = Complex { re: 1.23, im: -3.4 }; + const _b: Complex = Complex { re: -6.78, im: 9.0 }; + const _c: Complex = Complex { re: -2.34, im: 5.6 }; + assert!(close_to_tol(_a.div_add(_b, _c), _a / _b + _c, 1e-10)); + + for &a in &all_consts { + for &b in &non_zero_consts { + for &c in &all_consts { + assert_eq!(a.div_add(b, c), a / b + c); + assert_eq!((&a).div_add(&b, &c), a / b + c); + } + } + } + } + #[test] fn test_div() { test_op!(_neg1_1i / _0_1i, _1_1i);