Skip to content

Commit

Permalink
pow fast paths (#3738)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 19, 2022
1 parent e1f9cc3 commit efae6b2
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 131 deletions.
49 changes: 1 addition & 48 deletions polars/polars-core/src/chunked_array/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
use crate::prelude::*;
use crate::utils::{align_chunks_binary, align_chunks_binary_owned};
use arrow::array::PrimitiveArray;
use arrow::{
compute,
compute::{arithmetics::basic, arity_assign},
};
use arrow::compute::{arithmetics::basic, arity_assign};
use num::{Num, NumCast, ToPrimitive};
use std::borrow::Cow;
use std::ops::{Add, Div, Mul, Rem, Sub};
Expand Down Expand Up @@ -489,43 +486,6 @@ impl Add<&str> for &Utf8Chunked {
}
}

pub trait Pow {
fn pow_f32(&self, _exp: f32) -> Float32Chunked {
unimplemented!()
}
fn pow_f64(&self, _exp: f64) -> Float64Chunked {
unimplemented!()
}
}

impl<T> Pow for ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkCast,
{
fn pow_f32(&self, exp: f32) -> Float32Chunked {
let s = self.cast(&DataType::Float32).unwrap();
s.f32().unwrap().apply_kernel(&|arr| {
Box::new(compute::arity::unary(
arr,
|x| x.powf(exp),
DataType::Float32.to_arrow(),
))
})
}

fn pow_f64(&self, exp: f64) -> Float64Chunked {
let s = self.cast(&DataType::Float64).unwrap();
s.f64().unwrap().apply_kernel(&|arr| {
Box::new(compute::arity::unary(
arr,
|x| x.powf(exp),
DataType::Float64.to_arrow(),
))
})
}
}

#[cfg(test)]
pub(crate) mod test {
use crate::prelude::*;
Expand Down Expand Up @@ -554,11 +514,4 @@ pub(crate) mod test {
let _ = &a1 / &a1;
let _ = &a1 * &a1;
}

#[test]
fn test_power() {
let a = UInt32Chunked::new("", &[1, 2, 3]);
let b = a.pow_f64(2.);
println!("{:?}", b);
}
}
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<T: PolarsNumericType> ChunkedArray<T> {
}

