Skip to content

Commit

Permalink
power by expression and improve rust lazy ergonomics (#3475)
Browse files Browse the repository at this point in the history
* power by expression

* fix and improve rust lazy ergonomics
  • Loading branch information
ritchie46 committed May 23, 2022
1 parent c6d11d9 commit 26f7b2a
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 96 deletions.
2 changes: 2 additions & 0 deletions polars/polars-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ pub use arrow;
#[cfg(feature = "temporal")]
pub use chrono;

#[cfg(feature = "private")]
pub use num;
#[cfg(feature = "private")]
pub use once_cell;
#[cfg(feature = "private")]
Expand Down
41 changes: 41 additions & 0 deletions polars/polars-lazy/src/dsl/from.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use super::*;

impl From<AggExpr> for Expr {
fn from(agg: AggExpr) -> Self {
Expr::Agg(agg)
}
}

pub trait RefString {}

impl From<&str> for Expr {
fn from(s: &str) -> Self {
col(s)
}
}

macro_rules! from_literals {
($type:ty) => {
impl From<$type> for Expr {
fn from(val: $type) -> Self {
lit(val)
}
}
};
}

from_literals!(f32);
from_literals!(f64);
#[cfg(feature = "dtype-i8")]
from_literals!(i8);
#[cfg(feature = "dtype-i16")]
from_literals!(i16);
from_literals!(i32);
from_literals!(i64);
#[cfg(feature = "dtype-u8")]
from_literals!(u8);
#[cfg(feature = "dtype-u16")]
from_literals!(u16);
from_literals!(u32);
from_literals!(u64);
from_literals!(bool);
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod pow;

use super::*;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
Expand All @@ -7,7 +9,7 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, PartialEq, Debug)]
pub enum FunctionExpr {
NullCount,
Pow(f64),
Pow,
#[cfg(feature = "row_hash")]
Hash(usize),
}
Expand Down Expand Up @@ -35,7 +37,7 @@ impl FunctionExpr {
use FunctionExpr::*;
match self {
NullCount => with_dtype(IDX_DTYPE),
Pow(_) => float_dtype(),
Pow => float_dtype(),
#[cfg(feature = "row_hash")]
Hash(_) => with_dtype(DataType::UInt64),
}
Expand All @@ -59,12 +61,8 @@ impl From<FunctionExpr> for NoEq<Arc<dyn SeriesUdf>> {
};
wrap!(f)
}
Pow(exponent) => {
let f = move |s: &mut [Series]| {
let s = &s[0];
s.pow(exponent)
};
wrap!(f)
Pow => {
wrap!(pow::pow)
}
#[cfg(feature = "row_hash")]
Hash(seed) => {
Expand Down
72 changes: 72 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/pow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use super::*;
use polars_arrow::utils::CustomIterTools;
use polars_core::export::num;

fn pow_on_floats<T>(base: &ChunkedArray<T>, exponent: &Series) -> Result<Series>
where
T: PolarsFloatType,
T::Native: num::pow::Pow<T::Native, Output = T::Native>,
ChunkedArray<T>: IntoSeries,
{
let dtype = T::get_dtype();
let exponent = exponent.cast(&dtype)?;
let exponent = base.unpack_series_matching_type(&exponent).unwrap();

Ok(base
.into_iter()
.zip(exponent.into_iter())
.map(|(opt_base, opt_exponent)| match (opt_base, opt_exponent) {
(Some(base), Some(exponent)) => Some(num::pow::Pow::pow(base, exponent)),
_ => None,
})
.collect_trusted::<ChunkedArray<T>>()
.into_series())
}

fn pow_on_series(base: &Series, exponent: &Series) -> Result<Series> {
use DataType::*;
match base.dtype() {
Float32 => {
let ca = base.f32().unwrap();
pow_on_floats(ca, exponent)
}
Float64 => {
let ca = base.f64().unwrap();
pow_on_floats(ca, exponent)
}
_ => {
let base = base.cast(&DataType::Float64)?;
pow_on_series(&base, exponent)
}
}
}

pub(super) fn pow(s: &mut [Series]) -> Result<Series> {
let base = &s[0];
let exponent = &s[1];

match exponent.len() {
1 => {
let av = exponent.get(0);
let exponent = av.extract::<f64>().ok_or_else(|| {
PolarsError::ComputeError(
format!(
"expected a numerical exponent in the pow expression, but got dtype: {}",
exponent.dtype()
)
.into(),
)
})?;
base.pow(exponent)
}
len => {
let base_len = base.len();
if len != base_len {
Err(PolarsError::ComputeError(
format!("pow expression: the exponents length: {len} does not match that of the base: {base_len}. Please ensure the lengths match or consider a literal exponent.").into()))
} else {
pow_on_series(base, exponent)
}
}
}
}

0 comments on commit 26f7b2a

Please sign in to comment.