Skip to content

Commit

Permalink
python arr.contains
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 9, 2021
1 parent ac767f1 commit 6f72cdf
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 33 deletions.
111 changes: 78 additions & 33 deletions polars/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,32 @@ where
return left.is_in(&right);
}

let ca: BooleanChunked = self
.into_iter()
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<T>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted();
let mut ca: BooleanChunked = if self.len() == 1 && other.len() != 1 {
let value = self.get(0);
other
.list()?
.amortized_iter()
.map(|opt_s| {
opt_s.map(|s| {
let ca = s.as_ref().unpack::<T>().unwrap();
ca.into_iter().any(|a| a == value)
}) == Some(true)
})
.trust_my_length(other.len())
.collect_trusted()
} else {
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.as_ref().unpack::<T>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted()
};
ca.rename(self.name());
Ok(ca)
}
_ => {
Expand Down Expand Up @@ -109,17 +124,32 @@ impl IsIn for Utf8Chunked {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == &**dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<Utf8Type>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted();
let mut ca: BooleanChunked = if self.len() == 1 && other.len() != 1 {
let value = self.get(0);
other
.list()?
.amortized_iter()
.map(|opt_s| {
opt_s.map(|s| {
let ca = s.as_ref().unpack::<Utf8Type>().unwrap();
ca.into_iter().any(|a| a == value)
}) == Some(true)
})
.trust_my_length(other.len())
.collect_trusted()
} else {
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.as_ref().unpack::<Utf8Type>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted()
};
ca.rename(self.name());
Ok(ca)
}
DataType::Utf8 => {
Expand Down Expand Up @@ -158,17 +188,32 @@ impl IsIn for BooleanChunked {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == &**dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<BooleanType>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted();
let mut ca: BooleanChunked = if self.len() == 1 && other.len() != 1 {
let value = self.get(0);
other
.list()?
.amortized_iter()
.map(|opt_s| {
opt_s.map(|s| {
let ca = s.as_ref().unpack::<BooleanType>().unwrap();
ca.into_iter().any(|a| a == value)
}) == Some(true)
})
.trust_my_length(other.len())
.collect_trusted()
} else {
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.as_ref().unpack::<BooleanType>().unwrap();
ca.into_iter().any(|a| a == val)
}
_ => false,
})
.collect_trusted()
};
ca.rename(self.name());
Ok(ca)
}
_ => Err(PolarsError::SchemaMisMatch(
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,4 @@ The following methods are available under the `expr.arr` attribute.
ExprListNameSpace.get
ExprListNameSpace.first
ExprListNameSpace.last
ExprListNameSpace.contains
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,4 @@ The following methods are available under the `Series.arr` attribute.
ListNameSpace.get
ListNameSpace.first
ListNameSpace.last
ListNameSpace.contains
15 changes: 15 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,21 @@ def last(self) -> "Expr":
"""
return self.get(-1)

def contains(self, item: Union[float, str, bool, int, date, datetime]) -> "Expr":
    """
    Check if sublists contain the given item.

    Parameters
    ----------
    item
        Item that will be checked for membership

    Returns
    -------
    Boolean mask
    """
    # Delegate to the Series-level `arr.contains` via `map`.
    expr = wrap_expr(self._pyexpr)
    return expr.map(lambda series: series.arr.contains(item))


class ExprStringNameSpace:
"""
Expand Down
36 changes: 36 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import typing as tp
from datetime import date, datetime, timedelta
from numbers import Number
Expand Down Expand Up @@ -61,6 +62,11 @@
except ImportError: # pragma: no cover
_PANDAS_AVAILABLE = False

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


def match_dtype(value: Any, dtype: "Type[DataType]") -> Any:
"""
Expand Down Expand Up @@ -923,6 +929,18 @@ def alias(self, name: str) -> "Series":
s._s.rename(name)
return s

@tp.overload
def rename(self, name: str, in_place: Literal[False] = ...) -> "Series":
...

@tp.overload
def rename(self, name: str, in_place: Literal[True]) -> None:
...

@tp.overload
def rename(self, name: str, in_place: bool) -> Optional["Series"]:
...

def rename(self, name: str, in_place: bool = False) -> Optional["Series"]:
"""
Rename this Series.
Expand Down Expand Up @@ -3414,6 +3432,24 @@ def last(self) -> "Series":
"""
return self.get(-1)

def contains(self, item: Union[float, str, bool, int, date, datetime]) -> "Series":
    """
    Check if sublists contain the given item.

    Parameters
    ----------
    item
        Item that will be checked for membership

    Returns
    -------
    Boolean mask
    """
    # Wrap the item in a single-element Series and test it for membership
    # in the list column with `is_in`.
    needle = pli.Series("", [item])
    lists = wrap_s(self._s)
    mask = needle.is_in(lists)
    # The result is renamed to the name of the list Series.
    return mask.rename(lists.name)


class DateTimeNameSpace:
"""
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,13 @@ def test_list_arr_get() -> None:
out = a.arr.get(-3)
expected = pl.Series("a", [1, None, 7])
testing.assert_series_equal(out, expected)


def test_contains() -> None:
    """`arr.contains` yields the same boolean mask via the Series and Expr APIs."""
    lists = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])

    # Series namespace.
    out = lists.arr.contains(2)
    expected = pl.Series("a", [True, True, False])
    testing.assert_series_equal(out, expected)

    # Expression namespace, evaluated through `pl.select`.
    selected = pl.select(pl.lit(lists).arr.contains(2))
    out = selected.to_series()
    testing.assert_series_equal(out, expected)

0 comments on commit 6f72cdf

Please sign in to comment.