handle null values in correlation (#4318)
ritchie46 committed Aug 8, 2022
1 parent decd539 commit 796a5ab
Showing 10 changed files with 162 additions and 42 deletions.
2 changes: 1 addition & 1 deletion polars/benches/groupby.rs
@@ -143,7 +143,7 @@ fn q9(c: &mut Criterion) {
.lazy()
.drop_nulls(Some(vec![col("v1"), col("v2")]))
.groupby([col("id2"), col("id4")])
.agg([pearson_corr(col("v1"), col("v2")).alias("r2").pow(2.0)])
.agg([pearson_corr(col("v1"), col("v2"), 1).alias("r2").pow(2.0)])
.collect()
.unwrap();
})
21 changes: 21 additions & 0 deletions polars/polars-core/src/chunked_array/mod.rs
@@ -410,6 +410,27 @@ impl<T: PolarsDataType> ChunkedArray<T> {
BooleanChunked::from_chunks(self.name(), chunks)
}

pub(crate) fn coalesce_nulls(&self, other: &[ArrayRef]) -> Self {
assert_eq!(self.chunks.len(), other.len());
let chunks = self
.chunks
.iter()
.zip(other)
.map(|(a, b)| {
assert_eq!(a.len(), b.len());
let validity = match (a.validity(), b.validity()) {
(None, Some(b)) => Some(b.clone()),
(Some(a), Some(b)) => Some(a & b),
(Some(a), None) => Some(a.clone()),
(None, None) => None,
};

a.with_validity(validity)
})
.collect();
self.copy_with_chunks(chunks, true)
}

/// Get data type of ChunkedArray.
pub fn dtype(&self) -> &DataType {
self.field.data_type()
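
The added `coalesce_nulls` combines the array's validity bitmap with that of the other chunks (`a & b`), so the returned array is null wherever either input was null; the `utils/mod.rs` wrapper further down applies the shared mask to both sides. A minimal plain-Python sketch of that end-to-end effect (illustration only, not the polars API):

```python
from typing import Optional, Sequence

def coalesce_nulls(
    a: Sequence[Optional[float]], b: Sequence[Optional[float]]
) -> tuple[list, list]:
    """Return copies of a and b where a slot is None in BOTH outputs
    if it was None in EITHER input (the Rust code does this by
    AND-ing the Arrow validity bitmaps)."""
    assert len(a) == len(b)
    out_a, out_b = [], []
    for x, y in zip(a, b):
        keep = x is not None and y is not None
        out_a.append(x if keep else None)
        out_b.append(y if keep else None)
    return out_a, out_b

print(coalesce_nulls([1.0, None, 3.0, 4.0], [2.0, 4.0, None, 8.0]))
# ([1.0, None, None, 4.0], [2.0, None, None, 8.0])
```
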
21 changes: 16 additions & 5 deletions polars/polars-core/src/functions.rs
@@ -5,6 +5,7 @@
#[cfg(feature = "sort_multiple")]
use crate::chunked_array::ops::sort::prepare_argsort;
use crate::prelude::*;
use crate::utils::coalesce_nulls;
#[cfg(feature = "diagonal_concat")]
use crate::utils::concat_df;
#[cfg(feature = "diagonal_concat")]
@@ -58,7 +59,7 @@ where
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<f64>
pub fn pearson_corr_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>, ddof: u8) -> Option<f64>
where
T: PolarsIntegerType,
T::Native: ToPrimitive,
@@ -67,11 +68,15 @@ where
+ compute::aggregate::SimdOrd<T::Native>,
ChunkedArray<T>: ChunkVar<f64>,
{
Some(cov_i(a, b)? / (a.std(1)? * b.std(1)?))
let (a, b) = coalesce_nulls(a, b);
let a = a.as_ref();
let b = b.as_ref();

Some(cov_i(a, b)? / (a.std(ddof)? * b.std(ddof)?))
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
pub fn pearson_corr_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>, ddof: u8) -> Option<T::Native>
where
T: PolarsFloatType,
T::Native: Float,
@@ -80,7 +85,11 @@ where
+ compute::aggregate::SimdOrd<T::Native>,
ChunkedArray<T>: ChunkVar<T::Native>,
{
Some(cov_f(a, b)? / (a.std(1)? * b.std(1)?))
let (a, b) = coalesce_nulls(a, b);
let a = a.as_ref();
let b = b.as_ref();

Some(cov_f(a, b)? / (a.std(ddof)? * b.std(ddof)?))
}

#[cfg(feature = "sort_multiple")]
@@ -298,7 +307,9 @@ mod test {
let a = Series::new("a", &[1.0f32, 2.0]);
let b = Series::new("b", &[1.0f32, 2.0]);
assert!((cov_f(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 0.5).abs() < 0.001);
assert!((pearson_corr_f(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 1.0).abs() < 0.001);
assert!(
(pearson_corr_f(a.f32().unwrap(), b.f32().unwrap(), 1).unwrap() - 1.0).abs() < 0.001
);
}

#[test]
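
As a reference for what `pearson_corr_i` and `pearson_corr_f` now compute, here is a hedged plain-Python sketch. It assumes the covariance keeps its usual n - 1 denominator (consistent with the `cov_f` test above) while only the standard deviations use the caller-supplied `ddof`:

```python
import math
from statistics import mean

def pearson_corr(a: list[float], b: list[float], ddof: int = 1) -> float:
    # Nulls are assumed to have been coalesced and dropped already.
    n = len(a)
    ma, mb = mean(a), mean(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b)) / (n - 1)
    std_a = math.sqrt(sum((x - ma) ** 2 for x in a) / (n - ddof))
    std_b = math.sqrt(sum((y - mb) ** 2 for y in b) / (n - ddof))
    return cov / (std_a * std_b)

# Mirrors the unit test in this hunk: corr([1, 2], [1, 2], ddof=1) == 1.
assert abs(pearson_corr([1.0, 2.0], [1.0, 2.0]) - 1.0) < 1e-3
```

With the default `ddof = 1` this is the ordinary sample Pearson correlation.
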
41 changes: 41 additions & 0 deletions polars/polars-core/src/utils/mod.rs
@@ -1008,6 +1008,47 @@
None
}

/// ensure that nulls are propagated to both arrays
pub fn coalesce_nulls<'a, T: PolarsDataType>(
a: &'a ChunkedArray<T>,
b: &'a ChunkedArray<T>,
) -> (Cow<'a, ChunkedArray<T>>, Cow<'a, ChunkedArray<T>>) {
if a.null_count() > 0 || b.null_count() > 0 {
let (a, b) = align_chunks_binary(a, b);
let mut b = b.into_owned();
let a = a.coalesce_nulls(b.chunks());

for arr in a.chunks().iter() {
for arr_b in unsafe { b.chunks_mut() } {
*arr_b = arr_b.with_validity(arr.validity().cloned())
}
}
(Cow::Owned(a), Cow::Owned(b))
} else {
(Cow::Borrowed(a), Cow::Borrowed(b))
}
}

pub fn coalesce_nulls_series(a: &Series, b: &Series) -> (Series, Series) {
if a.null_count() > 0 || b.null_count() > 0 {
let mut a = a.rechunk();
let mut b = b.rechunk();
for (arr_a, arr_b) in unsafe { a.chunks_mut().iter_mut().zip(b.chunks_mut()) } {
let validity = match (arr_a.validity(), arr_b.validity()) {
(None, Some(b)) => Some(b.clone()),
(Some(a), Some(b)) => Some(a & b),
(Some(a), None) => Some(a.clone()),
(None, None) => None,
};
*arr_a = arr_a.with_validity(validity.clone());
*arr_b = arr_b.with_validity(validity);
}
(a, b)
} else {
(a.clone(), b.clone())
}
}

#[cfg(test)]
mod test {
use super::*;
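
`coalesce_nulls_series` applies the same idea one level up, on rechunked `Series`: the two validity masks are AND-ed and the shared mask is written back to both sides. In bitmap terms (sketch, `True` meaning valid):

```python
validity_a = [True, False, True, True]   # null at position 1
validity_b = [True, True, False, True]   # null at position 2
shared = [va and vb for va, vb in zip(validity_a, validity_b)]
print(shared)  # [True, False, False, True] -- applied to both series
```
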
68 changes: 52 additions & 16 deletions polars/polars-lazy/src/dsl/functions.rs
@@ -7,9 +7,10 @@ use crate::dsl::function_expr::FunctionExpr;
use crate::prelude::*;
use crate::utils::has_wildcard;
use polars_core::export::arrow::temporal_conversions::NANOSECONDS;
use polars_core::functions::pearson_corr_i;
use polars_core::prelude::*;
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
use polars_core::utils::get_supertype;
use polars_core::utils::{coalesce_nulls_series, get_supertype};
#[cfg(feature = "list")]
use polars_ops::prelude::ListNameSpaceImpl;
use rayon::prelude::*;
@@ -80,46 +81,67 @@ pub fn cov(a: Expr, b: Expr) -> Expr {
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
let name = "pearson_corr";
let function = move |a: Series, b: Series| {
let s = match a.dtype() {
DataType::Float32 => {
let ca_a = a.f32().unwrap();
let ca_b = b.f32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)],
)
}
DataType::Float64 => {
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)],
)
}
DataType::Int32 => {
let ca_a = a.i32().unwrap();
let ca_b = b.i32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)],
)
}
DataType::Int64 => {
let ca_a = a.i64().unwrap();
let ca_b = b.i64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)],
)
}
DataType::UInt32 => {
let ca_a = a.u32().unwrap();
let ca_b = b.u32().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)],
)
}
DataType::UInt64 => {
let ca_a = a.u64().unwrap();
let ca_b = b.u64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_i(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)],
)
}
_ => {
let a = a.cast(&DataType::Float64)?;
let b = b.cast(&DataType::Float64)?;
let ca_a = a.f64().unwrap();
let ca_b = b.f64().unwrap();
Series::new(name, &[polars_core::functions::pearson_corr_f(ca_a, ca_b)])
Series::new(
name,
&[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)],
)
}
};
Ok(s)
@@ -146,18 +168,32 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
/// Compute the spearman rank correlation between two columns.
#[cfg(feature = "rank")]
#[cfg_attr(docsrs, doc(cfg(feature = "rank")))]
pub fn spearman_rank_corr(a: Expr, b: Expr) -> Expr {
pearson_corr(
a.rank(RankOptions {
pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
let function = move |a: Series, b: Series| {
let (a, b) = coalesce_nulls_series(&a, &b);

let a = a.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
}),
b.rank(RankOptions {
});
let b = b.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
}),
});
let a = a.idx().unwrap();
let b = b.idx().unwrap();

let name = "spearman_rank_correlation";
Ok(Series::new(name, &[pearson_corr_i(a, b, ddof)]))
};

apply_binary(a, b, function, GetOutput::from_type(DataType::Float64)).with_function_options(
|mut options| {
options.auto_explode = true;
options.fmt_str = "spearman_rank_correlation";
options
},
)
.with_fmt("spearman_rank_correlation")
}

/// Find the indexes that would sort these series in order of appearance.
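
The intended user-facing effect of these two expression changes can be sketched through the Python API, assuming both functions are exposed at the top level as `pl.pearson_corr` and `pl.spearman_rank_corr`, as the py-polars module further down suggests. This is illustrative only and no exact output is asserted: rows carrying a null on either side should no longer skew the result, and because `b` is a monotonic but non-linear function of `a` on the complete rows, the Spearman coefficient should come out at 1.0 with the Pearson coefficient a bit below it.

```python
import polars as pl

df = pl.DataFrame(
    {
        "a": [1.0, 2.0, None, 4.0, 5.0],
        "b": [1.0, 4.0, 9.0, None, 25.0],  # b = a ** 2 on the complete rows
    }
)

out = df.select(
    [
        pl.pearson_corr("a", "b").alias("pearson"),
        pl.spearman_rank_corr("a", "b").alias("spearman"),
    ]
)
print(out)
```
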
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/tests/arity.rs
@@ -14,7 +14,7 @@ fn test_pearson_corr() -> Result<()> {
.lazy()
.groupby_stable([col("uid")])
// a double aggregation expression.
.agg([pearson_corr(col("day"), col("cumcases")).alias("pearson_corr")])
.agg([pearson_corr(col("day"), col("cumcases"), 1).alias("pearson_corr")])
.collect()?;
let s = out.column("pearson_corr")?.f64()?;
assert!((s.get(0).unwrap() - 0.997176).abs() < 0.000001);
@@ -24,7 +24,7 @@
.lazy()
.groupby_stable([col("uid")])
// a double aggregation expression.
.agg([pearson_corr(col("day"), col("cumcases"))
.agg([pearson_corr(col("day"), col("cumcases"), 1)
.pow(2.0)
.alias("pearson_corr")])
.collect()
18 changes: 8 additions & 10 deletions py-polars/polars/internals/lazy_functions.py
@@ -687,10 +687,7 @@ def lit(value: Any, dtype: type[DataType] | None = None) -> pli.Expr:
return pli.wrap_expr(pylit(item))


def spearman_rank_corr(
a: str | pli.Expr,
b: str | pli.Expr,
) -> pli.Expr:
def spearman_rank_corr(a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1) -> pli.Expr:
"""
Compute the spearman rank correlation between two columns.
@@ -700,19 +697,18 @@ def spearman_rank_corr(
Column name or Expression.
b
Column name or Expression.
ddof
Delta degrees of freedom
"""
if isinstance(a, str):
a = col(a)
if isinstance(b, str):
b = col(b)
return pli.wrap_expr(pyspearman_rank_corr(a._pyexpr, b._pyexpr))
return pli.wrap_expr(pyspearman_rank_corr(a._pyexpr, b._pyexpr, ddof))


def pearson_corr(
a: str | pli.Expr,
b: str | pli.Expr,
) -> pli.Expr:
def pearson_corr(a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1) -> pli.Expr:
"""
Compute the pearson's correlation between two columns.
@@ -722,13 +718,15 @@ def pearson_corr(
Column name or Expression.
b
Column name or Expression.
ddof
Delta degrees of freedom
"""
if isinstance(a, str):
a = col(a)
if isinstance(b, str):
b = col(b)
return pli.wrap_expr(pypearson_corr(a._pyexpr, b._pyexpr))
return pli.wrap_expr(pypearson_corr(a._pyexpr, b._pyexpr, ddof))


def cov(
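
Both wrappers keep accepting either column names or expressions and forward the new `ddof` argument, which defaults to 1. A quick sketch of the accepted call shapes (illustrative only):

```python
import polars as pl

df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [2.0, 2.5, 4.0]})

out = df.select(
    [
        pl.pearson_corr("a", "b").alias("r_default"),  # ddof defaults to 1
        pl.pearson_corr(pl.col("a"), pl.col("b"), ddof=0).alias("r_ddof0"),
        pl.spearman_rank_corr("a", "b", ddof=1).alias("rho"),
    ]
)
print(out)
```
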
8 changes: 4 additions & 4 deletions py-polars/src/lib.rs
@@ -151,13 +151,13 @@ fn repeat(value: &PyAny, n_times: PyExpr) -> PyExpr {
}

#[pyfunction]
fn pearson_corr(a: dsl::PyExpr, b: dsl::PyExpr) -> dsl::PyExpr {
polars::lazy::dsl::pearson_corr(a.inner, b.inner).into()
fn pearson_corr(a: dsl::PyExpr, b: dsl::PyExpr, ddof: u8) -> dsl::PyExpr {
polars::lazy::dsl::pearson_corr(a.inner, b.inner, ddof).into()
}

#[pyfunction]
fn spearman_rank_corr(a: dsl::PyExpr, b: dsl::PyExpr) -> dsl::PyExpr {
polars::lazy::dsl::spearman_rank_corr(a.inner, b.inner).into()
fn spearman_rank_corr(a: dsl::PyExpr, b: dsl::PyExpr, ddof: u8) -> dsl::PyExpr {
polars::lazy::dsl::spearman_rank_corr(a.inner, b.inner, ddof).into()
}

#[pyfunction]
6 changes: 2 additions & 4 deletions py-polars/tests/db-benchmark/main.py
@@ -23,9 +23,7 @@
ON_STRINGS = sys.argv.pop() == "on_strings"

if not ON_STRINGS:
x["id1"] = x["id1"].cast(pl.Categorical)
x["id2"] = x["id2"].cast(pl.Categorical)
x["id3"] = x["id3"].cast(pl.Categorical)
x.with_columns([pl.col(["id1", "id2", "id3"]).cast(pl.Categorical)])
df = x.clone()
x = df.lazy()

@@ -274,7 +272,7 @@
)
print(time.time() - t0)
assert out.shape == (9216, 3)
assert np.isclose(out["r2"].sum(), 9.896846028461322)
assert np.isclose(out["r2"].sum(), 9.902706276948825)

t0 = time.time()
print("q10")
