Skip to content

Commit

Permalink
lazy upper/lower bound exprs; python: fix horizontal max/min exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 13, 2021
1 parent 42e680d commit e8f6b0f
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 7 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ pub enum DataType {

impl DataType {
/// Convert to the physical data type
pub(crate) fn to_physical(&self) -> DataType {
pub fn to_physical(&self) -> DataType {
use DataType::*;
match self {
Date => Int32,
Expand Down
66 changes: 66 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,72 @@ impl Expr {
GetOutput::from_type(DataType::Utf8),
)
}

/// Get the maximal value that can be held by the dtype of this expression.
///
/// The mapped function inspects the physical dtype of the incoming `Series`
/// and produces a one-element `Series` with the same name:
/// integer dtypes yield their `MAX` constant, float dtypes yield positive
/// infinity. The `Int8`, `Int16`, `UInt8` and `UInt16` arms are only compiled
/// when the matching `dtype-*` feature is enabled.
///
/// # Errors
///
/// Evaluation fails with `PolarsError::ComputeError` for any physical dtype
/// that has no arm below.
pub fn upper_bound(self) -> Expr {
    self.map(
        |series| {
            use DataType::*;
            let name = series.name();
            let physical = series.dtype().to_physical();
            let bound = match physical {
                #[cfg(feature = "dtype-i8")]
                Int8 => Series::new(name, &[i8::MAX]),
                #[cfg(feature = "dtype-i16")]
                Int16 => Series::new(name, &[i16::MAX]),
                Int32 => Series::new(name, &[i32::MAX]),
                Int64 => Series::new(name, &[i64::MAX]),
                #[cfg(feature = "dtype-u8")]
                UInt8 => Series::new(name, &[u8::MAX]),
                #[cfg(feature = "dtype-u16")]
                UInt16 => Series::new(name, &[u16::MAX]),
                UInt32 => Series::new(name, &[u32::MAX]),
                UInt64 => Series::new(name, &[u64::MAX]),
                Float32 => Series::new(name, &[f32::INFINITY]),
                Float64 => Series::new(name, &[f64::INFINITY]),
                unsupported => {
                    let msg = format!("cannot determine upper bound of dtype {}", unsupported);
                    return Err(PolarsError::ComputeError(msg.into()));
                }
            };
            Ok(bound)
        },
        // NOTE(review): the Series built above carries the *physical* dtype,
        // while the output is declared as `same_type` — confirm this is the
        // intended schema for logical dtypes.
        GetOutput::same_type(),
    )
}

/// Get the minimal value that can be held by the dtype of this expression.
///
/// The mapped function inspects the physical dtype of the incoming `Series`
/// and produces a one-element `Series` with the same name:
/// integer dtypes yield their `MIN` constant, float dtypes yield negative
/// infinity. The `Int8`, `Int16`, `UInt8` and `UInt16` arms are only compiled
/// when the matching `dtype-*` feature is enabled.
///
/// # Errors
///
/// Evaluation fails with `PolarsError::ComputeError` for any physical dtype
/// that has no arm below.
pub fn lower_bound(self) -> Expr {
    self.map(
        |series| {
            use DataType::*;
            let name = series.name();
            let physical = series.dtype().to_physical();
            let bound = match physical {
                #[cfg(feature = "dtype-i8")]
                Int8 => Series::new(name, &[i8::MIN]),
                #[cfg(feature = "dtype-i16")]
                Int16 => Series::new(name, &[i16::MIN]),
                Int32 => Series::new(name, &[i32::MIN]),
                Int64 => Series::new(name, &[i64::MIN]),
                #[cfg(feature = "dtype-u8")]
                UInt8 => Series::new(name, &[u8::MIN]),
                #[cfg(feature = "dtype-u16")]
                UInt16 => Series::new(name, &[u16::MIN]),
                UInt32 => Series::new(name, &[u32::MIN]),
                UInt64 => Series::new(name, &[u64::MIN]),
                Float32 => Series::new(name, &[f32::NEG_INFINITY]),
                Float64 => Series::new(name, &[f64::NEG_INFINITY]),
                unsupported => {
                    let msg = format!("cannot determine lower bound of dtype {}", unsupported);
                    return Err(PolarsError::ComputeError(msg.into()));
                }
            };
            Ok(bound)
        },
        // NOTE(review): the Series built above carries the *physical* dtype,
        // while the output is declared as `same_type` — confirm this is the
        // intended schema for logical dtypes.
        GetOutput::same_type(),
    )
}
}

/// Create a Column Expression based on a column name.
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ Manipulation/ selection
Expr.interpolate
Expr.argsort
Expr.clip
Expr.lower_bound
Expr.upper_bound
Expr.str_concat

Column names
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def date_like_to_physical(dtype: Type[DataType]) -> Type[DataType]:
return Int32
if dtype == Datetime:
return Int64
if dtype == Time:
return Int64
return dtype


Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/lazy/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,18 @@ def clip(self, min_val: Union[int, float], max_val: Union[int, float]) -> "Expr"
.otherwise(self)
).keep_name()

def lower_bound(self) -> "Expr":
    """
    Return a unit Series holding the lowest value possible for the dtype of this expression.

    Delegates to ``lower_bound`` on the underlying ``_pyexpr`` and wraps the result.
    """
    bound = self._pyexpr.lower_bound()
    return wrap_expr(bound)

def upper_bound(self) -> "Expr":
    """
    Return a unit Series holding the highest value possible for the dtype of this expression.

    Delegates to ``upper_bound`` on the underlying ``_pyexpr`` and wraps the result.
    """
    bound = self._pyexpr.upper_bound()
    return wrap_expr(bound)

def str_concat(self, delimiter: str = "-") -> "Expr": # type: ignore
"""
Vertically concat the values in the Series to a single string value.
Expand Down
19 changes: 14 additions & 5 deletions py-polars/polars/lazy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,13 @@ def max(column: Union[str, tp.List["pl.Expr"], "pl.Series"]) -> Union["pl.Expr",
elif isinstance(column, list):

def max_(acc: "pl.Series", val: "pl.Series") -> "pl.Series":
mask = acc < val
mask = acc > val
return acc.zip_with(mask, val)

return fold(lit(0), max_, column).alias("max")
first = column[0]
if isinstance(first, str):
first = pl.col(first)
return fold(first, max_, column[1:]).alias("max")
else:
return col(column).max()

Expand All @@ -253,10 +256,13 @@ def min(column: Union[str, tp.List["pl.Expr"], "pl.Series"]) -> Union["pl.Expr",
elif isinstance(column, list):

def min_(acc: "pl.Series", val: "pl.Series") -> "pl.Series":
mask = acc > val
mask = acc < val
return acc.zip_with(mask, val)

return fold(lit(0), min_, column).alias("min")
first = column[0]
if isinstance(first, str):
first = pl.col(first)
return fold(first, min_, column[1:]).alias("min")
else:
return col(column).min()

Expand All @@ -274,7 +280,10 @@ def sum(column: Union[str, tp.List["pl.Expr"], "pl.Series"]) -> Union["pl.Expr",
if isinstance(column, pl.Series):
return column.sum()
elif isinstance(column, list):
return fold(lit(0), lambda a, b: a + b, column).alias("sum")
first = column[0]
if isinstance(first, str):
first = pl.col(first)
return fold(first, lambda a, b: a + b, column[1:]).alias("sum")
else:
return col(column).sum()

Expand Down
8 changes: 8 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,14 @@ impl PyExpr {
.into()
}

/// Return a new `PyExpr` built from `lower_bound` applied to a clone of the
/// wrapped expression; `self` is left untouched.
pub fn lower_bound(&self) -> Self {
    let expr = self.inner.clone();
    expr.lower_bound().into()
}

/// Return a new `PyExpr` built from `upper_bound` applied to a clone of the
/// wrapped expression; `self` is left untouched.
pub fn upper_bound(&self) -> Self {
    let expr = self.inner.clone();
    expr.upper_bound().into()
}

fn lst_max(&self) -> Self {
self.inner
.clone()
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import polars as pl


def test_horizontal_agg(fruits_cars):
    """
    Check the horizontal (row-wise) ``pl.max`` / ``pl.min`` over a list of expressions.

    ``fruits_cars`` is a pytest fixture defined outside this file; the expected
    values below assume its "A" and "B" columns are the ones this test was
    written against (presumably A ascending 1..5 and B descending 5..1 — verify
    against the fixture).
    """
    df = fruits_cars
    out = df.select(pl.max([pl.col("A"), pl.col("B")]))
    # Without `assert` the comparison result is discarded and the test can never fail.
    assert out[:, 0].to_list() == [5, 4, 3, 4, 5]

    out = df.select(pl.min([pl.col("A"), pl.col("B")]))
    assert out[:, 0].to_list() == [1, 2, 3, 2, 1]
10 changes: 9 additions & 1 deletion py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,16 @@ def test_agg():

def test_fold():
df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]})
out = df.lazy().select(pl.sum(["a", "b"])).collect()
out = df.select(
[
pl.sum(["a", "b"]),
pl.max(["a", pl.col("b") ** 2]),
pl.min(["a", pl.col("b") ** 2]),
]
)
assert out["sum"].series_equal(pl.Series("sum", [2.0, 4.0, 6.0]))
assert out["max"].series_equal(pl.Series("max", [1.0, 4.0, 9.0]))
assert out["min"].series_equal(pl.Series("max", [1.0, 2.0, 3.0]))

out = df.select(
pl.fold(acc=lit(0), f=lambda acc, x: acc + x, exprs=pl.col("*")).alias("foo")
Expand Down

0 comments on commit e8f6b0f

Please sign in to comment.