Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: str = ...,
*,
into: type[defaultdict],
into: type[defaultdict[Any, Any]],
index: Literal[True] = True,
) -> Never: ...
@overload
Expand All @@ -500,7 +500,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["records"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> list[dict[Hashable, Any]]: ...
@overload
Expand All @@ -516,23 +516,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["index"],
*,
into: OrderedDict | type[OrderedDict],
into: OrderedDict[Any, Any] | type[OrderedDict[Any, Any]],
index: Literal[True] = True,
) -> OrderedDict[Hashable, dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["index"],
*,
into: type[MutableMapping],
into: type[MutableMapping[Any, Any]],
index: Literal[True] = True,
) -> MutableMapping[Hashable, dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["index"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> dict[Hashable, dict[Hashable, Any]]: ...
@overload
Expand All @@ -548,23 +548,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["dict", "list", "series"] = ...,
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> dict[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"],
*,
into: MutableMapping[Any, Any] | type[MutableMapping],
into: MutableMapping[Any, Any] | type[MutableMapping[Any, Any]],
index: bool = ...,
) -> MutableMapping[str, list[Any]]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: bool = ...,
) -> dict[str, list[Any]]: ...
@classmethod
Expand Down
8 changes: 4 additions & 4 deletions pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
random_state: RandomState | None = ...,
) -> NDFrameT: ...

_GroupByT = TypeVar("_GroupByT", bound=GroupBy)
_GroupByT = TypeVar("_GroupByT", bound=GroupBy[Any])

# GroupByPlot does not really inherit from PlotAccessor but it delegates
# to it using __call__ and __getattr__. We lie here to avoid repeating the
Expand Down Expand Up @@ -383,15 +383,15 @@ class BaseGroupBy(SelectionMixin[NDFrameT], GroupByIndexingMixin):
@final
def __iter__(self) -> Iterator[tuple[Hashable, NDFrameT]]: ...
@overload
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy[Any, Any]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
@overload
def __getitem__(
self: BaseGroupBy[DataFrame], key: Iterable[Hashable]
) -> generic.DataFrameGroupBy: ...
) -> generic.DataFrameGroupBy[Any, Any]: ...
@overload
def __getitem__(
self: BaseGroupBy[Series[S1]],
idx: list[str] | Index | Series[S1] | MaskType | tuple[Hashable | slice, ...],
) -> generic.SeriesGroupBy: ...
) -> generic.SeriesGroupBy[Any, Any]: ...
@overload
def __getitem__(self: BaseGroupBy[Series[S1]], idx: Scalar) -> S1: ...
2 changes: 1 addition & 1 deletion pandas-stubs/core/resample.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ _SeriesGroupByFuncArgs: TypeAlias = (
)

class Resampler(BaseGroupBy[NDFrameT]):
def __getattr__(self, attr: str) -> SeriesGroupBy: ...
def __getattr__(self, attr: str) -> SeriesGroupBy[Any, Any]: ...
@overload
def aggregate(
self: Resampler[DataFrame],
Expand Down
6 changes: 3 additions & 3 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,10 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
def items(self) -> Iterator[tuple[Hashable, S1]]: ...
def keys(self) -> Index: ...
@overload
def to_dict(self, *, into: type[dict] = ...) -> dict[Any, S1]: ...
def to_dict(self, *, into: type[dict[Any, Any]] = ...) -> dict[Hashable, S1]: ...
@overload
def to_dict(
self, *, into: type[MutableMapping] | MutableMapping[Any, Any]
self, *, into: type[MutableMapping[Any, Any]] | MutableMapping[Any, Any]
) -> MutableMapping[Hashable, S1]: ...
def to_frame(self, name: object | None = ...) -> DataFrame: ...
@overload
Expand Down Expand Up @@ -1105,7 +1105,7 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
def swaplevel(
self, i: Level = -2, j: Level = -1, copy: _bool = True
) -> Series[S1]: ...
def reorder_levels(self, order: list) -> Series[S1]: ...
def reorder_levels(self, order: list[Any]) -> Series[S1]: ...
def explode(self, ignore_index: _bool = ...) -> Series[S1]: ...
def unstack(
self,
Expand Down
36 changes: 26 additions & 10 deletions tests/frame/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,10 +3021,10 @@ def test_to_dict_simple() -> None:
check(assert_type(data.to_dict("dict"), dict[Hashable, Any]), dict)
check(assert_type(data.to_dict("list"), dict[Hashable, Any]), dict)
check(assert_type(data.to_dict("series"), dict[Hashable, Any]), dict)
check(assert_type(data.to_dict("split"), dict[str, list]), dict, str)
check(assert_type(data.to_dict("split"), dict[str, list[Any]]), dict, str)

# orient param accepting "tight" added in 1.4.0 https://pandas.pydata.org/docs/whatsnew/v1.4.0.html
check(assert_type(data.to_dict("tight"), dict[str, list]), dict, str)
check(assert_type(data.to_dict("tight"), dict[str, list[Any]]), dict, str)

if TYPE_CHECKING_INVALID_USAGE:

Expand Down Expand Up @@ -3075,7 +3075,7 @@ def test_to_dict_into_defaultdict() -> None:
defaultdict,
)
check(
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]),
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list[Any]]),
defaultdict,
str,
)
Expand All @@ -3093,7 +3093,11 @@ def test_to_dict_into_ordered_dict() -> None:

