Skip to content

Commit

Permalink
python fix arr.contains type (#3782)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 23, 2022
1 parent ac8b273 commit d4f3cc3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
3 changes: 2 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from polars import internals as pli
from polars.datatypes import (
Boolean,
DataType,
Date,
Datetime,
Expand Down Expand Up @@ -3958,7 +3959,7 @@ def contains(self, item: Union[float, str, bool, int, date, datetime]) -> "Expr"
└───────┘
"""
return wrap_expr(self._pyexpr).map(lambda s: s.arr.contains(item))
return wrap_expr(self._pyexpr).map(lambda s: s.arr.contains(item), Boolean)

def join(self, separator: str) -> "Expr":
"""
Expand Down
14 changes: 8 additions & 6 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,12 +864,14 @@ impl PySeries {
}

#[cfg(feature = "is_in")]
pub fn is_in(&self, other: &PySeries) -> PyResult<Self> {
let out = self
.series
.is_in(&other.series)
.map_err(PyPolarsErr::from)?;
Ok(out.into_series().into())
pub fn is_in(&self, py: Python, other: &PySeries) -> PyResult<Self> {
py.allow_threads(|| {
let out = self
.series
.is_in(&other.series)
.map_err(PyPolarsErr::from)?;
Ok(out.into_series().into())
})
}

pub fn clone(&self) -> Self {
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,20 @@ def test_regex_in_filter() -> None:
assert df.filter(
pl.fold(acc=False, f=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3))
).row(0) == (1, "foo", 1.0)


def test_arr_contains() -> None:
df_groups = pl.DataFrame(
{
"str_list": [
["cat", "mouse", "dog"],
["dog", "mouse", "cat"],
["dog", "mouse", "aardvark"],
],
}
)
assert df_groups.lazy().filter(
pl.col("str_list").arr.contains("cat")
).collect().to_dict(False) == {
"str_list": [["cat", "mouse", "dog"], ["dog", "mouse", "cat"]]
}

0 comments on commit d4f3cc3

Please sign in to comment.