Skip to content

Commit

Permalink
Tighten type annotations on DataFrame and Series get/set (#2142)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Dec 23, 2021
1 parent e19b0ae commit f3013a7
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 42 deletions.
35 changes: 20 additions & 15 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@

from polars._html import NotebookFormatter
from polars.datatypes import Boolean, DataType, Datetime, UInt32, py_type_to_dtype
from polars.utils import _process_null_values, is_int_sequence, is_str_sequence
from polars.utils import (
_process_null_values,
is_int_sequence,
is_str_sequence,
range_to_slice,
)

try:
import pandas as pd
Expand Down Expand Up @@ -1155,11 +1160,16 @@ def __getitem__(
) -> "DataFrame":
...

def __getitem__(self, item: Any) -> Union["DataFrame", "pli.Series"]:
def __getitem__(
self,
item: Union[
str, int, range, slice, np.ndarray, "pli.Expr", "pli.Series", List, tuple
],
) -> Union["DataFrame", "pli.Series"]:
"""
Does quite a lot. Read the comments.
"""
if hasattr(item, "_pyexpr"):
if isinstance(item, pli.Expr):
return self.select(item)
# select rows and columns at once
# every 2d selection, i.e. tuple is row column order, just like numpy
Expand Down Expand Up @@ -1250,14 +1260,7 @@ def __getitem__(self, item: Any) -> Union["DataFrame", "pli.Series"]:

# df[range(n)]
if isinstance(item, range):
step: Optional[int]
# maybe we can slice instead of take by indices
if item.step != 1:
step = item.step
else:
step = None
slc = slice(item.start, item.stop, step)
return self[slc]
return self[range_to_slice(item)]

# df[:]
if isinstance(item, slice):
Expand Down Expand Up @@ -1306,7 +1309,7 @@ def __getitem__(self, item: Any) -> Union["DataFrame", "pli.Series"]:
return self.select(item)
elif type(item[0]) == bool:
item = pli.Series("", item) # fall through to next if isinstance
else:
elif is_int_sequence(item):
return wrap_df(self._df.take([self._pos_idx(i, dim=0) for i in item]))

if isinstance(item, pli.Series):
Expand All @@ -1317,9 +1320,11 @@ def __getitem__(self, item: Any) -> Union["DataFrame", "pli.Series"]:
return wrap_df(self._df.take_with_series(item.inner()))

# if no data has been returned, the operation is not supported
raise IndexError
raise NotImplementedError

def __setitem__(self, key: Union[str, int, Tuple[Any, Any]], value: Any) -> None:
def __setitem__(
self, key: Union[str, int, List, Tuple[Any, Any]], value: Any
) -> None:
# df["foo"] = series
if isinstance(key, str):
try:
Expand Down Expand Up @@ -4024,7 +4029,7 @@ def get_group(self, group_value: Union[Any, Tuple[Any]]) -> DataFrame:

# should be only one match
try:
groups_idx = groups[mask][0]
groups_idx = groups[mask][0] # type: ignore
except IndexError:
raise ValueError(f"no group: {group_value} found")

Expand Down
33 changes: 14 additions & 19 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
maybe_cast,
py_type_to_dtype,
)
from polars.utils import _date_to_pl_date, _datetime_to_pl_timestamp, _ptr_to_numpy
from polars.utils import (
_date_to_pl_date,
_datetime_to_pl_timestamp,
_ptr_to_numpy,
range_to_slice,
)

try:
import pandas as pd
Expand Down Expand Up @@ -402,7 +407,7 @@ def __pow__(self, power: float, modulo: None = None) -> "Series":
def __neg__(self) -> "Series":
return 0 - self

def __getitem__(self, item: Any) -> Any:
def __getitem__(self, item: Union[int, "Series", range, slice]) -> Any:
if isinstance(item, int):
if item < 0:
item = self.len() + item
Expand All @@ -423,32 +428,22 @@ def __getitem__(self, item: Any) -> Any:
return wrap_s(self._s.filter(item._s))

