Skip to content

Commit

Permalink
perf[rust]: SIMD accelerate bound checks in take (#4754)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 7, 2022
1 parent d23ba91 commit 70820ea
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
8 changes: 4 additions & 4 deletions polars/polars-arrow/src/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ pub unsafe fn take_unchecked(arr: &dyn Array, idx: &IdxArr) -> ArrayRef {
match arr.data_type().to_physical_type() {
Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let arr: &PrimitiveArray<$T> = arr.as_any().downcast_ref().unwrap();
if arr.validity().is_some() {
if arr.null_count() > 0 {
take_primitive_unchecked::<$T>(arr, idx)
} else {
take_no_null_primitive::<$T>(arr, idx)
take_no_null_primitive_unchecked::<$T>(arr, idx)
}
}),
LargeUtf8 => {
Expand Down Expand Up @@ -98,11 +98,11 @@ pub unsafe fn take_primitive_unchecked<T: NativeType>(
/// Take kernel for single chunk without nulls and arrow array as index.
/// # Safety
/// caller must ensure indices are in bounds
pub unsafe fn take_no_null_primitive<T: NativeType>(
pub unsafe fn take_no_null_primitive_unchecked<T: NativeType>(
arr: &PrimitiveArray<T>,
indices: &IdxArr,
) -> Box<PrimitiveArray<T>> {
debug_assert!(!arr.has_validity());
debug_assert!(arr.null_count() == 0);
let array_values = arr.values().as_slice();
let index_values = indices.values().as_slice();

Expand Down
7 changes: 3 additions & 4 deletions polars/polars-core/src/chunked_array/ops/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
//!
use std::borrow::Cow;

use polars_arrow::array::PolarsArray;
use polars_arrow::compute::take::*;
pub use take_random::*;
pub use traits::*;
Expand Down Expand Up @@ -75,9 +74,9 @@ where
if array.null_count() == array.len() {
return Self::full_null(self.name(), array.len());
}
let array = match (self.has_validity(), self.chunks.len()) {
(false, 1) => {
take_no_null_primitive::<T::Native>(chunks.next().unwrap(), array)
let array = match (self.null_count(), self.chunks.len()) {
(0, 1) => {
take_no_null_primitive_unchecked::<T::Native>(chunks.next().unwrap(), array)
as ArrayRef
}
(_, 1) => take_primitive_unchecked::<T::Native>(chunks.next().unwrap(), array)
Expand Down
14 changes: 6 additions & 8 deletions polars/polars-core/src/chunked_array/ops/take/traits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
//! Traits that indicate the allowed arguments in a ChunkedArray::take operation.
use polars_arrow::array::PolarsArray;

use crate::frame::groupby::GroupsProxyIter;
use crate::prelude::*;

Expand Down Expand Up @@ -52,8 +50,8 @@ where

for i in iter {
if i >= bound {
// we will not break here as that prevents SIMD
inbounds = false;
break;
}
}
if inbounds {
Expand All @@ -80,8 +78,8 @@ where

for i in iter.flatten() {
if i >= bound {
// we will not break here as that prevents SIMD
inbounds = false;
break;
}
}
if inbounds {
Expand Down Expand Up @@ -120,21 +118,21 @@ where
TakeIdx::Iter(i) => i.check_bounds(bound),
TakeIdx::IterNulls(i) => i.check_bounds(bound),
TakeIdx::Array(arr) => {
let values = arr.values().as_slice();
let mut inbounds = true;
let len = bound as IdxSize;
if !arr.has_validity() {
for &i in arr.values().as_slice() {
if arr.null_count() == 0 {
for &i in values {
// we will not break here as that prevents SIMD
if i >= len {
inbounds = false;
break;
}
}
} else {
for opt_v in *arr {
match opt_v {
Some(&v) if v >= len => {
inbounds = false;
break;
}
_ => {}
}
Expand Down

0 comments on commit 70820ea

Please sign in to comment.