Skip to content

Commit

Permalink
any/all expressions (#2300)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 7, 2022
1 parent 78aa370 commit 0cf8087
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 1 deletion.
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::utils::NoNull;
impl BooleanChunked {
    /// Returns the row indices that remain after filtering `0..self.len()` with
    /// this boolean array as the mask.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `filter` call returns an error. The index array
    /// is built with the same length as `self`, so this is presumably
    /// unreachable in practice (TODO confirm against `filter`'s error cases).
    pub fn arg_true(&self) -> UInt32Chunked {
        // the allocation is probably cheaper as the filter is super fast
        //
        // Only one index array is materialized: an earlier `collect()` binding
        // was shadowed immediately by this one and never read.
        // NOTE(review): `self.len() as u32` truncates for lengths above
        // `u32::MAX`; kept as-is because the result type is `UInt32Chunked`.
        let ca: NoNull<UInt32Chunked> = (0u32..self.len() as u32).collect_trusted();
        ca.into_inner().filter(self).unwrap()
    }
}
32 changes: 32 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,38 @@ impl Expr {
}),
)
}

/// Check if any boolean value is `true`
///
/// The expression evaluates to a single-value boolean series that carries the
/// name of the input series.
///
/// # Errors
///
/// Evaluation fails with whatever error `Series::bool` reports for the input
/// (presumably when the column is not of boolean type).
pub fn any(self) -> Self {
    self.apply(
        move |s| {
            let mask = s.bool()?;
            // TODO! Optimize this in arrow2/ polars-arrow
            // "any true" is expressed as the negation of "all false".
            let any_true = !mask.all_false();
            Ok(Series::new(s.name(), [any_true]))
        },
        GetOutput::from_type(DataType::Boolean),
    )
}

/// Check if all boolean values are `true`
///
/// The expression evaluates to a single-value boolean series that carries the
/// name of the input series.
///
/// # Errors
///
/// Evaluation fails with whatever error `Series::bool` reports for the input
/// (presumably when the column is not of boolean type).
pub fn all(self) -> Self {
    self.apply(
        move |s| {
            let mask = s.bool()?;
            // TODO! Optimize this in arrow2/ polars-arrow
            let every_true = mask.all_true();
            Ok(Series::new(s.name(), [every_true]))
        },
        GetOutput::from_type(DataType::Boolean),
    )
}
}

/// Create a Column Expression based on a column name.
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Boolean
Expr.is_duplicated
Expr.is_between
Expr.is_in
Expr.any
Expr.all


Computations
Expand Down
7 changes: 7 additions & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ Descriptive stats
Series.n_unique
Series.has_validity

Boolean
-------
.. autosummary::
:toctree: api/

Series.any
Series.all

Computations
------------
Expand Down
20 changes: 20 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,26 @@ def to_physical(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.to_physical())

def any(self) -> "Expr":
    """
    Check if any boolean value in the column is `True`.

    Returns
    -------
    Boolean literal
    """
    # Delegate to the wrapped native expression, then re-wrap the result.
    reduced = self._pyexpr.any()
    return wrap_expr(reduced)

def all(self) -> "Expr":
    """
    Check if all boolean values in the column are `True`.

    Returns
    -------
    Boolean literal
    """
    # Delegate to the wrapped native expression, then re-wrap the result.
    reduced = self._pyexpr.all()
    return wrap_expr(reduced)

def sqrt(self) -> "Expr":
"""
Compute the square root of the elements
Expand Down
20 changes: 20 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,26 @@ def sqrt(self) -> "Series":
"""
return self ** 0.5

def any(self) -> "Series":
    """
    Check if any boolean value in the column is `True`.

    The check is evaluated through the expression API: the Series is turned
    into a frame, `col(<name>).any()` is selected, and the resulting column
    is returned.

    Returns
    -------
    Series
        Series holding the boolean result.
    """
    frame = self.to_frame()
    reduced = frame.select(pli.col(self.name).any())
    return reduced.to_series()

def all(self) -> "Series":
    """
    Check if all boolean values in the column are `True`.

    The check is evaluated through the expression API: the Series is turned
    into a frame, `col(<name>).all()` is selected, and the resulting column
    is returned.

    Returns
    -------
    Series
        Series holding the boolean result.
    """
    frame = self.to_frame()
    reduced = frame.select(pli.col(self.name).all())
    return reduced.to_series()

def log(self) -> "Series":
"""
Natural logarithm, element-wise.
Expand Down
7 changes: 7 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,13 @@ impl PyExpr {
)
.into()
}
/// Builds a new expression that applies `any` to a clone of the wrapped expression.
pub fn any(&self) -> Self {
    // `any` consumes the expression, so work on a clone and leave `self` intact.
    let expr = self.inner.clone();
    expr.any().into()
}

/// Builds a new expression that applies `all` to a clone of the wrapped expression.
pub fn all(&self) -> Self {
    // `all` consumes the expression, so work on a clone and leave `self` intact.
    let expr = self.inner.clone();
    expr.all().into()
}
}

impl From<dsl::Expr> for PyExpr {
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,3 +1362,23 @@ def test_extend() -> None:

expected = pl.Series("a", [1, 2, 3, None, None, None])
verify_series_and_expr_api(a, expected, "extend", None, 3)


def test_any_all() -> None:
    """Check `any` and `all` on mixed, all-True and all-False boolean Series."""
    # (input values, expected `any` result, expected `all` result)
    cases = [
        ([True, False, True], True, False),
        ([True, True, True], True, True),
        ([False, False, False], False, False),
    ]
    for values, any_expected, all_expected in cases:
        a = pl.Series("a", values)
        expected = pl.Series("a", [any_expected])
        verify_series_and_expr_api(a, expected, "any")
        expected = pl.Series("a", [all_expected])
        verify_series_and_expr_api(a, expected, "all")

0 comments on commit 0cf8087

Please sign in to comment.