Skip to content

Commit

Permalink
add is_in for categoricals (#4153)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 25, 2022
1 parent a09b192 commit 3b7b605
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
1 change: 1 addition & 0 deletions polars/polars-core/src/chunked_array/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ impl ChunkFilter<ListType> for ListChunked {
// inner type may be categorical or logical type so we clone the state.
let mut ca = self.clone();
ca.chunks = chunks;
ca.compute_len();
Ok(ca)
}
}
Expand Down
48 changes: 47 additions & 1 deletion polars/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,53 @@ where
impl IsIn for Utf8Chunked {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == &**dt => {
#[cfg(feature = "dtype-categorical")]
DataType::List(dt) if matches!(&**dt, DataType::Categorical(_)) => {
if let DataType::Categorical(Some(rev_map)) = &**dt {
let opt_val = self.get(0);

let other = other.list()?;
match opt_val {
None => {
let mut ca: BooleanChunked = other
.amortized_iter()
.map(|opt_s| {
opt_s.map(|s| s.as_ref().null_count() > 0) == Some(true)
})
.collect_trusted();
ca.rename(self.name());
Ok(ca)
}
Some(value) => {
match rev_map.find(value) {
// all false
None => Ok(BooleanChunked::full(self.name(), false, other.len())),
Some(idx) => {
let mut ca: BooleanChunked = other
.amortized_iter()
.map(|opt_s| {
opt_s.map(|s| {
let s = s.as_ref().to_physical_repr();
let ca = s.as_ref().u32().unwrap();
if ca.null_count() == 0 {
ca.into_no_null_iter().any(|a| a == idx)
} else {
ca.into_iter().any(|a| a == Some(idx))
}
}) == Some(true)
})
.collect_trusted();
ca.rename(self.name());
Ok(ca)
}
}
}
}
} else {
unreachable!()
}
}
DataType::List(dt) if DataType::Utf8 == **dt => {
let mut ca: BooleanChunked = if self.len() == 1 && other.len() != 1 {
let value = self.get(0);
other
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,14 @@ def test_list_hash() -> None:
)
assert out.dtypes == [pl.List(pl.Int64), pl.UInt64]
assert out[0, "b"] == out[2, "b"]


def test_arr_contains_categorical() -> None:
df = pl.DataFrame(
{"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}
).lazy()
df = df.with_column(pl.col("str").cast(pl.Categorical))
df_groups = df.groupby("group").agg([pl.col("str").list().alias("str_list")])
assert df_groups.filter(pl.col("str_list").arr.contains("C")).collect().to_dict(
False
) == {"group": [2], "str_list": [["A", "C"]]}

0 comments on commit 3b7b605

Please sign in to comment.