Skip to content

Commit

Permalink
fix[rust]: specialize sqrt as power behavior of infinite values is wr…
Browse files Browse the repository at this point in the history
…ong (#4517)
  • Loading branch information
ritchie46 committed Aug 21, 2022
1 parent 82257cb commit ad8df0f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
15 changes: 9 additions & 6 deletions polars/polars-lazy/src/dsl/function_expr/pow.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
use num::pow::Pow;
use polars_arrow::utils::CustomIterTools;
use polars_core::export::num;
use polars_core::export::num::ToPrimitive;
use polars_core::export::num::{Float, ToPrimitive};

use super::*;

fn pow_on_floats<T>(base: &ChunkedArray<T>, exponent: &Series) -> Result<Series>
where
T: PolarsFloatType,
T::Native: num::pow::Pow<T::Native, Output = T::Native> + ToPrimitive,
T::Native: num::pow::Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
ChunkedArray<T>: IntoSeries,
{
let dtype = T::get_dtype();
let exponent = exponent.cast(&dtype)?;
let exponent = base.unpack_series_matching_type(&exponent).unwrap();

if exponent.len() == 1 {
let av = exponent
let exponent_value = exponent
.get(0)
.ok_or_else(|| PolarsError::ComputeError("exponent is null".into()))?;
let s = match av.to_f64().unwrap() {
let s = match exponent_value.to_f64().unwrap() {
a if a == 1.0 => base.clone().into_series(),
// specialized sqrt will ensure (-inf)^0.5 = NaN
// and will likely be faster as well.
a if a == 0.5 => base.apply(|v| v.sqrt()).into_series(),
a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
let mut out = base.clone();

for _ in 1..av.to_u8().unwrap() {
for _ in 1..exponent_value.to_u8().unwrap() {
out = out * base.clone()
}
out.into_series()
}
_ => base.apply(|v| Pow::pow(v, av)).into_series(),
_ => base.apply(|v| Pow::pow(v, exponent_value)).into_series(),
};
Ok(s)
} else if (base.len() == 1) && (exponent.len() != 1) {
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import polars as pl


def test_sqrt_neg_inf() -> None:
out = pl.DataFrame(
{
"val": [float("-Inf"), -9, 0, 9, float("Inf")],
}
).with_column(pl.col("val").sqrt().alias("sqrt"))
# comparing nans and infinities by string value as they are not cmp
assert str(out["sqrt"].to_list()) == str(
[float("NaN"), float("NaN"), 0.0, 3.0, float("Inf")]
)

0 comments on commit ad8df0f

Please sign in to comment.