Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow search_sorted directly on multiple chunks, and fix behavior around nulls #16447

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
//! Implementations of the ChunkAgg trait.
use num_traits::Float;

use self::search_sorted::{
binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide,
};
use self::search_sorted::{binary_search_ca, SearchSortedSide};
use crate::prelude::*;

impl<T> ChunkedArray<T>
Expand All @@ -14,31 +11,21 @@ where
fn float_arg_max_sorted_ascending(&self) -> usize {
let ca = self;
debug_assert!(ca.is_sorted_ascending_flag());
let is_descending = false;
let side = SearchSortedSide::Left;

let maybe_max_idx = ca.last_non_null().unwrap();

let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
if !maybe_max.is_nan() {
return maybe_max_idx;
}

let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) };
let arr = unsafe { ca.downcast_get_unchecked(0) };
let search_val = T::Native::nan();
let idx = binary_search_array(side, arr, search_val, is_descending) as usize;

let idx = idx.saturating_sub(1);

offset + idx
let search_val = std::iter::once(Some(T::Native::nan()));
let idx = binary_search_ca(ca, search_val, SearchSortedSide::Left, false)[0] as usize;
idx.saturating_sub(1)
}

fn float_arg_max_sorted_descending(&self) -> usize {
let ca = self;
debug_assert!(ca.is_sorted_descending_flag());
let is_descending = true;
let side = SearchSortedSide::Right;

let maybe_max_idx = ca.first_non_null().unwrap();

Expand All @@ -47,14 +34,13 @@ where
return maybe_max_idx;
}

let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) };
let arr = unsafe { ca.downcast_get_unchecked(0) };
let search_val = T::Native::nan();
let idx = binary_search_array(side, arr, search_val, is_descending) as usize;

let idx = if idx == arr.len() { idx - 1 } else { idx };

offset + idx
let search_val = std::iter::once(Some(T::Native::nan()));
let idx = binary_search_ca(ca, search_val, SearchSortedSide::Right, true)[0] as usize;
if idx == ca.len() {
idx - 1
} else {
idx
}
}
}

Expand Down
275 changes: 184 additions & 91 deletions crates/polars-core/src/chunked_array/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::cmp::Ordering;
use std::fmt::Debug;

#[cfg(feature = "serde")]
Expand All @@ -15,114 +14,208 @@ pub enum SearchSortedSide {
Right,
}

/// Search the left or right index that still fulfills the requirements.
fn get_side_idx<'a, A>(side: SearchSortedSide, mid: IdxSize, arr: &'a A, len: usize) -> IdxSize
/// Computes the first point on [lo, hi) where f is true, assuming it is first
/// always false and then always true. It is assumed f(hi) is true.
/// midpoint is a function that returns some lo < i < hi if one exists, else lo.
fn lower_bound<I, F, M>(mut lo: I, mut hi: I, midpoint: M, f: F) -> I
where
A: StaticArray,
A::ValueT<'a>: TotalOrd + Debug + Copy,
I: PartialEq + Eq,
M: Fn(&I, &I) -> I,
F: Fn(&I) -> bool,
{
let mut mid = mid;

// approach the boundary from any side
// this is O(n) we could make this binary search later
match side {
SearchSortedSide::Any => mid,
SearchSortedSide::Left => {
if mid as usize == len {
mid -= 1;
}
loop {
let m = midpoint(&lo, &hi);
if m == lo {
return if f(&lo) { lo } else { hi };
}

let current = unsafe { arr.get_unchecked(mid as usize) };
loop {
if mid == 0 {
return mid;
}
mid -= 1;
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) {
return mid + 1;
}
}
},
SearchSortedSide::Right => {
if mid as usize == len {
return mid;
}
let current = unsafe { arr.get_unchecked(mid as usize) };
let bound = (len - 1) as IdxSize;
loop {
if mid >= bound {
return mid + 1;
}
mid += 1;
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) {
return mid;
}
}
},
if f(&m) {
hi = m;
} else {
lo = m;
}
}
}

