Skip to content

Commit

Permalink
more accurate sqrt function
Browse files Browse the repository at this point in the history
  • Loading branch information
pascalkuthe committed May 28, 2024
1 parent 91fdc06 commit ed22935
Showing 1 changed file with 121 additions and 30 deletions.
151 changes: 121 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
///
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
#[inline]
pub fn sqrt(self) -> Self {
if self.im.is_zero() {
if self.re.is_sign_positive() {
// simple positive real √r, and copy `im` for its sign
Self::new(self.re.sqrt(), self.im)
pub fn sqrt(mut self) -> Self {
// Complex sqrt algorithm based on the algorithm from
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
// to increase accuracy. Compared to a naive implementation that
// reuses the complex exp/ln implementations, this algorithm has better
// accuracy since both (real) sqrt and (real) hypot are guaranteed to
// round perfectly. It's also faster since this implementation requires
// fewer transcendental functions and those it does use (sqrt/hypot) are
// faster compared to exp/sin/cos.
//
// The musl libc implementation was referenced while implementing the
// algorithm here:
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c

// TODO: rounding for very tiny subnormal numbers isn't perfect yet, so
// the assert shown below fails. In the very worst case this leads to
// about 10% accuracy loss (see example below). As the magnitude
// increases, the error quickly drops to basically zero.
//
// glibc handles that (but other implementations like musl and numpy do
// not) by upscaling very small values. That upscaling (and particularly
// its reversal) is weird and hard to understand (and relies on the
// mantissa bit size, which we can't get out of the trait). In general the
// glibc implementation is ever so subtly different and I wouldn't want
// to introduce bugs by trying to adapt the underflow handling.
//
// assert_eq!(
// Complex64::new(5.212e-324, 5.212e-324).sqrt(),
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
// );

// special cases for correct NaN/inf handling
// see https://en.cppreference.com/w/c/numeric/complex/csqrt

if self.re.is_zero() && self.im.is_zero() {
// 0 +/- 0 i
return Self::new(T::zero(), self.im);
}
if self.im.is_infinite() {
// inf +/- inf i
return Self::new(T::infinity(), self.im);
}
if self.re.is_nan() {
// nan + nan i
return Self::new(self.re, T::nan());
}
if self.re.is_infinite() {
// √(inf +/- NaN i) = inf +/- NaN i
// √(inf +/- x i) = inf +/- 0 i
// √(-inf +/- NaN i) = NaN +/- inf i
// √(-inf +/- x i) = 0 +/- inf i

// if im is inf (or nan) this is nan, otherwise it's zero
#[allow(clippy::eq_op)]
let zero_or_nan = self.im - self.im;
if self.re.is_sign_negative() {
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im));
} else {
// √(r e^(iπ)) = √r e^(iπ/2) = i√r
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
let re = T::zero();
let im = (-self.re).sqrt();
if self.im.is_sign_positive() {
Self::new(re, im)
} else {
Self::new(re, -im)
}
}
} else if self.re.is_zero() {
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
let one = T::one();
let two = one + one;
let x = (self.im.abs() / two).sqrt();
if self.im.is_sign_positive() {
Self::new(x, x)
} else {
Self::new(x, -x)
return Self::new(self.re, zero_or_nan.copysign(self.im));
}
}
let two = T::one() + T::one();
let four = two + two;
let overflow = T::max_value() / (T::one() + T::sqrt(two));
let max_magnitude = self.re.abs().max(self.im.abs());
let scale = max_magnitude >= overflow;
if scale {
self = self / four;
}
if self.re.is_sign_negative() {
let tmp = ((-self.re + self.norm()) / two).sqrt();
self.re = self.im.abs() / (two * tmp);
self.im = tmp.copysign(self.im);
} else {
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
let one = T::one();
let two = one + one;
let (r, theta) = self.to_polar();
Self::from_polar(r.sqrt(), theta / two)
self.re = ((self.re + self.norm()) / two).sqrt();
self.im = self.im / (two * self.re);
}
if scale {
self = self * two;
}
self
}

/// Computes the principal value of the cube root of `self`.
Expand Down Expand Up @@ -2065,6 +2112,50 @@ pub(crate) mod test {
}
}

#[test]
fn test_sqrt_nan() {
    // Asserts that `sqrt` of `input` matches `expected` according to
    // `close_naninf`.
    let check = |input: Complex64, expected: Complex64| {
        assert!(close_naninf(input.sqrt(), expected));
    };

    let inf = f64::INFINITY;
    let neg_inf = f64::NEG_INFINITY;
    let nan = f64::NAN;

    // Fixed inputs mixing NaN, infinities and signed zero.
    check(Complex64::new(inf, nan), Complex64::new(inf, nan));
    check(Complex64::new(nan, inf), Complex64::new(inf, inf));
    check(Complex64::new(neg_inf, -nan), Complex64::new(nan, neg_inf));
    check(Complex64::new(neg_inf, nan), Complex64::new(nan, inf));
    check(Complex64::new(-0.0, 0.0), Complex64::new(0.0, 0.0));

    // Finite integer-valued components in [-100, 100) combined with one
    // non-finite component.
    for i in -100_i32..100 {
        let x = f64::from(i);

        // √(x + inf i) = inf + inf i
        check(Complex64::new(x, inf), Complex64::new(inf, inf));
        // √(NaN + x i) = NaN + NaN i
        check(Complex64::new(nan, x), Complex64::new(nan, nan));
        // √(inf + x i) = inf + 0 i, zero carrying the sign of x
        check(Complex64::new(inf, x), Complex64::new(inf, 0.0.copysign(x)));
        // √(-inf + x i) = 0 + inf i, infinity carrying the sign of x
        check(Complex64::new(neg_inf, x), Complex64::new(0.0, inf.copysign(x)));
    }
}

#[test]
fn test_cbrt() {
assert!(close(_0_0i.cbrt(), _0_0i));
Expand Down

0 comments on commit ed22935

Please sign in to comment.