Skip to content

Commit

Permalink
fix(rust, python): pow return type evaluation (#15506)
Browse files Browse the repository at this point in the history
  • Loading branch information
CanglongCl committed Apr 6, 2024
1 parent 87b84c9 commit 7dfc53e
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 177 deletions.
17 changes: 17 additions & 0 deletions crates/polars-core/src/utils/mod.rs
Expand Up @@ -319,18 +319,35 @@ macro_rules! with_match_physical_integer_type {(
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use $crate::datatypes::DataType::*;
match $dtype {
#[cfg(feature = "dtype-i8")]
Int8 => __with_ty__! { i8 },
#[cfg(feature = "dtype-i16")]
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
#[cfg(feature = "dtype-u8")]
UInt8 => __with_ty__! { u8 },
#[cfg(feature = "dtype-u16")]
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

#[macro_export]
macro_rules! with_match_physical_float_type {(
$dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use $crate::datatypes::DataType::*;
match $dtype {
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

#[macro_export]
macro_rules! with_match_physical_float_polars_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
Expand Down
107 changes: 48 additions & 59 deletions crates/polars-plan/src/dsl/function_expr/pow.rs
Expand Up @@ -2,6 +2,7 @@ use arrow::legacy::kernels::pow::pow as pow_kernel;
use num::pow::Pow;
use polars_core::export::num;
use polars_core::export::num::{Float, ToPrimitive};
use polars_core::with_match_physical_integer_type;

use super::*;

Expand Down Expand Up @@ -128,65 +129,53 @@ where

fn pow_on_series(base: &Series, exponent: &Series) -> PolarsResult<Option<Series>> {
use DataType::*;
match (base.dtype(), exponent.dtype()) {
#[cfg(feature = "dtype-u8")]
(UInt8, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u8().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-i8")]
(Int8, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i8().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-u16")]
(UInt16, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u16().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-i16")]
(Int16, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i16().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(UInt32, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u32().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Int32, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i32().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(UInt64, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u64().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Int64, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i64().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Float32, _) => {
let ca = base.f32().unwrap();
let exponent = exponent.strict_cast(&DataType::Float32)?;
pow_on_floats(ca, exponent.f32().unwrap())
},
(Float64, _) => {
let ca = base.f64().unwrap();
let exponent = exponent.strict_cast(&DataType::Float64)?;
pow_on_floats(ca, exponent.f64().unwrap())
},
_ => {
let base = base.cast(&DataType::Float64)?;
pow_on_series(&base, exponent)
},

let base_dtype = base.dtype();
polars_ensure!(
base_dtype.is_numeric(),
InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
);
let exponent_dtype = exponent.dtype();
polars_ensure!(
exponent_dtype.is_numeric(),
InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
);

// if false, dtype is float
if base_dtype.is_integer() {
with_match_physical_integer_type!(base_dtype, |$native_type| {
if exponent_dtype.is_float() {
match exponent_dtype {
Float32 => {
let ca = base.cast(&DataType::Float32)?;
pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
},
Float64 => {
let ca = base.cast(&DataType::Float64)?;
pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
},
_ => unreachable!(),
}
} else {
let ca = base.$native_type().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
}
})
} else {
match base_dtype {
Float32 => {
let ca = base.f32().unwrap();
let exponent = exponent.strict_cast(&DataType::Float32)?;
pow_on_floats(ca, exponent.f32().unwrap())
},
Float64 => {
let ca = base.f64().unwrap();
let exponent = exponent.strict_cast(&DataType::Float64)?;
pow_on_floats(ca, exponent.f64().unwrap())
},
_ => unreachable!(),
}
}
}

Expand Down
18 changes: 10 additions & 8 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Expand Up @@ -466,14 +466,16 @@ impl<'a> FieldsMapper<'a> {
}

pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
// base, exponent
match (self.fields[0].data_type(), self.fields[1].data_type()) {
(
base_dtype,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64,
) => Ok(Field::new(self.fields[0].name(), base_dtype.clone())),
(DataType::Float32, _) => Ok(Field::new(self.fields[0].name(), DataType::Float32)),
(_, _) => Ok(Field::new(self.fields[0].name(), DataType::Float64)),
let base_dtype = self.fields[0].data_type();
let exponent_dtype = self.fields[1].data_type();
if base_dtype.is_integer() {
if exponent_dtype.is_float() {
Ok(Field::new(self.fields[0].name(), exponent_dtype.clone()))
} else {
Ok(Field::new(self.fields[0].name(), base_dtype.clone()))
}
} else {
Ok(Field::new(self.fields[0].name(), base_dtype.clone()))
}
}

Expand Down
60 changes: 30 additions & 30 deletions py-polars/polars/dataframe/frame.py
Expand Up @@ -8192,16 +8192,16 @@ def with_columns(
... )
>>> df.with_columns((pl.col("a") ** 2).alias("a^2"))
shape: (4, 4)
┌─────┬──────┬───────┬─────
│ a ┆ b ┆ c ┆ a^2
│ --- ┆ --- ┆ --- ┆ ---
│ i64 ┆ f64 ┆ bool ┆ f64
╞═════╪══════╪═══════╪═════
│ 1 ┆ 0.5 ┆ true ┆ 1.0
│ 2 ┆ 4.0 ┆ true ┆ 4.0
│ 3 ┆ 10.0 ┆ false ┆ 9.0
│ 4 ┆ 13.0 ┆ true ┆ 16.0
└─────┴──────┴───────┴─────
┌─────┬──────┬───────┬─────┐
│ a ┆ b ┆ c ┆ a^2 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ bool ┆ i64
╞═════╪══════╪═══════╪═════╡
│ 1 ┆ 0.5 ┆ true ┆ 1
│ 2 ┆ 4.0 ┆ true ┆ 4
│ 3 ┆ 10.0 ┆ false ┆ 9
│ 4 ┆ 13.0 ┆ true ┆ 16
└─────┴──────┴───────┴─────┘
Added columns will replace existing columns with the same name.
Expand All @@ -8228,16 +8228,16 @@ def with_columns(
... ]
... )
shape: (4, 6)
┌─────┬──────┬───────┬─────┬──────┬───────┐
│ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │
╞═════╪══════╪═══════╪═════╪══════╪═══════╡
│ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │
│ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │
│ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │
│ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │
└─────┴──────┴───────┴─────┴──────┴───────┘
┌─────┬──────┬───────┬─────┬──────┬───────┐
│ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │
╞═════╪══════╪═══════╪═════╪══════╪═══════╡
│ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │
│ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │
│ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │
│ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │
└─────┴──────┴───────┴─────┴──────┴───────┘
Multiple columns also can be added using positional arguments instead of a list.
Expand All @@ -8247,16 +8247,16 @@ def with_columns(
... (pl.col("c").not_()).alias("not c"),
... )
shape: (4, 6)
┌─────┬──────┬───────┬─────┬──────┬───────┐
│ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │
╞═════╪══════╪═══════╪═════╪══════╪═══════╡
│ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │
│ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │
│ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │
│ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │
└─────┴──────┴───────┴─────┴──────┴───────┘
┌─────┬──────┬───────┬─────┬──────┬───────┐
│ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │
╞═════╪══════╪═══════╪═════╪══════╪═══════╡
│ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │
│ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │
│ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │
│ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │
└─────┴──────┴───────┴─────┴──────┴───────┘
Use keyword arguments to easily name your expression inputs.
Expand Down
32 changes: 16 additions & 16 deletions py-polars/polars/expr/expr.py
Expand Up @@ -5330,16 +5330,16 @@ def pow(self, exponent: IntoExprColumn | int | float) -> Self:
... pl.col("x").pow(pl.col("x").log(2)).alias("x ** xlog2"),
... )
shape: (4, 3)
┌─────┬──────┬────────────┐
│ x ┆ cube ┆ x ** xlog2 │
│ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ f64 │
╞═════╪══════╪════════════╡
│ 1 ┆ 1.0 ┆ 1.0 │
│ 2 ┆ 8.0 ┆ 2.0 │
│ 4 ┆ 64.0 ┆ 16.0 │
│ 8 ┆ 512.0 ┆ 512.0 │
└─────┴──────┴────────────┘
┌─────┬──────┬────────────┐
│ x ┆ cube ┆ x ** xlog2 │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 │
╞═════╪══════╪════════════╡
│ 1 ┆ 1 ┆ 1.0 │
│ 2 ┆ 8 ┆ 2.0 │
│ 4 ┆ 64 ┆ 16.0 │
│ 8 ┆ 512 ┆ 512.0 │
└─────┴──────┴────────────┘
"""
return self.__pow__(exponent)

Expand Down Expand Up @@ -9185,13 +9185,13 @@ def cumulative_eval(
┌────────┐
│ values │
│ --- │
f64
i64
╞════════╡
│ 0.0
│ -3.0
│ -8.0
│ -15.0
│ -24.0
│ 0
│ -3
│ -8
│ -15
│ -24
└────────┘
"""
return self._from_pyexpr(
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/functions/lazy.py
Expand Up @@ -2210,10 +2210,10 @@ def sql_expr(sql: str | Sequence[str]) -> Expr | list[Expr]:
┌─────┬─────┬───────┐
│ a ┆ a_a ┆ a_txt │
│ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ str │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═══════╡
│ 2 ┆ 4.0 ┆ 2 │
│ 1 ┆ 1.0 ┆ 1 │
│ 2 ┆ 4 ┆ 2 │
│ 1 ┆ 1 ┆ 1 │
└─────┴─────┴───────┘
"""
if isinstance(sql, str):
Expand Down

0 comments on commit 7dfc53e

Please sign in to comment.