data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})

check(assert_type(data.to_dict(into=OrderedDict), OrderedDict), OrderedDict, tuple)
check(
assert_type(data.to_dict(into=OrderedDict), OrderedDict[Any, Any]),
OrderedDict,
tuple,
)
check(
assert_type(
data.to_dict("index", into=OrderedDict),
Expand All @@ -3102,12 +3106,16 @@ def test_to_dict_into_ordered_dict() -> None:
OrderedDict,
)
check(
assert_type(data.to_dict("tight", into=OrderedDict), MutableMapping[str, list]),
assert_type(
data.to_dict("tight", into=OrderedDict), MutableMapping[str, list[Any]]
),
OrderedDict,
str,
)
check(
assert_type(data.to_dict("records", into=OrderedDict), list[OrderedDict]),
assert_type(
data.to_dict("records", into=OrderedDict), list[OrderedDict[Any, Any]]
),
list,
OrderedDict,
)
Expand Down Expand Up @@ -3446,16 +3454,24 @@ def test_to_dict_index() -> None:
dict,
)
check(
assert_type(df.to_dict(orient="split", index=True), dict[str, list]), dict, str
assert_type(df.to_dict(orient="split", index=True), dict[str, list[Any]]),
dict,
str,
)
check(
assert_type(df.to_dict(orient="tight", index=True), dict[str, list]), dict, str
assert_type(df.to_dict(orient="tight", index=True), dict[str, list[Any]]),
dict,
str,
)
check(
assert_type(df.to_dict(orient="tight", index=False), dict[str, list]), dict, str
assert_type(df.to_dict(orient="tight", index=False), dict[str, list[Any]]),
dict,
str,
)
check(
assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str
assert_type(df.to_dict(orient="split", index=False), dict[str, list[Any]]),
dict,
str,
)
if TYPE_CHECKING_INVALID_USAGE:
_0 = df.to_dict(orient="records", index=False) # type: ignore[call-overload] # pyright: ignore[reportArgumentType,reportCallIssue]
Expand Down
6 changes: 3 additions & 3 deletions tests/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def makeseries(x: float) -> pd.Series:
def retseries(x: float) -> float:
return x

check(assert_type(s.apply(retseries).tolist(), list), list)
check(assert_type(s.apply(retseries).tolist(), list[Any]), list)

def retlist(x: float) -> list[float]:
return [x]
Expand Down Expand Up @@ -1780,7 +1780,7 @@ def test_types_to_list() -> None:

def test_types_to_dict() -> None:
s = pd.Series(["a", "b", "c"], dtype=str)
assert_type(s.to_dict(), dict[Any, str])
assert_type(s.to_dict(), dict[Hashable, str])


def test_categorical_codes() -> None:
Expand Down Expand Up @@ -2182,7 +2182,7 @@ def test_change_to_dict_return_type() -> None:
value = ["a", "b", "c"]
df = pd.DataFrame(zip(id, value), columns=["id", "value"])
fd = df.set_index("id")["value"].to_dict()
check(assert_type(fd, dict[Any, Any]), dict)
check(assert_type(fd, dict[Hashable, Any]), dict)


ASTYPE_BOOL_ARGS: list[tuple[BooleanDtypeArg, type]] = [
Expand Down
5 changes: 2 additions & 3 deletions tests/test_api_typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportMissingTypeArgument=false
"""Test module for classes in pandas.api.typing."""

from typing import TypeAlias
Expand Down Expand Up @@ -26,9 +27,7 @@
Window,
)
import pytest
from typing_extensions import (
assert_type,
)
from typing_extensions import assert_type

from tests import (
check,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import decimal
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -29,9 +30,9 @@ def test_tolist() -> None:
s1 = pd.Series(data1)
# python/mypy#19952: mypy believes ExtensionArray and its subclasses have a
# conflict and gives Any for s.array
check(assert_type(s.array.tolist(), list), list) # type: ignore[assert-type]
check(assert_type(s1.array.tolist(), list), list)
check(assert_type(pd.array([1, 2, 3]).tolist(), list), list)
check(assert_type(s.array.tolist(), list[Any]), list) # type: ignore[assert-type]
check(assert_type(s1.array.tolist(), list[Any]), list)
check(assert_type(pd.array([1, 2, 3]).tolist(), list[Any]), list)


def test_ExtensionArray_reduce_accumulate() -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/test_resampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportMissingTypeArgument=false
from collections.abc import (
Hashable,
Iterator,
Expand Down