Skip to content

Commit

Permalink
Add rank(method="random"); closes #1540 (#1708)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghuls committed Nov 9, 2021
1 parent a2bbd4d commit 31e4df2
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 7 deletions.
97 changes: 92 additions & 5 deletions polars/polars-core/src/chunked_array/ops/unique/rank.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
use crate::prelude::*;

#[cfg(feature = "random")]
use rand::prelude::SliceRandom;
#[cfg(feature = "random")]
use rand::thread_rng;

#[derive(Copy, Clone)]
pub enum RankMethod {
Dense,
Ordinal,
Average,
Min,
Max,
Average,
Dense,
Ordinal,
#[cfg(feature = "random")]
Random,
}

pub(crate) fn rank(s: &Series, method: RankMethod) -> Series {
Expand Down Expand Up @@ -38,6 +45,14 @@ pub(crate) fn rank(s: &Series, method: RankMethod) -> Series {
unsafe { inv.set_len(len) }
let inv_values = inv.as_mut_slice();

#[cfg(feature = "random")]
let mut count = if let RankMethod::Ordinal | RankMethod::Random = method {
1u32
} else {
0
};

#[cfg(not(feature = "random"))]
let mut count = if let RankMethod::Ordinal = method {
1u32
} else {
Expand All @@ -52,12 +67,66 @@ pub(crate) fn rank(s: &Series, method: RankMethod) -> Series {
count += 1;
});
}
let inv_ca = UInt32Chunked::new_from_aligned_vec(s.name(), inv);

use RankMethod::*;
match method {
Ordinal => inv_ca.into_series(),
Ordinal => {
let inv_ca = UInt32Chunked::new_from_aligned_vec(s.name(), inv);
inv_ca.into_series()
}
#[cfg(feature = "random")]
Random => {
// Safety:
// in bounds
let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() };
let not_consecutive_same = (&arr.slice(1, len - 1))
.neq(&arr.slice(0, len - 1))
.rechunk();
let obs = not_consecutive_same.downcast_iter().next().unwrap();

// Collect slice indices for sort_idx which point to ties in the original series.
let mut ties_indices = Vec::with_capacity(len + 1);
let mut ties_index: usize = 0;

ties_indices.push(ties_index);
obs.iter().for_each(|b| {
if let Some(b) = b {
ties_index += 1;
if b {
ties_indices.push(ties_index)
}
}
});
// Close last slice (if there where nulls in the original series, they will always be in the last slice).
ties_indices.push(len);

let mut sort_idx = sort_idx.to_vec();

let rng = &mut thread_rng();

// Shuffle sort_idx positions which point to ties in the original series.
for i in 0..(ties_indices.len() - 1) as usize {
let ties_index_start = ties_indices[i];
let ties_index_end = ties_indices[i + 1];
if ties_index_end - ties_index_start > 1 {
sort_idx[ties_index_start..ties_index_end].shuffle(rng);
}
}

// Recreate inv_ca (where ties are randomly shuffled compared with Ordinal).
let mut count = 1u32;
unsafe {
sort_idx.iter().for_each(|&i| {
*inv_values.get_unchecked_mut(i as usize) = count;
count += 1;
});
}

let inv_ca = UInt32Chunked::new_from_aligned_vec(s.name(), inv);
inv_ca.into_series()
}
_ => {
let inv_ca = UInt32Chunked::new_from_aligned_vec(s.name(), inv);
// Safety:
// in bounds
let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() };
Expand Down Expand Up @@ -146,6 +215,9 @@ pub(crate) fn rank(s: &Series, method: RankMethod) -> Series {
+ 1.0;
(&a + &b) * 0.5
}
#[cfg(feature = "random")]
Dense | Ordinal | Random => unimplemented!(),
#[cfg(not(feature = "random"))]
Dense | Ordinal => unimplemented!(),
}
}
Expand All @@ -166,6 +238,21 @@ mod test {
.collect::<Vec<_>>();
assert_eq!(out, &[2, 3, 6, 4, 5, 7, 1]);

#[cfg(feature = "random")]
{
let out = rank(&s, RankMethod::Random)
.u32()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out[0], 2);
assert_eq!(out[6], 1);
assert_eq!(out[1] + out[3] + out[4], 12);
assert_eq!(out[2] + out[5], 13);
assert_ne!(out[1], out[3]);
assert_ne!(out[1], out[4]);
assert_ne!(out[3], out[4]);
}

let out = rank(&s, RankMethod::Dense)
.u32()?
.into_no_null_iter()
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/eager/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,7 +2843,7 @@ def rank(self, method: str = "average") -> "Series": # type: ignore
Parameters
----------
method
{'average', 'min', 'max', 'dense', 'ordinal'}, optional
{'average', 'min', 'max', 'dense', 'ordinal', 'random'}, optional
The method used to assign ranks to tied elements.
The following methods are available (default is 'average'):
* 'average': The average of the ranks that would have been assigned to
Expand All @@ -2858,6 +2858,8 @@ def rank(self, method: str = "average") -> "Series": # type: ignore
elements.
* 'ordinal': All values are given a distinct rank, corresponding to
the order that the values occur in `a`.
* 'random': Like 'ordinal', but the rank for ties is not dependent
on the order that the values occur in `a`.
"""
return wrap_s(self._s.rank(method))

Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/lazy/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,7 @@ def rank(self, method: str = "average") -> "Expr": # type: ignore
Parameters
----------
method
{'average', 'min', 'max', 'dense', 'ordinal'}, optional
{'average', 'min', 'max', 'dense', 'ordinal', 'random'}, optional
The method used to assign ranks to tied elements.
The following methods are available (default is 'average'):
* 'average': The average of the ranks that would have been assigned to
Expand All @@ -1626,6 +1626,8 @@ def rank(self, method: str = "average") -> "Expr": # type: ignore
elements.
* 'ordinal': All values are given a distinct rank, corresponding to
the order that the values occur in `a`.
* 'random': Like 'ordinal', but the rank for ties is not dependent
on the order that the values occur in `a`.
"""
return wrap_expr(self._pyexpr.rank(method))

Expand Down
1 change: 1 addition & 0 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ pub(crate) fn str_to_rankmethod(method: &str) -> PyResult<RankMethod> {
"average" => RankMethod::Average,
"dense" => RankMethod::Dense,
"ordinal" => RankMethod::Ordinal,
"random" => RankMethod::Random,
_ => {
return Err(PyValueError::new_err(
"use one of 'avg, min, max, dense, ordinal'".to_string(),
Expand Down

0 comments on commit 31e4df2

Please sign in to comment.