Skip to content

Commit

Permalink
python dataframe.filter accept numpy boolean mask (#4350)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 9, 2022
1 parent c1e8174 commit d5a0aff
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
10 changes: 8 additions & 2 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1971,7 +1971,10 @@ def insert_at_idx(self, index: int, series: pli.Series) -> None:
index = len(self.columns) + index
self._df.insert_at_idx(index, series._s)

def filter(self, predicate: pli.Expr | str | pli.Series | list[bool]) -> DataFrame:
def filter(
self,
predicate: pli.Expr | str | pli.Series | list[bool] | np.ndarray[Any, Any],
) -> DataFrame:
"""
Filter the rows in the DataFrame based on a predicate expression.
Expand Down Expand Up @@ -2017,9 +2020,12 @@ def filter(self, predicate: pli.Expr | str | pli.Series | list[bool]) -> DataFra
└─────┴─────┴─────┘
"""
if _NUMPY_AVAILABLE and isinstance(predicate, np.ndarray):
predicate = pli.Series(predicate)

return (
self.lazy()
.filter(predicate)
.filter(predicate) # type: ignore[arg-type]
.collect(no_optimization=True, string_cache=False)
)

Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ def filter(self: LDF, predicate: pli.Expr | str | pli.Series | list[bool]) -> LD
"""
if isinstance(predicate, list):
predicate = pli.Series(predicate)

return self._from_pyldf(
self._ldf.filter(
pli.expr_to_lit_or_expr(predicate, str_to_lit=False)._pyexpr
Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,7 @@ def test_len_compute(df: pl.DataFrame) -> None:
assert len(taken[col]) == 2


def test_filter_python_list() -> None:
def test_filter_sequence() -> None:
df = pl.DataFrame({"a": [1, 2, 3]})
assert df.filter([True, False, True])["a"].to_list() == [1, 3]
assert df.filter(np.array([True, False, True]))["a"].to_list() == [1, 3]

0 comments on commit d5a0aff

Please sign in to comment.