Skip to content

Commit

Permalink
use specialize kernels in rolling_groupby aggregation (#3515)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 27, 2022
1 parent 3829abf commit 682663d
Show file tree
Hide file tree
Showing 18 changed files with 1,008 additions and 438 deletions.
3 changes: 2 additions & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ docs-selection = [
"describe",
"list_eval",
"cumulative_eval",
"timezones",
]

bench = [
Expand Down Expand Up @@ -281,4 +282,4 @@ harness = false
# all-features = true
features = ["docs-selection"]
# defines the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
rustdoc-args = ["--cfg", "docsrs", "polars-core/docsrs"]
4 changes: 2 additions & 2 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::sum::SumWindow;
use super::*;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls};

pub struct MeanWindow<'a, T> {
sum: SumWindow<'a, T>,
Expand All @@ -9,7 +9,7 @@ pub struct MeanWindow<'a, T> {
impl<
'a,
T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Div<Output = T> + NumCast,
> RollingAggWindow<'a, T> for MeanWindow<'a, T>
> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
Self {
Expand Down
80 changes: 48 additions & 32 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use no_nulls;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls};

pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
slice: &'a [T],
Expand All @@ -9,7 +9,7 @@ pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
last_end: usize,
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MinWindow<'a, T> {
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MinWindow<'a, T> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
let min = *slice[start..end]
.iter()
Expand All @@ -24,23 +24,32 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MinWi
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// remove elements that should leave the window
let mut recompute_min = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_min = if start >= self.last_end {
true
} else {
let mut recompute_min = false;

// if the leaving value is the
// min value, we need to recompute the min.
if matches!(
compare_fn_nan_min(leaving_value, &self.min),
Ordering::Equal
) {
recompute_min = true;
break;
// remove elements that should leave the window
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);

// if the leaving value is the
// min value, we need to recompute the min.
if matches!(
compare_fn_nan_min(leaving_value, &self.min),
Ordering::Equal
) {
recompute_min = true;
break;
}
}
}
recompute_min
};

self.last_start = start;

// we traverse all values and compute
Expand Down Expand Up @@ -77,7 +86,7 @@ pub struct MaxWindow<'a, T: NativeType> {
last_end: usize,
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MaxWindow<'a, T> {
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MaxWindow<'a, T> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
let max = *slice[start..end]
.iter()
Expand All @@ -92,21 +101,28 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MaxWi
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// remove elements that should leave the window
let mut recompute_max = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if the leaving value is the max value, we need to recompute the max.
if matches!(
compare_fn_nan_max(leaving_value, &self.max),
Ordering::Equal
) {
recompute_max = true;
break;
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_max = if start >= self.last_end {
true
} else {
// remove elements that should leave the window
let mut recompute_max = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if the leaving value is the max value, we need to recompute the max.
if matches!(
compare_fn_nan_max(leaving_value, &self.max),
Ordering::Equal
) {
recompute_max = true;
break;
}
}
}
recompute_max
};
self.last_start = start;

// we traverse all values and compute
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ use std::sync::Arc;

pub use mean::*;
pub use min_max::*;
pub use quantile::{rolling_median, rolling_quantile};
pub use quantile::*;
pub use sum::*;
pub use variance::*;

pub trait RollingAggWindow<'a, T: NativeType> {
pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self;

/// Update and recompute the window
Expand All @@ -39,7 +39,7 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
) -> ArrayRef
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Agg: RollingAggWindow<'a, T>,
Agg: RollingAggWindowNoNulls<'a, T>,
T: Debug + IsFloat + NativeType,
{
let len = values.len();
Expand Down
52 changes: 52 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/quantile.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,58 @@
use super::*;
use crate::index::IdxSize;
use crate::trusted_len::TrustedLen;
use std::fmt::Debug;

// used by agg_quantile
/// Computes one quantile per window, where the windows are described by an
/// iterator of `(start, len)` pairs into `values`.
///
/// * `values` - the slice the windows index into.
/// * `quantile` / `interpolation` - forwarded unchanged to `compute_quantile2`.
/// * `offsets` - yields `(start, len)`; the window is `start..start + len`.
///
/// Returns an empty array when `values` is empty. Otherwise every offset
/// produces one entry: `None` for a zero-length window, `Some(quantile)` for
/// anything else.
pub fn rolling_quantile_by_iter<T, O>(
    values: &[T],
    quantile: f64,
    interpolation: QuantileInterpolOptions,
    offsets: O,
) -> ArrayRef
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: std::iter::Sum<T>
        + NativeType
        + Copy
        + std::cmp::PartialOrd
        + num::ToPrimitive
        + NumCast
        + Add<Output = T>
        + Sub<Output = T>
        + Div<Output = T>
        + Mul<Output = T>
        + IsFloat,
{
    // Nothing to aggregate: hand back an empty primitive array right away.
    if values.is_empty() {
        let empty: Vec<T> = Vec::new();
        let array = PrimitiveArray::from_data(T::PRIMITIVE.into(), empty.into(), None);
        return Arc::new(array);
    }

    // A single sorted buffer is reused for all windows; it is seeded with the
    // first element and moved to each requested window via `update`.
    let mut sorted_buf = SortedBuf::new(values, 0, 1);

    let result: PrimitiveArray<T> = offsets
        .map(|(offset, length)| {
            let stop = offset + length;

            if offset != stop {
                // SAFETY: the original code states "we are in bounds" here without
                // checking; presumably callers only pass offsets that lie inside
                // `values` -- TODO confirm against the call sites.
                let window = unsafe { sorted_buf.update(offset as usize, stop as usize) };
                Some(compute_quantile2(window, quantile, interpolation))
            } else {
                // empty window: no value to report
                None
            }
        })
        .collect();

    Arc::new(result)
}

pub(crate) fn compute_quantile2<T>(
vals: &[T],
quantile: f64,
Expand Down
44 changes: 25 additions & 19 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/sum.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use no_nulls;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls};

pub struct SumWindow<'a, T> {
slice: &'a [T],
Expand All @@ -9,8 +9,8 @@ pub struct SumWindow<'a, T> {
last_end: usize,
}

impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> RollingAggWindow<'a, T>
for SumWindow<'a, T>
impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
let sum = slice[start..end].iter().copied().sum::<T>();
Expand All @@ -23,33 +23,39 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> Rolli
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// remove elements that should leave the window
let mut recompute_sum = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end {
true
} else {
// remove elements that should leave the window
let mut recompute_sum = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);

if T::is_float() && leaving_value.is_nan() {
recompute_sum = true;
break;
}
if T::is_float() && leaving_value.is_nan() {
recompute_sum = true;
break;
}

self.sum -= *leaving_value;
}
self.sum -= *leaving_value;
}
recompute_sum
};
self.last_start = start;

// we traverese all values and compute
if T::is_float() && recompute_sum {
// we traverse all values and compute
if recompute_sum {
self.sum = self
.slice
.get_unchecked(start..end)
.iter()
.copied()
.sum::<T>();
}
// the max has not left the window, so we only check
// if the entering values are larger
// add the entering values.
else {
for idx in self.last_end..end {
self.sum += *self.slice.get_unchecked(idx);
Expand Down
38 changes: 23 additions & 15 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::mean::MeanWindow;
use super::*;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls};
use num::pow::Pow;

pub(super) struct SumSquaredWindow<'a, T> {
Expand All @@ -11,7 +11,7 @@ pub(super) struct SumSquaredWindow<'a, T> {
}

impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>>
RollingAggWindow<'a, T> for SumSquaredWindow<'a, T>
RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
let sum = slice[start..end].iter().map(|v| *v * *v).sum::<T>();
Expand All @@ -24,20 +24,28 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// remove elements that should leave the window
let mut recompute_sum = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end {
true
} else {
// remove elements that should leave the window
let mut recompute_sum = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);

if T::is_float() && leaving_value.is_nan() {
recompute_sum = true;
break;
}

if T::is_float() && leaving_value.is_nan() {
recompute_sum = true;
break;
self.sum_of_squares -= *leaving_value * *leaving_value;
}
recompute_sum
};

self.sum_of_squares -= *leaving_value * *leaving_value;
}
self.last_start = start;

// we traverse all values and compute
Expand Down Expand Up @@ -82,7 +90,7 @@ impl<
+ One
+ Zero
+ Sub<Output = T>,
> RollingAggWindow<'a, T> for VarWindow<'a, T>
> RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
Self {
Expand Down Expand Up @@ -185,7 +193,7 @@ impl<
+ Zero
+ Sub<Output = T>
+ Pow<T, Output = T>,
> RollingAggWindow<'a, T> for StdWindow<'a, T>
> RollingAggWindowNoNulls<'a, T> for StdWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize) -> Self {
Self {
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-arrow/src/kernels/rolling/nulls/mean.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use super::sum::SumWindow;
use super::*;
use super::{rolling_apply_agg_window, RollingAggWindow};
use super::{rolling_apply_agg_window, RollingAggWindowNulls};

pub(super) struct MeanWindow<'a, T> {
pub struct MeanWindow<'a, T> {
sum: SumWindow<'a, T>,
}

impl<
'a,
T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + NumCast + Div<Output = T>,
> RollingAggWindow<'a, T> for MeanWindow<'a, T>
> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
{
unsafe fn new(
slice: &'a [T],
Expand Down

0 comments on commit 682663d

Please sign in to comment.