Skip to content

Commit

Permalink
use eq_and_validity for strings (#2819)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 3, 2022
1 parent 5a013c6 commit 854a430
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 26 deletions.
34 changes: 8 additions & 26 deletions polars/polars-core/src/chunked_array/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ use num::{NumCast, ToPrimitive};
use std::ops::Not;
use std::sync::Arc;

type LargeStringArray = Utf8Array<i64>;

impl<T> ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down Expand Up @@ -427,24 +425,14 @@ impl Utf8Chunked {
f: impl Fn(&Utf8Array<i64>, &Utf8Array<i64>) -> BooleanArray,
) -> BooleanChunked {
let chunks = self
.chunks
.iter()
.zip(&rhs.chunks)
.downcast_iter()
.zip(rhs.downcast_iter())
.map(|(left, right)| {
let left = left
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("could not downcast one of the chunks");
let right = right
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("could not downcast one of the chunks");
let arr = f(left, right);
Arc::new(arr) as ArrayRef
})
.collect::<Vec<_>>();

ChunkedArray::from_chunks("", chunks)
.collect();
BooleanChunked::from_chunks("", chunks)
}
}

Expand All @@ -467,12 +455,9 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
} else {
BooleanChunked::full("", false, self.len())
}
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, |l, r| comparison::eq_and_validity(l, r))
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, ==)
let (lhs, rhs) = align_chunks_binary(self, rhs);
lhs.comparison(&rhs, comparison::utf8::eq_and_validity)
}
}

Expand All @@ -490,12 +475,9 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
} else {
BooleanChunked::full("", false, self.len())
}
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, |l, r| comparison::neq_and_validity(l, r))
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, !=)
let (lhs, rhs) = align_chunks_binary(self, rhs);
lhs.comparison(&rhs, comparison::utf8::neq_and_validity)
}
}

Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,9 @@ def test_auto_explode() -> None:
.get_column("grouped")
)
assert grouped.dtype == pl.Utf8


def test_null_comparisons() -> None:
    """Check that ==/!= between a string Series and its shifted self yield no nulls."""
    s = pl.Series("s", [None, "str", "a"])

    # The shifted series is compared element-wise against the original; both
    # operands contain a null, and neither comparison result may contain one.
    equal_mask = s.shift() == s
    assert equal_mask.null_count() == 0

    not_equal_mask = s.shift() != s
    assert not_equal_mask.null_count() == 0

0 comments on commit 854a430

Please sign in to comment.