Skip to content

Commit

Permalink
Add mean/median for date
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Apr 16, 2024
1 parent 70b71da commit f9c3401
Show file tree
Hide file tree
Showing 16 changed files with 292 additions and 202 deletions.
100 changes: 55 additions & 45 deletions crates/polars-core/src/frame/group_by/aggregations/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use DataType::*;

use super::*;

// implemented on the series because we don't need types
Expand Down Expand Up @@ -106,35 +108,6 @@ impl Series {
}
}

#[doc(hidden)]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series {
use DataType::*;

match self.dtype() {
Boolean => self.cast(&Float64).unwrap().agg_median(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_median(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => self
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ Date => {
let ca = self.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_median, groups);
// back to physical and then
// back to logical type
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
_ => Series::full_null("", groups.len(), self.dtype()),
}
}

#[doc(hidden)]
pub unsafe fn agg_quantile(
&self,
Expand Down Expand Up @@ -166,29 +139,66 @@ impl Series {

#[doc(hidden)]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series {
use DataType::*;
#[cfg(any(
feature = "dtype-date",
feature = "dtype-datetime",
feature = "dtype-duration",
feature = "dtype-time"
))]
unsafe fn temporal_mean(s: &Series, groups: &GroupsProxy, dt: &DataType) -> Series {
s.agg_mean(groups).cast(&Int64).unwrap().cast(dt).unwrap()
}

match self.dtype() {
Boolean => self.cast(&Float64).unwrap().agg_mean(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_mean, groups),
#[cfg(feature = "dtype-date")]
Date => temporal_mean(
&(self.cast(&Int64).unwrap() * (MS_IN_DAY as f64)),
groups,
&Datetime(TimeUnit::Milliseconds, None),
),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => self
.to_physical_repr()
.agg_mean(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ Date => {
let ca = self.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_mean, groups);
// back to physical and then
// back to logical type
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
dt @ Datetime(_, _) => temporal_mean(&self.to_physical_repr(), groups, dt),
#[cfg(feature = "dtype-duration")]
dt @ Duration(_) => temporal_mean(&self.to_physical_repr(), groups, dt),
#[cfg(feature = "dtype-time")]
dt @ Time => temporal_mean(&self.to_physical_repr(), groups, dt),
_ => Series::full_null("", groups.len(), self.dtype()),
}
}

#[doc(hidden)]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series {
#[cfg(any(
feature = "dtype-date",
feature = "dtype-datetime",
feature = "dtype-duration",
feature = "dtype-time"
))]
unsafe fn temporal_median(s: &Series, groups: &GroupsProxy, dt: &DataType) -> Series {
s.agg_median(groups).cast(&Int64).unwrap().cast(dt).unwrap()
}

match self.dtype() {
Boolean => self.cast(&Float64).unwrap().agg_median(groups),
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_median(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups),
#[cfg(feature = "dtype-date")]
Date => temporal_median(
&(self.cast(&Int64).unwrap() * (MS_IN_DAY as f64)),
groups,
&Datetime(TimeUnit::Milliseconds, None),
),
#[cfg(feature = "dtype-datetime")]
dt @ Datetime(_, _) => temporal_median(&self.to_physical_repr(), groups, dt),
#[cfg(feature = "dtype-duration")]
dt @ Duration(_) => temporal_median(&self.to_physical_repr(), groups, dt),
#[cfg(feature = "dtype-time")]
dt @ Time => temporal_median(&self.to_physical_repr(), groups, dt),
_ => Series::full_null("", groups.len(), self.dtype()),
}
}
Expand Down
10 changes: 9 additions & 1 deletion crates/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,15 @@ macro_rules! impl_dyn_series {
Ok(self.0.min_as_series().$into_logical())
}
fn median_as_series(&self) -> PolarsResult<Series> {
Series::new(self.name(), &[self.median().map(|v| v as i64)]).cast(self.dtype())
match self.dtype() {
#[cfg(feature = "dtype-date")]
DataType::Date => {
let ms = MS_IN_DAY as f64;
Series::new(self.name(), &[self.median().map(|v| (v * ms) as i64)])
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
},
dt => Series::new(self.name(), &[self.median().map(|v| v as i64)]).cast(dt),
}
}

fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
Expand Down
31 changes: 17 additions & 14 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,12 @@ impl Series {
}

