Skip to content

Commit

Permalink
feat[rust, python] add Expr::search_sorted (#4456)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 17, 2022
1 parent a55a3a5 commit bf7b091
Show file tree
Hide file tree
Showing 17 changed files with 169 additions and 2 deletions.
1 change: 1 addition & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ describe = ["polars-core/describe"]
timezones = ["polars-core/timezones"]
string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"]
arg_where = ["polars-lazy/arg_where"]
search_sorted = ["polars-lazy/search_sorted"]
date_offset = ["polars-lazy/date_offset"]
trigonometry = ["polars-lazy/trigonometry"]
sign = ["polars-lazy/sign"]
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Idx = usize;
type WindowSize = usize;
type Len = usize;

fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
pub fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat + NativeType,
{
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ python = ["pyo3"]
row_hash = ["polars-core/row_hash", "polars-ops/hash"]
string_justify = ["polars-ops/string_justify"]
arg_where = []
search_sorted = ["polars-ops/search_sorted"]

# no guarantees whatsoever
private = ["polars-time/private"]
Expand Down
10 changes: 10 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ mod pow;
mod rolling;
#[cfg(feature = "row_hash")]
mod row_hash;
#[cfg(feature = "search_sorted")]
mod search_sorted;
mod shift_and_fill;
#[cfg(feature = "sign")]
mod sign;
Expand Down Expand Up @@ -41,6 +43,8 @@ pub enum FunctionExpr {
IsIn,
#[cfg(feature = "arg_where")]
ArgWhere,
#[cfg(feature = "search_sorted")]
SearchSorted,
#[cfg(feature = "strings")]
StringExpr(StringFunction),
#[cfg(feature = "date_offset")]
Expand Down Expand Up @@ -125,6 +129,8 @@ impl FunctionExpr {
IsIn => with_dtype(DataType::Boolean),
#[cfg(feature = "arg_where")]
ArgWhere => with_dtype(IDX_DTYPE),
#[cfg(feature = "search_sorted")]
SearchSorted => with_dtype(IDX_DTYPE),
#[cfg(feature = "strings")]
StringExpr(s) => {
use StringFunction::*;
Expand Down Expand Up @@ -255,6 +261,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
ArgWhere => {
wrap!(arg_where::arg_where)
}
#[cfg(feature = "search_sorted")]
SearchSorted => {
wrap!(search_sorted::search_sorted_impl)
}
#[cfg(feature = "strings")]
StringExpr(s) => s.into(),

Expand Down
10 changes: 10 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use polars_ops::prelude::search_sorted;

use super::*;

/// Series-level entry point backing `FunctionExpr::SearchSorted`.
///
/// `s[0]` is the series that is searched and the first element of `s[1]` is the
/// value that is looked up. The result is a single-element series carrying the
/// name of `s[0]` and holding the index reported by `search_sorted`.
pub(super) fn search_sorted_impl(s: &mut [Series]) -> Result<Series> {
    let haystack = &s[0];
    // Only the first value of the second input is used as the search value.
    let needle = s[1].get(0);

    let idx = search_sorted(haystack, &needle)?;
    Ok(Series::new(haystack.name(), &[idx]))
}
16 changes: 16 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,22 @@ impl Expr {
)
}

#[cfg(feature = "search_sorted")]
/// Find indices where elements should be inserted to maintain order.
///
/// Builds an `Expr::Function` node that applies `FunctionExpr::SearchSorted`
/// to `self` (first input) and `element` (second input).
pub fn search_sorted<E: Into<Expr>>(self, element: E) -> Expr {
    let options = FunctionOptions {
        collect_groups: ApplyOptions::ApplyGroups,
        input_wildcard_expansion: false,
        auto_explode: true,
        fmt_str: "search_sorted",
    };

    Expr::Function {
        input: vec![self, element.into()],
        function: FunctionExpr::SearchSorted,
        options,
    }
}

/// Cast expression to another data type.
/// Throws an error if conversion had overflows
pub fn strict_cast(self, data_type: DataType) -> Self {
Expand Down
1 change: 1 addition & 0 deletions polars/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ log = []
hash = []
rolling_window = ["polars-core/rolling_window"]
moment = ["polars-core/moment"]
search_sorted = []
4 changes: 4 additions & 0 deletions polars/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
mod log;
#[cfg(feature = "rolling_window")]
mod rolling;
#[cfg(feature = "search_sorted")]
mod search_sorted;
mod various;

#[cfg(feature = "log")]
pub use log::*;
use polars_core::prelude::*;
#[cfg(feature = "rolling_window")]
pub use rolling::*;
#[cfg(feature = "search_sorted")]
pub use search_sorted::*;
pub use various::*;

pub trait SeriesSealed {
Expand Down
75 changes: 75 additions & 0 deletions polars/polars-ops/src/series/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::cmp::Ordering;

use polars_arrow::kernels::rolling::compare_fn_nan_min;
use polars_arrow::prelude::*;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::export::num::NumCast;
use polars_core::prelude::*;

/// Binary-searches `ca` for `search_value`.
///
/// The search is only meaningful when `ca` is sorted in ascending order with
/// respect to the comparison used here: null entries compare as smaller than
/// any search value, non-null entries are ordered by `compare_fn_nan_min`.
///
/// # Returns
///
/// * `Ok(idx)` - `idx` is the *leftmost* position whose element compares equal
///   to `search_value`.
/// * `Err(idx)` - no element compares equal; `idx` is the position at which
///   `search_value` would have to be inserted to keep the order
///   (`a[idx - 1] < search_value < a[idx]`).
///
/// In both cases `idx` is the left insertion point, i.e. it satisfies
/// `a[idx - 1] < search_value <= a[idx]` even when the value occurs more than once.
fn search_sorted_ca<T>(
    ca: &ChunkedArray<T>,
    search_value: T::Native,
) -> std::result::Result<IdxSize, IdxSize>
where
    T: PolarsNumericType,
    T::Native: PartialOrd + IsFloat + NumCast,
{
    let taker = ca.take_rand();
    let len = ca.len() as IdxSize;

    // Ordering of the element stored at `idx` relative to `search_value`.
    let compare_at = |idx: IdxSize| {
        // SAFETY: every call site below passes an `idx` that is strictly
        // smaller than `len`, the length of `ca`.
        match unsafe { taker.get_unchecked(idx as usize) } {
            None => Ordering::Less,
            Some(value) => compare_fn_nan_min(&value, &search_value),
        }
    };

    // Invariant: every element before `left` is smaller than `search_value`,
    // every element from `right` onwards is greater than or equal to it.
    let mut left = 0 as IdxSize;
    let mut right = len;
    while left < right {
        // `left <= mid < right <= len`, so `mid` is always in bounds.
        let mid = left + (right - left) / 2;

        // On equality we keep narrowing towards the left instead of returning
        // immediately; otherwise an arbitrary duplicate would be reported.
        if compare_at(mid) == Ordering::Less {
            left = mid + 1;
        } else {
            right = mid;
        }
    }

    // `left` is now the first position that is not smaller than `search_value`.
    if left < len && compare_at(left) == Ordering::Equal {
        Ok(left)
    } else {
        Err(left)
    }
}

/// Find the index at which `search_value` should be inserted into the sorted
/// series `s` to maintain order.
///
/// The series is converted to its physical representation and searched with
/// `search_sorted_ca`; the index is returned regardless of whether an equal
/// element was found.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when
/// * `s` has a logical dtype and the dtype of `search_value` differs from it, or
/// * `search_value` cannot be extracted as the native type of the series
///   (for instance a null search value); previously this case panicked.
pub fn search_sorted(s: &Series, search_value: &AnyValue) -> Result<IdxSize> {
    if s.is_logical() {
        let search_dtype: DataType = search_value.into();
        if &search_dtype != s.dtype() {
            return Err(PolarsError::ComputeError(
                format!(
                    "Cannot search a series of dtype: {} with a value of dtype: {}",
                    s.dtype(),
                    search_dtype
                )
                .into(),
            ));
        }
    }
    let s = s.to_physical_repr();

    macro_rules! dispatch {
        ($ca:expr) => {
            // `extract` yields `None` when the value has no representation in
            // the target native type; report that instead of unwrapping.
            match search_value.extract() {
                Some(value) => match search_sorted_ca($ca, value) {
                    // Found or not, the insertion index is the answer.
                    Ok(idx) | Err(idx) => Ok(idx),
                },
                None => Err(PolarsError::ComputeError(
                    format!(
                        "Cannot extract a numeric search value from: {:?}",
                        search_value
                    )
                    .into(),
                )),
            }
        };
    }

    downcast_as_macro_arg_physical!(s, dispatch)
}
3 changes: 2 additions & 1 deletion polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@
//! - `list_to_struct` - Convert `List` to `Struct` dtypes.
//! - `list_eval` - Apply expressions over list elements.
//! - `cumulative_eval` - Apply expressions over cumulatively increasing windows.
//! - `argwhere` Get indices where condition holds.
//! - `arg_where` - Get indices where condition holds.
//! - `search_sorted` - Find indices where elements should be inserted to maintain order.
//! - `date_offset` - Add an offset to dates that take months and leap years into account.
//! - `trigonometry` - Trigonometric functions.
//! - `sign` - Compute the element-wise sign of a Series.
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ all = [
"asof_join",
"cross_join",
"pct_change",
"polars/search_sorted",
]

# we cannot conditionally activate simd
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ Computations
Expr.rolling_std
Expr.rolling_sum
Expr.rolling_var
Expr.search_sorted
Expr.sign
Expr.sin
Expr.sinh
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ Computations
Series.rolling_std
Series.rolling_sum
Series.rolling_var
Series.search_sorted
Series.sign
Series.sin
Series.sinh
Expand Down
15 changes: 15 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,21 @@ def arg_min(self) -> Expr:
"""
return wrap_expr(self._pyexpr.arg_min())

def search_sorted(self, element: Expr | int | float) -> Expr:
    """
    Find indices where elements should be inserted to maintain order.

    .. math:: a[i-1] < v <= a[i]

    Parameters
    ----------
    element
        Expression or scalar value.

    """
    # Scalars are turned into literal expressions; strings are not
    # converted to literals (``str_to_lit=False``).
    needle = expr_to_lit_or_expr(element, str_to_lit=False)
    return wrap_expr(self._pyexpr.search_sorted(needle._pyexpr))

def sort_by(
self,
by: Expr | str | list[Expr | str],
Expand Down
14 changes: 14 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,20 @@ def arg_max(self) -> int | None:
"""Get the index of the maximal value."""
return self._s.arg_max()

def search_sorted(self, element: int | float) -> int:
    """
    Find the index where an element should be inserted to maintain order.

    .. math:: a[i-1] < v <= a[i]

    Parameters
    ----------
    element
        Scalar value to locate.

    """
    # Delegate to the expression implementation and pull the single
    # resulting value out of the one-cell frame.
    lookup = pli.lit(self).search_sorted(element)
    frame = pli.select(lookup)
    return frame[0, 0]

def unique(self, maintain_order: bool = False) -> Series:
"""
Get unique elements in series.
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ impl PyExpr {
pub fn arg_min(&self) -> PyExpr {
self.clone().inner.arg_min().into()
}

/// Expose `Expr::search_sorted`: builds the search expression from a clone of
/// the wrapped expression and the expression wrapped by `element`.
pub fn search_sorted(&self, element: PyExpr) -> PyExpr {
    let haystack = self.inner.clone();
    haystack.search_sorted(element.inner).into()
}
pub fn take(&self, idx: PyExpr) -> PyExpr {
self.clone().inner.take(idx.inner).into()
}
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import cast

import numpy as np

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api

Expand Down Expand Up @@ -323,3 +325,13 @@ def test_unique_empty() -> None:
for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]:
s = pl.Series([], dtype=dt)
assert s.unique().series_equal(s)


def test_search_sorted() -> None:
    """Compare ``Series.search_sorted`` with ``np.searchsorted`` on sorted random data."""
    for seed in (1, 2, 3):
        np.random.seed(seed)
        values = np.sort(np.random.randn(10) * 100)
        series = pl.Series(values)

        # Probe integer values between the minimum and maximum in steps of 20.
        lower = int(np.min(values))
        upper = int(np.max(values))
        for probe in range(lower, upper, 20):
            expected = np.searchsorted(values, probe)
            assert expected == series.search_sorted(probe)

0 comments on commit bf7b091

Please sign in to comment.