Skip to content

Commit

Permalink
correct logical type names in header in pivot operation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 15, 2022
1 parent 42d09c1 commit 77b41d0
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 15 deletions.
131 changes: 117 additions & 14 deletions polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ use std::collections::hash_map::RandomState;
use std::fmt::{Debug, Formatter};
use std::ops::{Add, Deref};

#[cfg(feature = "dtype-date")]
use arrow::temporal_conversions::date32_to_date;
#[cfg(feature = "dtype-datetime")]
use arrow::temporal_conversions::{timestamp_ms_to_datetime, timestamp_ns_to_datetime};

/// Utility enum used for grouping on multiple columns
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
pub(crate) enum Groupable<'a> {
Expand Down Expand Up @@ -595,68 +600,134 @@ where
builder.append_option(max);
}

fn finish_logical_types(mut out: DataFrame, pivot_series: &Series) -> Result<DataFrame> {
match pivot_series.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => {
let piv = pivot_series.categorical().unwrap();
let rev_map = piv.categorical_map.as_ref().unwrap();
for s in out.columns[1..].iter_mut() {
let category = s.name().parse::<u32>().unwrap();
let name = rev_map.get(category);
s.rename(name);
}
Ok(out)
}
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, _) => {
let fun = match tu {
TimeUnit::Nanoseconds => timestamp_ns_to_datetime,
TimeUnit::Milliseconds => timestamp_ms_to_datetime,
};

for s in out.columns[1..].iter_mut() {
let ts = s.name().parse::<i64>().unwrap();
let nd = fun(ts);
s.rename(&format!("{}", nd));
}
Ok(out)
}
#[cfg(feature = "dtype-date")]
DataType::Date => {
for s in out.columns[1..].iter_mut() {
let days = s.name().parse::<i32>().unwrap();
let nd = date32_to_date(days);
s.rename(&format!("{}", nd));
}
Ok(out)
}
_ => Ok(out),
}
}

impl<'df, 'sel_str> Pivot<'df, 'sel_str> {
/// Aggregate the pivot results by taking the count the values.
/// Aggregate the pivot results by taking the count values.
pub fn count(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot_count(pivot_series, self.gb.keys(), &self.gb.groups)
let out = values_series.pivot_count(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
)?;
finish_logical_types(out, pivot_series)
}

/// Aggregate the pivot results by taking the first occurring value.
pub fn first(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(
pivot_series,
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::First,
)
)?;
finish_logical_types(out, pivot_series)
}

/// Aggregate the pivot results by taking the sum of all duplicates.
pub fn sum(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Sum)
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::Sum,
)?;
finish_logical_types(out, pivot_series)
}

/// Aggregate the pivot results by taking the minimal value of all duplicates.
pub fn min(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Min)
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::Min,
)?;
finish_logical_types(out, pivot_series)
}

/// Aggregate the pivot results by taking the maximum value of all duplicates.
pub fn max(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Max)
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::Max,
)?;
finish_logical_types(out, pivot_series)
}

/// Aggregate the pivot results by taking the mean value of all duplicates.
pub fn mean(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(
pivot_series,
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::Mean,
)
)?;
finish_logical_types(out, pivot_series)
}
/// Aggregate the pivot results by taking the median value of all duplicates.
pub fn median(&self) -> Result<DataFrame> {
let pivot_series = self.gb.df.column(self.pivot_column)?;
let values_series = self.gb.df.column(self.values_column)?;
values_series.pivot(
pivot_series,
let out = values_series.pivot(
&pivot_series.to_physical_repr(),
self.gb.keys(),
&self.gb.groups,
PivotAgg::Median,
)
)?;
finish_logical_types(out, pivot_series)
}
}

Expand Down Expand Up @@ -702,4 +773,36 @@ mod test {
&[Some(0), Some(0), Some(2)]
);
}

#[test]
#[cfg(feature = "dtype-categorical")]
fn test_pivot_categorical() -> Result<()> {
let mut df = df![
"A" => [1, 1, 1, 1, 1, 1, 1, 1],
"B" => [8, 2, 3, 6, 3, 6, 2, 2],
"C" => ["a", "b", "c", "a", "b", "c", "a", "b"]
]?;
df.try_apply("C", |s| s.cast(&DataType::Categorical))?;

let out = df.groupby("B")?.pivot("C", "A").count()?;
assert_eq!(out.get_column_names(), &["B", "a", "b", "c"]);

Ok(())
}

#[test]
#[cfg(feature = "dtype-date")]
fn test_pivot_date() -> Result<()> {
let mut df = df![
"A" => [1, 1, 1, 1, 1, 1, 1, 1],
"B" => [8, 2, 3, 6, 3, 6, 2, 2],
"C" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
]?;
df.try_apply("C", |s| s.cast(&DataType::Date))?;

let out = df.groupby("B")?.pivot("C", "A").count()?;
assert_eq!(out.get_column_names(), &["B", "1972-09-27"]);

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ mod test {
}

#[test]
#[cfg(feature = "dtype-duration")]
fn test_duration() -> Result<()> {
let a = Int64Chunked::new("", &[1, 2, 3])
.into_datetime(TimeUnit::Nanoseconds, None)
Expand All @@ -821,7 +822,6 @@ mod test {
.into_duration(TimeUnit::Nanoseconds)
.into_series()
);
// assert_eq!(a.add_to(&c)?, b);
Ok(())
}
}

0 comments on commit 77b41d0

Please sign in to comment.