if isinstance(item, range):
step: Optional[int]
# maybe we can slice instead of take by indices
if item.step != 1:
step = item.step
else:
step = None
slc = slice(item.start, item.stop, step)
return self[slc]
return self[range_to_slice(item)]

# slice
if type(item) == slice:
if isinstance(item, slice):
start, stop, stride = item.indices(self.len())
out = self.slice(start, stop - start)
if stride != 1:
return out.take_every(stride)
else:
return out
f = get_ffi_func("get_<>", self.dtype, self._s)
if f is None:
return NotImplemented
out = f(item)
if self.dtype == PlList:
return wrap_s(out)
return out

def __setitem__(self, key: Any, value: Any) -> None:
raise NotImplementedError

def __setitem__(
self, key: Union[int, "Series", np.ndarray, List, Tuple], value: Any
) -> None:
if isinstance(value, list):
raise ValueError("cannot set with a list as value, use a primitive value")
if isinstance(key, Series):
Expand Down
12 changes: 11 additions & 1 deletion py-polars/polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
import sys
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np

Expand Down Expand Up @@ -77,3 +77,13 @@ def is_int_sequence(val: Sequence[object]) -> TypeGuard[Sequence[int]]:

def _is_iterable_of(val: Iterable, itertype: Type, eltype: Type) -> bool:
return isinstance(val, itertype) and all(isinstance(x, eltype) for x in val)


def range_to_slice(rng: range) -> slice:
step: Optional[int]
# maybe we can slice instead of take by indices
if rng.step != 1:
step = rng.step
else:
step = None
return slice(rng.start, rng.stop, step)
8 changes: 4 additions & 4 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,8 @@ def test_from_pandas_nan_to_none() -> None:
"nulls": [None, np.nan, np.nan],
}
)
out_true = pl.from_pandas(df)
out_false = pl.from_pandas(df, nan_to_none=False)
out_true: pl.DataFrame = pl.from_pandas(df) # type: ignore
out_false: pl.DataFrame = pl.from_pandas(df, nan_to_none=False) # type: ignore
df.loc[2, "nulls"] = pd.NA
assert [val is None for val in out_true["nulls"]]
assert [np.isnan(val) for val in out_false["nulls"][1:]]
Expand Down Expand Up @@ -1593,7 +1593,7 @@ def test_get_item() -> None:
df[np.array([1])].frame_equal(pl.DataFrame({"a": [2.0], "b": [4]}))
df[np.array(["a"])].frame_equal(pl.DataFrame({"a": [1.0, 2.0]}))
# note that we cannot use floats (even if they could be casted to integer without loss)
with pytest.raises(IndexError):
with pytest.raises(NotImplementedError):
_ = df[np.array([1.0])]

# sequences (lists or tuples; tuple only if length != 2)
Expand All @@ -1608,7 +1608,7 @@ def test_get_item() -> None:
# pl.Series: like sequences, but only for rows
df[[1]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
df[[False, True]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
with pytest.raises(IndexError):
with pytest.raises(NotImplementedError):
_ = df[pl.Series("", ["hello Im a string"])]


Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
def test_from_pandas_datetime() -> None:
ts = datetime.datetime(2021, 1, 1, 20, 20, 20, 20)
s = pd.Series([ts, ts])
s = pl.from_pandas(s.to_frame("a"))["a"]
tmp: pl.DataFrame = pl.from_pandas(s.to_frame("a")) # type: ignore
s = tmp["a"]
assert s.dt.hour()[0] == 20
assert s.dt.minute()[0] == 20
assert s.dt.second()[0] == 20
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import date, datetime
from typing import Any, Sequence
from typing import Any, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_set_np_array(dtype: Any) -> None:


@pytest.mark.parametrize("idx", [[0, 2], (0, 2)])
def test_set_list_and_tuple(idx: Sequence) -> None:
def test_set_list_and_tuple(idx: Union[list, tuple]) -> None:
a = pl.Series("a", [1, 2, 3])
a[idx] = 4
testing.assert_series_equal(a, pl.Series("a", [4, 2, 4]))
Expand Down

0 comments on commit f3013a7

Please sign in to comment.