Skip to content

Commit

Permalink
group by aggregate function 'last'; #44
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 3, 2020
1 parent 6b301e6 commit 08edbad
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 12 deletions.
92 changes: 80 additions & 12 deletions polars/src/frame/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,7 @@ where

#[enum_dispatch(Series)]
trait AggFirst {
fn agg_first(&self, _groups: &Vec<(usize, Vec<usize>)>) -> Series {
unimplemented!()
}
fn agg_first(&self, _groups: &Vec<(usize, Vec<usize>)>) -> Series;
}

macro_rules! impl_agg_first {
Expand All @@ -424,30 +422,60 @@ macro_rules! impl_agg_first {

impl<T> AggFirst for ChunkedArray<T>
where
T: PolarsNumericType + std::marker::Sync,
T: ArrowPrimitiveType + Send,
{
fn agg_first(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
impl_agg_first!(self, groups, ChunkedArray<T>)
}
}

impl AggFirst for BooleanChunked {
impl AggFirst for Utf8Chunked {
fn agg_first(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
impl_agg_first!(self, groups, BooleanChunked)
impl_agg_first!(self, groups, Utf8Chunked)
}
}

impl AggFirst for Utf8Chunked {
impl AggFirst for LargeListChunked {
fn agg_first(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
groups
impl_agg_first!(self, groups, LargeListChunked)
}
}

/// Group-by aggregation that yields one value per group: the last one.
///
/// `_groups` holds one `(usize, Vec<usize>)` entry per group. The
/// implementations in this file read the element at the final position of the
/// `Vec<usize>` part of every entry.
// NOTE(review): the `enum_dispatch` attribute presumably forwards the method
// to the concrete array wrapped by `Series` - confirm against the enum definition.
#[enum_dispatch(Series)]
trait AggLast {
/// Returns a `Series` built from the last value of every group.
fn agg_last(&self, _groups: &Vec<(usize, Vec<usize>)>) -> Series;
}

/// Expands to an expression that builds a `Series` holding, for every group,
/// the value found at the last index of that group.
///
/// * `$self` - the array to read from; it must provide `get(usize)`.
/// * `$groups` - iterable of `(first, indices)` pairs, one per group.
/// * `$ca_type` - the type the per-group values are collected into before
///   `into_series()` is called on the result.
///
/// # Panics
///
/// Panics if a group has an empty index list.
macro_rules! impl_agg_last {
    ($self:ident, $groups:ident, $ca_type:ty) => {{
        $groups
            .iter()
            .map(|(_first, idx)| {
                // `last()` replaces the manual `idx[idx.len() - 1]`, whose
                // subtraction underflows on an empty index list.
                let last = *idx
                    .last()
                    .expect("cannot take the last value of an empty group");
                $self.get(last)
            })
            .collect::<$ca_type>()
            .into_series()
    }};
}

/// `AggLast` for every `ChunkedArray<T>` whose `T` is an Arrow primitive type.
impl<T: ArrowPrimitiveType + Send> AggLast for ChunkedArray<T> {
    /// Collects the last value of each group into a new array of the same
    /// type and returns it as a `Series`.
    fn agg_last(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
        impl_agg_last!(self, groups, Self)
    }
}

/// `AggLast` for string arrays.
impl AggLast for Utf8Chunked {
    /// Collects the last value of each group into a new `Utf8Chunked` and
    /// returns it as a `Series`.
    fn agg_last(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
        impl_agg_last!(self, groups, Self)
    }
}

impl AggFirst for LargeListChunked {}
/// `AggLast` for list arrays.
impl AggLast for LargeListChunked {
    /// Collects the last value of each group into a new `LargeListChunked`
    /// and returns it as a `Series`.
    fn agg_last(&self, groups: &Vec<(usize, Vec<usize>)>) -> Series {
        impl_agg_last!(self, groups, Self)
    }
}

impl<'df, 'selection_str> GroupBy<'df, 'selection_str> {
/// Select the column by which to determine the groups.
Expand Down Expand Up @@ -638,7 +666,7 @@ impl<'df, 'selection_str> GroupBy<'df, 'selection_str> {
DataFrame::new(cols)
}

/// Aggregate grouped series and find the first value per group.
/// Aggregate grouped `Series` and find the first value per group.
///
/// # Example
///
Expand Down Expand Up @@ -674,6 +702,42 @@ impl<'df, 'selection_str> GroupBy<'df, 'selection_str> {
DataFrame::new(cols)
}

/// Reduce every selected `Series` to the last value of each group.
///
/// Each aggregated column is renamed to `<original name>_last`.
///
/// # Example
///
/// ```rust
/// # use polars::prelude::*;
/// fn example(df: DataFrame) -> Result<DataFrame> {
///     df.groupby("date")?.select("temp").last()
/// }
/// ```
/// Returns:
///
/// ```text
/// +------------+------------+
/// | date       | temp_last  |
/// | ---        | ---        |
/// | date32     | i32        |
/// +============+============+
/// | 2020-08-23 | 9          |
/// +------------+------------+
/// | 2020-08-22 | 1          |
/// +------------+------------+
/// | 2020-08-21 | 10         |
/// +------------+------------+
/// ```
pub fn last(&self) -> Result<DataFrame> {
    let (mut cols, agg_cols) = self.prepare_agg()?;
    for column in agg_cols {
        let last_name = format!("{}_last", column.name());
        let mut last_values = column.agg_last(&self.groups);
        last_values.rename(&last_name);
        cols.push(last_values);
    }
    DataFrame::new(cols)
}

/// Aggregate grouped series and compute the number of values per group.
///
/// # Example
Expand Down Expand Up @@ -1199,6 +1263,10 @@ mod test {
"{:?}",
df.groupby("date").unwrap().select("temp").first().unwrap()
);
println!(
"{:?}",
df.groupby("date").unwrap().select("temp").last().unwrap()
);
}

#[test]
Expand Down
42 changes: 42 additions & 0 deletions py-polars/pypolars/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,14 +438,35 @@ def __init__(self, df: DataFrame, by: List[str]):
self.by = by

def select(self, columns: Union[str, List[str]]) -> GBSelection:
    """
    Choose the columns the aggregation will be applied to.

    Parameters
    ----------
    columns
        A single column name or a list of column names.
    """
    selection = [columns] if isinstance(columns, str) else columns
    return GBSelection(self._df, self.by, selection)

def select_all(self):
    """
    Use every column of the DataFrame as the aggregation selection.
    """
    every_column = self._df.columns()
    return GBSelection(self._df, self.by, every_column)

def pivot(self, pivot_column: str, values_column: str) -> PivotOps:
    """
    Set up a pivot on the group key, a pivot column and a values column
    that will be aggregated.

    Parameters
    ----------
    pivot_column
        Column to pivot.
    values_column
        Column that will be aggregated.
    """
    pivot_ops = PivotOps(self._df, self.by, pivot_column, values_column)
    return pivot_ops


Expand Down Expand Up @@ -496,18 +517,39 @@ def __init__(self, df: DataFrame, by: List[str], selection: List[str]):
self.selection = selection

def first(self):
    """
    Reduce each group to its first value.
    """
    aggregated = self._df.groupby(self.by, self.selection, "first")
    return wrap_df(aggregated)

def last(self):
    """
    Aggregate the last value in the group.
    """
    return wrap_df(self._df.groupby(self.by, self.selection, "last"))

def sum(self):
    """
    Aggregate each group to the sum of its values.
    """
    aggregated = self._df.groupby(self.by, self.selection, "sum")
    return wrap_df(aggregated)

def min(self):
    """
    Aggregate each group to its minimal value.
    """
    aggregated = self._df.groupby(self.by, self.selection, "min")
    return wrap_df(aggregated)

def max(self):
    """
    Aggregate each group to its maximal value.
    """
    aggregated = self._df.groupby(self.by, self.selection, "max")
    return wrap_df(aggregated)

def count(self):
    """
    Aggregate each group to the number of values it contains.
    """
    aggregated = self._df.groupby(self.by, self.selection, "count")
    return wrap_df(aggregated)

def mean(self):
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ impl PyDataFrame {
"max" => selection.max(),
"mean" => selection.mean(),
"first" => selection.first(),
"last" => selection.last(),
"sum" => selection.sum(),
"count" => selection.count(),
a => Err(PolarsError::Other(format!("agg fn {} does not exists", a))),
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def test_groupby():
.mean()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [2.0, (2 + 4 + 5) / 3, 6.0]}))
)
assert (
df.groupby("a")
.select("b")
.last()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [3, 5, 6]}))
)
#
# # TODO: is false because count is u32
# df.groupby(by="a", select="b", agg="count").frame_equal(
Expand Down

0 comments on commit 08edbad

Please sign in to comment.