Skip to content

Commit

Permalink
Lazy: cast to float before we apply per group and integer cov/pearson
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 22, 2022
1 parent fc413ac commit 1afb301
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 51 deletions.
60 changes: 54 additions & 6 deletions polars/polars-core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use crate::utils::concat_df;
use ahash::AHashSet;
use arrow::compute;
use arrow::types::simd::Simd;
use num::{Float, NumCast};
use num::{Float, NumCast, ToPrimitive};
#[cfg(feature = "concat_str")]
use polars_arrow::prelude::ValueSize;
use std::ops::Add;

/// Compute the covariance between two columns.
pub fn cov<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
pub fn cov_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
where
T: PolarsFloatType,
T::Native: Float,
Expand All @@ -34,8 +34,44 @@ where
}
}

/// Compute the covariance between two columns.
pub fn cov_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<f64>
where
T: PolarsIntegerType,
T::Native: ToPrimitive,
<T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
+ compute::aggregate::Sum<T::Native>
+ compute::aggregate::SimdOrd<T::Native>,
{
if a.len() != b.len() {
None
} else {
let a_mean = a.mean()?;
let b_mean = b.mean()?;
let a = a.apply_cast_numeric::<_, Float64Type>(|a| a.to_f64().unwrap() - a_mean);
let b = b.apply_cast_numeric(|b| b.to_f64().unwrap() - b_mean);

let tmp = a * b;
let n = tmp.len() - tmp.null_count();
Some(tmp.sum()? / (n - 1) as f64)
}
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<f64>
where
T: PolarsIntegerType,
T::Native: ToPrimitive,
<T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
+ compute::aggregate::Sum<T::Native>
+ compute::aggregate::SimdOrd<T::Native>,
ChunkedArray<T>: ChunkVar<f64>,
{
Some(cov_i(a, b)? / (a.std()? * b.std()?))
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
pub fn pearson_corr_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
where
T: PolarsFloatType,
T::Native: Float,
Expand All @@ -44,7 +80,7 @@ where
+ compute::aggregate::SimdOrd<T::Native>,
ChunkedArray<T>: ChunkVar<T::Native>,
{
Some(cov(a, b)? / (a.std()? * b.std()?))
Some(cov_f(a, b)? / (a.std()? * b.std()?))
}

#[cfg(feature = "sort_multiple")]
Expand Down Expand Up @@ -241,12 +277,24 @@ pub fn diag_concat_df(dfs: &[DataFrame]) -> Result<DataFrame> {
mod test {
use super::*;

#[test]
fn test_cov() {
let a = Series::new("a", &[1.0f32, 2.0, 5.0]);
let b = Series::new("b", &[1.0f32, 2.0, -3.0]);
let out = cov_f(a.f32().unwrap(), b.f32().unwrap());
assert_eq!(out, Some(-5.0));
let a = a.cast(&DataType::Int32).unwrap();
let b = b.cast(&DataType::Int32).unwrap();
let out = cov_i(a.i32().unwrap(), b.i32().unwrap());
assert_eq!(out, Some(-5.0));
}

#[test]
fn test_pearson_corr() {
let a = Series::new("a", &[1.0f32, 2.0]);
let b = Series::new("b", &[1.0f32, 2.0]);
assert!((cov(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 0.5).abs() < 0.001);
assert!((pearson_corr(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 1.0).abs() < 0.001);
assert!((cov_f(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 0.5).abs() < 0.001);
assert!((pearson_corr_f(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 1.0).abs() < 0.001);
}

#[test]
Expand Down
72 changes: 38 additions & 34 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1653,46 +1653,32 @@ impl Expr {
#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
#[cfg(feature = "rolling_window")]
pub fn rolling_var(self, options: RollingOptions) -> Expr {
self.apply(
move |s| match s.dtype() {
DataType::Float32 => s.f32().unwrap().rolling_var(options.clone()),
DataType::Float64 => s.f64().unwrap().rolling_var(options.clone()),
_ => s
.cast(&DataType::Float64)?
.f64()
.unwrap()
.rolling_var(options.clone()),
},
GetOutput::map_field(|field| match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
}),
)
.with_fmt("rolling_var")
self.to_float()
.apply(
move |s| match s.dtype() {
DataType::Float32 => s.f32().unwrap().rolling_var(options.clone()),
DataType::Float64 => s.f64().unwrap().rolling_var(options.clone()),
_ => unreachable!(),
},
GetOutput::same_type(),
)
.with_fmt("rolling_var")
}

/// Apply a rolling std-dev
#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
#[cfg(feature = "rolling_window")]
pub fn rolling_std(self, options: RollingOptions) -> Expr {
self.apply(
move |s| match s.dtype() {
DataType::Float32 => s.f32().unwrap().rolling_std(options.clone()),
DataType::Float64 => s.f64().unwrap().rolling_std(options.clone()),
_ => s
.cast(&DataType::Float64)?
.f64()
.unwrap()
.rolling_std(options.clone()),
},
GetOutput::map_field(|field| match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
}),
)
.with_fmt("rolling_std")
self.to_float()
.apply(
move |s| match s.dtype() {
DataType::Float32 => s.f32().unwrap().rolling_std(options.clone()),
DataType::Float64 => s.f64().unwrap().rolling_std(options.clone()),
_ => unreachable!(),
},
GetOutput::same_type(),
)
.with_fmt("rolling_std")
}

#[cfg_attr(docsrs, doc(cfg(feature = "rolling_window")))]
Expand Down Expand Up @@ -2021,6 +2007,24 @@ impl Expr {
})
}

/// This is useful if an `apply` function needs a floating point type.
/// Because this cast is done on a `map` level, it will be faster.
pub fn to_float(self) -> Self {
self.map(
|s| match s.dtype() {
DataType::Float32 | DataType::Float64 => Ok(s),
_ => s.cast(&DataType::Float64),
},
GetOutput::map_dtype(|dt| {
if matches!(dt, DataType::Float32) {
DataType::Float32
} else {
DataType::Float64
}
}),
)
}

#[cfg(feature = "strings")]
pub fn str(self) -> string::StringNameSpace {
string::StringNameSpace(self)
Expand Down
52 changes: 46 additions & 6 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,39 @@ pub fn cov(a: Expr, b: Expr) -> Expr {
DataType::Float32 => {
let ca_a = a.f32().unwrap();
let ca_b = b.f32().unwrap();
Series::new(name, &[polars_core::functions::cov(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)])
}
DataType::Float64 => {
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::cov(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)])
}
DataType::Int32 => {
let ca_a = a.i32().unwrap();
let ca_b = b.i32().unwrap();
Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)])
}
DataType::Int64 => {
let ca_a = a.i64().unwrap();
let ca_b = b.i64().unwrap();
Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)])
}
DataType::UInt32 => {
let ca_a = a.u32().unwrap();
let ca_b = b.u32().unwrap();
Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)])
}
DataType::UInt64 => {
let ca_a = a.u64().unwrap();
let ca_b = b.u64().unwrap();
Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)])
}
_ => {
let a = a.cast(&DataType::Float64)?;
let b = b.cast(&DataType::Float64)?;
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::cov(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)])
}
};
Ok(s)
Expand Down Expand Up @@ -61,19 +81,39 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
DataType::Float32 => {
let ca_a = a.f32().unwrap();
let ca_b = b.f32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
}
DataType::Float64 => {
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
}
DataType::Int32 => {
let ca_a = a.i32().unwrap();
let ca_b = b.i32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
}
DataType::Int64 => {
let ca_a = a.i64().unwrap();
let ca_b = b.i64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
}
DataType::UInt32 => {
let ca_a = a.u32().unwrap();
let ca_b = b.u32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
}
DataType::UInt64 => {
let ca_a = a.u64().unwrap();
let ca_b = b.u64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
}
_ => {
let a = a.cast(&DataType::Float64)?;
let b = b.cast(&DataType::Float64)?;
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr(ca_a, ca_b)])
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
}
};
Ok(s)
Expand Down
15 changes: 10 additions & 5 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::prelude::*;
use polars_arrow::utils::CustomIterTools;
use polars_core::frame::groupby::GroupsProxy;
use polars_core::prelude::*;
use polars_core::POOL;
use rayon::prelude::*;
use std::sync::Arc;

Expand All @@ -23,10 +24,12 @@ impl ApplyExpr {
groups: &'a GroupsProxy,
state: &ExecutionState,
) -> Result<Vec<AggregationContext<'a>>> {
self.inputs
.par_iter()
.map(|e| e.evaluate_on_groups(df, groups, state))
.collect()
POOL.install(|| {
self.inputs
.par_iter()
.map(|e| e.evaluate_on_groups(df, groups, state))
.collect()
})
}

fn finish_apply_groups<'a>(
Expand Down Expand Up @@ -186,7 +189,9 @@ impl PhysicalExpr for ApplyExpr {

let s = self.function.call_udf(&mut s)?;
let mut ac = acs.pop().unwrap();
ac.with_update_groups(UpdateGroups::WithGroupsLen);
if ac.is_aggregated() {
ac.with_update_groups(UpdateGroups::WithGroupsLen);
}
ac.with_series(s, false);
Ok(ac)
}
Expand Down

0 comments on commit 1afb301

Please sign in to comment.