improve performance of rolling_var/rolling_std
ritchie46 committed Sep 19, 2021
1 parent d1b8cdc commit d053cd0
Showing 6 changed files with 195 additions and 60 deletions.
127 changes: 124 additions & 3 deletions polars/polars-core/src/chunked_array/ops/rolling_window.rs
@@ -1,7 +1,8 @@
use crate::prelude::*;
use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::utils::count_zeros;
use arrow::bitmap::MutableBitmap;
use num::{Bounded, NumCast, One, Zero};
use num::{Bounded, Float, NumCast, One, Zero};
use polars_arrow::bit_util::unset_bit_raw;
use polars_arrow::trusted_len::PushUnchecked;
use polars_arrow::utils::CustomIterTools;
@@ -447,10 +448,26 @@ where
}
}

/// Unbiased sample variance of `vals`: two passes (mean, then summed squared
/// deviations), divided by `n - 1` (Bessel's correction).
fn variance<T>(vals: &[T]) -> T
where
T: Float + std::iter::Sum,
{
let len = T::from(vals.len()).unwrap();
let mean = vals.iter().copied().sum::<T>() / len;

let mut sum = T::zero();
for &val in vals {
let v = val - mean;
sum = sum + v * v
}
sum / (len - T::one())
}

impl<T> ChunkedArray<T>
where
ChunkedArray<T>: IntoSeries,
T: PolarsFloatType,
T::Native: Default,
T::Native: Default + std::iter::Sum + Float,
{
pub fn rolling_apply_float<F>(&self, window_size: usize, f: F) -> Result<Self>
where
@@ -466,7 +483,7 @@ where

let mut validity = MutableBitmap::with_capacity(ca.len());
validity.extend_constant(window_size - 1, false);
validity.extend_constant(ca.len() - window_size - 1, true);
validity.extend_constant(ca.len() - (window_size - 1), true);
let validity_ptr = validity.as_slice().as_ptr() as *mut u8;

let mut values = AlignedVec::with_capacity(ca.len());
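
A quick check of the bitmap sizing above: for n input values and a window of size w, the output has w - 1 leading nulls and one entry per complete window, so the two extend_constant calls together cover exactly n bits:

    (w - 1) + (n - (w - 1)) = n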
@@ -496,6 +513,80 @@ where
);
Ok(Self::new_from_chunks(self.name(), vec![Arc::new(arr)]))
}

pub fn rolling_var(&self, window_size: usize) -> Self {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let values = arr.values().as_slice();

let mut validity = MutableBitmap::with_capacity(ca.len());
validity.extend_constant(window_size - 1, false);
validity.extend_constant(ca.len() - (window_size - 1), true);
let validity_ptr = validity.as_slice().as_ptr() as *mut u8;

let mut rolling_values = AlignedVec::with_capacity(ca.len());
rolling_values.extend_constant(window_size - 1, Default::default());

if ca.null_count() == 0 {
for offset in 0..self.len() + 1 - window_size {
let window = &values[offset..offset + window_size];
let val = variance(window);

unsafe {
// Safety:
// We pre-allocated enough capacity
rolling_values.push_unchecked(val);
};
}
} else {
let old_validity = arr.validity().as_ref().unwrap().clone();
let (bytes, bytes_offset, _) = old_validity.as_slice();
for offset in 0..self.len() + 1 - window_size {
if count_zeros(bytes, bytes_offset + offset, window_size) > 0 {
unsafe {
// Safety:
// We pre-allocated enough capacity
rolling_values.push_unchecked(Default::default());
// Safety:
// We are in bounds
unset_bit_raw(validity_ptr, offset + window_size - 1)
};
} else {
let window = &values[offset..offset + window_size];
let val = variance(window);
// Safety:
// We pre-allocated enough capacity
unsafe { rolling_values.push_unchecked(val) };
}
}
}

let arr = PrimitiveArray::from_data(
T::get_dtype().to_arrow(),
rolling_values.into(),
Some(validity.into()),
);
Self::new_from_chunks(self.name(), vec![Arc::new(arr)])
}

pub fn rolling_std(&self, window_size: usize) -> Self {
let s = self.rolling_var(window_size).into_series();
// Safety:
// We are still guarded by the type system.
match self.dtype() {
DataType::Float32 => unsafe {
std::mem::transmute::<Float32Chunked, ChunkedArray<T>>(
s.f32().unwrap().pow_f32(0.5),
)
},
DataType::Float64 => unsafe {
std::mem::transmute::<Float64Chunked, ChunkedArray<T>>(
s.f64().unwrap().pow_f64(0.5),
)
},
_ => unreachable!(),
}
}
}

impl ChunkRollApply for ListChunked {}
@@ -614,4 +705,34 @@ mod test {
]
);
}

#[test]
fn test_rolling_var() {
let ca = Float64Chunked::new_from_opt_slice(
"foo",
&[
Some(0.0),
Some(1.0),
Some(2.0),
None,
None,
Some(5.0),
Some(6.0),
],
);
let out = ca.rolling_var(3).cast::<Int32Type>().unwrap();
assert_eq!(
Vec::from(&out),
&[None, None, Some(1), None, None, None, None,]
);

let ca = Float64Chunked::new_from_slice("", &[0.0, 2.0, 8.0, 3.0, 12.0, 1.0]);
let out = ca.rolling_var(3).cast::<Int32Type>().unwrap();

assert_eq!(
Vec::from(&out),
&[None, None, Some(17), Some(10), Some(20), Some(34),]
);
dbg!(out);
}
}
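
As a cross-check on the expected values in the second test, take the first complete window [0.0, 2.0, 8.0]:

    mean = (0 + 2 + 8) / 3 = 10/3
    var  = ((0 - 10/3)² + (2 - 10/3)² + (8 - 10/3)²) / (3 - 1) = (104/3) / 2 ≈ 17.33

which the cast to Int32 turns into Some(17); the remaining windows follow the same pattern.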
46 changes: 46 additions & 0 deletions polars/polars-lazy/src/dsl.rs
@@ -1483,6 +1483,52 @@ impl Expr {
)
}

/// Apply a rolling variance
#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
#[cfg(feature = "rolling_window")]
pub fn rolling_var(self, window_size: usize) -> Expr {
self.apply(
move |s| match s.dtype() {
DataType::Float32 => Ok(s.f32().unwrap().rolling_var(window_size).into_series()),
DataType::Float64 => Ok(s.f64().unwrap().rolling_var(window_size).into_series()),
_ => Ok(s
.cast_with_dtype(&DataType::Float64)?
.f64()
.unwrap()
.rolling_var(window_size)
.into_series()),
},
GetOutput::map_field(|field| match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
}),
)
}

/// Apply a rolling std-dev
#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
#[cfg(feature = "rolling_window")]
pub fn rolling_std(self, window_size: usize) -> Expr {
self.apply(
move |s| match s.dtype() {
DataType::Float32 => Ok(s.f32().unwrap().rolling_std(window_size).into_series()),
DataType::Float64 => Ok(s.f64().unwrap().rolling_std(window_size).into_series()),
_ => Ok(s
.cast_with_dtype(&DataType::Float64)?
.f64()
.unwrap()
.rolling_std(window_size)
.into_series()),
},
GetOutput::map_field(|field| match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
}),
)
}
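
For context, a rough usage sketch of the new expressions (illustrative only; the column name, aliases, and wrapper function are hypothetical, and the `lazy` and `rolling_window` features are assumed to be enabled):

use polars::prelude::*;

// Illustrative sketch: `df` is assumed to have an f64 column "a".
fn window_stats(df: DataFrame) -> Result<DataFrame> {
    df.lazy()
        .select(vec![
            col("a").rolling_var(3).alias("a_var"),
            col("a").rolling_std(3).alias("a_std"),
        ])
        .collect()
}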

#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
#[cfg(feature = "rolling_window")]
/// Apply a custom function over a rolling/ moving window of the array.
10 changes: 4 additions & 6 deletions py-polars/src/conversion.rs
@@ -170,9 +170,8 @@ impl ToPyObject for Wrap<DataType> {
DataType::Boolean => pl.getattr("Boolean").unwrap().into(),
DataType::Utf8 => pl.getattr("Utf8").unwrap().into(),
DataType::List(_) => pl.getattr("List").unwrap().into(),
dt => panic!("{} not supported", dt)
dt => panic!("{} not supported", dt),
}

}
}

@@ -192,8 +191,7 @@ impl<'s> FromPyObject<'s> for Wrap<AnyValue<'s>> {
Ok(AnyValue::Utf8(v).into())
} else if let Ok(v) = ob.extract::<bool>() {
Ok(AnyValue::Boolean(v).into())
}
else if ob.get_type().name()?.contains("datetime") {
} else if ob.get_type().name()?.contains("datetime") {
let gil = Python::acquire_gil();
let py = gil.python();

@@ -205,10 +203,10 @@ impl<'s> FromPyObject<'s> for Wrap<AnyValue<'s>> {
let dt = ob.call_method("replace", (), Some(kwargs))?;

let pytz = PyModule::import(py, "pytz")?;
let tz = pytz.call_method("timezone", ("UTC", ), None)?;
let tz = pytz.call_method("timezone", ("UTC",), None)?;
let kwargs = PyDict::new(py);
kwargs.set_item("is_dst", py.None())?;
let loc_tz = tz.call_method("localize", (dt, ), Some(kwargs))?;
let loc_tz = tz.call_method("localize", (dt,), Some(kwargs))?;
loc_tz.call_method0("timestamp")?;
// s to ms
let v = ts.extract::<f64>()? as i64;
2 changes: 1 addition & 1 deletion py-polars/src/dataframe.rs
@@ -98,7 +98,7 @@ impl PyDataFrame {
low_memory: bool,
comment_char: Option<&str>,
null_values: Option<Wrap<NullValues>>,
parse_dates: bool
parse_dates: bool,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
53 changes: 14 additions & 39 deletions py-polars/src/lazy/dsl.rs
@@ -764,64 +764,39 @@ impl PyExpr {
.into()
}

pub fn rolling_std(
&self,
window_size: usize,
) -> Self {
self.inner
.clone()
.rolling_apply_float(window_size, |ca| ca.std()
)
.into()
pub fn rolling_std(&self, window_size: usize) -> Self {
self.inner.clone().rolling_std(window_size).into()
}

pub fn rolling_var(
&self,
window_size: usize,
) -> Self {
self.inner
.clone()
.rolling_apply_float(window_size, |ca| ca.var()
)
.into()
pub fn rolling_var(&self, window_size: usize) -> Self {
self.inner.clone().rolling_var(window_size).into()
}

pub fn rolling_median(
&self,
window_size: usize,
) -> Self {
pub fn rolling_median(&self, window_size: usize) -> Self {
self.inner
.clone()
.rolling_apply_float(window_size, |ca| ChunkAgg::median(ca)
)
.rolling_apply_float(window_size, |ca| ChunkAgg::median(ca))
.into()
}

pub fn rolling_quantile(
&self,
window_size: usize,
quantile: f64
) -> Self {
pub fn rolling_quantile(&self, window_size: usize, quantile: f64) -> Self {
self.inner
.clone()
.rolling_apply_float(window_size, move |ca| ChunkAgg::quantile(ca, quantile).unwrap()
)
.rolling_apply_float(window_size, move |ca| {
ChunkAgg::quantile(ca, quantile).unwrap()
})
.into()
}

pub fn rolling_skew(
&self,
window_size: usize,
bias: bool
) -> Self {
pub fn rolling_skew(&self, window_size: usize, bias: bool) -> Self {
self.inner
.clone()
.rolling_apply_float(window_size, move |ca| ca.clone().into_series().skew(bias).unwrap()
)
.rolling_apply_float(window_size, move |ca| {
ca.clone().into_series().skew(bias).unwrap()
})
.into()
}


fn lst_max(&self) -> Self {
self.inner
.clone()
17 changes: 6 additions & 11 deletions py-polars/src/lib.rs
@@ -8,12 +8,12 @@ use pyo3::wrap_pyfunction;
use crate::lazy::dsl::PyExpr;
use crate::{
dataframe::PyDataFrame,
file::EitherRustPythonFile,
lazy::{
dataframe::{PyLazyFrame, PyLazyGroupBy},
dsl,
},
series::PySeries,
file::EitherRustPythonFile
};

pub mod apply;
@@ -30,12 +30,12 @@ pub mod series;
pub mod utils;

use crate::conversion::{get_df, get_pyseq, Wrap};
use polars_core::export::arrow::io::ipc::read::read_file_metadata;
use crate::error::PyPolarsEr;
use crate::file::get_either_file;
use crate::prelude::DataType;
use mimalloc::MiMalloc;
use polars_core::export::arrow::io::ipc::read::read_file_metadata;
use pyo3::types::PyDict;
use crate::prelude::DataType;
use crate::file::get_either_file;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
@@ -140,18 +140,13 @@ fn concat_df(dfs: &PyAny) -> PyResult<PyDataFrame> {
Ok(df.into())
}




#[pyfunction]
fn ipc_schema(py: Python, py_f: PyObject) -> PyResult<PyObject>{
fn ipc_schema(py: Python, py_f: PyObject) -> PyResult<PyObject> {
let metadata = match get_either_file(py_f, false)? {
EitherRustPythonFile::Rust(mut r) => {
read_file_metadata(&mut r).map_err(PyPolarsEr::from)?
},
EitherRustPythonFile::Py(mut r) => {
read_file_metadata(&mut r).map_err(PyPolarsEr::from)?
}
EitherRustPythonFile::Py(mut r) => read_file_metadata(&mut r).map_err(PyPolarsEr::from)?,
};

let dict = PyDict::new(py);
