Skip to content

Commit

Permalink
feat[rust]: spearman rank null/nan handling (#4944)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 23, 2022
1 parent 43c9fb1 commit a3e772b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 14 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ pub trait SeriesTrait:

/// Drop all null values and return a new Series.
fn drop_nulls(&self) -> Series {
if !self.has_validity() {
if self.null_count() == 0 {
Series(self.clone_inner())
} else {
self.filter(&self.is_not_null()).unwrap()
Expand Down
40 changes: 31 additions & 9 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,49 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
}

/// Compute the spearman rank correlation between two columns.
#[cfg(feature = "rank")]
#[cfg_attr(docsrs, doc(cfg(feature = "rank")))]
pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
/// Missing data will be excluded from the computation.
/// # Arguments
/// * ddof
/// Delta degrees of freedom
/// * propagate_nans
/// If `true`, any `NaN` encountered will lead to `NaN` in the output.
/// If set to `false`, `NaN` values are regarded as larger than any finite number
/// and thus receive the highest rank.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "rank", feature = "propagate_nans"))))]
pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> Expr {
use polars_ops::prelude::nan_propagating_aggregate::nan_max_s;

let function = move |a: Series, b: Series| {
let (a, b) = coalesce_nulls_series(&a, &b);

let a = a.rank(RankOptions {
let name = "spearman_rank_correlation";
if propagate_nans && a.dtype().is_float() {
for s in [&a, &b] {
if nan_max_s(s, "").get(0).extract::<f64>().unwrap().is_nan() {
return Ok(Series::new(name, &[f64::NAN]));
}
}
}

// drop nulls so that they are excluded
let a = a.drop_nulls();
let b = b.drop_nulls();

let a_idx = a.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let b = b.rank(RankOptions {
let b_idx = b.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let a = a.idx().unwrap();
let b = b.idx().unwrap();
let a_idx = a_idx.idx().unwrap();
let b_idx = b_idx.idx().unwrap();

let name = "spearman_rank_correlation";
Ok(Series::new(
name,
&[polars_core::functions::pearson_corr_i(a, b, ddof)],
&[polars_core::functions::pearson_corr_i(a_idx, b_idx, ddof)],
))
};

Expand Down
14 changes: 12 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,14 @@ def lit(
return pli.wrap_expr(pylit(item, allow_object))


def spearman_rank_corr(a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1) -> pli.Expr:
def spearman_rank_corr(
a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1, propagate_nans: bool = False
) -> pli.Expr:
"""
Compute the spearman rank correlation between two columns.
Missing data will be excluded from the computation.
Parameters
----------
a
Expand All @@ -792,13 +796,19 @@ def spearman_rank_corr(a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1) -> p
Column name or Expression.
ddof
Delta degrees of freedom
propagate_nans
If `True`, any `NaN` encountered will lead to `NaN` in the output.
Defaults to `False`, in which case `NaN` values are regarded as larger than any
finite number and thus receive the highest rank.
"""
if isinstance(a, str):
a = col(a)
if isinstance(b, str):
b = col(b)
return pli.wrap_expr(pyspearman_rank_corr(a._pyexpr, b._pyexpr, ddof))
return pli.wrap_expr(
pyspearman_rank_corr(a._pyexpr, b._pyexpr, ddof, propagate_nans)
)


def pearson_corr(a: str | pli.Expr, b: str | pli.Expr, ddof: int = 1) -> pli.Expr:
Expand Down
9 changes: 7 additions & 2 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ fn pearson_corr(a: dsl::PyExpr, b: dsl::PyExpr, ddof: u8) -> dsl::PyExpr {
}

#[pyfunction]
fn spearman_rank_corr(a: dsl::PyExpr, b: dsl::PyExpr, ddof: u8) -> dsl::PyExpr {
polars::lazy::dsl::spearman_rank_corr(a.inner, b.inner, ddof).into()
fn spearman_rank_corr(
a: dsl::PyExpr,
b: dsl::PyExpr,
ddof: u8,
propagate_nans: bool,
) -> dsl::PyExpr {
polars::lazy::dsl::spearman_rank_corr(a.inner, b.inner, ddof, propagate_nans).into()
}

#[pyfunction]
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -119,6 +120,16 @@ def test_null_handling_correlation() -> None:
assert out["pearson"][0] == pytest.approx(1.0)
assert out["spearman"][0] == pytest.approx(1.0)

# see #4930
df1 = pl.DataFrame({"a": [None, 1, 2], "b": [None, 2, 1]})
df2 = pl.DataFrame({"a": [np.nan, 1, 2], "b": [np.nan, 2, 1]})

assert np.isclose(df1.select(pl.spearman_rank_corr("a", "b"))[0, 0], -1.0)
assert (
str(df2.select(pl.spearman_rank_corr("a", "b", propagate_nans=True))[0, 0])
== "nan"
)


def test_align_frames() -> None:
import numpy as np
Expand Down

0 comments on commit a3e772b

Please sign in to comment.