Skip to content

Commit

Permalink
is_in for struct dtype (#3639)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 9, 2022
1 parent ca84d99 commit 99e6cbe
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 2 deletions.
64 changes: 64 additions & 0 deletions polars/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,70 @@ impl IsIn for BooleanChunked {
}
}

#[cfg(feature = "dtype-struct")]
impl IsIn for StructChunked {
    /// Membership test for struct rows.
    ///
    /// * `other` has dtype `List`: each list is read as a struct array and a row of
    ///   `self` matches when it equals any struct in the list at the same position.
    ///   When `self` has exactly one row and `other` does not, that single row is
    ///   tested against every list.
    /// * any other dtype: `other` is read through `struct_()` and a row of `self`
    ///   matches when one row of `other` equals it in every field.
    ///
    /// The returned mask carries the name of `self` in both cases.
    ///
    /// # Errors
    ///
    /// Propagates the errors of `list()` / `struct_()` and of the per-field `is_in`
    /// calls. Returns a `ComputeError` when the number of fields differs or when the
    /// struct has no fields.
    ///
    /// # Panics
    ///
    /// The list branch unwraps `struct_()` on every sub-series, so it panics when a
    /// list element is not a struct.
    fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
        match other.dtype() {
            DataType::List(_) => {
                let mut ca: BooleanChunked = if self.len() == 1 && other.len() != 1 {
                    // Single struct on the left: compare it against every list.
                    let left = self.clone().into_series();
                    // Falls back to an empty row when the first value is not a struct.
                    let value = match left.get(0) {
                        AnyValue::Struct(val, _) => val,
                        _ => vec![],
                    };
                    other
                        .list()?
                        .amortized_iter()
                        .map(|opt_s| {
                            // A missing list never contains the value.
                            opt_s.map_or(false, |s| {
                                let ca = s.as_ref().struct_().unwrap();
                                ca.into_iter().any(|a| a == value)
                            })
                        })
                        .collect()
                } else {
                    // Pairwise: row `i` of `self` against list `i` of `other`.
                    self.into_iter()
                        .zip(other.list()?.amortized_iter())
                        .map(|(value, series)| match series {
                            Some(series) => {
                                let ca = series.as_ref().struct_().unwrap();
                                ca.into_iter().any(|a| a == value)
                            }
                            None => false,
                        })
                        .collect()
                };
                ca.rename(self.name());
                Ok(ca)
            }
            _ => {
                let other = other.struct_()?;
                let width = self.fields().len();

                if width != other.fields().len() {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "Cannot compare structs in 'is_in', the number of fields differ. Fields left: {}, fields right: {}",
                            self.fields().len(),
                            other.fields().len()
                        )
                        .into(),
                    ));
                }

                // Field-wise membership is only a necessary condition: every field
                // value may occur somewhere in `other` without occurring together in
                // one row. It is kept as a cheap pre-filter and because it surfaces
                // the per-field errors.
                let field_masks = self
                    .fields()
                    .iter()
                    .zip(other.fields())
                    .map(|(lhs, rhs)| lhs.is_in(rhs))
                    .collect::<Result<Vec<_>>>()?;

                let candidates = field_masks
                    .into_iter()
                    .reduce(|acc, val| {
                        // Once everything is false, further masks cannot change that.
                        if !acc.any() {
                            acc
                        } else {
                            acc & val
                        }
                    })
                    .ok_or_else(|| PolarsError::ComputeError("no fields in struct".into()))?;

                let mut ca: BooleanChunked = if !candidates.any() {
                    // No row passed the pre-filter, nothing left to verify.
                    candidates
                } else {
                    // Copy the rows of `other` into one flat owned buffer so they can
                    // be scanned repeatedly; `width >= 1` here because a struct
                    // without fields already returned an error above.
                    let mut other_rows = Vec::with_capacity(other.len() * width);
                    for row in other.into_iter() {
                        other_rows.extend_from_slice(row);
                    }

                    // Confirm each candidate with a full row comparison; entries that
                    // were false or null in the pre-filter are passed through as is.
                    (&candidates)
                        .into_iter()
                        .zip(self.into_iter())
                        .map(|(candidate, row)| match candidate {
                            Some(true) => Some(
                                other_rows
                                    .chunks_exact(width)
                                    .any(|rhs_row| rhs_row == row),
                            ),
                            not_candidate => not_candidate,
                        })
                        .collect()
                };
                // Same naming as the list branch.
                ca.rename(self.name());
                Ok(ca)
            }
        }
    }
}

#[cfg(test)]
mod test {
use crate::prelude::*;
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ impl SeriesTrait for SeriesWrap<StructChunked> {
.map(|ca| ca.into_series())
}

/// Membership test against `other`, forwarded to the `is_in` of the wrapped
/// value (`self.0`).
#[cfg(feature = "is_in")]
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
    let inner = &self.0;
    inner.is_in(other)
}

/// String form of this series as produced by the wrapped value's `fmt_list`.
fn fmt_list(&self) -> String {
    let inner = &self.0;
    inner.fmt_list()
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ csv-file = ["polars-io/csv-file"]
temporal = ["polars-core/temporal", "polars-time", "dtype-datetime"]
# debugging purposes
fmt = ["polars-core/fmt"]
strings = ["polars-core/strings"]
strings = ["polars-core/strings", "polars-ops/strings"]
future = []
dtype-u8 = ["polars-core/dtype-u8"]
dtype-u16 = ["polars-core/dtype-u16"]
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,7 +1813,7 @@ def pow(self, exponent: Union[float, "Expr"]) -> "Expr":
exponent = expr_to_lit_or_expr(exponent)
return wrap_expr(self._pyexpr.pow(exponent._pyexpr))

def is_in(self, other: Union["Expr", List[Any]]) -> "Expr":
def is_in(self, other: Union["Expr", List[Any], str]) -> "Expr":
"""
Check if elements of this Series are in the right Series, or List values of the right Series.
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,21 @@ def test_arr_unique() -> None:
assert len(unique_el) == 2
assert {"a": 2, "b": 12} in unique_el
assert {"a": 1, "b": 11} in unique_el


def test_is_in_struct() -> None:
    """Filtering with `is_in` keeps only rows whose struct occurs in that row's list."""
    df = pl.DataFrame(
        {
            "struct_elem": [{"a": 1, "b": 11}, {"a": 1, "b": 90}],
            "struct_list": [
                [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}],
                [{"a": 3, "b": 3}],
            ],
        }
    )

    # Only the first row's struct is present in its own list; the second row's
    # struct shares a field value with the data but matches no list element.
    assert df.filter(pl.col("struct_elem").is_in("struct_list")).to_dict(False) == {
        "struct_elem": [{"a": 1, "b": 11}],
        "struct_list": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]],
    }

0 comments on commit 99e6cbe

Please sign in to comment.