Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Improve rolling_median algorithm #12704

Merged
merged 42 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f21591d
refactor(rust): simplify rolling_median update
ritchie46 Nov 28, 2023
fd60cc6
nearest
ritchie46 Nov 28, 2023
7fdb488
fix idx
ritchie46 Nov 28, 2023
6450d92
move cmov
ritchie46 Nov 28, 2023
65fa738
WIP
ritchie46 Nov 27, 2023
297b449
ensure median in delete
ritchie46 Nov 27, 2023
b27d815
implement undelete
ritchie46 Nov 27, 2023
dd5dd4f
add test for undelete
ritchie46 Nov 27, 2023
eb11d23
test all block states
ritchie46 Nov 28, 2023
4140f7d
save
ritchie46 Nov 28, 2023
62a6e7f
median helper
ritchie46 Nov 28, 2023
213af31
setup union median
ritchie46 Nov 28, 2023
be973ad
what to do at end?
ritchie46 Nov 28, 2023
9d86a4f
test block_2 correct
ritchie46 Nov 29, 2023
24bf97c
wip
ritchie46 Nov 30, 2023
f385152
merge
ritchie46 Dec 2, 2023
c2737a4
Merge branch 'main' into rolling_median
ritchie46 Dec 3, 2023
cfd4cbd
union block test
ritchie46 Dec 3, 2023
8896981
test block_union 2
ritchie46 Dec 4, 2023
3764ba0
prepare sharing
ritchie46 Dec 4, 2023
8240257
change sort
ritchie46 Dec 5, 2023
e872328
make generic quantile
ritchie46 Dec 5, 2023
d058750
setup dispatch
ritchie46 Dec 5, 2023
2c9c54d
add test against pandas for random values
ritchie46 Dec 5, 2023
3169a41
correct
ritchie46 Dec 5, 2023
23e25ae
add Indexable and Sliceable trait
ritchie46 Dec 5, 2023
628aac3
try make generic
ritchie46 Dec 5, 2023
680a075
wip
ritchie46 Dec 6, 2023
d5a52da
elide bound checks and call nulls
ritchie46 Dec 6, 2023
4f60a63
add min periods
ritchie46 Dec 6, 2023
f9e054e
move pushable to polars-arrow
ritchie46 Dec 6, 2023
069beab
extend is null
ritchie46 Dec 6, 2023
d2452cf
traits work for nulls/nonulls
ritchie46 Dec 6, 2023
aaea834
working nulls
ritchie46 Dec 6, 2023
4ab38c0
clippy
ritchie46 Dec 6, 2023
82428d5
use dedicated null count
ritchie46 Dec 6, 2023
6bbba21
lin/compile
ritchie46 Dec 6, 2023
bb86562
compile test
ritchie46 Dec 6, 2023
262d9ac
compile tests
ritchie46 Dec 7, 2023
2cd3b79
compile test
ritchie46 Dec 7, 2023
158d6d7
lint
ritchie46 Dec 7, 2023
7812ef9
lint
ritchie46 Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ macro_rules! impl_sliced {
offset + length <= self.len(),
"the offset of the new Buffer cannot exceed the existing length"
);
unsafe { self.sliced_unchecked(offset, length) }
unsafe { Self::sliced_unchecked(self, offset, length) }
}

/// Returns this array sliced.
Expand All @@ -493,7 +493,7 @@ macro_rules! impl_sliced {
#[inline]
#[must_use]
pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self {
self.slice_unchecked(offset, length);
Self::slice_unchecked(&mut self, offset, length);
self
}
};
Expand Down
11 changes: 11 additions & 0 deletions crates/polars-arrow/src/array/primitive/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use polars_utils::iter::IntoIteratorCopied;

use super::{MutablePrimitiveArray, PrimitiveArray};
use crate::array::MutableArray;
use crate::bitmap::utils::{BitmapIter, ZipValidity};
Expand Down Expand Up @@ -45,3 +47,12 @@ impl<'a, T: NativeType> MutablePrimitiveArray<T> {
self.values().iter()
}
}

/// Lets a [`PrimitiveArray`] be consumed through the `IntoIteratorCopied` abstraction.
///
/// The yielded items are `Option<T>`, and the iterator type is the very same one
/// the array already exposes through its `IntoIterator` implementation.
impl<T: NativeType> IntoIteratorCopied for PrimitiveArray<T> {
    type OwnedItem = Option<T>;
    type IntoIterCopied = Self::IntoIter;

    /// Consumes the array and hands back its by-value iterator.
    fn into_iter(self) -> Self::IntoIterCopied {
        // Name the trait explicitly: this method is itself called `into_iter`,
        // so we must forward to the `IntoIterator` one rather than to ourselves.
        IntoIterator::into_iter(self)
    }
}
71 changes: 59 additions & 12 deletions crates/polars-arrow/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Range;

use either::Either;

use super::Array;
Expand All @@ -18,6 +20,8 @@ mod iterator;
mod mutable;
pub use mutable::*;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::index::{Bounded, Indexable, NullCount};
use polars_utils::slice::SliceAble;

/// A [`PrimitiveArray`] is Arrow's semantically equivalent of an immutable `Vec<Option<T>>` where
/// T is [`NativeType`] (e.g. [`i32`]). It implements [`Array`].
Expand Down Expand Up @@ -190,18 +194,18 @@ impl<T: NativeType> PrimitiveArray<T> {
*self.values.get_unchecked(i)
}

