Skip to content

Commit

Permalink
feat(python): Improve iterating over GroupBy (#6051)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jan 5, 2023
1 parent a8e7e4a commit 40424c4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 286 deletions.
245 changes: 22 additions & 223 deletions py-polars/polars/internals/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Sequence, TypeVar
from typing import TYPE_CHECKING, Callable, Generic, Sequence, TypeVar

import polars.internals as pli
from polars.internals.dataframe.pivot import PivotOps
from polars.utils import _timedelta_to_pl_duration

if TYPE_CHECKING:
from polars.datatypes import DataType
from polars.internals.type_aliases import (
ClosedWindow,
RollingInterpolationMethod,
Expand All @@ -31,9 +30,8 @@ class GroupBy(Generic[DF]):
Examples
--------
>>> df = pl.DataFrame({"foo": ["a", "a", "b"], "bar": [1, 2, 3]})
>>> for group in df.groupby("foo"):
>>> for group in df.groupby("foo", maintain_order=True):
... print(group)
... # doctest: +IGNORE_RESULT
...
shape: (2, 2)
┌─────┬─────┐
Expand All @@ -55,11 +53,6 @@ class GroupBy(Generic[DF]):
"""

_df: PyDataFrame
_dataframe_class: type[DF]
by: str | list[str]
maintain_order: bool

def __init__(
self,
df: PyDataFrame,
Expand Down Expand Up @@ -90,90 +83,31 @@ def __init__(
self.by = by
self.maintain_order = maintain_order

def __iter__(self) -> Iterable[Any]:
groups_df = self._groups()
groups = groups_df["groups"]
df = self._dataframe_class._from_pydf(self._df)
for i in range(groups_df.height):
yield df[groups[i]]

def _select(self, columns: str | list[str]) -> GBSelection[DF]: # pragma: no cover
"""
Select the columns that will be aggregated.
Parameters
----------
columns
One or multiple columns.
"""
warnings.warn(
"accessing GroupBy by index is deprecated, consider using the `.agg`"
" method",
DeprecationWarning,
)
if isinstance(columns, str):
columns = [columns]
return GBSelection(
self._df,
self.by,
columns,
dataframe_class=self._dataframe_class,
)
def __iter__(self) -> GroupBy[DF]:
by = {self.by} if isinstance(self.by, str) else set(self.by)

def _select_all(self) -> GBSelection[DF]:
"""Select all columns for aggregation."""
return GBSelection(
self._df,
self.by,
None,
dataframe_class=self._dataframe_class,
)
# Aggregate groups for any single column that is not specified as 'by'
columns = self._df.columns()
if len(by) < len(columns):
non_by_col = next(c for c in columns if c not in by)
groups_df = self.agg(pli.col(non_by_col).agg_groups())
group_indices = groups_df.select(non_by_col).to_series()
else:
group_indices = pli.Series([[i] for i in range(self._df.height())])

def _groups(self) -> DF: # pragma: no cover
"""
Get keys and group indices for each group in the groupby.
self._group_indices = group_indices
self._current_index = 0

Returns
-------
DataFrame
A DataFrame with:
return self

- the groupby keys
- the group indexes aggregated as lists
def __next__(self) -> DF:
if self._current_index >= len(self._group_indices):
raise StopIteration

Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [1, 1, 2, 3, 4, 5],
... "b": [0.5, 0.5, 4, 10, 13, 14],
... "c": [True, True, True, False, True, True],
... "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
... }
... )
>>> df.groupby("d")._groups().sort(by="d")
shape: (3, 2)
┌────────┬───────────┐
│ d ┆ groups │
│ --- ┆ --- │
│ str ┆ list[u32] │
╞════════╪═══════════╡
│ Apple ┆ [0, 2, 3] │
│ Banana ┆ [4, 5] │
│ Orange ┆ [1] │
└────────┴───────────┘
"""
warnings.warn(
"accessing GroupBy by index is deprecated, consider using the `.agg`"
" method",
DeprecationWarning,
)
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, None, "groups")
)
df = self._dataframe_class._from_pydf(self._df)
group = df[self._group_indices[self._current_index]]
self._current_index += 1
return group

def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> DF:
"""
Expand Down Expand Up @@ -884,138 +818,3 @@ def agg(self, aggs: pli.Expr | Sequence[pli.Expr]) -> pli.DataFrame:
.agg(aggs)
.collect(no_optimization=True)
)


class GBSelection(Generic[DF]):
"""Utility class returned in a groupby operation."""

def __init__(
self,
df: PyDataFrame,
by: str | Sequence[str],
selection: Sequence[str] | None,
dataframe_class: type[DF],
):
self._df = df
self.by = by
self.selection = selection
self._dataframe_class = dataframe_class

def first(self) -> DF:
"""Aggregate the first values in the group."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "first")
)

def last(self) -> DF:
"""Aggregate the last values in the group."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "last")
)

