Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue for pandera allowing generic Series to work #492

Merged
merged 3 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ from typing_extensions import (
)
import xarray as xr

from pandas._libs.interval import Interval
from pandas._libs.missing import NAType
from pandas._libs.tslibs import BaseOffset
from pandas._typing import (
Expand All @@ -94,7 +95,6 @@ from pandas._typing import (
IgnoreRaise,
IndexingInt,
IntervalClosedType,
IntervalT,
JoinHow,
JsonSeriesOrient,
Level,
Expand Down Expand Up @@ -216,13 +216,43 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
@overload
def __new__(
cls,
data: IntervalIndex[IntervalT],
data: IntervalIndex[Interval[int]],
index: Axes | None = ...,
dtype=...,
name: Hashable | None = ...,
copy: bool = ...,
fastpath: bool = ...,
) -> Series[IntervalT]: ...
) -> Series[Interval[int]]: ...
@overload
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be nice to add a TODO referencing the mypy issue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be nice to add a TODO referencing the mypy issue

I didn't research whether there exists a mypy issue, or how to create a small sample to replicate.

I could add a comment indicating that we prefer the Series[IntervalT] notation, but that it doesn't seem to work and reference the original pandas-stubs issue. Let me know if I should do that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like the following work?

T = TypeVar("T", float, int, Timestamp, Timedelta)

        data: IntervalIndex[Interval[T]],
        ....
    ) -> Series[Interval[T]]: ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see how this would work, because we have:

class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]):

Since IntervalIndex is generic, and Interval is generic, you can't make IntervalIndex depend on the "inner" generic type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, sorry for the trouble.

def __new__(
cls,
data: IntervalIndex[Interval[float]],
index: Axes | None = ...,
dtype=...,
name: Hashable | None = ...,
copy: bool = ...,
fastpath: bool = ...,
) -> Series[Interval[float]]: ...
@overload
def __new__(
cls,
data: IntervalIndex[Interval[Timestamp]],
index: Axes | None = ...,
dtype=...,
name: Hashable | None = ...,
copy: bool = ...,
fastpath: bool = ...,
) -> Series[Interval[Timestamp]]: ...
@overload
def __new__(
cls,
data: IntervalIndex[Interval[Timedelta]],
index: Axes | None = ...,
dtype=...,
name: Hashable | None = ...,
copy: bool = ...,
fastpath: bool = ...,
) -> Series[Interval[Timedelta]]: ...
@overload
def __new__(
cls,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
TYPE_CHECKING,
Any,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
List,
Sequence,
TypeVar,
cast,
)

Expand Down Expand Up @@ -1362,3 +1364,16 @@ def test_AnyArrayLike_and_clip() -> None:
s2 = ser.clip(upper=ser)
check(assert_type(s1, pd.Series), pd.Series)
check(assert_type(s2, pd.Series), pd.Series)


def test_pandera_generic() -> None:
# GH 471
T = TypeVar("T")

class MySeries(pd.Series, Generic[T]):
...

def func() -> MySeries[float]:
return MySeries[float]([1, 2, 3])

func()