Skip to content

Commit

Permalink
Type args and kwargs in pipe (#823)
Browse files Browse the repository at this point in the history
* Test dataframe pipe typing

* Test series pipe typing

* Remove pipe annotations from DataFrame

* Type args and kwargs parameters in generic pipe
  • Loading branch information
paw-lu committed Dec 27, 2023
1 parent a370cab commit 117e97a
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 15 deletions.
6 changes: 5 additions & 1 deletion pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame
from pandas.core.groupby.grouper import Grouper
from pandas.core.indexes.base import Index
from pandas.core.series import Series
from typing_extensions import TypeAlias
from typing_extensions import (
ParamSpec,
TypeAlias,
)

from pandas._libs.interval import Interval
from pandas._libs.tslibs import (
Expand Down Expand Up @@ -447,6 +450,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict
Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple
Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label]
T = TypeVar("T")
P = ParamSpec("P")
FuncType: TypeAlias = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
HashableT = TypeVar("HashableT", bound=Hashable)
Expand Down
7 changes: 0 additions & 7 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ from pandas._typing import (
StorageOptions,
StrLike,
Suffixes,
T as TType,
TimestampConvention,
ValidationOptions,
WriteBuffer,
Expand Down Expand Up @@ -1829,12 +1828,6 @@ class DataFrame(NDFrame, OpsMixin):
freq=...,
**kwargs,
) -> DataFrame: ...
# Stub for ``pipe``: ``func`` is either a callable returning ``TType`` or a
# ``(callable, name)`` tuple; the call is typed as returning that same ``TType``.
# ``*args`` / ``**kwargs`` carry no annotations here, so the extra arguments are
# not checked against ``func``'s own signature.
# NOTE(review): ``_str`` is presumably this stub module's alias for the builtin
# ``str`` -- confirm against the file's imports.
def pipe(
    self,
    func: Callable[..., TType] | tuple[Callable[..., TType], _str],
    *args,
    **kwargs,
) -> TType: ...
def pop(self, item: _str) -> Series: ...
def pow(
self,
Expand Down
19 changes: 17 additions & 2 deletions pandas-stubs/core/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ from pandas import Index
import pandas.core.indexing as indexing
from pandas.core.series import Series
import sqlalchemy.engine
from typing_extensions import Self
from typing_extensions import (
Concatenate,
Self,
)

from pandas._typing import (
S1,
Expand All @@ -40,6 +43,7 @@ from pandas._typing import (
IgnoreRaise,
IndexLabel,
Level,
P,
ReplaceMethod,
SortKind,
StorageOptions,
Expand Down Expand Up @@ -352,8 +356,19 @@ class NDFrame(indexing.IndexingMixin):
) -> Self: ...
def head(self, n: int = ...) -> Self: ...
def tail(self, n: int = ...) -> Self: ...
# Overload for the plain-callable form of ``pipe``.
# ``Concatenate[Self, P]`` requires ``func`` to take this object as its first
# positional parameter; the remaining parameters are captured by the ParamSpec
# ``P``, so ``*args`` / ``**kwargs`` are checked against ``func``'s signature.
# The call is typed as returning ``func``'s return type ``T``.
@overload
def pipe(
    self,
    func: Callable[Concatenate[Self, P], T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T: ...
@overload
def pipe(
self, func: Callable[..., T] | tuple[Callable[..., T], str], *args, **kwargs
self,
func: tuple[Callable[..., T], str],
*args: Any,
**kwargs: Any,
) -> T: ...
def __finalize__(self, other, method=..., **kwargs) -> Self: ...
def __setattr__(self, name: _str, value) -> None: ...
Expand Down
112 changes: 107 additions & 5 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,21 +1436,123 @@ def foo(df: pd.DataFrame) -> pd.DataFrame:
.pipe(foo)
)

df = pd.DataFrame({"a": [1], "b": [2]})
check(assert_type(val, pd.DataFrame), pd.DataFrame)

check(assert_type(pd.DataFrame({"a": [1]}).pipe(foo), pd.DataFrame), pd.DataFrame)
check(assert_type(df.pipe(foo), pd.DataFrame), pd.DataFrame)

def bar(val: Styler) -> Styler:
return val

check(
assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(bar), Styler), Styler
)
check(assert_type(df.style.pipe(bar), Styler), Styler)

def baz(val: Styler) -> str:
return val.to_latex()

check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str)
check(assert_type(df.style.pipe(baz), str), str)

def qux(
df: pd.DataFrame,
positional_only: int,
/,
argument_1: list[float],
argument_2: str,
*,
keyword_only: tuple[int, int],
) -> pd.DataFrame:
return pd.DataFrame(df)

check(
assert_type(
df.pipe(qux, 1, [1.0, 2.0], argument_2="hi", keyword_only=(1, 2)),
pd.DataFrame,
),
pd.DataFrame,
)

if TYPE_CHECKING_INVALID_USAGE:
df.pipe(
qux,
"a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
[1.0, 2.0],
argument_2="hi",
keyword_only=(1, 2),
)
df.pipe(
qux,
1,
[1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues]
argument_2="hi",
keyword_only=(1, 2),
)
df.pipe(
qux,
1,
[1.0, 2.0],
argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
keyword_only=(1, 2),
)
df.pipe(
qux,
1,
[1.0, 2.0],
argument_2="hi",
keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
)
df.pipe( # type: ignore[call-arg]
qux,
1,
[1.0, 2.0],
argument_3="hi", # pyright: ignore[reportGeneralTypeIssues]
keyword_only=(1, 2),
)
df.pipe( # type: ignore[misc]
qux,
1,
[1.0, 2.0],
11, # type: ignore[arg-type]
(1, 2), # pyright: ignore[reportGeneralTypeIssues]
)
df.pipe( # type: ignore[call-arg]
qux,
positional_only=1, # pyright: ignore[reportGeneralTypeIssues]
argument_1=[1.0, 2.0],
argument_2=11, # type: ignore[arg-type]
keyword_only=(1, 2),
)

def dataframe_not_first_arg(x: int, df: pd.DataFrame) -> pd.DataFrame:
return df

check(
assert_type(
df.pipe(
(
dataframe_not_first_arg,
"df",
),
1,
),
pd.DataFrame,
),
pd.DataFrame,
)

if TYPE_CHECKING_INVALID_USAGE:
df.pipe(
(
dataframe_not_first_arg, # type: ignore[arg-type]
1, # pyright: ignore[reportGeneralTypeIssues]
),
1,
)
df.pipe(
( # pyright: ignore[reportGeneralTypeIssues]
1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
"df",
),
1,
)


# set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html
Expand Down
110 changes: 110 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,3 +2914,113 @@ def test_timedeltaseries_operators() -> None:
pd.Series,
pd.Timedelta,
)


def test_pipe() -> None:
    """Exercise the typing of ``Series.pipe`` for both accepted ``func`` forms.

    Valid calls are wrapped in ``check(assert_type(...))`` and are expected to
    be typed as ``pd.Series``.  Calls under ``TYPE_CHECKING_INVALID_USAGE`` are
    deliberately wrong; each carries ``# type: ignore[...]`` and/or
    ``# pyright: ignore[...]`` markers on the line where the respective checker
    is expected to report the error, so those markers must stay on their lines.
    """
    ser = pd.Series(range(10))

    # Helper whose first parameter is the Series, followed by one
    # positional-only, two positional-or-keyword and one keyword-only parameter,
    # so each parameter kind can be probed through ``pipe``.
    def first_arg_series(
        ser: pd.Series,
        positional_only: int,
        /,
        argument_1: list[float],
        argument_2: str,
        *,
        keyword_only: tuple[int, int],
    ) -> pd.Series:
        return ser

    # Well-typed call: every extra argument matches the helper's signature.
    check(
        assert_type(
            ser.pipe(
                first_arg_series,
                1,
                [1.0, 2.0],
                argument_2="hi",
                keyword_only=(1, 2),
            ),
            pd.Series,
        ),
        pd.Series,
    )

    if TYPE_CHECKING_INVALID_USAGE:
        # ``positional_only`` is annotated ``int`` but receives a ``str``.
        ser.pipe(
            first_arg_series,
            "a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
            [1.0, 2.0],
            argument_2="hi",
            keyword_only=(1, 2),
        )
        # ``argument_1`` is ``list[float]`` but the list contains a ``str``.
        ser.pipe(
            first_arg_series,
            1,
            [1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues]
            argument_2="hi",
            keyword_only=(1, 2),
        )
        # ``argument_2`` is annotated ``str`` but receives an ``int``.
        ser.pipe(
            first_arg_series,
            1,
            [1.0, 2.0],
            argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
            keyword_only=(1, 2),
        )
        # ``keyword_only`` is ``tuple[int, int]`` but receives a 1-tuple.
        ser.pipe(
            first_arg_series,
            1,
            [1.0, 2.0],
            argument_2="hi",
            keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
        )
        # ``argument_3`` is not a parameter of the helper (and ``argument_2``
        # is not supplied).
        ser.pipe( # type: ignore[call-arg]
            first_arg_series,
            1,
            [1.0, 2.0],
            argument_3="hi", # pyright: ignore[reportGeneralTypeIssues]
            keyword_only=(1, 2),
        )
        # ``argument_2`` is passed positionally with an ``int``, and the
        # keyword-only parameter is passed positionally.
        ser.pipe( # type: ignore[misc]
            first_arg_series,
            1,
            [1.0, 2.0],
            11, # type: ignore[arg-type]
            (1, 2), # pyright: ignore[reportGeneralTypeIssues]
        )
        # The positional-only parameter is passed by keyword, and
        # ``argument_2`` again receives an ``int``.
        ser.pipe( # type: ignore[call-arg]
            first_arg_series,
            positional_only=1, # pyright: ignore[reportGeneralTypeIssues]
            argument_1=[1.0, 2.0],
            argument_2=11, # type: ignore[arg-type]
            keyword_only=(1, 2),
        )

    # Helper that does NOT take the Series first; it is used with the
    # ``(callable, keyword_name)`` tuple form of ``pipe``.
    def first_arg_not_series(argument_1: int, ser: pd.Series) -> pd.Series:
        return ser

    # Valid tuple form: the second element is a string (here the name of the
    # helper's Series parameter).
    check(
        assert_type(
            ser.pipe(
                (first_arg_not_series, "ser"),
                1,
            ),
            pd.Series,
        ),
        pd.Series,
    )

    if TYPE_CHECKING_INVALID_USAGE:
        # Second tuple element is an ``int`` instead of a string.
        ser.pipe(
            (
                first_arg_not_series, # type: ignore[arg-type]
                1, # pyright: ignore[reportGeneralTypeIssues]
            ),
            1,
        )
        # First tuple element is an ``int`` instead of a callable.
        ser.pipe(
            (
                1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
                # NOTE(review): "df" does not match the helper's parameter name
                # ("ser", used in the valid call above); any string satisfies
                # this case, but confirm whether "ser" was intended.
                "df",
            ),
            1,
        )

0 comments on commit 117e97a

Please sign in to comment.