Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): extend filter capabilities with new support for *args predicates, **kwargs constraints, and chained boolean masks #11740

Merged
merged 9 commits into from
Oct 16, 2023
49 changes: 38 additions & 11 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3899,7 +3899,8 @@ def insert_at_idx(self, index: int, series: Series) -> Self:

def filter(
self,
predicate: (Expr | str | Series | list[bool] | np.ndarray[Any, Any] | bool),
*predicates: IntoExpr | list[bool] | np.ndarray[Any, Any],
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
**constraints: Any,
) -> DataFrame:
"""
Filter the rows in the DataFrame based on a predicate expression.
Expand All @@ -3908,8 +3909,10 @@ def filter(

Parameters
----------
predicate
predicates
Expression that evaluates to a boolean Series.
constraints
Column filters. Use name=value to filter column name by the supplied value.

Examples
--------
Expand All @@ -3923,18 +3926,18 @@ def filter(

Filter on one condition:

>>> df.filter(pl.col("foo") < 3)
>>> df.filter(pl.col("foo") > 1)
shape: (2, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β”‚ 2 ┆ 7 ┆ b β”‚
β”‚ 3 ┆ 8 ┆ c β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Filter on multiple conditions:
Filter on multiple conditions, combined with and/or operators:

>>> df.filter((pl.col("foo") < 3) & (pl.col("ham") == "a"))
shape: (1, 3)
Expand All @@ -3946,8 +3949,6 @@ def filter(
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Filter on an OR condition:

>>> df.filter((pl.col("foo") == 1) | (pl.col("ham") == "c"))
shape: (2, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
Expand All @@ -3959,12 +3960,38 @@ def filter(
β”‚ 3 ┆ 8 ┆ c β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

"""
if _check_for_numpy(predicate) and isinstance(predicate, np.ndarray):
predicate = pl.Series(predicate)
Provide multiple filters using `*args` syntax:

