Skip to content

Commit

Permalink
feat(rust, python): implement mean aggregation for duration (#5807)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 14, 2022
1 parent a935bf4 commit c0c3a08
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 219 deletions.
7 changes: 3 additions & 4 deletions polars/polars-core/src/frame/groupby/aggregations/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::*;

// Implemented on `Series` directly because these aggregations do not need the concrete element type.
impl Series {
fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self {
self.slice(first as i64, len as usize)
}

fn restore_logical(&self, out: Series) -> Series {
if self.is_logical() {
if self.dtype().is_logical() {
out.cast(self.dtype()).unwrap()
} else {
out
Expand Down Expand Up @@ -165,9 +166,7 @@ impl Series {
use DataType::*;

match self.dtype() {
Boolean => {
self.cast(&DataType::Float64).unwrap().agg_mean(groups)
}
Boolean => self.cast(&Float64).unwrap().agg_mean(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
dt if dt.is_numeric() => {
Expand Down
220 changes: 5 additions & 215 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod agg_list;
mod dispatch;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
Expand Down Expand Up @@ -428,217 +429,6 @@ impl Utf8Chunked {
}
}

// Implemented on `Series` directly because these aggregations do not need the concrete element type.
impl Series {
/// Slices out the rows described by a `[first, len]` pair.
///
/// Thin adapter that converts the two `IdxSize` values into the `(i64, usize)`
/// arguments passed to `slice`.
fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self {
    let offset = first as i64;
    let length = len as usize;
    self.slice(offset, length)
}

/// Casts `out` back to the dtype of `self` when that dtype is logical.
///
/// When `self.dtype().is_logical()` is false, `out` is returned untouched.
/// Panics if the cast fails.
fn restore_logical(&self, out: Series) -> Series {
    let dtype = self.dtype();
    if !dtype.is_logical() {
        return out;
    }
    out.cast(dtype).unwrap()
}

/// Counts the non-null values of every group.
///
/// Each group produces one value: `None` for an empty group, the group length
/// when `has_validity()` is false, and otherwise the number of rows in the
/// group minus its null count. The values are collected through the `IdxType`
/// aggregation helpers.
#[doc(hidden)]
pub fn agg_valid_count(&self, groups: &GroupsProxy) -> Series {
    match groups {
        GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
            debug_assert!(idx.len() <= self.len());
            if idx.is_empty() {
                None
            } else if !self.has_validity() {
                // `has_validity()` is false, so the whole group is counted.
                Some(idx.len() as IdxSize)
            } else {
                // NOTE(review): the take is unchecked; presumably group indices are
                // always in bounds of `self` -- confirm against the callers.
                let take =
                    unsafe { self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) };
                Some((take.len() - take.null_count()) as IdxSize)
            }
        }),
        GroupsProxy::Slice { groups, .. } => {
            _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
                // Check the end of the window, not just its length: a window that
                // starts too far to the right is out of range as well. The operands
                // are widened to `usize` before adding.
                debug_assert!((first as usize) + (len as usize) <= self.len());
                if len == 0 {
                    None
                } else if !self.has_validity() {
                    Some(len)
                } else {
                    let take = self.slice_from_offsets(first, len);
                    Some((take.len() - take.null_count()) as IdxSize)
                }
            })
        }
    }
}

/// Takes the first row of every group; empty groups become null.
///
/// The gathered result is passed through `restore_logical` before it is
/// returned.
///
/// # Safety
///
/// Rows are gathered with `take_opt_iter_unchecked`, so every group must refer
/// to rows that are in bounds of `self`.
#[doc(hidden)]
pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series {
    let firsts = match groups {
        GroupsProxy::Idx(groups) => {
            let mut positions = groups
                .iter()
                .map(|(first, idx)| (!idx.is_empty()).then(|| first as usize));
            // Safety:
            // groups are always in bounds
            self.take_opt_iter_unchecked(&mut positions)
        }
        GroupsProxy::Slice { groups, .. } => {
            let mut positions = groups
                .iter()
                .map(|&[first, len]| (len != 0).then(|| first as usize));
            // Safety:
            // groups are always in bounds
            self.take_opt_iter_unchecked(&mut positions)
        }
    };
    self.restore_logical(firsts)
}

/// Counts the unique values of every group.
///
/// Each group produces one value: `None` for an empty group, otherwise the
/// result of `n_unique()` on the rows of that group. An `Err` from `n_unique()`
/// is turned into `None` via `.ok()`.
///
/// # Safety
///
/// NOTE(review): index groups are gathered with `take_iter_unchecked`;
/// presumably every index must be in bounds of `self` -- confirm.
#[doc(hidden)]
pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Series {
    match groups {
        GroupsProxy::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
            debug_assert!(idx.len() <= self.len());
            if idx.is_empty() {
                None
            } else {
                let take = self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize));
                take.n_unique().ok().map(|v| v as IdxSize)
            }
        }),
        GroupsProxy::Slice { groups, .. } => {
            _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
                // Check the end of the window, not just its length: a window that
                // starts too far to the right is out of range as well. The operands
                // are widened to `usize` before adding.
                debug_assert!((first as usize) + (len as usize) <= self.len());
                if len == 0 {
                    None
                } else {
                    let take = self.slice_from_offsets(first, len);
                    take.n_unique().ok().map(|v| v as IdxSize)
                }
            })
        }
    }
}

