Skip to content

Commit

Permalink
Supporting Struct comparison and any/all API (#3180)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjermain committed Apr 19, 2022
1 parent 69dc5ba commit 8b2db30
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 16 deletions.
46 changes: 46 additions & 0 deletions polars/polars-core/src/chunked_array/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,52 @@ impl ChunkEqualElement for Utf8Chunked {

impl ChunkEqualElement for ListChunked {}

#[cfg(feature = "dtype-struct")]
impl ChunkCompare<&StructChunked> for StructChunked {
    /// Equality that is meant to account for missing values.
    ///
    /// Currently this simply forwards to `equal`.
    fn eq_missing(&self, rhs: &StructChunked) -> BooleanChunked {
        self.equal(rhs)
    }

    /// Compares two struct arrays as a whole.
    ///
    /// The returned mask always has `self.len()` entries and is uniform:
    /// every entry is `true` when both sides have the same length, the same
    /// number of fields, and each pair of fields (taken in order) passes
    /// `series_equal`; otherwise every entry is `false`.
    fn equal(&self, rhs: &StructChunked) -> BooleanChunked {
        let lhs_fields = self.fields();
        let rhs_fields = rhs.fields();

        // `zip` stops at the shorter side, so the field counts have to be
        // compared explicitly; otherwise a `rhs` with extra trailing fields
        // would be reported as equal.
        let all_equal = self.len() == rhs.len()
            && lhs_fields.len() == rhs_fields.len()
            && lhs_fields
                .iter()
                .zip(rhs_fields.iter())
                .all(|(l, r)| l.series_equal(r));

        BooleanChunked::full("", all_equal, self.len())
    }

    /// Logical negation of `equal`.
    fn not_equal(&self, rhs: &StructChunked) -> BooleanChunked {
        self.equal(rhs).not()
    }

    // The ordering comparisons below are intentionally left unimplemented:
    // `gt`/`lt`-style comparisons between structs don't make sense.

    /// # Panics
    ///
    /// Always panics; ordering is not implemented for struct arrays.
    fn gt(&self, _rhs: &StructChunked) -> BooleanChunked {
        unimplemented!("`gt` is not implemented for struct arrays")
    }

    /// # Panics
    ///
    /// Always panics; ordering is not implemented for struct arrays.
    fn gt_eq(&self, _rhs: &StructChunked) -> BooleanChunked {
        unimplemented!("`gt_eq` is not implemented for struct arrays")
    }

    /// # Panics
    ///
    /// Always panics; ordering is not implemented for struct arrays.
    fn lt(&self, _rhs: &StructChunked) -> BooleanChunked {
        unimplemented!("`lt` is not implemented for struct arrays")
    }

    /// # Panics
    ///
    /// Always panics; ordering is not implemented for struct arrays.
    fn lt_eq(&self, _rhs: &StructChunked) -> BooleanChunked {
        unimplemented!("`lt_eq` is not implemented for struct arrays")
    }
}

#[cfg(test)]
mod test {
use super::super::{arithmetic::test::create_two_chunked, test::get_chunked_array};
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-core/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ macro_rules! impl_compare {
.unwrap()
.logical()
.$method(rhs.categorical().unwrap().logical()),
#[cfg(feature = "dtype-struct")]
DataType::Struct(_) => lhs
.struct_()
.unwrap()
.$method(rhs.struct_().unwrap().deref()),

_ => unimplemented!(),
}
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,25 +547,25 @@ def sqrt(self) -> "Series":
"""
return self ** 0.5

def any(self) -> "Series":
def any(self) -> bool:
"""
Check if any boolean value in the column is `True`
Returns
-------
Boolean literal
"""
return self.to_frame().select(pli.col(self.name).any()).to_series()
return self.to_frame().select(pli.col(self.name).any()).to_series()[0]

def all(self) -> "Series":
def all(self) -> bool:
"""
Check if all boolean values in the column are `True`
Returns
-------
Boolean literal
"""
return self.to_frame().select(pli.col(self.name).all()).to_series()
return self.to_frame().select(pli.col(self.name).all()).to_series()[0]

def log(self, base: float = math.e) -> "Series":
"""
Expand Down
18 changes: 6 additions & 12 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,22 +1443,16 @@ def test_extend_constant() -> None:

def test_any_all() -> None:
a = pl.Series("a", [True, False, True])
expected = pl.Series("a", [True])
verify_series_and_expr_api(a, expected, "any")
expected = pl.Series("a", [False])
verify_series_and_expr_api(a, expected, "all")
assert a.any() is True
assert a.all() is False

a = pl.Series("a", [True, True, True])
expected = pl.Series("a", [True])
verify_series_and_expr_api(a, expected, "any")
expected = pl.Series("a", [True])
verify_series_and_expr_api(a, expected, "all")
assert a.any() is True
assert a.all() is True

a = pl.Series("a", [False, False, False])
expected = pl.Series("a", [False])
verify_series_and_expr_api(a, expected, "any")
expected = pl.Series("a", [False])
verify_series_and_expr_api(a, expected, "all")
assert a.any() is False
assert a.all() is False


def test_product() -> None:
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
import pytest

import polars as pl

Expand Down Expand Up @@ -126,6 +127,25 @@ def test_value_counts_expr() -> None:
assert out == [("a", 1), ("b", 2), ("c", 3)]


def test_struct_comparison() -> None:
    """Struct series are equal only when field names and values both match."""
    base = pl.DataFrame({"b": [1, 2, 3]}).to_struct("a")
    other_values = pl.DataFrame({"b": [0, 0, 0]}).to_struct("a")
    other_field = pl.DataFrame({"c": [1, 2, 3]}).to_struct("a")
    same = pl.DataFrame({"b": [1, 2, 3]}).to_struct("a")

    # Identity and an independently built copy must both compare equal.
    for matching in (base, same):
        pl.testing.assert_series_equal(base, matching)

    # Differing values, then a differing field name, must both be rejected.
    for differing in (other_values, other_field):
        with pytest.raises(AssertionError):
            pl.testing.assert_series_equal(base, differing)

    # The comparison operators return boolean series; reduce them with all().
    assert (base != other_values).all() is True
    assert (base == same).all() is True


def test_nested_struct() -> None:
df = pl.DataFrame({"d": [1, 2, 3], "e": ["foo", "bar", "biz"]})
# Nest the dataframe
Expand Down

0 comments on commit 8b2db30

Please sign in to comment.