Skip to content

Commit

Permalink
add specialized rolling_std kernel (#3476)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 23, 2022
1 parent 26f7b2a commit 0dfea6d
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 10 deletions.
2 changes: 1 addition & 1 deletion polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use mean::rolling_mean;
pub use min_max::{rolling_max, rolling_min};
pub use quantile::{rolling_median, rolling_quantile};
pub use sum::rolling_sum;
pub use variance::rolling_var;
pub use variance::{rolling_std, rolling_var};

pub(crate) trait RollingAggWindow<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self;
Expand Down
73 changes: 73 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::mean::MeanWindow;
use super::*;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use num::pow::Pow;

pub(super) struct SumSquaredWindow<'a, T> {
slice: &'a [T],
Expand Down Expand Up @@ -157,6 +158,78 @@ where
}
}

// Rolling standard deviation window (no-null input).
//
// The wrapped `VarWindow` supplies the variance
//   E[(xi - E[x])^2]
// which can be expanded to
//   E[x^2] - E[x]^2
// and every value it yields is raised to the power 0.5 here.
struct StdWindow<'a, T> {
    var: VarWindow<'a, T>,
}

impl<'a, T> RollingAggWindow<'a, T> for StdWindow<'a, T>
where
    T: NativeType
        + IsFloat
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Sub<Output = T>
        + Pow<T, Output = T>,
{
    // Builds the inner variance window over `slice[start..end]`-style bounds;
    // the bounds are forwarded untouched to `VarWindow::new`.
    fn new(slice: &'a [T], start: usize, end: usize) -> Self {
        let var = VarWindow::new(slice, start, end);
        Self { var }
    }

    // Advances the inner variance window to `start..end` and returns the
    // resulting variance raised to 0.5.
    //
    // Safety: the caller contract is whatever `VarWindow::update` requires;
    // the arguments are passed through unchanged.
    unsafe fn update(&mut self, start: usize, end: usize) -> T {
        let variance = self.var.update(start, end);
        // The exponent is built through `NumCast` so it has the same float type as `T`.
        let half: T = NumCast::from(0.5).unwrap();
        variance.pow(half)
    }
}

/// Rolling standard deviation over a slice without null values.
///
/// Runs `rolling_apply_agg_window` with a `StdWindow`, using
/// `det_offsets_center` when `center` is true and `det_offsets` otherwise.
///
/// # Panics
///
/// Panics when `weights` is `Some(_)`; weighted windows are not handled by
/// this kernel.
pub fn rolling_std<T>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType
        + Float
        + IsFloat
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Sub<Output = T>
        + Pow<T, Output = T>,
{
    // Guard first: no weighted variant exists for this kernel.
    if weights.is_some() {
        panic!("weights not yet supported for rolling_std")
    }

    // The two offset functions are distinct function items, so each branch
    // makes its own call rather than sharing a variable.
    if center {
        rolling_apply_agg_window::<StdWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets_center,
        )
    } else {
        rolling_apply_agg_window::<StdWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets,
        )
    }
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-arrow/src/kernels/rolling/nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use mean::rolling_mean;
pub use min_max::{rolling_max, rolling_min};
pub use quantile::{rolling_median, rolling_quantile};
pub use sum::rolling_sum;
pub use variance::rolling_var;
pub use variance::{rolling_std, rolling_var};

pub(crate) trait RollingAggWindow<'a, T: NativeType> {
unsafe fn new(
Expand Down
78 changes: 78 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/nulls/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::*;
use mean::MeanWindow;
use nulls;
use nulls::{rolling_apply_agg_window, RollingAggWindow};
use num::pow::Pow;

pub struct SumSquaredWindow<'a, T> {
slice: &'a [T],
Expand Down Expand Up @@ -201,3 +202,80 @@ where
)
}
}

// Rolling standard deviation window for arrays that carry a validity bitmap.
// Wraps the null-aware `VarWindow` and raises each variance it yields to 0.5.
struct StdWindow<'a, T> {
    var: VarWindow<'a, T>,
}

