Skip to content

Commit

Permalink
fix: allow search_sorted directly on multiple chunks, and fix behavior around nulls (#16447)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed May 24, 2024
1 parent bb1c73c commit d4c3aba
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 222 deletions.
36 changes: 11 additions & 25 deletions crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
//! Implementations of the ChunkAgg trait.
use num_traits::Float;

use self::search_sorted::{
binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide,
};
use self::search_sorted::{binary_search_ca, SearchSortedSide};
use crate::prelude::*;

impl<T> ChunkedArray<T>
Expand All @@ -14,31 +11,21 @@ where
/// Returns the index of the maximum value, assuming the array is sorted in
/// ascending order (`is_sorted_ascending_flag` is set) and has at least one
/// non-null value.
///
/// The last non-null value is normally the maximum, but NaN compares greater
/// than every real value in the total ordering used here, so a NaN tail can
/// follow the true maximum. In that case we binary-search for the start of
/// the NaN run and return the position just before it.
fn float_arg_max_sorted_ascending(&self) -> usize {
    let ca = self;
    debug_assert!(ca.is_sorted_ascending_flag());

    // Candidate: the last non-null entry (largest under ascending order).
    let maybe_max_idx = ca.last_non_null().unwrap();

    // SAFETY: the index comes from last_non_null, so it is in bounds.
    let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
    if !maybe_max.is_nan() {
        return maybe_max_idx;
    }

    // The tail consists of NaNs: find where the NaN run starts (leftmost
    // insertion point of NaN) and step back one to land on the largest real
    // value. saturating_sub handles the all-NaN case by returning index 0.
    let search_val = std::iter::once(Some(T::Native::nan()));
    let idx = binary_search_ca(ca, search_val, SearchSortedSide::Left, false)[0] as usize;
    idx.saturating_sub(1)
}

fn float_arg_max_sorted_descending(&self) -> usize {
let ca = self;
debug_assert!(ca.is_sorted_descending_flag());
let is_descending = true;
let side = SearchSortedSide::Right;

let maybe_max_idx = ca.first_non_null().unwrap();

Expand All @@ -47,14 +34,13 @@ where
return maybe_max_idx;
}

let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) };
let arr = unsafe { ca.downcast_get_unchecked(0) };
let search_val = T::Native::nan();
let idx = binary_search_array(side, arr, search_val, is_descending) as usize;

let idx = if idx == arr.len() { idx - 1 } else { idx };

offset + idx
let search_val = std::iter::once(Some(T::Native::nan()));
let idx = binary_search_ca(ca, search_val, SearchSortedSide::Right, true)[0] as usize;
if idx == ca.len() {
idx - 1
} else {
idx
}
}
}

Expand Down
275 changes: 184 additions & 91 deletions crates/polars-core/src/chunked_array/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::cmp::Ordering;
use std::fmt::Debug;

#[cfg(feature = "serde")]
Expand All @@ -15,114 +14,208 @@ pub enum SearchSortedSide {
Right,
}

/// Search the left or right index that still fulfills the requirements.
/// Computes the first point on [lo, hi) where f is true, assuming it is first
/// always false and then always true. It is assumed f(hi) is true.
///
/// `midpoint` must return some `lo < i < hi` if one exists, else `lo`. Making
/// the midpoint a callback lets callers binary-search over non-integer index
/// types (e.g. (chunk, offset) pairs) where "the middle" needs custom logic.
fn lower_bound<I, F, M>(mut lo: I, mut hi: I, midpoint: M, f: F) -> I
where
    I: PartialEq + Eq,
    M: Fn(&I, &I) -> I,
    F: Fn(&I) -> bool,
{
    loop {
        let m = midpoint(&lo, &hi);
        if m == lo {
            // The interval can no longer be split: the answer is lo if f
            // already holds there, otherwise hi (f(hi) is true by assumption).
            return if f(&lo) { lo } else { hi };
        }

        // Shrink the interval, maintaining the invariant that f is false
        // somewhere at or after lo and true at hi.
        if f(&m) {
            hi = m;
        } else {
            lo = m;
        }
    }
}