pub fn mean_as_series(&self) -> Series {
fn temporal_mean(s: &Series, dt: &DataType) -> Series {
Series::new(s.name(), &[s.mean().map(|v| v as i64)])
.cast(dt)
.unwrap()
}

match self.dtype() {
DataType::Float32 => {
let val = &[self.mean().map(|m| m as f32)];
Expand All @@ -777,23 +783,20 @@ impl Series {
let val = &[self.mean()];
Series::new(self.name(), val)
},
#[cfg(feature = "dtype-date")]
DataType::Date => Series::new(
self.name(),
&[self.mean().map(|v| (v * (MS_IN_DAY as f64)) as i64)],
)
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
.unwrap(),
#[cfg(feature = "dtype-datetime")]
dt @ DataType::Datetime(_, _) => {
Series::new(self.name(), &[self.mean().map(|v| v as i64)])
.cast(dt)
.unwrap()
},
dt @ DataType::Datetime(_, _) => temporal_mean(self, dt),
#[cfg(feature = "dtype-duration")]
dt @ DataType::Duration(_) => {
Series::new(self.name(), &[self.mean().map(|v| v as i64)])
.cast(dt)
.unwrap()
},
dt @ DataType::Duration(_) => temporal_mean(self, dt),
#[cfg(feature = "dtype-time")]
dt @ DataType::Time => Series::new(self.name(), &[self.mean().map(|v| v as i64)])
.cast(dt)
.unwrap(),
_ => return Series::full_null(self.name(), 1, self.dtype()),
dt @ DataType::Time => temporal_mean(self, dt),
dt => return Series::full_null(self.name(), 1, dt),
}
}

Expand Down
22 changes: 2 additions & 20 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1398,16 +1398,7 @@ impl LazyFrame {
/// - String columns will have a mean of None.
pub fn mean(self) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|dt| *dt == DataType::Boolean || dt.is_numeric() || dt.is_temporal(),
|name| col(name).mean(),
)
}
Expand All @@ -1419,16 +1410,7 @@ impl LazyFrame {
/// - String columns will sum to None.
pub fn median(self) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|dt| *dt == DataType::Boolean || dt.is_numeric() || dt.is_temporal(),
|name| col(name).median(),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,26 +238,36 @@ where
.unwrap();

let logical_dtype = phys_expr.field(schema).unwrap().dtype;
#[cfg(feature = "dtype-categorical")]
if matches!(
logical_dtype,
DataType::Categorical(_, _) | DataType::Enum(_, _)
) {
return (

match &logical_dtype {
#[cfg(feature = "dtype-categorical")]
&DataType::Categorical(_, _) | DataType::Enum(_, _) => (
logical_dtype.clone(),
phys_expr,
AggregateFunction::Null(NullAgg::new(logical_dtype)),
);
}
let agg_fn = match logical_dtype.to_physical() {
dt if dt.is_integer() | dt.is_bool() => {
AggregateFunction::MeanF64(MeanAgg::<f64>::new())
),
&DataType::Date => (
logical_dtype,
to_physical(
&ExprIR::from_node(*input, expr_arena),
expr_arena,
Some(schema),
)
.unwrap(),
AggregateFunction::MeanDate(MeanAgg::<i64>::new_date()),
),
dt => {
let agg_fn = match dt.to_physical() {
dt if dt.is_integer() | dt.is_bool() => {
AggregateFunction::MeanF64(MeanAgg::<f64>::new())
},
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
dt => AggregateFunction::Null(NullAgg::new(dt)),
};
(logical_dtype, phys_expr, agg_fn)
},
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
dt => AggregateFunction::Null(NullAgg::new(dt)),
};
(logical_dtype, phys_expr, agg_fn)
}
},
AAggExpr::First(input) => {
let phys_expr = to_physical(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub(crate) enum AggregateFunction {
SumU64(SumAgg<u64>),
SumI32(SumAgg<i32>),
SumI64(SumAgg<i64>),
MeanDate(MeanAgg<i64>),
MeanF32(MeanAgg<f32>),
MeanF64(MeanAgg<f64>),
Null(NullAgg),
Expand Down Expand Up @@ -81,6 +82,7 @@ impl AggregateFunction {
SumU64(_) => SumU64(SumAgg::new()),
SumI32(_) => SumI32(SumAgg::new()),
SumI64(_) => SumI64(SumAgg::new()),
MeanDate(_) => MeanDate(MeanAgg::new_date()),
MeanF32(_) => MeanF32(MeanAgg::new()),
MeanF64(_) => MeanF64(MeanAgg::new()),
Count(_) => Count(CountAgg::new()),
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@ use super::*;
pub struct MeanAgg<K: NumericNative> {
sum: Option<K>,
count: IdxSize,
as_date: bool,
}

impl<K: NumericNative> MeanAgg<K> {
pub(crate) fn new() -> Self {
MeanAgg {
sum: None,
count: 0,
as_date: false,
}
}

pub(crate) fn new_date() -> Self {
MeanAgg {
sum: None,
count: 0,
as_date: true,
}
}
}
Expand Down Expand Up @@ -120,6 +130,14 @@ where
if let Some(val) = self.sum {
unsafe {
match K::PRIMITIVE {
PrimitiveType::Int64 => {
let mut arr = val.to_i64().unwrap_unchecked_release();
if self.as_date {
let ms_in_day = 86_400_000i64;
arr *= ms_in_day;
}
AnyValue::Int64(arr / self.count as i64)
},
PrimitiveType::Float32 => AnyValue::Float32(
val.to_f32().unwrap_unchecked_release() / self.count as f32,
),
Expand Down
12 changes: 10 additions & 2 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,21 @@ impl AExpr {
Median(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
float_type(&mut field);
if field.dtype == DataType::Date {
field.coerce(DataType::Datetime(TimeUnit::Milliseconds, None));
} else {
float_type(&mut field);
}
Ok(field)
},
Mean(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
float_type(&mut field);
if field.dtype == DataType::Date {
field.coerce(DataType::Datetime(TimeUnit::Milliseconds, None));
} else {
float_type(&mut field);
}
Ok(field)
},
Implode(expr) => {
Expand Down
Loading

0 comments on commit f9c3401

Please sign in to comment.