Skip to content

Commit

Permalink
feat(rust, python): add search_sorted for arrays and utf8 dtype (#6083)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 6, 2023
1 parent 48995af commit c129710
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 51 deletions.
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/ops/take/take_random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ where
{
type Item = I;

#[inline]
fn get(&self, index: usize) -> Option<Self::Item> {
match self {
Self::SingleNoNull(s) => s.get(index),
Expand All @@ -94,6 +95,7 @@ where
}
}

#[inline]
unsafe fn get_unchecked(&self, index: usize) -> Option<Self::Item> {
match self {
Self::SingleNoNull(s) => s.get_unchecked(index),
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub enum FunctionExpr {
#[cfg(feature = "arg_where")]
ArgWhere,
#[cfg(feature = "search_sorted")]
SearchSorted,
SearchSorted(SearchSortedSide),
#[cfg(feature = "strings")]
StringExpr(StringFunction),
#[cfg(feature = "dtype-binary")]
Expand Down Expand Up @@ -133,7 +133,7 @@ impl Display for FunctionExpr {
#[cfg(feature = "arg_where")]
ArgWhere => "arg_where",
#[cfg(feature = "search_sorted")]
SearchSorted => "search_sorted",
SearchSorted(_) => "search_sorted",
#[cfg(feature = "strings")]
StringExpr(s) => return write!(f, "{s}"),
#[cfg(feature = "dtype-binary")]
Expand Down Expand Up @@ -280,8 +280,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
wrap!(arg_where::arg_where)
}
#[cfg(feature = "search_sorted")]
SearchSorted => {
wrap!(search_sorted::search_sorted_impl)
SearchSorted(side) => {
map_as_slice!(search_sorted::search_sorted_impl, side)
}
#[cfg(feature = "strings")]
StringExpr(s) => s.into(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl FunctionExpr {
#[cfg(feature = "arg_where")]
ArgWhere => with_dtype(IDX_DTYPE),
#[cfg(feature = "search_sorted")]
SearchSorted => with_dtype(IDX_DTYPE),
SearchSorted(_) => with_dtype(IDX_DTYPE),
#[cfg(feature = "strings")]
StringExpr(s) => {
use StringFunction::*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use polars_ops::prelude::search_sorted;

use super::*;

pub(super) fn search_sorted_impl(s: &mut [Series]) -> PolarsResult<Series> {
pub(super) fn search_sorted_impl(s: &mut [Series], side: SearchSortedSide) -> PolarsResult<Series> {
let sorted_array = &s[0];
let search_value = s[1].get(0).unwrap();
let search_value = &s[1];

search_sorted(sorted_array, &search_value).map(|idx| Series::new(sorted_array.name(), &[idx]))
search_sorted(sorted_array, search_value, side).map(|ca| ca.into_series())
}
5 changes: 3 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,16 @@ impl Expr {

#[cfg(feature = "search_sorted")]
/// Find indices where elements should be inserted to maintain order.
pub fn search_sorted<E: Into<Expr>>(self, element: E) -> Expr {
pub fn search_sorted<E: Into<Expr>>(self, element: E, side: SearchSortedSide) -> Expr {
let element = element.into();
Expr::Function {
input: vec![self, element],
function: FunctionExpr::SearchSorted,
function: FunctionExpr::SearchSorted(side),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
auto_explode: true,
fmt_str: "search_sorted",
cast_to_supertypes: true,
..Default::default()
},
}
Expand Down
221 changes: 187 additions & 34 deletions polars/polars-ops/src/series/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,124 @@
use std::cmp::Ordering;
use std::fmt::Debug;

use arrow::array::{PrimitiveArray, Utf8Array};
use polars_arrow::kernels::rolling::compare_fn_nan_max;
use polars_arrow::prelude::*;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::export::num::NumCast;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

fn search_sorted_ca<T>(
ca: &ChunkedArray<T>,
search_value: T::Native,
) -> std::result::Result<IdxSize, IdxSize>
where
T: PolarsNumericType,
T::Native: PartialOrd + IsFloat + NumCast,
/// Selects which index is reported when the searched value is present in the
/// sorted array.
#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SearchSortedSide {
    /// Report the index of whichever equal element the binary search hits.
    #[default]
    Any,
    /// Report the index of the first element of the run of equal values.
    Left,
    /// Report the index one past the last element of the run of equal values.
    Right,
}
/// Small private abstraction over "read the optional value at a position",
/// so the search routines can be written once and shared by several array
/// types.
trait GetArray<T> {
    /// Returns the optional value stored at position `idx`.
    ///
    /// # Safety
    /// The method is unchecked: the caller is responsible for `idx` being a
    /// position the implementation can read (presumably `idx < len`).
    unsafe fn _get_value_unchecked(&self, idx: usize) -> Option<T>;
}

// `GetArray` for a borrowed primitive (numeric) arrow array: the lookup is
// forwarded to the array's own `get_unchecked`.
impl<T: NumericNative> GetArray<T> for &PrimitiveArray<T> {
// No bounds check is performed here; presumably `i` must be a valid position
// in the array — confirm against the `get_unchecked` being called.
unsafe fn _get_value_unchecked(&self, i: usize) -> Option<T> {
self.get_unchecked(i)
}
}

// `GetArray` for a borrowed utf8 arrow array with `i64` offsets: yields string
// slices that borrow from the array for `'a`, forwarding to the array's own
// `get_unchecked`.
impl<'a> GetArray<&'a str> for &'a Utf8Array<i64> {
// No bounds check is performed here; presumably `i` must be a valid position
// in the array — confirm against the `get_unchecked` being called.
unsafe fn _get_value_unchecked(&self, i: usize) -> Option<&'a str> {
self.get_unchecked(i)
}
}

/// Resolves the final index once the binary search has hit an element equal to
/// the search value at position `mid`, and appends that index to `out`.
///
/// * `SearchSortedSide::Any`: `mid` is pushed unchanged.
/// * `SearchSortedSide::Left`: walks left over the run of values equal to
///   `arr[mid]` and pushes the index of the first element of that run.
/// * `SearchSortedSide::Right`: walks right over the run and pushes the index
///   one past its last element (`mid` itself when `mid == len`).
///
/// Exactly one index is pushed per call. `len` is the number of readable
/// positions in `arr`; every access is unchecked, so `mid` must not exceed
/// `len`. The walk is linear in the length of the run of equal values.
fn finish_side<G, I>(
    side: SearchSortedSide,
    out: &mut Vec<IdxSize>,
    mid: IdxSize,
    arr: G,
    len: usize,
) where
    G: GetArray<I>,
    I: PartialEq + Debug + Copy,
{
    let mut mid = mid;

    match side {
        SearchSortedSide::Any => out.push(mid),
        SearchSortedSide::Left => {
            // Nothing can be read from an empty array; without this guard the
            // `mid -= 1` below would underflow and the read would be out of bounds.
            if len == 0 {
                out.push(0);
                return;
            }
            if mid as usize == len {
                mid -= 1;
            }

            // SAFETY: `mid < len` holds here (caller contract plus the clamp above).
            let current = unsafe { arr._get_value_unchecked(mid as usize) };
            while mid > 0 {
                // SAFETY: `mid - 1 < mid < len`.
                let previous = unsafe { arr._get_value_unchecked((mid - 1) as usize) };
                if previous != current {
                    break;
                }
                mid -= 1;
            }
            out.push(mid);
        }
        SearchSortedSide::Right => {
            if mid as usize == len {
                out.push(mid);
                return;
            }

            // SAFETY: `mid < len` per the caller contract and the check above.
            let current = unsafe { arr._get_value_unchecked(mid as usize) };
            let mut next = mid + 1;
            while (next as usize) < len {
                // SAFETY: `next < len` is the loop condition.
                let following = unsafe { arr._get_value_unchecked(next as usize) };
                if following != current {
                    break;
                }
                next += 1;
            }
            out.push(next);
        }
    }
}

let mut size = ca.len() as IdxSize;
fn binary_search_array<G, I>(
side: SearchSortedSide,
out: &mut Vec<IdxSize>,
arr: G,
len: usize,
search_value: I,
) where
G: GetArray<I>,
I: PartialEq + Debug + Copy + PartialOrd + IsFloat,
{
let mut size = len as IdxSize;
let mut left = 0 as IdxSize;
let mut right = size;
let current_len = out.len();
while left < right {
let mid = left + size / 2;

// SAFETY: the call is made safe by the following invariants:
// - `mid >= 0`
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = match unsafe { taker.get_unchecked(mid as usize) } {
let cmp = match unsafe { arr._get_value_unchecked(mid as usize) } {
None => Ordering::Less,
Some(value) => compare_fn_nan_max(&value, &search_value),
};
Expand All @@ -38,38 +131,98 @@ where
} else if cmp == Ordering::Greater {
right = mid;
} else {
return Ok(mid);
finish_side(side, out, mid, arr, len);
break;
}

size = right - left;
}
Err(left)
if out.len() == current_len {
out.push(left);
}
}

pub fn search_sorted(s: &Series, search_value: &AnyValue) -> PolarsResult<IdxSize> {
if s.dtype().is_logical() {
let search_dtype: DataType = search_value.into();
if &search_dtype != s.dtype() {
return Err(PolarsError::ComputeError(
format!(
"Cannot search a series of dtype: {} with a value of dtype: {}",
s.dtype(),
search_dtype
)
.into(),
));
/// For every entry of `search_values`, computes an index into the sorted
/// numeric array `ca` and returns the indices in input order.
///
/// A null search value yields index `0`; every other value is resolved by
/// `binary_search_array`, with `side` deciding which index is reported when
/// the value is present in `ca`.
fn search_sorted_ca_array<T>(
    ca: &ChunkedArray<T>,
    search_values: &ChunkedArray<T>,
    side: SearchSortedSide,
) -> Vec<IdxSize>
where
    T: PolarsNumericType,
{
    // Rechunk first and search the first chunk; presumably the rechunked
    // array always exposes exactly one chunk, which is what the `unwrap` relies on.
    let ca = ca.rechunk();
    let arr = ca.downcast_iter().next().unwrap();
    let len = ca.len();

    let mut out = Vec::with_capacity(search_values.len());

    for opt_v in search_values {
        match opt_v {
            None => out.push(0),
            Some(search_value) => binary_search_array(side, &mut out, arr, len, search_value),
        }
    }
    out
}

macro_rules! dispatch {
($ca:expr) => {
match search_sorted_ca($ca, search_value.extract().unwrap()) {
Ok(idx) => Ok(idx),
Err(idx) => Ok(idx),
fn search_sorted_utf8_array(
ca: &Utf8Chunked,
search_values: &Utf8Chunked,
side: SearchSortedSide,
) -> Vec<IdxSize> {
let ca = ca.rechunk();
let arr = ca.downcast_iter().next().unwrap();

let mut out = Vec::with_capacity(search_values.len());

for opt_v in search_values {
match opt_v {
None => out.push(0),
Some(search_value) => {
binary_search_array(side, &mut out, arr, ca.len(), search_value);
}
};
}
}
out
}

/// Finds, for every value in `search_values`, an index into the sorted series
/// `s`, and returns the indices as an `IdxCa` named after `s`.
///
/// `side` selects which index is reported when a value is present in `s`
/// (see [`SearchSortedSide`]). Both inputs are converted to their physical
/// representation before searching.
///
/// # Errors
/// * `s` is neither utf8 nor (physically) numeric.
/// * `s` is utf8 and `search_values` is not.
/// * `s` is numeric and the physical dtype of `search_values` differs from
///   that of `s`.
pub fn search_sorted(
    s: &Series,
    search_values: &Series,
    side: SearchSortedSide,
) -> PolarsResult<IdxCa> {
    let original_dtype = s.dtype();
    let s = s.to_physical_repr();
    let phys_dtype = s.dtype();

    match phys_dtype {
        DataType::Utf8 => {
            let ca = s.utf8().unwrap();
            let search_values = search_values.utf8()?;
            let idx = search_sorted_utf8_array(ca, search_values, side);

            Ok(IdxCa::new_vec(s.name(), idx))
        }
        dt if dt.is_numeric() => {
            let search_phys = search_values.to_physical_repr();

            // The typed downcast below assumes both series share one physical
            // dtype; reject a mismatch with an error instead of relying on
            // whatever the `as_ref` conversion does for a wrong type.
            if search_phys.dtype() != phys_dtype {
                return Err(PolarsError::ComputeError(
                    format!(
                        "Cannot search a series of dtype: {} with a value of dtype: {}",
                        original_dtype,
                        search_values.dtype()
                    )
                    .into(),
                ));
            }

            let idx = with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
                let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
                let search_values: &ChunkedArray<$T> = search_phys.as_ref().as_ref().as_ref();

                search_sorted_ca_array(ca, search_values, side)
            });
            Ok(IdxCa::new_vec(s.name(), idx))
        }
        _ => Err(PolarsError::ComputeError(
            format!("'search_sorted' not allowed on dtype: {original_dtype}").into(),
        )),
    }
}

downcast_as_macro_arg_physical!(s, dispatch)
// Right-sided search: values present in `s` map to the index one past their
// last occurrence, absent values map to their insertion point.
#[test]
fn test_search_sorted() {
    let s = Series::new("", [1, 1, 4, 4]);
    let b = Series::new("", [0, 1, 2, 4, 5]);
    let out = search_sorted(&s, &b, SearchSortedSide::Right).unwrap();
    let got: Vec<Option<IdxSize>> = (&out).into_iter().collect();
    assert_eq!(got, vec![Some(0), Some(2), Some(2), Some(4), Some(4)]);
}
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
NullBehavior,
RankMethod,
RollingInterpolationMethod,
SearchSortedSide,
)
elif os.getenv("BUILDING_SPHINX_DOCS"):
property = sphinx_accessor
Expand Down Expand Up @@ -1859,7 +1860,9 @@ def arg_min(self) -> Expr:
"""
return wrap_expr(self._pyexpr.arg_min())

def search_sorted(self, element: Expr | int | float) -> Expr:
def search_sorted(
self, element: Expr | int | float | pli.Series, side: SearchSortedSide = "any"
) -> Expr:
"""
Find indices where elements should be inserted to maintain order.
Expand All @@ -1869,6 +1872,10 @@ def search_sorted(self, element: Expr | int | float) -> Expr:
----------
element
Expression or scalar value.
side : {'any', 'left', 'right'}
If 'any', the index of the first suitable location found is given.
If 'left', the index of the leftmost suitable location found is given.
If 'right', the index of the rightmost suitable location found is given.
Examples
--------
Expand All @@ -1895,7 +1902,7 @@ def search_sorted(self, element: Expr | int | float) -> Expr:
"""
element = expr_to_lit_or_expr(element, str_to_lit=False)
return wrap_expr(self._pyexpr.search_sorted(element._pyexpr))
return wrap_expr(self._pyexpr.search_sorted(element._pyexpr, side))

def sort_by(
self,
Expand Down

0 comments on commit c129710

Please sign in to comment.