Skip to content

Commit

Permalink
allow more aggregations on dtype duration (#3550)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 1, 2022
1 parent 9654f7d commit de92284
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 11 deletions.
23 changes: 14 additions & 9 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,21 @@ impl Series {
/// Group-wise mean of this series, one output value per entry in `groups`.
///
/// Dispatch is on `self.dtype()`: `Float32`/`Float64` use their typed
/// `SeriesWrap`, other numeric dtypes go through
/// `apply_method_physical_integer!`, `Duration` is aggregated on its
/// physical representation and cast back, and every remaining dtype yields
/// an all-null series.
#[doc(hidden)]
pub fn agg_mean(&self, groups: &GroupsProxy) -> Series {
use DataType::*;
if self.dtype().is_numeric() {
match self.dtype() {
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
_ => {
apply_method_physical_integer!(self, agg_mean, groups)
}

match self.dtype() {
Float32 => SeriesWrap(self.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
// any numeric dtype that is not one of the two float types above
dt if dt.is_numeric() => {
apply_method_physical_integer!(self, agg_mean, groups)
}
} else {
Series::full_null("", groups.len(), self.dtype())
// `dt` keeps the full duration dtype so the result can be cast back to it
dt @ Duration(_) => {
let s = self.to_physical_repr();
// agg_mean returns Float64
let out = s.agg_mean(groups);
// cast back to Int64 and then to logical duration type
out.cast(&Int64).unwrap().cast(dt).unwrap()
}
// unsupported dtype: unnamed all-null series of the same dtype, one row per group
_ => Series::full_null("", groups.len(), self.dtype()),
}
}

Expand Down
33 changes: 33 additions & 0 deletions polars/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,33 @@ impl private::PrivateSeries for SeriesWrap<DurationChunked> {
.into_series()
}

/// Group-wise sum: delegates to the wrapped array's `agg_sum` and re-wraps
/// the result as a duration series carrying this array's time unit.
fn agg_sum(&self, groups: &GroupsProxy) -> Series {
    let summed = self.0.agg_sum(groups);
    let as_duration = summed.into_duration(self.0.time_unit());
    as_duration.into_series()
}

/// Group-wise standard deviation, returned as a duration series in this
/// array's time unit.
///
/// Panics if the cast of the aggregate to `Int64` fails.
fn agg_std(&self, groups: &GroupsProxy) -> Series {
    let std_per_group = self.0.agg_std(groups);
    // the aggregate comes back as f64; go back to the physical type first
    let physical = std_per_group.cast(&DataType::Int64).unwrap();
    let as_duration = physical.into_duration(self.0.time_unit());
    as_duration.into_series()
}

/// Group-wise variance, returned as a duration series in this array's
/// time unit.
///
/// Panics if the cast of the aggregate to `Int64` fails.
fn agg_var(&self, groups: &GroupsProxy) -> Series {
    let var_per_group = self.0.agg_var(groups);
    // the aggregate comes back as f64; go back to the physical type first
    let physical = var_per_group.cast(&DataType::Int64).unwrap();
    let as_duration = physical.into_duration(self.0.time_unit());
    as_duration.into_series()
}

fn agg_list(&self, groups: &GroupsProxy) -> Series {
// we cannot cast and dispatch as the inner type of the list would be incorrect
self.0
Expand All @@ -111,13 +138,19 @@ impl private::PrivateSeries for SeriesWrap<DurationChunked> {
) -> Series {
self.0
.agg_quantile(groups, quantile, interpol)
// cast f64 back to physical type
.cast(&DataType::Int64)
.unwrap()
.into_duration(self.0.time_unit())
.into_series()
}

/// Group-wise median, returned as a duration series in this array's
/// time unit.
///
/// Panics if the cast of the aggregate to `Int64` fails.
fn agg_median(&self, groups: &GroupsProxy) -> Series {
    let median_per_group = self.0.agg_median(groups);
    // the aggregate comes back as f64; go back to the physical type first
    let physical = median_per_group.cast(&DataType::Int64).unwrap();
    let as_duration = physical.into_duration(self.0.time_unit());
    as_duration.into_series()
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def read_ipc(
file: Union[str, BinaryIO, BytesIO, Path, bytes],
columns: Optional[Union[List[int], List[str]]] = None,
n_rows: Optional[int] = None,
use_pyarrow: bool = _PYARROW_AVAILABLE,
use_pyarrow: bool = False,
memory_map: bool = True,
storage_options: Optional[Dict] = None,
row_count_name: Optional[str] = None,
Expand Down
17 changes: 16 additions & 1 deletion py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,22 @@ impl PySeries {
}

/// Formats the value at `index` using its `Display` implementation.
///
/// For `Utf8` and `Categorical` series the formatted text is limited to
/// its first 15 characters; when anything was cut off, `...` is appended.
/// All other dtypes are returned unmodified.
pub fn get_fmt(&self, index: usize) -> String {
format!("{}", self.series.get(index))
let val = format!("{}", self.series.get(index));
if let DataType::Utf8 | DataType::Categorical(_) = self.series.dtype() {
// Slice up to the byte offset just past the 15th char (or past the last
// char when the text is shorter; 0 for empty text), so the cut always
// lands on a char boundary.
let v_trunc = &val[..val
.char_indices()
.take(15)
.last()
.map(|(i, c)| i + c.len_utf8())
.unwrap_or(0)];
// equal means nothing was cut, so no ellipsis is needed
if val == v_trunc {
val
} else {
format!("{}...", v_trunc)
}
} else {
val
}
}

pub fn rechunk(&mut self, in_place: bool) -> Option<Self> {
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,47 @@ def test_timedelta_from() -> None:
},
]
assert pl.DataFrame(as_dict).frame_equal(pl.DataFrame(as_rows))


def test_duration_aggregations() -> None:
    """Check group-wise aggregations on a Duration column against timedelta values."""
    frame = pl.DataFrame(
        {
            "group": ["A", "B", "A", "B"],
            "start": [datetime(2022, 1, day) for day in (1, 2, 3, 4)],
            "end": [datetime(2022, 1, day) for day in (2, 4, 6, 6)],
        }
    )
    # end - start gives the Duration column under test
    frame = frame.with_column((pl.col("end") - pl.col("start")).alias("duration"))

    duration = pl.col("duration")
    aggregations = [
        duration.mean().alias("mean"),
        duration.sum().alias("sum"),
        duration.min().alias("min"),
        duration.max().alias("max"),
        duration.quantile(0.1).alias("quantile"),
        duration.median().alias("median"),
        duration.list().alias("list"),
    ]
    result = (
        frame.groupby("group", maintain_order=True).agg(aggregations).to_dict(False)
    )

    expected = {
        "group": ["A", "B"],
        "mean": [timedelta(days=2), timedelta(days=2)],
        "sum": [timedelta(days=4), timedelta(days=4)],
        "min": [timedelta(days=1), timedelta(days=2)],
        "max": [timedelta(days=3), timedelta(days=2)],
        "quantile": [timedelta(days=1), timedelta(days=2)],
        "median": [timedelta(days=2), timedelta(days=2)],
        "list": [
            [timedelta(days=1), timedelta(days=3)],
            [timedelta(days=2), timedelta(days=2)],
        ],
    }
    assert result == expected

0 comments on commit de92284

Please sign in to comment.