Skip to content

Commit

Permalink
More mypy improvements (#4243)
Browse files Browse the repository at this point in the history
Co-authored-by: Matteo Santamaria <msantama@gmail.com>
  • Loading branch information
matteosantama and Matteo Santamaria committed Aug 4, 2022
1 parent e6dc1c9 commit d1e5b10
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 23 deletions.
64 changes: 51 additions & 13 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from io import BytesIO, IOBase, StringIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Expand Down Expand Up @@ -98,10 +99,25 @@
else:
from typing_extensions import Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

# A type variable used to refer to a polars.DataFrame or any subclass of it.
# Used to annotate DataFrame methods which returns the same type as self.
DF = TypeVar("DF", bound="DataFrame")

if TYPE_CHECKING:
# these aliases are used to annotate DataFrame.__getitem__()
# MultiRowSelector indexes into the vertical axis and
# MultiColSelector indexes into the horizontal axis
# NOTE: wrapping these as strings is necessary for Python <3.10
MultiRowSelector: TypeAlias = "slice | range | list[int] | list[bool] | pli.Series"
MultiColSelector: TypeAlias = (
"slice | range | list[int] | list[bool] | list[str] | pli.Series"
)


def wrap_df(df: PyDataFrame) -> DataFrame:
return DataFrame._from_pydf(df)
Expand Down Expand Up @@ -1821,32 +1837,54 @@ def _pos_idxs(self, idxs: np.ndarray | pli.Series, dim: int) -> pli.Series:

raise NotImplementedError("Unsupported idxs datatype.")

# __getitem__() mostly returns a dataframe. The major exception is when a string is
# passed in. Note that there are more subtle cases possible where a non-string value
# leads to a Series.
@overload
def __getitem__(self, item: str) -> pli.Series:
def __getitem__(self: DF, item: str) -> pli.Series:
...

@overload
def __getitem__(
self: DF,
item: int | range | slice | np.ndarray | pli.Expr | pli.Series | list | tuple,
item: int
| np.ndarray
| pli.Expr
| list[pli.Expr]
| MultiColSelector
| tuple[int, MultiColSelector]
| tuple[MultiRowSelector, MultiColSelector],
) -> DF:
...

@overload
def __getitem__(self: DF, item: tuple[MultiRowSelector, int]) -> pli.Series:
...

@overload
def __getitem__(self: DF, item: tuple[MultiRowSelector, str]) -> pli.Series:
...

@overload
def __getitem__(self: DF, item: tuple[int, int]) -> Any:
...

@overload
def __getitem__(self: DF, item: tuple[int, str]) -> Any:
...

def __getitem__(
self: DF,
item: (
str
| int
| range
| slice
| np.ndarray
| pli.Expr
| pli.Series
| list
| tuple
| list[pli.Expr]
| MultiColSelector
| tuple[int, MultiColSelector]
| tuple[MultiRowSelector, MultiColSelector]
| tuple[MultiRowSelector, int]
| tuple[MultiRowSelector, str]
| tuple[int, int]
| tuple[int, str]
),
) -> DF | pli.Series:
"""Get item. Does quite a lot. Read the comments."""
Expand Down Expand Up @@ -1928,11 +1966,11 @@ def __getitem__(

if isinstance(col_selection, list):
# df[:, [1, 2]]
# select by column indexes
if isinstance(col_selection[0], int):
if is_int_sequence(col_selection):
series_list = [self.to_series(i) for i in col_selection]
df = self.__class__(series_list)
return df[row_selection]

df = self.__getitem__(col_selection)
return df.__getitem__(row_selection)

Expand Down Expand Up @@ -2041,7 +2079,7 @@ def __setitem__(
if isinstance(col_selection, str):
s = self.__getitem__(col_selection)
elif isinstance(col_selection, int):
s = self[:, col_selection] # type: ignore[assignment]
s = self[:, col_selection]
else:
raise ValueError(f"column selection not understood: {col_selection}")

Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,9 @@ def _pos_idxs(self, idxs: np.ndarray | Series) -> Series:

raise NotImplementedError("Unsupported idxs datatype.")

def __getitem__(self, item: int | Series | range | slice | np.ndarray) -> Any:
def __getitem__(
self, item: int | Series | range | slice | np.ndarray | list[int] | list[bool]
) -> Any:
if isinstance(item, int):
if item < 0:
item = self.len() + item
Expand Down
4 changes: 1 addition & 3 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,7 @@ def verify_series_and_expr_api(
"""
expr = _getattr_multi(pli.col("*"), op)(*args, **kwargs)
result_expr: pli.Series = input.to_frame().select(expr)[ # type: ignore[assignment]
:, 0
]
result_expr = input.to_frame().select(expr)[:, 0]
result_series = _getattr_multi(input, op)(*args, **kwargs)
if expected is None:
assert_series_equal(result_series, result_expr)
Expand Down
8 changes: 2 additions & 6 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,10 @@ def test_flatten_explode() -> None:
df = pl.Series("a", ["Hello", "World"])
expected = pl.Series("a", ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"])

result: pl.Series = df.to_frame().select( # type: ignore[assignment]
pl.col("a").flatten()
)[:, 0]
result = df.to_frame().select(pl.col("a").flatten())[:, 0]
assert_series_equal(result, expected)

result: pl.Series = df.to_frame().select( # type: ignore[no-redef]
pl.col("a").explode()
)[:, 0]
result = df.to_frame().select(pl.col("a").explode())[:, 0]
assert_series_equal(result, expected)


Expand Down

0 comments on commit d1e5b10

Please sign in to comment.