>>> df.filter(
... pl.col("foo") == 1,
... pl.col("ham") == "a",
... )
shape: (1, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Provide multiple filters using `**kwargs` syntax:

>>> df.filter(foo=1, ham="a")
shape: (1, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

"""
return (
self.lazy().filter(predicate).collect(_eager=True) # type: ignore[arg-type]
self.lazy()
.filter(*predicates, **constraints) # type: ignore[arg-type]
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
.collect(_eager=True)
)

@overload
Expand Down
97 changes: 89 additions & 8 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import os
import warnings
from datetime import date, datetime, time, timedelta
from functools import reduce
from io import BytesIO, StringIO
from operator import and_
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -69,6 +71,7 @@
_prepare_row_count_args,
_process_null_values,
find_stacklevel,
is_bool_sequence,
normalize_filepath,
)

Expand Down Expand Up @@ -2542,16 +2545,18 @@ def clone(self) -> Self:
"""
return self._from_pyldf(self._ldf.clone())

def filter(self, predicate: IntoExpr) -> Self:
def filter(self, *predicates: IntoExpr, **constraints: Any) -> Self:
"""
Filter the rows in the LazyFrame based on a predicate expression.

The original order of the remaining rows is preserved.

Parameters
----------
predicate
predicates
Expression that evaluates to a boolean Series.
constraints
Column filters. Use name=value to filter column name by the supplied value.

Examples
--------
Expand All @@ -2565,15 +2570,15 @@ def filter(self, predicate: IntoExpr) -> Self:

Filter on one condition:

>>> lf.filter(pl.col("foo") < 3).collect()
>>> lf.filter(pl.col("foo") > 1).collect()
shape: (2, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β”‚ 2 ┆ 7 ┆ b β”‚
β”‚ 3 ┆ 8 ┆ c β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Filter on multiple conditions:
Expand All @@ -2588,6 +2593,33 @@ def filter(self, predicate: IntoExpr) -> Self:
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Provide multiple filters using `*args` syntax:

>>> lf.filter(
... pl.col("foo") == 1,
... pl.col("ham") == "a",
... ).collect()
shape: (1, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Provide multiple filters using `**kwargs` syntax:

>>> lf.filter(foo=1, ham="a").collect()
shape: (1, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ foo ┆ bar ┆ ham β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ i64 ┆ str β”‚
β•žβ•β•β•β•β•β•ͺ═════β•ͺ═════║
β”‚ 1 ┆ 6 ┆ a β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

Filter on an OR condition:

>>> lf.filter((pl.col("foo") == 1) | (pl.col("ham") == "c")).collect()
Expand All @@ -2602,11 +2634,60 @@ def filter(self, predicate: IntoExpr) -> Self:
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

"""
if isinstance(predicate, list):
predicate = pl.Series(predicate)
# note: identify masks separately from predicates
all_predicates, boolean_masks = [], []
for p in predicates:
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
if is_bool_sequence(p):
boolean_masks.append(pl.Series(p, dtype=Boolean))
else:
all_predicates.append(parse_as_expression(p, wrap=True))

# identify deprecated usage of 'predicate' parameter
if "predicate" in constraints:
is_mask = False
if isinstance(p := constraints["predicate"], pl.Expr) or (
is_mask := is_bool_sequence(p)
):
p = constraints.pop("predicate")
warnings.warn(
"`filter` no longer takes a 'predicate' parameter.\n"
"To silence this warning you should omit the keyword and pass "
"as a positional argument instead.",
DeprecationWarning,
stacklevel=find_stacklevel(),
)
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
if is_mask:
boolean_masks.append(pl.Series(p, dtype=Boolean))
else:
all_predicates += (p,)
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved

# unpack equality constraints from kwargs
all_predicates.extend(
F.col(name).eq(value) for name, value in constraints.items()
)
if not (all_predicates or boolean_masks):
raise ValueError("No predicates or constraints provided to `filter`.")

# if multiple predicates, combine as 'horizontal' expression
combined_predicate = (
(
F.all_horizontal(*all_predicates)
if len(all_predicates) > 1
else all_predicates[0]
)._pyexpr
if all_predicates
else None
)

predicate = parse_as_expression(predicate)
return self._from_pyldf(self._ldf.filter(predicate))
# apply reduced boolean mask first, if applicable, then predicates
ldf = (
self._ldf.filter(F.lit(reduce(and_, boolean_masks))._pyexpr)
if boolean_masks
else self._ldf
)
return self._from_pyldf(
ldf if combined_predicate is None else ldf.filter(combined_predicate)
)

def select(
self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/utils/_parse_expr_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def parse_as_expression(
*,
str_as_lit: bool = False,
structify: bool = False,
wrap: bool = False,
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
) -> PyExpr | Expr:
"""
Parse a single input into an expression.
Expand Down Expand Up @@ -119,7 +120,7 @@ def parse_as_expression(
if structify:
expr = _structify_expression(expr)

return expr._pyexpr
return expr if wrap else expr._pyexpr


def _structify_expression(expr: Expr) -> Expr:
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
Utf8,
unpack_dtypes,
)
from polars.dependencies import _PYARROW_AVAILABLE
from polars.dependencies import _PYARROW_AVAILABLE, _check_for_numpy
from polars.dependencies import numpy as np

if TYPE_CHECKING:
from collections.abc import Reversible
Expand Down Expand Up @@ -68,11 +69,15 @@ def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> b

def is_bool_sequence(val: object) -> TypeGuard[Sequence[bool]]:
"""Check whether the given sequence is a sequence of booleans."""
if _check_for_numpy(val) and isinstance(val, np.ndarray):
return val.dtype == np.bool_
return isinstance(val, Sequence) and _is_iterable_of(val, bool)


def is_int_sequence(val: object) -> TypeGuard[Sequence[int]]:
"""Check whether the given sequence is a sequence of integers."""
if _check_for_numpy(val) and isinstance(val, np.ndarray):
return np.issubdtype(val.dtype, np.integer)
return isinstance(val, Sequence) and _is_iterable_of(val, int)


Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2814,6 +2814,40 @@ def test_filter_sequence() -> None:
assert df.filter(np.array([True, False, True]))["a"].to_list() == [1, 3]


def test_filter_multiple_predicates() -> None:
df = pl.DataFrame(
{"a": [1, 1, 1, 2, 2], "b": [1, 1, 2, 2, 2], "c": [1, 1, 2, 3, 4]}
)

# using multiple predicates
out = df.filter(pl.col("a") == 1, pl.col("b") <= 2)
expected = pl.DataFrame({"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]})
assert_frame_equal(out, expected)

# using multiple kwargs
out = df.filter(a=1, b=2)
expected = pl.DataFrame({"a": [1], "b": [2], "c": [2]})
assert_frame_equal(out, expected)

# using both
out = df.filter(pl.col("a") == 1, pl.col("b") <= 2, a=1, b=2)
expected = pl.DataFrame({"a": [1], "b": [2], "c": [2]})
assert_frame_equal(out, expected)

# using boolean mask
out = df.filter([True, False, False, False, True])
expected = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": [1, 4]})
assert_frame_equal(out, expected)

# using multiple boolean masks
out = df.filter(
np.array([True, True, False, True, False]),
np.array([True, False, True, True, False]),
)
expected = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": [1, 3]})
assert_frame_equal(out, expected)


def test_indexing_set() -> None:
df = pl.DataFrame({"bool": [True, True], "str": ["N/A", "N/A"], "nr": [1, 2]})

Expand Down
39 changes: 39 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,45 @@ def test_filter_str() -> None:
assert_frame_equal(result, expected)


def test_filter_multiple_predicates() -> None:
ldf = pl.LazyFrame(
{"a": [1, 1, 1, 2, 2], "b": [1, 1, 2, 2, 2], "c": [1, 1, 2, 3, 4]}
)

# using multiple predicates
out = ldf.filter(pl.col("a") == 1, pl.col("b") <= 2).collect()
expected = pl.DataFrame({"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]})
assert_frame_equal(out, expected)

# using multiple kwargs
out = ldf.filter(a=1, b=2).collect()
expected = pl.DataFrame({"a": [1], "b": [2], "c": [2]})
assert_frame_equal(out, expected)

# using both
out = ldf.filter(pl.col("a") == 1, pl.col("b") <= 2, a=1, b=2).collect()
expected = pl.DataFrame({"a": [1], "b": [2], "c": [2]})
assert_frame_equal(out, expected)

# check 'predicate' keyword deprecation:
# note: can disambiguate new/old usage - only warn on old-style usage
with pytest.warns(
DeprecationWarning,
match="`filter` no longer takes a 'predicate' parameter",
):
ldf.filter(
predicate=pl.col("a").ge(1),
).collect()

ldf = pl.LazyFrame(
{
"description": ["eq", "gt", "ge"],
"predicate": ["==", ">", ">="],
},
)
assert ldf.filter(predicate="==").select("description").collect().item() == "eq"


def test_apply_custom_function() -> None:
ldf = pl.LazyFrame(
{
Expand Down