Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 144 additions & 133 deletions src/roots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,170 +202,181 @@ fn log2<T: PrimInt>(x: T) -> u32 {
macro_rules! unsigned_roots {
($T:ident) => {
impl Roots for $T {
#[inline]
fn nth_root(&self, n: u32) -> Self {
// Specialize small roots
match n {
0 => panic!("can't find a root of degree 0!"),
1 => return *self,
2 => return self.sqrt(),
3 => return self.cbrt(),
_ => (),
}
fn go(a: $T, n: u32) -> $T {
// Specialize small roots
match n {
0 => panic!("can't find a root of degree 0!"),
1 => return a,
2 => return a.sqrt(),
3 => return a.cbrt(),
_ => (),
}

// The root of values less than 2ⁿ can only be 0 or 1.
if bits::<$T>() <= n || *self < (1 << n) {
return (*self > 0) as $T;
}
// The root of values less than 2ⁿ can only be 0 or 1.
if bits::<$T>() <= n || a < (1 << n) {
return (a > 0) as $T;
}

if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
return if *self <= core::u64::MAX as $T {
(*self as u64).nth_root(n) as $T
} else {
let lo = (self >> n).nth_root(n) << 1;
let hi = lo + 1;
// 128-bit `checked_mul` also involves division, but we can't always
// compute `hiⁿ` without risking overflow. Try to avoid it though...
if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
match checked_pow(hi, n as usize) {
Some(x) if x <= *self => hi,
_ => lo,
}
if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
return if a <= core::u64::MAX as $T {
(a as u64).nth_root(n) as $T
} else {
if hi.pow(n) <= *self {
hi
let lo = (a >> n).nth_root(n) << 1;
let hi = lo + 1;
// 128-bit `checked_mul` also involves division, but we can't always
// compute `hiⁿ` without risking overflow. Try to avoid it though...
if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
match checked_pow(hi, n as usize) {
Some(x) if x <= a => hi,
_ => lo,
}
} else {
lo
if hi.pow(n) <= a {
hi
} else {
lo
}
}
};
}

#[cfg(feature = "std")]
#[inline]
fn guess(x: $T, n: u32) -> $T {
// for smaller inputs, `f64` doesn't justify its cost.
if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
1 << ((log2(x) + n - 1) / n)
} else {
((x as f64).ln() / f64::from(n)).exp() as $T
}
};
}
}

#[cfg(feature = "std")]
#[inline]
fn guess(x: $T, n: u32) -> $T {
// for smaller inputs, `f64` doesn't justify its cost.
if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T, n: u32) -> $T {
1 << ((log2(x) + n - 1) / n)
} else {
((x as f64).ln() / f64::from(n)).exp() as $T
}
}

#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T, n: u32) -> $T {
1 << ((log2(x) + n - 1) / n)
}

// https://en.wikipedia.org/wiki/Nth_root_algorithm
let n1 = n - 1;
let next = |x: $T| {
let y = match checked_pow(x, n1 as usize) {
Some(ax) => self / ax,
None => 0,
// https://en.wikipedia.org/wiki/Nth_root_algorithm
let n1 = n - 1;
let next = |x: $T| {
let y = match checked_pow(x, n1 as usize) {
Some(ax) => a / ax,
None => 0,
};
(y + x * n1 as $T) / n as $T
};
(y + x * n1 as $T) / n as $T
};
fixpoint(guess(*self, n), next)
fixpoint(guess(a, n), next)
}
go(*self, n)
}

#[inline]
fn sqrt(&self) -> Self {
if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
// https://en.wikipedia.org/wiki/Integer_square_root#Using_bitwise_operations
return if *self <= core::u64::MAX as $T {
(*self as u64).sqrt() as $T
} else {
let lo = (self >> 2u32).sqrt() << 1;
let hi = lo + 1;
if hi * hi <= *self {
hi
fn go(a: $T) -> $T {
if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
return if a <= core::u64::MAX as $T {
(a as u64).sqrt() as $T
} else {
lo
}
};
}
let lo = (a >> 2u32).sqrt() << 1;
let hi = lo + 1;
if hi * hi <= a {
hi
} else {
lo
}
};
}

if *self < 4 {
return (*self > 0) as Self;
}
if a < 4 {
return (a > 0) as $T;
}

#[cfg(feature = "std")]
#[inline]
fn guess(x: $T) -> $T {
(x as f64).sqrt() as $T
}
#[cfg(feature = "std")]
#[inline]
fn guess(x: $T) -> $T {
(x as f64).sqrt() as $T
}

#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T) -> $T {
1 << ((log2(x) + 1) / 2)
}
#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T) -> $T {
1 << ((log2(x) + 1) / 2)
}

// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
let next = |x: $T| (self / x + x) >> 1;
fixpoint(guess(*self), next)
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
let next = |x: $T| (a / x + x) >> 1;
fixpoint(guess(a), next)
}
go(*self)
}

#[inline]
fn cbrt(&self) -> Self {
if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
return if *self <= core::u64::MAX as $T {
(*self as u64).cbrt() as $T
} else {
let lo = (self >> 3u32).cbrt() << 1;
let hi = lo + 1;
if hi * hi * hi <= *self {
hi
fn go(a: $T) -> $T {
if bits::<$T>() > 64 {
// 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
return if a <= core::u64::MAX as $T {
(a as u64).cbrt() as $T
} else {
lo
}
};
}
let lo = (a >> 3u32).cbrt() << 1;
let hi = lo + 1;
if hi * hi * hi <= a {
hi
} else {
lo
}
};
}

if bits::<$T>() <= 32 {
// Implementation based on Hacker's Delight `icbrt2`
let mut x = *self;
let mut y2 = 0;
let mut y = 0;
let smax = bits::<$T>() / 3;
for s in (0..smax + 1).rev() {
let s = s * 3;
y2 *= 4;
y *= 2;
let b = 3 * (y2 + y) + 1;
if x >> s >= b {
x -= b << s;
y2 += 2 * y + 1;
y += 1;
if bits::<$T>() <= 32 {
// Implementation based on Hacker's Delight `icbrt2`
let mut x = a ;
let mut y2 = 0;
let mut y = 0;
let smax = bits::<$T>() / 3;
for s in (0..smax + 1).rev() {
let s = s * 3;
y2 *= 4;
y *= 2;
let b = 3 * (y2 + y) + 1;
if x >> s >= b {
x -= b << s;
y2 += 2 * y + 1;
y += 1;
}
}
return y;
}
return y;
}

if *self < 8 {
return (*self > 0) as Self;
}
if *self <= core::u32::MAX as $T {
return (*self as u32).cbrt() as $T;
}
if a < 8 {
return (a > 0) as $T;
}
if a <= core::u32::MAX as $T {
return (a as u32).cbrt() as $T;
}

#[cfg(feature = "std")]
#[inline]
fn guess(x: $T) -> $T {
(x as f64).cbrt() as $T
}
#[cfg(feature = "std")]
#[inline]
fn guess(x: $T) -> $T {
(x as f64).cbrt() as $T
}

#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T) -> $T {
1 << ((log2(x) + 2) / 3)
}
#[cfg(not(feature = "std"))]
#[inline]
fn guess(x: $T) -> $T {
1 << ((log2(x) + 2) / 3)
}

// https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
let next = |x: $T| (self / (x * x) + x * 2) / 3;
fixpoint(guess(*self), next)
// https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
let next = |x: $T| (a / (x * x) + x * 2) / 3;
fixpoint(guess(a), next)
}
go(*self)
}
}
};
Expand Down