Skip to content

Commit

Permalink
Refactor DataFrame.__getitem__ (#2134)
Browse files Browse the repository at this point in the history
* added a test covering all one-dimensional uses. Most were hit already by other tests, but good to have this in a single place. The tuple input (2D) is to be tested in a future PR.
* moved around some logic for numpy ndarray, it was confusing as some code path could not be hit at all.
* added type overloads so a user no longer loses typing the moment this method is used. This has led in various places in the tests for type warnings to pop up, some I could not fix/didnt fully understand, added ignores in those cases.
  • Loading branch information
zundertj committed Dec 23, 2021
1 parent 3bca19d commit f0767f4
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 40 deletions.
75 changes: 45 additions & 30 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import os
import sys
from datetime import timedelta
from datetime import datetime, timedelta
from io import BytesIO, StringIO
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -1140,17 +1140,30 @@ def _pos_idx(self, idx: int, dim: int) -> int:
else:
return self.shape[dim] + idx

def __getitem__(self, item: Any) -> Any:
# __getitem__() mostly returns a dataframe. The major exception is when a string is passed in. Note that there are
# more subtle cases possible where a non-string value leads to a Series.
@overload
def __getitem__(self, item: str) -> "pli.Series": # type: ignore
...

@overload
def __getitem__(
self,
item: Union[
int, range, slice, np.ndarray, "pli.Expr", "pli.Series", List, tuple
],
) -> "DataFrame":
...

def __getitem__(self, item: Any) -> Union["DataFrame", "pli.Series"]:
"""
Does quite a lot. Read the comments.
"""
if hasattr(item, "_pyexpr"):
return self.select(item)
if isinstance(item, np.ndarray):
item = pli.Series("", item)
# select rows and columns at once
# every 2d selection, i.e. tuple is row column order, just like numpy
if isinstance(item, tuple):
if isinstance(item, tuple) and len(item) == 2:
row_selection, col_selection = item

# df[:, unknown]
Expand Down Expand Up @@ -1220,10 +1233,10 @@ def __getitem__(self, item: Any) -> Any:
# df[:, [1, 2]]
# select by column indexes
if isinstance(col_selection[0], int):
series = [self.to_series(i) for i in col_selection]
df = DataFrame(series)
series_list = [self.to_series(i) for i in col_selection]
df = DataFrame(series_list)
return df[row_selection]
df = self.__getitem__(col_selection)
df = self.__getitem__(col_selection) # type: ignore
return df.__getitem__(row_selection)

# select single column
Expand Down Expand Up @@ -1273,37 +1286,39 @@ def __getitem__(self, item: Any) -> Any:
pli.col("*").slice(start, length).take_every(item.step) # type: ignore
)

# select multiple columns
# df["foo", "bar"]
if isinstance(item, Sequence):
if isinstance(item[0], str):
return wrap_df(self._df.select(item))
elif isinstance(item[0], pli.Expr):
return self.select(item)

# select rows by mask or index
# select rows by numpy mask or index
# df[[1, 2, 3]]
# df[true, false, true]
# df[[true, false, true]]
if isinstance(item, np.ndarray):
if item.dtype == int:
return wrap_df(self._df.take(item))
if isinstance(item[0], str):
return wrap_df(self._df.select(item))
if isinstance(item, (pli.Series, Sequence)):
if isinstance(item, Sequence):
# only bool or integers allowed
if type(item[0]) == bool:
item = pli.Series("", item)
else:
return wrap_df(
self._df.take([self._pos_idx(i, dim=0) for i in item])
)
if item.dtype == bool:
return wrap_df(self._df.filter(pli.Series("", item).inner()))

if isinstance(item, Sequence):
if isinstance(item[0], str):
# select multiple columns
# df[["foo", "bar"]]
return wrap_df(self._df.select(item))
elif isinstance(item[0], pli.Expr):
return self.select(item)
elif type(item[0]) == bool:
item = pli.Series("", item) # fall through to next if isinstance
else:
return wrap_df(self._df.take([self._pos_idx(i, dim=0) for i in item]))

if isinstance(item, pli.Series):
dtype = item.dtype
if dtype == Boolean:
return wrap_df(self._df.filter(item.inner()))
if dtype == UInt32:
return wrap_df(self._df.take_with_series(item.inner()))

# if no data has been returned, the operation is not supported
raise IndexError

def __setitem__(self, key: Union[str, int, Tuple[Any, Any]], value: Any) -> None:
# df["foo"] = series
if isinstance(key, str):
Expand Down Expand Up @@ -1335,7 +1350,7 @@ def __setitem__(self, key: Union[str, int, Tuple[Any, Any]], value: Any) -> None
if isinstance(col_selection, str):
s = self.__getitem__(col_selection)
elif isinstance(col_selection, int):
s = self[:, col_selection]
s = self[:, col_selection] # type: ignore
else:
raise ValueError(f"column selection not understood: {col_selection}")

Expand Down Expand Up @@ -2526,8 +2541,8 @@ def upsample(self, by: str, interval: Union[str, timedelta]) -> "DataFrame":
bounds = self.select(
[pli.col(by).min().alias("low"), pli.col(by).max().alias("high")]
)
low = bounds["low"].dt[0]
high = bounds["high"].dt[0]
low: datetime = bounds["low"].dt[0] # type: ignore
high: datetime = bounds["high"].dt[0] # type: ignore
upsampled = pli.date_range(low, high, interval, name=by)
return DataFrame(upsampled).join(self, on=by, how="left")

Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_apply_none() -> None:
assert out[0].to_list() == [4.75, 326.75, 82.75]
assert out[1].to_list() == [238.75, 3418849.75, 372.75]

