Skip to content

Commit

Permalink
Migrated comparison.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored and ritchie46 committed Jun 14, 2021
1 parent 6db0eda commit 1367b07
Showing 1 changed file with 52 additions and 28 deletions.
80 changes: 52 additions & 28 deletions polars/polars-core/src/chunked_array/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ impl Utf8Chunked {
fn comparison(
&self,
rhs: &Utf8Chunked,
operator: impl Fn(&LargeStringArray, &LargeStringArray) -> arrow::error::Result<BooleanArray>,
operator: comparison::Operator,
) -> Result<BooleanChunked> {
let chunks = self
.chunks
Expand All @@ -304,7 +304,7 @@ impl Utf8Chunked {
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("could not downcast one of the chunks");
let arr_res = operator(left, right);
let arr_res = comparison::compare(left, right, operator);
let arr = match arr_res {
Ok(arr) => arr,
Err(e) => return Err(PolarsError::ArrowError(e)),
Expand Down Expand Up @@ -333,7 +333,8 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, eq_utf8).expect("should not fail")
self.comparison(rhs, comparison::Operator::Eq)
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, ==)
}
Expand All @@ -350,7 +351,8 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, neq_utf8).expect("should not fail")
self.comparison(rhs, comparison::Operator::Neq)
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, !=)
}
Expand All @@ -367,7 +369,8 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, gt_utf8).expect("should not fail")
self.comparison(rhs, comparison::Operator::Gt)
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, >)
}
Expand All @@ -384,7 +387,8 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, gt_eq_utf8).expect("should not fail")
self.comparison(rhs, comparison::Operator::GtEq)
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, >=)
}
Expand All @@ -401,7 +405,10 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, lt_utf8).expect("should not fail")
self.comparison(rhs, |x, y| {
comparison::compare(x, comparison::Operator::Lt, y)
})
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, <)
}
Expand All @@ -418,7 +425,8 @@ impl ChunkCompare<&Utf8Chunked> for Utf8Chunked {
}
// same length
else if self.chunk_id().zip(rhs.chunk_id()).all(|(l, r)| l == r) {
self.comparison(rhs, lt_eq_utf8).expect("should not fail")
self.comparison(rhs, comparison::Operator::LtEq)
.expect("should not fail")
} else {
apply_operand_on_chunkedarray_by_iter!(self, rhs, <=)
}
Expand All @@ -438,6 +446,20 @@ impl NumComp for u16 {}
impl NumComp for u32 {}
impl NumComp for u64 {}

impl<T, Rhs> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: NumCast,
Rhs: NumComp + ToPrimitive,
{
fn primitive_compare_scalar(&self, rhs: Rhs, op: comparison::Operator) {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| {
Arc::new(comparison::primitive_compare_scalar(arr, rhs, op).unwrap())
})
}
}

impl<T, Rhs> ChunkCompare<Rhs> for ChunkedArray<T>
where
T: PolarsNumericType,
Expand All @@ -449,33 +471,35 @@ where
}

fn eq(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(eq_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::Eq)
}

fn neq(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(neq_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::Neq)
}

fn gt(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(gt_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::Gt)
}

fn gt_eq(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(gt_eq_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::GtEq)
}

fn lt(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(lt_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::Lt)
}

fn lt_eq(&self, rhs: Rhs) -> BooleanChunked {
let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type");
self.apply_kernel_cast(|arr| Arc::new(lt_eq_scalar(arr, rhs).unwrap()))
self.primitive_compare_scalar(rhs, comparison::Operator::LtEq)
}
}

impl<T> Utf8Chunked {
fn utf8_compare_scalar(&self, rhs: &str, op: comparison::Operator) {
self.apply_kernel_cast(|arr| {
Arc::new(comparison::utf8_compare_scalar(arr, rhs, op).unwrap())
})
}
}

Expand All @@ -485,26 +509,26 @@ impl ChunkCompare<&str> for Utf8Chunked {
}

fn eq(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(eq_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::Eq)
}
fn neq(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(neq_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::Neq)
}

fn gt(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(gt_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::Gt)
}

fn gt_eq(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(gt_eq_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::GtEq)
}

fn lt(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(lt_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::Lt)
}

fn lt_eq(&self, rhs: &str) -> BooleanChunked {
self.apply_kernel_cast(|arr| Arc::new(lt_eq_utf8_scalar(arr, rhs).unwrap()))
self.utf8_compare_scalar(comparison::Operator::LtEq)
}
}

Expand Down Expand Up @@ -599,7 +623,7 @@ impl BooleanChunked {
macro_rules! impl_bitwise_op {
($self:ident, $rhs:ident, $arrow_method:ident, $op:tt) => {{
if $self.chunk_id().zip($rhs.chunk_id()).all(|(l, r)| l == r) {
let result = $self.bit_operation($rhs, compute::$arrow_method);
let result = $self.bit_operation($rhs, compute::boolean::$arrow_method);
result.unwrap()
} else {
let ca = $self
Expand Down Expand Up @@ -655,7 +679,7 @@ impl Not for &BooleanChunked {
let chunks = self
.downcast_iter()
.map(|a| {
let arr = compute::not(a).expect("should not fail");
let arr = compute::boolean::not(a).expect("should not fail");
Arc::new(arr) as ArrayRef
})
.collect::<Vec<_>>();
Expand Down

0 comments on commit 1367b07

Please sign in to comment.