/// Computes the median of every group, dispatching on the dtype of `self`.
///
/// * `Float32` / `Float64`: forwarded to the typed `agg_median`.
/// * other numeric or temporal dtypes: aggregated on the physical
///   representation; for logical dtypes the result is cast to the physical
///   type and then back to the logical dtype.
/// * anything else: `Series::full_null` with `groups.len()` rows of the
///   original dtype.
///
/// # Safety
///
/// NOTE(review): the contract of this `unsafe fn` is not visible here;
/// presumably `groups` must only refer to rows in bounds of `self` -- confirm.
#[doc(hidden)]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series {
    use DataType::*;

    match self.dtype() {
        Float32 => {
            let floats = SeriesWrap(self.f32().unwrap().clone());
            floats.agg_median(groups)
        }
        Float64 => {
            let floats = SeriesWrap(self.f64().unwrap().clone());
            floats.agg_median(groups)
        }
        dt if dt.is_numeric() || dt.is_temporal() => {
            let physical = self.to_physical_repr();
            let aggregated = apply_method_physical_integer!(physical, agg_median, groups);
            if !dt.is_logical() {
                return aggregated;
            }
            // back to physical and then
            // back to logical type
            aggregated.cast(physical.dtype()).unwrap().cast(dt).unwrap()
        }
        _ => Series::full_null("", groups.len(), self.dtype()),
    }
}

/// Computes the requested quantile of every group, dispatching on the dtype of
/// `self`.
///
/// `quantile` and `interpol` are forwarded unchanged to the typed
/// implementation.
///
/// * `Float32` / `Float64`: forwarded to the typed `agg_quantile`.
/// * other numeric or temporal dtypes: aggregated on the physical
///   representation; for logical dtypes the result is cast to the physical
///   type and then back to the logical dtype.
/// * anything else: `Series::full_null` with `groups.len()` rows of the
///   original dtype.
///
/// # Safety
///
/// NOTE(review): the contract of this `unsafe fn` is not visible here;
/// presumably `groups` must only refer to rows in bounds of `self` -- confirm.
#[doc(hidden)]
pub unsafe fn agg_quantile(
    &self,
    groups: &GroupsProxy,
    quantile: f64,
    interpol: QuantileInterpolOptions,
) -> Series {
    use DataType::*;

    match self.dtype() {
        Float32 => {
            let floats = SeriesWrap(self.f32().unwrap().clone());
            floats.agg_quantile(groups, quantile, interpol)
        }
        Float64 => {
            let floats = SeriesWrap(self.f64().unwrap().clone());
            floats.agg_quantile(groups, quantile, interpol)
        }
        dt if dt.is_numeric() || dt.is_temporal() => {
            let physical = self.to_physical_repr();
            let aggregated =
                apply_method_physical_integer!(physical, agg_quantile, groups, quantile, interpol);
            if !dt.is_logical() {
                return aggregated;
            }
            // back to physical and then
            // back to logical type
            aggregated.cast(physical.dtype()).unwrap().cast(dt).unwrap()
        }
        _ => Series::full_null("", groups.len(), self.dtype()),
    }
}