def sum(self) -> DF:
"""Reduce the groups to the sum."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "sum")
)

def min(self) -> DF:
"""Reduce the groups to the minimal value."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "min")
)

def max(self) -> DF:
"""Reduce the groups to the maximal value."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "max")
)

def count(self) -> DF:
"""
Count the number of values in each group.
Examples
--------
>>> df = pl.DataFrame(
... {
... "foo": [1, None, 3, 4],
... "bar": ["a", "b", "c", "a"],
... }
... )
>>> df.groupby("bar", maintain_order=True).count() # counts nulls
shape: (3, 2)
┌─────┬───────┐
│ bar ┆ count │
│ --- ┆ --- │
│ str ┆ u32 │
╞═════╪═══════╡
│ a ┆ 2 │
│ b ┆ 1 │
│ c ┆ 1 │
└─────┴───────┘
"""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "count")
)

def mean(self) -> DF:
"""Reduce the groups to the mean values."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "mean")
)

def n_unique(self) -> DF:
"""Count the unique values per group."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "n_unique")
)

def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod = "nearest"
) -> DF:
"""
Compute the quantile per group.
Parameters
----------
quantile
Quantile between 0.0 and 1.0.
interpolation : {'nearest', 'higher', 'lower', 'midpoint', 'linear'}
Interpolation method.
"""
return self._dataframe_class._from_pydf(
self._df.groupby_quantile(self.by, self.selection, quantile, interpolation)
)

def median(self) -> DF:
"""Return the median per group."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "median")
)

def agg_list(self) -> DF:
"""Aggregate the groups into Series."""
return self._dataframe_class._from_pydf(
self._df.groupby(self.by, self.selection, "agg_list")
)

def apply(
self,
func: Callable[[Any], Any],
return_dtype: type[DataType] | None = None,
) -> DF:
"""Apply a function over the groups."""
df = self.agg_list()
if self.selection is None:
raise TypeError(
"apply not available for Groupby.select_all(). Use select() instead."
)
for name in self.selection:
s = df.drop_in_place(name + "_agg_list").apply(func, return_dtype)
s.rename(name, in_place=True)
df.with_column(s)

return df
51 changes: 0 additions & 51 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::io::BufWriter;
use std::ops::Deref;

use numpy::IntoPyArray;
use polars::frame::groupby::GroupBy;
use polars::frame::row::{rows_to_schema_supertypes, Row};
#[cfg(feature = "avro")]
use polars::io::avro::AvroCompression;
Expand Down Expand Up @@ -1057,16 +1056,6 @@ impl PyDataFrame {
Ok(df.into())
}

pub fn groupby(&self, by: Vec<&str>, select: Option<Vec<String>>, agg: &str) -> PyResult<Self> {
let gb = Python::with_gil(|py| py.allow_threads(|| self.df.groupby(&by)))
.map_err(PyPolarsErr::from)?;
let selection = match select.as_ref() {
Some(s) => gb.select(s),
None => gb,
};
finish_groupby(selection, agg)
}

pub fn groupby_apply(&self, by: Vec<&str>, lambda: PyObject) -> PyResult<Self> {
let gb = self.df.groupby(&by).map_err(PyPolarsErr::from)?;
let function = move |df: DataFrame| {
Expand Down Expand Up @@ -1105,21 +1094,6 @@ impl PyDataFrame {
Ok(df.into())
}

#[allow(deprecated)]
pub fn groupby_quantile(
&self,
by: Vec<&str>,
select: Vec<String>,
quantile: f64,
interpolation: Wrap<QuantileInterpolOptions>,
) -> PyResult<Self> {
let gb = self.df.groupby(&by).map_err(PyPolarsErr::from)?;
let selection = gb.select(&select);
let df = selection.quantile(quantile, interpolation.0);
let df = df.map_err(PyPolarsErr::from)?;
Ok(PyDataFrame::new(df))
}

pub fn clone(&self) -> Self {
PyDataFrame::new(self.df.clone())
}
Expand Down Expand Up @@ -1399,28 +1373,3 @@ impl PyDataFrame {
Ok(df.into())
}
}

#[allow(deprecated)]
fn finish_groupby(gb: GroupBy, agg: &str) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
let df = py.allow_threads(|| match agg {
"min" => gb.min(),
"max" => gb.max(),
"mean" => gb.mean(),
"first" => gb.first(),
"last" => gb.last(),
"sum" => gb.sum(),
"count" => gb.count(),
"n_unique" => gb.n_unique(),
"median" => gb.median(),
"agg_list" => gb.agg_list(),
"groups" => gb.groups(),
a => Err(PolarsError::ComputeError(
format!("agg fn {a} does not exists").into(),
)),
});

let df = df.map_err(PyPolarsErr::from)?;
Ok(PyDataFrame::new(df))
})
}

0 comments on commit 40424c4

Please sign in to comment.