Skip to content

Commit

Permalink
[Python] change groupby dsl
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 13, 2020
1 parent 9e0fba8 commit 2bc4405
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 24 deletions.
44 changes: 41 additions & 3 deletions py-polars/polars/frame.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from .polars import PyDataFrame, PySeries
from typing import Dict, Sequence, List, Tuple, Optional
from typing import Dict, Sequence, List, Tuple, Optional, Union
from .series import Series, wrap_s
import numpy as np

Expand Down Expand Up @@ -121,8 +121,10 @@ def head(self, length: int = 5) -> DataFrame:
def tail(self, length: int = 5) -> DataFrame:
return wrap_df(self._df.tail(length))

def groupby(self, by: str, select: str, agg: str) -> DataFrame:
return wrap_df(self._df.groupby(by, select, agg))
def groupby(self, by: Union[str, List[str]]) -> GroupBy:
if isinstance(by, str):
by = [by]
return GroupBy(self._df, by)

def join(
self, df: DataFrame, left_on: str, right_on: str, how="inner"
Expand All @@ -148,3 +150,39 @@ def drop_in_place(self, name: str) -> Series:

def select_at_idx(self, idx: int) -> Series:
return wrap_s(self._df.select_at_idx(idx))


class GroupBy:
def __init__(self, df: DataFrame, by: List[str]):
self._df = df
self.by = by

def select(self, columns: Union[str, List[str]]) -> GBSelection:
if isinstance(columns, str):
columns = [columns]
return GBSelection(self._df, self.by, columns)


class GBSelection:
def __init__(self, df: DataFrame, by: List[str], selection: List[str]):
self._df = df
self.by = by
self.selection = selection

def first(self):
return wrap_df(self._df.groupby(self.by, self.selection, "first"))

def sum(self):
return wrap_df(self._df.groupby(self.by, self.selection, "sum"))

def min(self):
return wrap_df(self._df.groupby(self.by, self.selection, "min"))

def max(self):
return wrap_df(self._df.groupby(self.by, self.selection, "max"))

def count(self):
return wrap_df(self._df.groupby(self.by, self.selection, "count"))

def mean(self):
return wrap_df(self._df.groupby(self.by, self.selection, "mean"))
7 changes: 4 additions & 3 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,14 @@ impl PyDataFrame {
self.df.frame_equal(&other.df)
}

pub fn groupby(&self, by: &str, select: &str, agg: &str) -> PyResult<Self> {
let gb = self.df.groupby(by).map_err(PyPolarsEr::from)?;
let selection = gb.select(select);
pub fn groupby(&self, by: Vec<String>, select: Vec<String>, agg: &str) -> PyResult<Self> {
let gb = self.df.groupby(&by).map_err(PyPolarsEr::from)?;
let selection = gb.select(&select);
let df = match agg {
"min" => selection.min(),
"max" => selection.max(),
"mean" => selection.mean(),
"first" => selection.first(),
"sum" => selection.sum(),
"count" => selection.count(),
a => Err(PolarsError::Other(format!("agg fn {} does not exists", a))),
Expand Down
48 changes: 30 additions & 18 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,41 @@ def test_groupby():
"c": [6, 5, 4, 3, 2, 1],
}
)
assert df.groupby(by="a", select="b", agg="sum").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [4, 11, 6]})
assert (
df.groupby("a")
.select("b")
.sum()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [4, 11, 6]}))
)
assert df.groupby(by="a", select="c", agg="sum").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [10, 10, 1]})
assert (
df.groupby("a")
.select("c")
.sum()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [10, 10, 1]}))
)
assert df.groupby(by="a", select="b", agg="min").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [1, 2, 6]})
assert (
df.groupby("a")
.select("b")
.min()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [1, 2, 6]}))
)
assert df.groupby(by="a", select="b", agg="min").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [1, 2, 6]})
assert (
df.groupby("a")
.select("b")
.max()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [3, 5, 6]}))
)
assert df.groupby(by="a", select="b", agg="max").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [3, 5, 6]})
)
assert df.groupby(by="a", select="b", agg="mean").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [2.0, (2 + 4 + 5) / 3, 6.0]})
)

# TODO: is false because count is u32
df.groupby(by="a", select="b", agg="count").frame_equal(
DataFrame({"a": ["a", "b", "c"], "": [2, 3, 1]})
assert (
df.groupby("a")
.select("b")
.mean()
.frame_equal(DataFrame({"a": ["a", "b", "c"], "": [2.0, (2 + 4 + 5) / 3, 6.0]}))
)
#
# # TODO: is false because count is u32
# df.groupby(by="a", select="b", agg="count").frame_equal(
# DataFrame({"a": ["a", "b", "c"], "": [2, 3, 1]})
# )


def test_join():
Expand Down

0 comments on commit 2bc4405

Please sign in to comment.