pub fn binary_search_array<'a, A>(
side: SearchSortedSide,
arr: &'a A,
search_value: A::ValueT<'a>,
descending: bool,
) -> IdxSize
/// Search through a series of chunks for the first position where f(x) is true,
/// assuming it is first always false and then always true. It repeats this for
/// each value in search_values. If the search value is null null_idx is returned.
///
/// Assumes the chunks are non-empty.
pub fn lower_bound_chunks<'a, T, F>(
chunks: &[&'a T::Array],
search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
null_idx: IdxSize,
f: F,
) -> Vec<IdxSize>
where
A: StaticArray,
A::ValueT<'a>: TotalOrd + Debug + Copy,
T: PolarsDataType,
F: Fn(&'a T::Array, usize, &T::Physical<'a>) -> bool,
{
let mut size = arr.len() as IdxSize;
let mut left = 0 as IdxSize;
let mut right = size;
while left < right {
let mid = left + size / 2;
if chunks.is_empty() {
return search_values.map(|_| 0).collect();
}

// SAFETY: the call is made safe by the following invariants:
// - `mid >= 0`
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = match unsafe { arr.get_unchecked(mid as usize) } {
None => Ordering::Less,
Some(value) => {
if descending {
search_value.tot_cmp(&value)
// Fast-path: only a single chunk.
if chunks.len() == 1 {
let chunk = &chunks[0];
return search_values
.map(|ov| {
if let Some(v) = ov {
lower_bound(0, chunk.len(), |l, r| (l + r) / 2, |m| f(chunk, *m, &v)) as IdxSize
} else {
value.tot_cmp(&search_value)
null_idx
}
},
};

// The reason why we use if/else control flow rather than match
// is because match reorders comparison operations, which is perf sensitive.
// This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
if cmp == Ordering::Less {
left = mid + 1;
} else if cmp == Ordering::Greater {
right = mid;
} else {
return get_side_idx(side, mid, arr, arr.len());
}
})
.collect();
}

size = right - left;
// Multiple chunks, precompute prefix sum of lengths so we can look up
// in O(1) the global position of chunk i.
let mut sz = 0;
let mut chunk_len_prefix_sum = Vec::with_capacity(chunks.len() + 1);
for c in chunks {
chunk_len_prefix_sum.push(sz);
sz += c.len();
}
chunk_len_prefix_sum.push(sz);

// For each search value do a binary search on (chunk_idx, idx_in_chunk) pairs.
search_values
.map(|ov| {
let Some(v) = ov else {
return null_idx;
};
let left = (0, 0);
let right = (chunks.len(), 0);
let midpoint = |l: &(usize, usize), r: &(usize, usize)| {
if l.0 == r.0 {
// Within same chunk.
(l.0, (l.1 + r.1) / 2)
} else if l.0 + 1 == r.0 {
// Two adjacent chunks, might have to be l or r.
let left_len = chunks[l.0].len() - l.1;

left
let logical_mid = (left_len + r.1) / 2;
if logical_mid < left_len {
(l.0, l.1 + logical_mid)
} else {
(r.0, logical_mid - left_len)
}
} else {
// Has a chunk in between.
((l.0 + r.0) / 2, 0)
}
};

let bound = lower_bound(left, right, midpoint, |m| {
f(unsafe { chunks.get_unchecked(m.0) }, m.1, &v)
});

(chunk_len_prefix_sum[bound.0] + bound.1) as IdxSize
})
.collect()
}

/// Get a slice of the non-null values of a sorted array. The returned array
/// will have a single chunk.
/// # Safety
/// The array is sorted and has at least one non-null value.
pub unsafe fn slice_sorted_non_null_and_offset<T>(ca: &ChunkedArray<T>) -> (usize, ChunkedArray<T>)
#[allow(clippy::collapsible_else_if)]
pub fn binary_search_ca<'a, T>(
ca: &'a ChunkedArray<T>,
search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
side: SearchSortedSide,
descending: bool,
) -> Vec<IdxSize>
where
T: PolarsDataType,
T::Physical<'a>: TotalOrd + Debug + Copy,
{
let offset = ca.first_non_null().unwrap();
let length = 1 + ca.last_non_null().unwrap() - offset;
let out = ca.slice(offset as i64, length);

debug_assert!(out.null_count() != out.len());
debug_assert!(out.null_count() == 0);
let chunks: Vec<_> = ca.downcast_iter().filter(|c| c.len() > 0).collect();
let has_nulls = ca.null_count() > 0;
let nulls_last = has_nulls && chunks[0].get(0).is_some();
let null_idx = if nulls_last {
if side == SearchSortedSide::Right {
ca.len()
} else {
ca.len() - ca.null_count()
}
} else {
if side == SearchSortedSide::Right {
ca.null_count()
} else {
0
}
} as IdxSize;

(offset, out.rechunk())
if !descending {
if !has_nulls {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_gt(sv) },
)
} else {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_ge(sv) },
)
}
} else {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_gt(sv)
} else {
nulls_last
}
})
} else {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_ge(sv)
} else {
nulls_last
}
})
}
}
} else {
if !has_nulls {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_lt(sv) },
)
} else {
lower_bound_chunks::<T, _>(
&chunks,
search_values,
null_idx,
|chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_le(sv) },
)
}
} else {
if side == SearchSortedSide::Right {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_lt(sv)
} else {
nulls_last
}
})
} else {
lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
if let Some(v) = unsafe { chunk.get_unchecked(i) } {
v.tot_le(sv)
} else {
nulls_last
}
})
}
}
}
}
Loading