Skip to content

Commit

Permalink
perf: improve rolling_median algorithm (#12704)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 7, 2023
1 parent 3d8dd34 commit 7567759
Show file tree
Hide file tree
Showing 33 changed files with 1,674 additions and 319 deletions.
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ macro_rules! impl_sliced {
offset + length <= self.len(),
"the offset of the new Buffer cannot exceed the existing length"
);
unsafe { self.sliced_unchecked(offset, length) }
unsafe { Self::sliced_unchecked(self, offset, length) }
}

/// Returns this array sliced.
Expand All @@ -493,7 +493,7 @@ macro_rules! impl_sliced {
#[inline]
#[must_use]
pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self {
self.slice_unchecked(offset, length);
Self::slice_unchecked(&mut self, offset, length);
self
}
};
Expand Down
11 changes: 11 additions & 0 deletions crates/polars-arrow/src/array/primitive/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use polars_utils::iter::IntoIteratorCopied;

use super::{MutablePrimitiveArray, PrimitiveArray};
use crate::array::MutableArray;
use crate::bitmap::utils::{BitmapIter, ZipValidity};
Expand Down Expand Up @@ -45,3 +47,12 @@ impl<'a, T: NativeType> MutablePrimitiveArray<T> {
self.values().iter()
}
}

/// Lets a [`PrimitiveArray`] be consumed through the `IntoIteratorCopied` abstraction.
///
/// Items are `Option<T>`, and the iterator type is exactly the one declared by the
/// array's own [`IntoIterator`] implementation.
impl<T: NativeType> IntoIteratorCopied for PrimitiveArray<T> {
    type OwnedItem = Option<T>;
    type IntoIterCopied = <Self as IntoIterator>::IntoIter;

    /// Consumes the array and returns its by-value iterator.
    fn into_iter(self) -> Self::IntoIterCopied {
        // Both `IntoIterator` and `IntoIteratorCopied` expose an `into_iter` method, so the
        // trait is named explicitly to select the standard-library one.
        IntoIterator::into_iter(self)
    }
}
71 changes: 59 additions & 12 deletions crates/polars-arrow/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Range;

use either::Either;

use super::Array;
Expand All @@ -18,6 +20,8 @@ mod iterator;
mod mutable;
pub use mutable::*;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::index::{Bounded, Indexable, NullCount};
use polars_utils::slice::SliceAble;

/// A [`PrimitiveArray`] is Arrow's semantically equivalent of an immutable `Vec<Option<T>>` where
/// T is [`NativeType`] (e.g. [`i32`]). It implements [`Array`].
Expand Down Expand Up @@ -190,18 +194,18 @@ impl<T: NativeType> PrimitiveArray<T> {
*self.values.get_unchecked(i)
}

/// Returns the element at index `i` or `None` if it is null
/// # Panics
/// iff `i >= self.len()`
#[inline]
pub fn get(&self, i: usize) -> Option<T> {
if !self.is_null(i) {
// soundness: Array::is_null panics if i >= self.len
unsafe { Some(self.value_unchecked(i)) }
} else {
None
}
}
// /// Returns the element at index `i` or `None` if it is null
// /// # Panics
// /// iff `i >= self.len()`
// #[inline]
// pub fn get(&self, i: usize) -> Option<T> {
// if !self.is_null(i) {
// // soundness: Array::is_null panics if i >= self.len
// unsafe { Some(self.value_unchecked(i)) }
// } else {
// None
// }
// }

/// Slices this [`PrimitiveArray`] by an offset and length.
/// # Implementation
Expand Down Expand Up @@ -438,6 +442,37 @@ impl<T: NativeType> Array for PrimitiveArray<T> {
}
}

/// Range-based slicing for [`PrimitiveArray`].
///
/// Both methods clone the array and then narrow the clone with the array's own
/// `sliced` / `sliced_unchecked` methods, translating the range into an
/// `(offset, length)` pair.
impl<T: NativeType> SliceAble for PrimitiveArray<T> {
    /// Returns a clone of this array narrowed to `range`, via `sliced_unchecked`.
    ///
    /// # Safety
    /// The call is forwarded to `sliced_unchecked`, so the caller must uphold that
    /// method's contract (presumably `range.start + range.len() <= self.len()` — confirm
    /// against `sliced_unchecked`'s documentation).
    unsafe fn slice_unchecked(&self, range: Range<usize>) -> Self {
        let owned = self.clone();
        let offset = range.start;
        let length = range.len();
        owned.sliced_unchecked(offset, length)
    }