impl<T: PolarsNumericType> ChunkedArray<T> {
pub(crate) fn apply_mut<F>(&mut self, f: F)
pub fn apply_mut<F>(&mut self, f: F)
where
F: Fn(T::Native) -> T::Native + Copy,
{
Expand Down
1 change: 0 additions & 1 deletion polars/polars-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ pub(crate) use crate::frame::{groupby::aggregations::*, hash_join::*};
pub(crate) use crate::utils::CustomIterTools;
pub use crate::{
chunked_array::{
arithmetic::Pow,
builder::{
BooleanChunkedBuilder, ChunkedBuilder, ListBooleanChunkedBuilder, ListBuilderTrait,
ListPrimitiveChunkedBuilder, ListUtf8ChunkedBuilder, NewChunkedArray,
Expand Down
6 changes: 0 additions & 6 deletions polars/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,6 @@ macro_rules! impl_dyn_series {
Arc::new(SeriesWrap(Clone::clone(&self.0)))
}

fn pow(&self, _exponent: f64) -> Result<Series> {
Err(PolarsError::ComputeError(
"cannot compute power of logical".into(),
))
}

fn peak_max(&self) -> BooleanChunked {
self.0.peak_max()
}
Expand Down
14 changes: 0 additions & 14 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,20 +421,6 @@ macro_rules! impl_dyn_series {
Arc::new(SeriesWrap(Clone::clone(&self.0)))
}

fn pow(&self, exponent: f64) -> Result<Series> {
let f_err = || {
Err(PolarsError::InvalidOperation(
format!("power operation not supported on dtype {:?}", self.dtype()).into(),
))
};

match self.dtype() {
DataType::Utf8 | DataType::List(_) | DataType::Boolean => f_err(),
DataType::Float32 => Ok(self.0.pow_f32(exponent as f32).into_series()),
_ => Ok(self.0.pow_f64(exponent).into_series()),
}
}

fn peak_max(&self) -> BooleanChunked {
self.0.peak_max()
}
Expand Down
14 changes: 0 additions & 14 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,20 +508,6 @@ macro_rules! impl_dyn_series {
Arc::new(SeriesWrap(Clone::clone(&self.0)))
}

fn pow(&self, exponent: f64) -> Result<Series> {
let f_err = || {
Err(PolarsError::InvalidOperation(
format!("power operation not supported on dtype {:?}", self.dtype()).into(),
))
};

match self.dtype() {
DataType::Utf8 | DataType::List(_) | DataType::Boolean => f_err(),
DataType::Float32 => Ok(self.0.pow_f32(exponent as f32).into_series()),
_ => Ok(self.0.pow_f64(exponent).into_series()),
}
}

fn peak_max(&self) -> BooleanChunked {
self.0.peak_max()
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl Series {

#[doc(hidden)]
#[cfg(feature = "private")]
pub(crate) fn _get_inner_mut(&mut self) -> &mut dyn SeriesTrait {
pub fn _get_inner_mut(&mut self) -> &mut dyn SeriesTrait {
if Arc::weak_count(&self.0) + Arc::strong_count(&self.0) != 1 {
self.0 = self.0.clone_inner();
}
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,6 @@ pub trait SeriesTrait:
invalid_operation_panic!(self)
}

/// Raise a numeric series to the power of exponent.
fn pow(&self, _exponent: f64) -> Result<Series> {
Err(PolarsError::InvalidOperation(
format!("power operation not supported on dtype {:?}", self.dtype()).into(),
))
}

/// Get a boolean mask of the local maximum peaks.
fn peak_max(&self) -> BooleanChunked {
invalid_operation_panic!(self)
Expand Down
70 changes: 37 additions & 33 deletions polars/polars-lazy/src/dsl/function_expr/pow.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,46 @@
use super::*;
use polars_arrow::utils::CustomIterTools;
use polars_core::export::num;
use polars_core::export::num::ToPrimitive;

fn pow_on_floats<T>(base: &ChunkedArray<T>, exponent: &Series) -> Result<Series>
where
T: PolarsFloatType,
T::Native: num::pow::Pow<T::Native, Output = T::Native>,
T::Native: num::pow::Pow<T::Native, Output = T::Native> + ToPrimitive,
ChunkedArray<T>: IntoSeries,
{
let dtype = T::get_dtype();
let exponent = exponent.cast(&dtype)?;
let exponent = base.unpack_series_matching_type(&exponent).unwrap();

Ok(base
.into_iter()
.zip(exponent.into_iter())
.map(|(opt_base, opt_exponent)| match (opt_base, opt_exponent) {
(Some(base), Some(exponent)) => Some(num::pow::Pow::pow(base, exponent)),
_ => None,
})
.collect_trusted::<ChunkedArray<T>>()
.into_series())
if exponent.len() == 1 {
let av = exponent
.get(0)
.ok_or_else(|| PolarsError::ComputeError("exponent is null".into()))?;
let s = match av.to_f64().unwrap() {
a if a == 1.0 => base.clone().into_series(),
a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
let mut out = base.clone();

for _ in 1..av.to_u8().unwrap() {
out = out * base.clone()
}
out.into_series()
}
_ => base.apply(|v| num::pow::Pow::pow(v, av)).into_series(),
};
Ok(s)
} else {
Ok(base
.into_iter()
.zip(exponent.into_iter())
.map(|(opt_base, opt_exponent)| match (opt_base, opt_exponent) {
(Some(base), Some(exponent)) => Some(num::pow::Pow::pow(base, exponent)),
_ => None,
})
.collect_trusted::<ChunkedArray<T>>()
.into_series())
}
}

fn pow_on_series(base: &Series, exponent: &Series) -> Result<Series> {
Expand All @@ -45,28 +65,12 @@ pub(super) fn pow(s: &mut [Series]) -> Result<Series> {
let base = &s[0];
let exponent = &s[1];

match exponent.len() {
1 => {
let av = exponent.get(0);
let exponent = av.extract::<f64>().ok_or_else(|| {
PolarsError::ComputeError(
format!(
"expected a numerical exponent in the pow expression, but got dtype: {}",
exponent.dtype()
)
.into(),
)
})?;
base.pow(exponent)
}
len => {
let base_len = base.len();
if len != base_len {
Err(PolarsError::ComputeError(
format!("pow expression: the exponents length: {len} does not match that of the base: {base_len}. Please ensure the lengths match or consider a literal exponent.").into()))
} else {
pow_on_series(base, exponent)
}
}
let base_len = base.len();
let exp_len = exponent.len();
if exp_len != base_len && (exp_len != 1) {
Err(PolarsError::ComputeError(
format!("pow expression: the exponents length: {exp_len} does not match that of the base: {base_len}. Please ensure the lengths match or consider a literal exponent.").into()))
} else {
pow_on_series(base, exponent)
}
}
10 changes: 5 additions & 5 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ fn test_groupby_rank() -> Result<()> {
fn test_apply_multiple_columns() -> Result<()> {
let df = fruits_cars();

let multiply = |s: &mut [Series]| Ok(&s[0].pow(2.0).unwrap() * &s[1]);
let multiply = |s: &mut [Series]| Ok(&(&s[0] * &s[0]) * &s[1]);

let out = df
.clone()
Expand All @@ -1700,10 +1700,10 @@ fn test_apply_multiple_columns() -> Result<()> {
)])
.collect()?;
let out = out.column("A")?;
let out = out.f64()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(5.0), Some(16.0), Some(27.0), Some(32.0), Some(25.0)]
&[Some(5), Some(16), Some(27), Some(32), Some(25)]
);

let out = df
Expand All @@ -1718,9 +1718,9 @@ fn test_apply_multiple_columns() -> Result<()> {

let out = out.column("A")?;
let out = out.list()?.get(1).unwrap();
let out = out.f64()?;
let out = out.i32()?;

assert_eq!(Vec::from(out), &[Some(16.0)]);
assert_eq!(Vec::from(out), &[Some(16)]);
Ok(())
}

Expand Down
15 changes: 14 additions & 1 deletion polars/polars-time/src/chunkedarray/rolling_window/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,20 @@ where
.clone()
.into_series()
.rolling_var(options)
.and_then(|ca| ca.pow(0.5));
.map(|mut s| {
match s.dtype().clone() {
DataType::Float32 => {
let ca: &mut ChunkedArray<Float32Type> = s._get_inner_mut().as_mut();
ca.apply_mut(|v| v.powf(0.5))
}
DataType::Float64 => {
let ca: &mut ChunkedArray<Float64Type> = s._get_inner_mut().as_mut();
ca.apply_mut(|v| v.powf(0.5))
}
_ => unreachable!(),
}
s
});
}

rolling_agg(
Expand Down

0 comments on commit efae6b2

Please sign in to comment.