Skip to content

Commit

Permalink
feat(rust, python): pl.min & pl.max accept wildcard similar to pl.sum (
Browse files Browse the repository at this point in the history
  • Loading branch information
sorhawell committed Nov 16, 2022
1 parent a6ccd66 commit b8716bd
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 39 deletions.
53 changes: 14 additions & 39 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,51 +897,26 @@ pub fn sum_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
/// Get the the maximum value per row
pub fn max_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
max_exprs_impl(exprs)
}

fn max_exprs_impl(mut exprs: Vec<Expr>) -> Expr {
if exprs.len() == 1 {
return std::mem::take(&mut exprs[0]);
if exprs.is_empty() {
return Expr::Columns(Vec::new());
}

let first = std::mem::take(&mut exprs[0]);
first
.map_many(
|s| {
let s = s.to_vec();
let df = DataFrame::new_no_checks(s);
df.hmax().map(|s| s.unwrap())
},
&exprs[1..],
GetOutput::super_type(),
)
.alias("max")
let func = |s1, s2| {
let df = DataFrame::new_no_checks(vec![s1, s2]);
df.hmax().map(|s| s.unwrap())
};
reduce_exprs(func, exprs).alias("max")
}

/// Get the the minimum value per row
pub fn min_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
min_exprs_impl(exprs)
}

fn min_exprs_impl(mut exprs: Vec<Expr>) -> Expr {
if exprs.len() == 1 {
return std::mem::take(&mut exprs[0]);
if exprs.is_empty() {
return Expr::Columns(Vec::new());
}

let first = std::mem::take(&mut exprs[0]);
first
.map_many(
|s| {
let s = s.to_vec();
let df = DataFrame::new_no_checks(s);
df.hmin().map(|s| s.unwrap())
},
&exprs[1..],
GetOutput::super_type(),
)
.alias("min")
let func = |s1, s2| {
let df = DataFrame::new_no_checks(vec![s1, s2]);
df.hmin().map(|s| s.unwrap())
};
reduce_exprs(func, exprs).alias("min")
}

/// Evaluate all the expressions with a bitwise or
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,23 @@ def test_max_min_multiple_columns(fruits_cars: pl.DataFrame) -> None:
assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1]))


def test_max_min_wildcard_columns(fruits_cars: pl.DataFrame) -> None:
res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.min(["*"]))
assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1]))
res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.min([pl.all()]))
assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1]))

res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.max(["*"]))
assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5]))
res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.max([pl.all()]))
assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5]))

res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(
pl.max([pl.all(), "A", "*"])
)
assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5]))


def test_head_tail(fruits_cars: pl.DataFrame) -> None:
res_expr = fruits_cars.select([pl.head("A", 2)])
res_series = pl.head(fruits_cars["A"], 2)
Expand Down

0 comments on commit b8716bd

Please sign in to comment.