    /// Returns a clone of this array narrowed to `range`, via the checked `sliced`.
    fn slice(&self, range: Range<usize>) -> Self {
        let owned = self.clone();
        let offset = range.start;
        let length = range.len();
        owned.sliced(offset, length)
    }
}

/// Index-based element access for [`PrimitiveArray`]: a null slot is reported as `None`,
/// any other slot as `Some(value)`.
impl<T: NativeType> Indexable for PrimitiveArray<T> {
    type Item = Option<T>;

    /// Returns the element at index `i`, or `None` if that slot is null.
    fn get(&self, i: usize) -> Self::Item {
        if self.is_null(i) {
            return None;
        }
        // soundness: Array::is_null panics if i >= self.len, so `i` is in bounds here
        let value = unsafe { self.value_unchecked(i) };
        Some(value)
    }

    /// Same as `get`, but uses the unchecked null test and the unchecked value read.
    ///
    /// # Safety
    /// The caller must uphold the contracts of `is_null_unchecked` and `value_unchecked`
    /// (presumably `i < self.len()` — confirm against those methods).
    unsafe fn get_unchecked(&self, i: usize) -> Self::Item {
        if self.is_null_unchecked(i) {
            None
        } else {
            Some(self.value_unchecked(i))
        }
    }
}

/// A type definition [`PrimitiveArray`] for `i8`
pub type Int8Array = PrimitiveArray<i8>;
/// A type definition [`PrimitiveArray`] for `i16`
Expand Down Expand Up @@ -505,3 +540,15 @@ impl<T: NativeType> Default for PrimitiveArray<T> {
PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None)
}
}

/// Exposes the length of a [`PrimitiveArray`] through the `Bounded` trait.
impl<T: NativeType> Bounded for PrimitiveArray<T> {
/// Returns the length of the array's `values` buffer.
fn len(&self) -> usize {
self.values.len()
}
}

/// Exposes the null count of a [`PrimitiveArray`] through the `NullCount` trait.
impl<T: NativeType> NullCount for PrimitiveArray<T> {
/// Returns whatever the array's `Array::null_count` implementation reports.
fn null_count(&self) -> usize {
// Fully qualified on purpose: `Array` and `NullCount` both define `null_count`,
// and this path selects the `Array` one rather than this trait's own method.
<Self as Array>::null_count(self)
}
}
1 change: 1 addition & 0 deletions crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod no_nulls;
pub mod nulls;
pub mod quantile_filter;
mod window;

use std::any::Any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub fn rolling_quantile<T>(
) -> PolarsResult<ArrayRef>
where
T: NativeType
+ IsFloat
+ Float
+ std::iter::Sum
+ AddAssign
Expand All @@ -125,13 +126,33 @@ where
false => det_offsets,
};
match weights {
None => rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
values,
window_size,
min_periods,
offset_fn,
params,
),
None => {
if !center {
let params = params.as_ref().unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
params.interpol,
min_periods,
window_size,
values,
params.prob,
);
let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
return Ok(Box::new(PrimitiveArray::new(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
)));
}

rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
values,
window_size,
min_periods,
offset_fn,
params,
)
},
Some(weights) => {
let wsum = weights.iter().sum();
polars_ensure!(
Expand Down
14 changes: 14 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_utils::slice::GetSaferUnchecked;

use super::*;
use crate::array::MutablePrimitiveArray;

pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
sorted: SortedBufNulls<'a, T>,
Expand Down Expand Up @@ -126,6 +127,19 @@ where
true => det_offsets_center,
false => det_offsets,
};
if !center {
let params = params.as_ref().unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
params.interpol,
min_periods,
window_size,
arr.clone(),
params.prob,
);
let out: PrimitiveArray<T> = out.into();
return Box::new(out);
}
rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
Expand Down
Loading

0 comments on commit 7567759

Please sign in to comment.