Skip to content

Commit

Permalink
quantile and median return floats always (#2408)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden committed Jan 25, 2022
1 parent a3311a6 commit 0a8ede9
Show file tree
Hide file tree
Showing 10 changed files with 549 additions and 292 deletions.
538 changes: 390 additions & 148 deletions polars/polars-core/src/chunked_array/ops/aggregate.rs

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,15 @@ pub trait ChunkAgg<T> {
fn mean(&self) -> Option<f64> {
None
}
}

/// Quantile and median aggregation
pub trait ChunkQuantile<T> {
/// Returns the mean value in the array.
/// Returns `None` if the array is empty or only contains null values.
fn median(&self) -> Option<f64> {
fn median(&self) -> Option<T> {
None
}

/// Aggregate a given quantile of the ChunkedArray.
/// Returns `None` if the array is empty or only contains null values.
fn quantile(&self, _quantile: f64, _interpol: QuantileInterpolOptions) -> Result<Option<T>> {
Expand Down
237 changes: 136 additions & 101 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,11 @@ where
impl<T> SeriesWrap<ChunkedArray<T>>
where
T: PolarsFloatType,
ChunkedArray<T>: IntoSeries + ChunkVar<T::Native>,
ChunkedArray<T>: IntoSeries
+ ChunkVar<T::Native>
+ VarAggSeries
+ ChunkQuantile<T::Native>
+ QuantileAggSeries,
T::Native: NativeType + PartialOrd + Num + NumCast + Simd + std::iter::Sum<T::Native>,
<T::Native as Simd>::Simd: std::ops::Add<Output = <T::Native as Simd>::Simd>
+ arrow::compute::aggregate::Sum<T::Native>
Expand Down Expand Up @@ -422,11 +426,7 @@ where
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.into_series()
.var_as_series()
.unpack::<T>()
.unwrap()
.get(0)
take.var_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
Expand All @@ -450,20 +450,79 @@ where
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.into_series()
.std_as_series()
take.std_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
match len {
0 => None,
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.std().map(|flt| NumCast::from(flt).unwrap())
}
}
}),
}
}

pub(crate) fn agg_quantile(
&self,
groups: &GroupsProxy,
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Option<Series> {
let ca = &self.0;
let invalid_quantile = !(0.0..=1.0).contains(&quantile);
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx::<T, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() | invalid_quantile {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.quantile_as_series(quantile, interpol)
.unwrap() // checked with invalid quantile check
.unpack::<T>()
.unwrap()
.get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
match len {
0 => None,
1 => self.get(first as usize),
_ => {
let arr_group = slice_from_offsets(self, first, len);
// unwrap checked with invalid quantile check
arr_group
.quantile(quantile, interpol)
.unwrap()
.map(|flt| NumCast::from(flt).unwrap())
}
}
}),
}
}
pub(crate) fn agg_median(&self, groups: &GroupsProxy) -> Option<Series> {
let ca = &self.0;
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx::<T, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= ca.len());
if idx.is_empty() {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.median_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
match len {
0 => None,
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.std().map(|flt| NumCast::from(flt).unwrap())
arr_group.median().map(|flt| NumCast::from(flt).unwrap())
}
}
}),
Expand Down Expand Up @@ -557,8 +616,70 @@ where
}
let take =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.into_series()
.var_as_series()
take.var_as_series().unpack::<Float64Type>().unwrap().get(0)
})
}
GroupsProxy::Slice(groups) => {
agg_helper_slice::<Float64Type, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
match len {
0 => None,
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.var()
}
}
})
}
}
}
pub(crate) fn agg_std(&self, groups: &GroupsProxy) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => {
agg_helper_idx::<Float64Type, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
return None;
}
let take =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.std_as_series().unpack::<Float64Type>().unwrap().get(0)
})
}
GroupsProxy::Slice(groups) => {
agg_helper_slice::<Float64Type, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as u32);
match len {
0 => None,
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.std()
}
}
})
}
}
}