pub fn binary_search_array<'a, A>(
side: SearchSortedSide,
arr: &'a A,
search_value: A::ValueT<'a>,
descending: bool,
) -> IdxSize
/// Search through a series of chunks for the first position where f(x) is true,
/// assuming it is first always false and then always true. It repeats this for
/// each value in search_values. If the search value is null null_idx is returned.
///
/// Assumes the chunks are non-empty.
pub fn lower_bound_chunks<'a, T, F>(
chunks: &[&'a T::Array],
search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
null_idx: IdxSize,
f: F,
) -> Vec<IdxSize>
where
A: StaticArray,
A::ValueT<'a>: TotalOrd + Debug + Copy,
T: PolarsDataType,
F: Fn(&'a T::Array, usize, &T::Physical<'a>) -> bool,
{
let mut size = arr.len() as IdxSize;
let mut left = 0 as IdxSize;
let mut right = size;
while left < right {
let mid = left + size / 2;
if chunks.is_empty() {
return search_values.map(|_| 0).collect();
}

// SAFETY: the call is made safe by the following invariants:
// - `mid >= 0`
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = match unsafe { arr.get_unchecked(mid as usize) } {
None => Ordering::Less,
Some(value) => {
if descending {
search_value.tot_cmp(&value)
// Fast-path: only a single chunk.
if chunks.len() == 1 {
let chunk = &chunks[0];
return search_values
.map(|ov| {
if let Some(v) = ov {
lower_bound(0, chunk.len(), |l, r| (l + r) / 2, |m| f(chunk, *m, &v)) as IdxSize
} else {
value.tot_cmp(&search_value)
null_idx
}
},
};

// The reason why we use if/else control flow rather than match
// is because match reorders comparison operations, which is perf sensitive.
// This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
if cmp == Ordering::Less {
left = mid + 1;
} else if cmp == Ordering::Greater {
right = mid;
} else {
return get_side_idx(side, mid, arr, arr.len());
}
})
.collect();
}

size = right - left;
// Multiple chunks, precompute prefix sum of lengths so we can look up
// in O(1) the global position of chunk i.
let mut sz = 0;
let mut chunk_len_prefix_sum = Vec::with_capacity(chunks.len() + 1);
for c in chunks {
chunk_len_prefix_sum.push(sz);
sz += c.len();
}
chunk_len_prefix_sum.push(sz);

// For each search value do a binary search on (chunk_idx, idx_in_chunk) pairs.
search_values
.map(|ov| {
let Some(v) = ov else {
return null_idx;
};
let left = (0, 0);
let right = (chunks.len(), 0);
let midpoint = |l: &(usize, usize), r: &(usize, usize)| {
if l.0 == r.0 {
// Within same chunk.
(l.0, (l.1 + r.1) / 2)
} else if l.0 + 1 == r.0 {
// Two adjacent chunks, might have to be l or r.
let left_len = chunks[l.0].len() - l.1;

left
let logical_mid = (left_len + r.1) / 2;
if logical_mid < left_len {
(l.0, l.1 + logical_mid)
} else {
(r.0, logical_mid - left_len)
}
} else {
// Has a chunk in between.
((l.0 + r.0) / 2, 0)
}
};

let bound = lower_bound(left, right, midpoint, |m| {
f(unsafe { chunks.get_unchecked(m.0) }, m.1, &v)
});

(chunk_len_prefix_sum[bound.0] + bound.1) as IdxSize
})
.collect()
}

/// Get a slice of the non-null values of a sorted array. The returned array
/// will have a single chunk.
/// # Safety
/// The array is sorted and has at least one non-null value.
pub unsafe fn slice_sorted_non_null_and_offset<T>(ca: &ChunkedArray<T>) -> (usize, ChunkedArray<T>)
#[allow(clippy::collapsible_else_if)]
pub fn binary_search_ca<'a, T>(
ca: &'a ChunkedArray<T>,
search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
side: SearchSortedSide,
descending: bool,
) -> Vec<IdxSize>
where
T: PolarsDataType,
T::Physical<'a>: TotalOrd + Debug + Copy,
{
let offset = ca.first_non_null().unwrap();
let length = 1 + ca.last_non_null().unwrap() - offset;
let out = ca.slice(offset as i64, length);

debug_assert!(out.null_count() != out.len());
debug_assert!(out.null_count() == 0);
let chunks: Vec<_> = ca.downcast_iter().filter(|c| c.len() > 0).collect();
let has_nulls = ca.null_count() > 0;
let nulls_last = has_nulls && chunks[0].get(0).is_some();
let null_idx = if nulls_last {
if side == SearchSortedSide::Right {
ca.len()
} else {
ca.len() - ca.null_count()
}
} else {
if side == SearchSortedSide::Right {
ca.null_count()
} else {
0
}
} as IdxSize;

(offset, out.rechunk())
if !descending {
if !has_nulls {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_gt(sv) },
)
} else {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_ge(sv) },
)
}
} else {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_gt(sv)
} else {
nulls_last
}
})
} else {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_ge(sv)
} else {
nulls_last
}
})
}
}
} else {
if !has_nulls {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_lt(sv) },
)
} else {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_le(sv) },
)
}
} else {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_lt(sv)
} else {
nulls_last
}
})
} else {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_le(sv)
} else {
nulls_last
}
})
}
}
}
}
Loading

0 comments on commit d4c3aba

Please sign in to comment.