/// Computes the mean of every group, dispatching on the dtype of `self`.
///
/// * `Boolean`: cast to `Float64` first, then aggregated.
/// * `Float32` / `Float64`: forwarded to the typed `agg_mean`.
/// * other numeric dtypes: dispatched through `apply_method_physical_integer!`.
/// * `Duration`: aggregated on the physical representation, then cast to
///   `Int64` and back to the original duration dtype.
/// * anything else: `Series::full_null` with `groups.len()` rows of the
///   original dtype.
///
/// # Safety
///
/// NOTE(review): the contract of this `unsafe fn` is not visible here;
/// presumably `groups` must only refer to rows in bounds of `self` -- confirm.
#[doc(hidden)]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series {
    use DataType::*;

    match self.dtype() {
        Boolean => {
            let as_float = self.cast(&Float64).unwrap();
            as_float.agg_mean(groups)
        }
        Float32 => {
            let floats = SeriesWrap(self.f32().unwrap().clone());
            floats.agg_mean(groups)
        }
        Float64 => {
            let floats = SeriesWrap(self.f64().unwrap().clone());
            floats.agg_mean(groups)
        }
        dt if dt.is_numeric() => {
            apply_method_physical_integer!(self, agg_mean, groups)
        }
        dt @ Duration(_) => {
            let physical = self.to_physical_repr();
            // agg_mean returns Float64
            let float_mean = physical.agg_mean(groups);
            // cast back to Int64 and then to logical duration type
            let int_mean = float_mean.cast(&Int64).unwrap();
            int_mean.cast(dt).unwrap()
        }
        _ => Series::full_null("", groups.len(), self.dtype()),
    }
}

/// Takes the last row of every group; empty groups become null.
///
/// For index groups the last stored index is used; for slice groups the row at
/// `first + len - 1`. The gathered result is passed through `restore_logical`
/// before it is returned.
///
/// # Safety
///
/// NOTE(review): rows are gathered with `take_opt_iter_unchecked`; presumably
/// every group must refer to rows in bounds of `self` -- confirm.
#[doc(hidden)]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series {
    let lasts = match groups {
        GroupsProxy::Idx(groups) => {
            let mut positions = groups
                .all()
                .iter()
                .map(|idx| idx.last().map(|&i| i as usize));
            self.take_opt_iter_unchecked(&mut positions)
        }
        GroupsProxy::Slice { groups, .. } => {
            // `then` is lazy, so `first + len - 1` is only evaluated for non-empty groups.
            let mut positions = groups
                .iter()
                .map(|&[first, len]| (len != 0).then(|| (first + len - 1) as usize));
            self.take_opt_iter_unchecked(&mut positions)
        }
    };
    self.restore_logical(lasts)
}
}

#[inline(always)]
fn take_min<T: PartialOrd>(a: T, b: T) -> T {
if a < b {
Expand Down Expand Up @@ -1120,7 +910,7 @@ where
ChunkedArray::<T>::from_chunks("", vec![arr]).into_series()
} else {
_agg_helper_slice::<T, _>(groups, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => self.get(first as usize),
Expand Down Expand Up @@ -1227,7 +1017,7 @@ where
ca.agg_mean(groups)
} else {
_agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
debug_assert!(len < self.len() as IdxSize);
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
Expand Down Expand Up @@ -1264,7 +1054,7 @@ where
ca.agg_var(groups, ddof)
} else {
_agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
Expand Down Expand Up @@ -1300,7 +1090,7 @@ where
ca.agg_std(groups, ddof)
} else {
_agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,11 @@ impl Series {
let val = &[self.mean()];
Series::new(self.name(), val)
}
dt @ DataType::Duration(_) => {
Series::new(self.name(), &[self.mean().map(|v| v as i64)])
.cast(dt)
.unwrap()
}
_ => return Series::full_null(self.name(), 1, self.dtype()),
}
}
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_aggregations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime, timedelta

import polars as pl


Expand Down Expand Up @@ -29,3 +31,26 @@ def test_boolean_aggs() -> None:
"std": [0.5773502691896258],
"var": [0.33333333333333337],
}


def test_duration_aggs() -> None:
    """The mean of a duration column is a timedelta, in `select` and in `groupby().agg()`."""
    # Two daily ranges of equal length whose element-wise difference is a
    # constant 31 days.
    first_range = pl.date_range(
        low=datetime(2022, 12, 12), high=datetime(2022, 12, 18), interval="1d"
    )
    second_range = pl.date_range(
        low=datetime(2023, 1, 12), high=datetime(2023, 1, 18), interval="1d"
    )
    frame = pl.DataFrame({"time1": first_range, "time2": second_range})

    difference = (pl.col("time2") - pl.col("time1")).alias("time_difference")
    frame = frame.with_column(difference)

    expected_mean = timedelta(days=31)

    # Mean over the whole column.
    overall = frame.select("time_difference").mean().to_dict(False)
    assert overall == {"time_difference": [expected_mean]}

    # Mean inside a groupby over a single literal group.
    grouped = frame.groupby(pl.lit(1)).agg(pl.mean("time_difference")).to_dict(False)
    assert grouped == {
        "literal": [1],
        "time_difference": [expected_mean],
    }

0 comments on commit c0c3a08

Please sign in to comment.