impl<'a, T> RollingAggWindow<'a, T> for StdWindow<'a, T>
where
    T: NativeType
        + IsFloat
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Add<Output = T>
        + Sub<Output = T>
        + Pow<T, Output = T>,
{
    // Safety: all arguments are forwarded unchanged to `VarWindow::new`, so
    // the caller must uphold that constructor's contract.
    unsafe fn new(
        slice: &'a [T],
        validity: &'a Bitmap,
        start: usize,
        end: usize,
        min_periods: usize,
    ) -> Self {
        let var = VarWindow::new(slice, validity, start, end, min_periods);
        Self { var }
    }

    // Returns `None` whenever the inner variance window returns `None`;
    // otherwise the variance raised to 0.5.
    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
        let variance = self.var.update(start, end)?;
        Some(variance.pow(NumCast::from(0.5).unwrap()))
    }
}

/// Rolling standard deviation over a primitive array that has a validity bitmap.
///
/// Runs the null-aware `rolling_apply_agg_window` with a `StdWindow`, using
/// `det_offsets_center` when `center` is true and `det_offsets` otherwise.
///
/// # Panics
///
/// * when `weights` is `Some(_)`;
/// * when `arr` has no validity bitmap (the bitmap is unwrapped).
pub fn rolling_std<T>(
    arr: &PrimitiveArray<T>,
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType
        + std::iter::Sum<T>
        + Zero
        + AddAssign
        + SubAssign
        + IsFloat
        + Float
        + Pow<T, Output = T>,
{
    match (weights, center) {
        // The weights check comes first so it panics before the validity unwrap.
        (Some(_), _) => {
            panic!("weights not yet supported on array with null values")
        }
        (None, true) => rolling_apply_agg_window::<StdWindow<_>, _, _>(
            arr.values().as_slice(),
            arr.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets_center,
        ),
        (None, false) => rolling_apply_agg_window::<StdWindow<_>, _, _>(
            arr.values().as_slice(),
            arr.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets,
        ),
    }
}
43 changes: 35 additions & 8 deletions polars/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ mod inner_mod {
where
ChunkedArray<T>: IntoSeries,
T: PolarsFloatType,
T::Native: Float + IsFloat + SubAssign,
T::Native: Float + IsFloat + SubAssign + num::pow::Pow<T::Native, Output = T::Native>,
{
/// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
pub fn rolling_apply_float<F>(&self, window_size: usize, f: F) -> Result<Self>
Expand Down Expand Up @@ -478,14 +478,41 @@ mod inner_mod {
/// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting
/// values will be aggregated to their std.
pub fn rolling_std(&self, options: RollingOptions) -> Result<Series> {
// NOTE(review): this span comes from a diff view without +/- markers. The
// lines from `let s = ...` through the `//Float64 case` arm, and the later
// `Ok(out)`, appear to be the REMOVED implementation (std derived from
// `rolling_var` followed by a 0.5 power) — confirm against the real file.
let s = self.rolling_var(options)?;
// Safety:
// We are still guarded by the type system.
let out = match self.dtype() {
DataType::Float32 => s.f32().unwrap().pow_f32(0.5).into_series(),
_ => s.f64().unwrap().pow_f64(0.5).into_series(), //Float64 case
// Propagate any error `check_input` reports for the window parameters,
// then rechunk so the first chunk can be handed to the kernels below.
check_input(options.window_size, options.min_periods)?;
let ca = self.rechunk();

// weights is only implemented by var kernel
if options.weights.is_some() {
// NOTE(review): the enclosing impl is bounded by `T: PolarsFloatType`, so
// this non-float cast branch looks unreachable — confirm.
if !matches!(self.dtype(), DataType::Float64 | DataType::Float32) {
let s = ca.cast(&DataType::Float64).unwrap();
return s
.f64()
.unwrap()
.rolling_var(options)
.and_then(|ca| ca.pow(0.5));
} else {
return ca.rolling_var(options).and_then(|ca| ca.pow(0.5));
}
}

// Only the first chunk of the rechunked array is passed to the kernel.
let arr = ca.downcast_iter().next().unwrap();
// Use the null-aware kernel when the array has validity, otherwise the
// no-null kernel. `options.weights` is `None` here because the `Some`
// case returned above.
let arr = match self.has_validity() {
false => rolling::no_nulls::rolling_std(
arr.values(),
options.window_size,
options.min_periods,
options.center,
options.weights.as_deref(),
),
_ => rolling::nulls::rolling_std(
arr,
options.window_size,
options.min_periods,
options.center,
options.weights.as_deref(),
),
};
Ok(out)
// Wrap the kernel output in a Series that keeps this array's name.
Series::try_from((self.name(), arr))
}
}
}
Expand Down

0 comments on commit 0dfea6d

Please sign in to comment.