/// Looks up the element stored at index `i`, yielding `None` when that slot is null.
///
/// # Panics
/// iff `i >= self.len()`
#[inline]
pub fn get(&self, i: usize) -> Option<T> {
    if self.is_null(i) {
        return None;
    }
    // SAFETY: `Array::is_null` panics if `i >= self.len`, so reaching this
    // point means `i` is in bounds.
    Some(unsafe { self.value_unchecked(i) })
}
// /// Returns the element at index `i` or `None` if it is null
// /// # Panics
// /// iff `i >= self.len()`
// #[inline]
// pub fn get(&self, i: usize) -> Option<T> {
// if !self.is_null(i) {
// // soundness: Array::is_null panics if i >= self.len
// unsafe { Some(self.value_unchecked(i)) }
// } else {
// None
// }
// }

/// Slices this [`PrimitiveArray`] by an offset and length.
/// # Implementation
Expand Down Expand Up @@ -438,6 +442,37 @@ impl<T: NativeType> Array for PrimitiveArray<T> {
}
}

/// Range-based slicing for [`PrimitiveArray`].
///
/// Both methods clone the array and then narrow the clone, translating the
/// half-open `range` into the `(offset, length)` pair the `sliced*` methods expect.
impl<T: NativeType> SliceAble for PrimitiveArray<T> {
    /// Returns the sub-array covered by `range` without validating the range.
    ///
    /// # Safety
    /// The caller must uphold the contract of `sliced_unchecked` for
    /// `offset = range.start` and `length = range.len()`
    /// (presumably `offset + length <= self.len()` — confirm against `sliced_unchecked`).
    unsafe fn slice_unchecked(&self, range: Range<usize>) -> Self {
        let (offset, length) = (range.start, range.len());
        self.clone().sliced_unchecked(offset, length)
    }

    /// Returns the sub-array covered by `range`, delegating validation to `sliced`.
    fn slice(&self, range: Range<usize>) -> Self {
        let (offset, length) = (range.start, range.len());
        self.clone().sliced(offset, length)
    }
}

/// Positional access to a [`PrimitiveArray`], where a null slot is reported as `None`.
impl<T: NativeType> Indexable for PrimitiveArray<T> {
    type Item = Option<T>;

    /// Returns `Some(value)` for a valid slot at `i` and `None` for a null one.
    ///
    /// The bounds check is the one performed by `is_null`.
    fn get(&self, i: usize) -> Self::Item {
        if self.is_null(i) {
            return None;
        }
        // SAFETY: `Array::is_null` panics if `i >= self.len`, so `i` is in
        // bounds once we get here.
        Some(unsafe { self.value_unchecked(i) })
    }

    /// Same as [`Indexable::get`], but performs no bounds check of its own.
    ///
    /// # Safety
    /// The caller must guarantee that `i` is a valid index for this array;
    /// both the validity lookup and the value read are unchecked.
    unsafe fn get_unchecked(&self, i: usize) -> Self::Item {
        if self.is_null_unchecked(i) {
            return None;
        }
        Some(self.value_unchecked(i))
    }
}

/// A type definition [`PrimitiveArray`] for `i8`
pub type Int8Array = PrimitiveArray<i8>;
/// A type definition [`PrimitiveArray`] for `i16`
Expand Down Expand Up @@ -505,3 +540,15 @@ impl<T: NativeType> Default for PrimitiveArray<T> {
PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None)
}
}

/// Exposes the array's length through the `Bounded` trait.
impl<T: NativeType> Bounded for PrimitiveArray<T> {
    /// Number of slots in the array, taken from the length of the values buffer.
    fn len(&self) -> usize {
        let values = &self.values;
        values.len()
    }
}

impl<T: NativeType> NullCount for PrimitiveArray<T> {
fn null_count(&self) -> usize {
<Self as Array>::null_count(self)
}
}
1 change: 1 addition & 0 deletions crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod no_nulls;
pub mod nulls;
pub mod quantile_filter;
mod window;

use std::any::Any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub fn rolling_quantile<T>(
) -> PolarsResult<ArrayRef>
where
T: NativeType
+ IsFloat
+ Float
+ std::iter::Sum
+ AddAssign
Expand All @@ -125,13 +126,33 @@ where
false => det_offsets,
};
match weights {
None => rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
values,
window_size,
min_periods,
offset_fn,
params,
),
None => {
if !center {
let params = params.as_ref().unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
params.interpol,
min_periods,
window_size,
values,
params.prob,
);
let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
return Ok(Box::new(PrimitiveArray::new(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
)));
}

rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
values,
window_size,
min_periods,
offset_fn,
params,
)
},
Some(weights) => {
let wsum = weights.iter().sum();
polars_ensure!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_utils::slice::GetSaferUnchecked;

use super::*;
use crate::array::MutablePrimitiveArray;

pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
sorted: SortedBufNulls<'a, T>,
Expand Down Expand Up @@ -126,6 +127,19 @@ where
true => det_offsets_center,
false => det_offsets,
};
if !center {
let params = params.as_ref().unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
params.interpol,
min_periods,
window_size,
arr.clone(),
params.prob,
);
let out: PrimitiveArray<T> = out.into();
return Box::new(out);
}
rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
Expand Down
Loading