Skip to content

Commit

Permalink
Add overloads to DataFrame aggregation methods (#2205)
Browse files Browse the repository at this point in the history
* Add overloads to DataFrame aggregation methods

Note that for `sum` and `mean`, the arguments are now keyword-only, as per previous discussions.
  • Loading branch information
zundertj committed Dec 28, 2021
1 parent f036caf commit 3008586
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 4 deletions.
72 changes: 71 additions & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,20 @@ def to_arrow(self) -> "pa.Table":
record_batches = self._df.to_arrow()
return pa.Table.from_batches(record_batches)

# Typing-only overloads for `to_dict`: the stub bodies are `...` and exist so
# a type checker can pick the return type from the `as_series` argument.
# `as_series` omitted, or the literal True -> Dict[str, "pli.Series"].
@overload
def to_dict(self, as_series: Literal[True] = ...) -> Dict[str, "pli.Series"]:
...

# `as_series` given as the literal False -> Dict[str, List[Any]].
@overload
def to_dict(self, as_series: Literal[False]) -> Dict[str, List[Any]]:
...

# Fallback for a plain `bool` whose value is not known statically: the result
# is typed as the union of both dictionary shapes.
@overload
def to_dict(
self, as_series: bool = True
) -> Union[Dict[str, "pli.Series"], Dict[str, List[Any]]]:
...

def to_dict(
self, as_series: bool = True
) -> Union[Dict[str, "pli.Series"], Dict[str, List[Any]]]:
Expand Down Expand Up @@ -3338,6 +3352,18 @@ def n_chunks(self) -> int:
"""
return self._df.n_chunks()

# Typing-only overloads for `max`: the return type depends on `axis`.
# `axis` omitted, or the literal 0 -> "DataFrame".
@overload
def max(self, axis: Literal[0] = ...) -> "DataFrame":
...

# `axis` given as the literal 1 -> "pli.Series".
@overload
def max(self, axis: Literal[1]) -> "pli.Series":
...

# Fallback for an `int` whose value is not known statically: typed as the
# union of both result types.
@overload
def max(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
...

def max(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
"""
Aggregate the columns of this DataFrame to their maximum value.
Expand Down Expand Up @@ -3368,6 +3394,18 @@ def max(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
return pli.wrap_s(self._df.hmax())
raise ValueError("Axis should be 0 or 1.") # pragma: no cover

# Typing-only overloads for `min`: the return type depends on `axis`.
# `axis` omitted, or the literal 0 -> "DataFrame".
@overload
def min(self, axis: Literal[0] = ...) -> "DataFrame":
...

# `axis` given as the literal 1 -> "pli.Series".
@overload
def min(self, axis: Literal[1]) -> "pli.Series":
...

# Fallback for an `int` whose value is not known statically: typed as the
# union of both result types.
@overload
def min(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
...

def min(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
"""
Aggregate the columns of this DataFrame to their minimum value.
Expand Down Expand Up @@ -3398,8 +3436,24 @@ def min(self, axis: int = 0) -> Union["DataFrame", "pli.Series"]:
return pli.wrap_s(self._df.hmin())
raise ValueError("Axis should be 0 or 1.") # pragma: no cover

@overload
def sum(
self, axis: int = 0, null_strategy: str = "ignore"
self, *, axis: Literal[0] = ..., null_strategy: str = "ignore"
) -> "DataFrame":
...

@overload
def sum(self, *, axis: Literal[1], null_strategy: str = "ignore") -> "pli.Series":
...

@overload
def sum(
self, *, axis: int = 0, null_strategy: str = "ignore"
) -> Union["DataFrame", "pli.Series"]:
...

def sum(
self, *, axis: int = 0, null_strategy: str = "ignore"
) -> Union["DataFrame", "pli.Series"]:
"""
Aggregate the columns of this DataFrame to their sum value.
Expand Down Expand Up @@ -3438,6 +3492,22 @@ def sum(
return pli.wrap_s(self._df.hsum(null_strategy))
raise ValueError("Axis should be 0 or 1.") # pragma: no cover

# Typing-only overloads for `mean`: the return type depends on `axis`.
# The bare `*` makes both `axis` and `null_strategy` keyword-only in every
# overload.
# NOTE(review): the implementation signature that follows these stubs is
# written without the `*`, so `axis` is still accepted positionally at
# runtime — confirm whether the implementation should be keyword-only too.
# `axis` omitted, or the literal 0 -> "DataFrame".
@overload
def mean(
self, *, axis: Literal[0] = ..., null_strategy: str = "ignore"
) -> "DataFrame":
...

# `axis` given as the literal 1 -> "pli.Series".
@overload
def mean(self, *, axis: Literal[1], null_strategy: str = "ignore") -> "pli.Series":
...

# Fallback for an `int` whose value is not known statically: typed as the
# union of both result types.
@overload
def mean(
self, *, axis: int = 0, null_strategy: str = "ignore"
) -> Union["DataFrame", "pli.Series"]:
...

def mean(
self, axis: int = 0, null_strategy: str = "ignore"
) -> Union["DataFrame", "pli.Series"]:
Expand Down
12 changes: 9 additions & 3 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,9 +1222,15 @@ def test_panic() -> None:
def test_h_agg() -> None:
df = pl.DataFrame({"a": [1, None, 3], "b": [1, 2, 3]})

assert df.sum(axis=1, null_strategy="ignore").to_list() == [2, 2, 6]
assert df.sum(axis=1, null_strategy="propagate").to_list() == [2, None, 6]
assert df.mean(axis=1, null_strategy="propagate")[1] is None
pl.testing.assert_series_equal(
df.sum(axis=1, null_strategy="ignore"), pl.Series("a", [2, 2, 6])
)
pl.testing.assert_series_equal(
df.sum(axis=1, null_strategy="propagate"), pl.Series("a", [2, None, 6])
)
pl.testing.assert_series_equal(
df.mean(axis=1, null_strategy="propagate"), pl.Series("a", [1.0, None, 3.0])
)


def test_slicing() -> None:
Expand Down

0 comments on commit 3008586

Please sign in to comment.