pub(crate) fn agg_quantile(
&self,
groups: &GroupsProxy,
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => {
agg_helper_idx::<Float64Type, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
return None;
}
let take =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.quantile_as_series(quantile, interpol)
.unwrap()
.unpack::<Float64Type>()
.unwrap()
.get(0)
Expand All @@ -572,14 +693,14 @@ where
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.var()
arr_group.quantile(quantile, interpol).unwrap()
}
}
})
}
}
}
pub(crate) fn agg_std(&self, groups: &GroupsProxy) -> Option<Series> {
pub(crate) fn agg_median(&self, groups: &GroupsProxy) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => {
agg_helper_idx::<Float64Type, _>(groups, |(_first, idx)| {
Expand All @@ -589,8 +710,7 @@ where
}
let take =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
take.into_series()
.std_as_series()
take.median_as_series()
.unpack::<Float64Type>()
.unwrap()
.get(0)
Expand All @@ -604,7 +724,7 @@ where
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
_ => {
let arr_group = slice_from_offsets(self, first, len);
arr_group.std()
arr_group.median()
}
}
})
Expand Down Expand Up @@ -993,88 +1113,3 @@ impl<T: PolarsObject> AggList for ObjectChunked<T> {
Some(listarr.into_series())
}
}

pub(crate) trait AggQuantile {
fn agg_quantile(
&self,
_groups: &GroupsProxy,
_quantile: f64,
_interpol: QuantileInterpolOptions,
) -> Option<Series> {
None
}

fn agg_median(&self, _groups: &GroupsProxy) -> Option<Series> {
None
}
}

impl<T> AggQuantile for ChunkedArray<T>
where
T: PolarsNumericType + Sync,
T::Native: PartialOrd + Num + NumCast + Zero + Simd + std::iter::Sum<T::Native>,
<T::Native as Simd>::Simd: std::ops::Add<Output = <T::Native as Simd>::Simd>
+ arrow::compute::aggregate::Sum<T::Native>
+ arrow::compute::aggregate::SimdOrd<T::Native>,
ChunkedArray<T>: IntoSeries,
{
fn agg_quantile(
&self,
groups: &GroupsProxy,
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx::<T, _>(groups, |(_first, idx)| {
if idx.is_empty() {
return None;
}

let group_vals =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
group_vals.quantile(quantile, interpol).unwrap()
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
if len == 0 {
return None;
}
let group_vals = slice_from_offsets(self, first, len);
group_vals.quantile(quantile, interpol).unwrap()
}),
}
}

fn agg_median(&self, groups: &GroupsProxy) -> Option<Series> {
match groups {
GroupsProxy::Idx(groups) => {
agg_helper_idx::<Float64Type, _>(groups, |(_first, idx)| {
if idx.is_empty() {
return None;
}

let group_vals =
unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
group_vals.median()
})
}
GroupsProxy::Slice(groups) => {
agg_helper_slice::<Float64Type, _>(groups, |[first, len]| {
if len == 0 {
return None;
}

let group_vals = slice_from_offsets(self, first, len);
group_vals.median()
})
}
}
}
}

impl AggQuantile for Utf8Chunked {}
impl AggQuantile for BooleanChunked {}
impl AggQuantile for ListChunked {}
#[cfg(feature = "dtype-categorical")]
impl AggQuantile for CategoricalChunked {}
#[cfg(feature = "object")]
impl<T> AggQuantile for ObjectChunked<T> {}
22 changes: 2 additions & 20 deletions polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,6 @@ impl private::PrivateSeries for SeriesWrap<BooleanChunked> {
fn agg_list(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_list(groups)
}

fn agg_quantile(
&self,
groups: &GroupsProxy,
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Option<Series> {
self.0.agg_quantile(groups, quantile, interpol)
}

fn agg_median(&self, groups: &GroupsProxy) -> Option<Series> {
self.0.agg_median(groups)
}

fn hash_join_inner(&self, other: &Series) -> Vec<(u32, u32)> {
HashJoin::hash_join_inner(&self.0, other.as_ref().as_ref())
}
Expand Down Expand Up @@ -191,10 +177,6 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
self.0.mean()
}

fn median(&self) -> Option<f64> {
self.0.median()
}

fn take(&self, indices: &UInt32Chunked) -> Result<Series> {
let indices = if indices.chunks.len() > 1 {
Cow::Owned(indices.rechunk())
Expand Down Expand Up @@ -345,7 +327,7 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
ChunkAggSeries::mean_as_series(&self.0)
}
fn median_as_series(&self) -> Series {
ChunkAggSeries::median_as_series(&self.0)
QuantileAggSeries::median_as_series(&self.0)
}
fn var_as_series(&self) -> Series {
VarAggSeries::var_as_series(&self.0)
Expand All @@ -358,7 +340,7 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
quantile: f64,
interpol: QuantileInterpolOptions,
) -> Result<Series> {
ChunkAggSeries::quantile_as_series(&self.0, quantile, interpol)
QuantileAggSeries::quantile_as_series(&self.0, quantile, interpol)
}

fn fmt_list(&self) -> String {
Expand Down

0 comments on commit 0a8ede9

Please sign in to comment.