out = df.select(pl.map(exprs=["a", "b"], f=lambda s: s[0] * s[1]))
assert out["a"].to_list() == (df["a"] * df["b"]).to_list()
out_df = df.select(pl.map(exprs=["a", "b"], f=lambda s: s[0] * s[1]))
assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list()

# check if we can return None
def func(s: List) -> Optional[int]:
Expand Down
84 changes: 81 additions & 3 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

import polars as pl
from polars import testing


def test_version() -> None:
Expand Down Expand Up @@ -374,7 +375,7 @@ def test_groupby() -> None:
# df.groupby(by="a", select="b", agg="count").frame_equal(
# pl.DataFrame({"a": ["a", "b", "c"], "": [2, 3, 1]})
# )
assert df.groupby("a").apply(lambda df: df[["c"]].sum()).sort("c")["c"][0] == 1
assert df.groupby("a").apply(lambda df: df[["c"]].sum()).sort("c")["c"][0] == 1 # type: ignore

assert (
df.groupby("a")
Expand Down Expand Up @@ -742,10 +743,10 @@ def test_lazy_functions() -> None:
]
expected = 1.0
assert np.isclose(out.select_at_idx(0), expected)
assert np.isclose(pl.var(df["b"]), expected)
assert np.isclose(pl.var(df["b"]), expected) # type: ignore
expected = 1.0
assert np.isclose(out.select_at_idx(1), expected)
assert np.isclose(pl.std(df["b"]), expected)
assert np.isclose(pl.std(df["b"]), expected) # type: ignore
expected = 3
assert np.isclose(out.select_at_idx(2), expected)
assert np.isclose(pl.max(df["b"]), expected)
Expand Down Expand Up @@ -1329,3 +1330,80 @@ def test_df_schema_unique() -> None:

def test_empty_projection() -> None:
assert pl.DataFrame({"a": [1, 2], "b": [3, 4]}).select([]).shape == (0, 0)


def test_arithmetic() -> None:
df = pl.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})

df_mul = df * 2
expected = pl.DataFrame({"a": [2, 4], "b": [6, 8]})
assert df_mul.frame_equal(expected)

df_div = df / 2
expected = pl.DataFrame({"a": [0.5, 1.0], "b": [1.5, 2.0]})
assert df_div.frame_equal(expected)

df_plus = df + 2
expected = pl.DataFrame({"a": [3, 4], "b": [5, 6]})
assert df_plus.frame_equal(expected)

df_minus = df - 2
expected = pl.DataFrame({"a": [-1, 0], "b": [1, 2]})
assert df_minus.frame_equal(expected)


def test_getattr() -> None:
df = pl.DataFrame({"a": [1.0, 2.0]})
testing.assert_series_equal(df.a, pl.Series("a", [1.0, 2.0]))

with pytest.raises(AttributeError):
_ = df.b


def test_get_item() -> None:
"""test all the methods to use [] on a dataframe"""
df = pl.DataFrame({"a": [1.0, 2.0], "b": [3, 4]})

# expression
assert df[pl.col("a")].frame_equal(pl.DataFrame({"a": [1.0, 2.0]}))

# numpy array
assert df[np.array([True, False])].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))

# tuple. The first element refers to the rows, the second element to columns
assert df[:, :].frame_equal(df)

# str, always refers to a column name
assert df["a"].series_equal(pl.Series("a", [1.0, 2.0]))

# int, always refers to a row index (zero-based): index=1 => second row
assert df[1].frame_equal(pl.DataFrame({"a": [2.0], "b": [4]}))

# range, refers to rows
assert df[range(1)].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))

# slice. Below an example of taking every second row
assert df[::2].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))

# numpy array; assumed to be row indices if integers, or columns if strings
# TODO: add boolean mask support
df[np.array([1])].frame_equal(pl.DataFrame({"a": [2.0], "b": [4]}))
df[np.array(["a"])].frame_equal(pl.DataFrame({"a": [1.0, 2.0]}))
# note that we cannot use floats (even if they could be casted to integer without loss)
with pytest.raises(IndexError):
_ = df[np.array([1.0])]

# sequences (lists or tuples; tuple only if length != 2)
# if strings or list of expressions, assumed to be column names
# if bools, assumed to be a row mask
# if integers, assumed to be row indices
assert df[["a", "b"]].frame_equal(df)
assert df[[pl.col("a"), pl.col("b")]].frame_equal(df)
df[[1]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
df[[False, True]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))

# pl.Series: like sequences, but only for rows
df[[1]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
df[[False, True]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
with pytest.raises(IndexError):
_ = df[pl.Series("", ["hello Im a string"])]
10 changes: 5 additions & 5 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,12 @@ def test_arange_expr() -> None:
assert out.select_at_idx(0)[-1] == 19

# eager arange
out = pl.arange(0, 10, 2, eager=True)
assert out == [0, 2, 4, 8, 8]
out2 = pl.arange(0, 10, 2, eager=True)
assert out2 == [0, 2, 4, 8, 8]

out = pl.arange(pl.Series([0, 19]), pl.Series([3, 39]), step=2, eager=True)
assert out.dtype == pl.List
assert out[0].to_list() == [0, 2]
out3 = pl.arange(pl.Series([0, 19]), pl.Series([3, 39]), step=2, eager=True)
assert out3.dtype == pl.List # type: ignore
assert out3[0].to_list() == [0, 2] # type: ignore


def test_round() -> None:
Expand Down

0 comments on commit f0767f4

Please sign in to comment.