Skip to content

Commit

Permalink
Improve __getitem__ for Dataframe/Series. (#4160)
Browse files Browse the repository at this point in the history
Improve __getitem__ for Dataframe/Series:
  - Allow fast path for retrieving multiple rows when using a
    sequence, numpy array or Series of row indexes.
    (speedup of x6).
  - Support all unsigned and signed integer types for multiple
    row indexing.
  - Allow negative indexes for multiple row indexing.
  - Add more getitem tests and fix old getitem tests that were
    wrong (didn't use assert, so didn't fail before).
  - Support getting columns from dataframe with pl.Utf8 series.
  • Loading branch information
ghuls committed Jul 27, 2022
1 parent a5e863f commit 08dd52c
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 40 deletions.
116 changes: 106 additions & 10 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,21 @@

from polars import internals as pli
from polars._html import NotebookFormatter
from polars.datatypes import Boolean, DataType, UInt32, Utf8, py_type_to_dtype
from polars.datatypes import (
Boolean,
DataType,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Utf8,
get_idx_type,
py_type_to_dtype,
)
from polars.internals.construction import (
ColumnsType,
arrow_to_pydf,
Expand All @@ -41,6 +55,7 @@
deprecated_alias,
format_path,
handle_projection_columns,
is_bool_sequence,
is_int_sequence,
is_str_sequence,
range_to_slice,
Expand Down Expand Up @@ -1638,6 +1653,78 @@ def _pos_idx(self, idx: int, dim: int) -> int:
else:
return self.shape[dim] + idx

def _pos_idxs(self, idxs: np.ndarray | pli.Series, dim: int) -> pli.Series:
# pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx).
idx_type = get_idx_type()

if isinstance(idxs, pli.Series):
if idxs.dtype == idx_type:
return idxs
if idxs.dtype in {
UInt8,
UInt16,
UInt64 if idx_type == UInt32 else UInt32,
Int8,
Int16,
Int32,
Int64,
}:
if idx_type == UInt32:
if idxs.dtype in {Int64, UInt64}:
if idxs.max() >= 2**32: # type: ignore[operator]
raise ValueError(
"Index positions should be smaller than 2^32."
)
if idxs.dtype == Int64:
if idxs.min() < -(2**32): # type: ignore[operator]
raise ValueError(
"Index positions should be bigger than -2^32 + 1."
)
if idxs.dtype in {Int8, Int16, Int32, Int64}:
if idxs.min() < 0: # type: ignore[operator]
if idx_type == UInt32:
if idxs.dtype in {Int8, Int16}:
idxs = idxs.cast(Int32)
else:
if idxs.dtype in {Int8, Int16, Int32}:
idxs = idxs.cast(Int64)

idxs = pli.select(
pli.when(pli.lit(idxs) < 0)
.then(self.shape[dim] + pli.lit(idxs))
.otherwise(pli.lit(idxs))
).to_series()

return idxs.cast(idx_type)

if _NUMPY_AVAILABLE and isinstance(idxs, np.ndarray):
if idxs.ndim != 1:
raise ValueError("Only 1D numpy array is supported as index.")
if idxs.dtype.kind in ("i", "u"):
# Numpy array with signed or unsigned integers.

if idx_type == UInt32:
if idxs.dtype in {np.int64, np.uint64} and idxs.max() >= 2**32:
raise ValueError("Index positions should be smaller than 2^32.")
if idxs.dtype == np.int64 and idxs.min() < -(2**32):
raise ValueError(
"Index positions should be bigger than -2^32 + 1."
)
if idxs.dtype.kind == "i" and idxs.min() < 0:
if idx_type == UInt32:
if idxs.dtype in (np.int8, np.int16):
idxs = idxs.astype(np.int32)
else:
if idxs.dtype in (np.int8, np.int16, np.int32):
idxs = idxs.astype(np.int64)

# Update negative indexes to absolute indexes.
idxs = np.where(idxs < 0, self.shape[dim] + idxs, idxs)

return pli.Series("", idxs, dtype=idx_type)

raise NotImplementedError("Unsupported idxs datatype.")

# __getitem__() mostly returns a dataframe. The major exception is when a string is
# passed in. Note that there are more subtle cases possible where a non-string value
# leads to a Series.
Expand Down Expand Up @@ -1773,11 +1860,16 @@ def __getitem__(
return PolarsSlice(self).apply(item) # type: ignore[return-value]

# select rows by numpy mask or index
# df[[1, 2, 3]]
# df[[true, false, true]]
# df[np.array([1, 2, 3])]
# df[np.array([True, False, True])]
if _NUMPY_AVAILABLE and isinstance(item, np.ndarray):
if item.dtype == int:
return self._from_pydf(self._df.take(item))
if item.ndim != 1:
raise ValueError("Only a 1D-Numpy array is supported as index.")
if item.dtype.kind in ("i", "u"):
# Numpy array with signed or unsigned integers.
return self._from_pydf(
self._df.take_with_series(self._pos_idxs(item, dim=0).inner())
)
if isinstance(item[0], str):
return self._from_pydf(self._df.select(item))
if item.dtype == bool:
Expand All @@ -1795,19 +1887,23 @@ def __getitem__(
return self._from_pydf(self._df.select(item))
elif isinstance(item[0], pli.Expr):
return self.select(item)
elif type(item[0]) == bool:
elif is_bool_sequence(item):
item = pli.Series("", item) # fall through to next if isinstance
elif is_int_sequence(item):
return self._from_pydf(
self._df.take([self._pos_idx(i, dim=0) for i in item])
)
item = pli.Series("", item) # fall through to next if isinstance

if isinstance(item, pli.Series):
dtype = item.dtype
if dtype == Utf8:
return self._from_pydf(self._df.select(item))
if dtype == Boolean:
return self._from_pydf(self._df.filter(item.inner()))
if dtype == UInt32:
return self._from_pydf(self._df.take_with_series(item.inner()))
if dtype in {UInt8, UInt16, UInt64, Int8, Int16, Int32, Int64}:
return self._from_pydf(
self._df.take_with_series(self._pos_idxs(item, dim=0).inner())
)

# if no data has been returned, the operation is not supported
raise NotImplementedError
Expand All @@ -1832,7 +1928,7 @@ def __setitem__(
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
value = np.array(value)
if len(value.shape) != 2:
if value.ndim != 2:
raise ValueError("can only set multiple columns with 2D matrix")
if value.shape[1] != len(key):
raise ValueError(
Expand Down
101 changes: 98 additions & 3 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Utf8,
dtype_to_ctype,
dtype_to_ffiname,
get_idx_type,
maybe_cast,
numpy_char_code_to_dtype,
py_type_to_dtype,
Expand All @@ -48,6 +49,8 @@
_ptr_to_numpy,
_to_python_datetime,
deprecated_alias,
is_bool_sequence,
is_int_sequence,
range_to_slice,
)

Expand Down Expand Up @@ -460,7 +463,79 @@ def __rpow__(self, other: Any) -> Series:
def __neg__(self) -> Series:
return 0 - self

def __getitem__(self, item: int | Series | range | slice) -> Any:
def _pos_idxs(self, idxs: np.ndarray | Series) -> Series:
# pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx).
idx_type = get_idx_type()

if isinstance(idxs, Series):
if idxs.dtype == idx_type:
return idxs
if idxs.dtype in {
UInt8,
UInt16,
UInt64 if idx_type == UInt32 else UInt32,
Int8,
Int16,
Int32,
Int64,
}:
if idx_type == UInt32:
if idxs.dtype in {Int64, UInt64}:
if idxs.max() >= 2**32: # type: ignore[operator]
raise ValueError(
"Index positions should be smaller than 2^32."
)
if idxs.dtype == Int64:
if idxs.min() < -(2**32): # type: ignore[operator]
raise ValueError(
"Index positions should be bigger than -2^32 + 1."
)
if idxs.dtype in {Int8, Int16, Int32, Int64}:
if idxs.min() < 0: # type: ignore[operator]
if idx_type == UInt32:
if idxs.dtype in {Int8, Int16}:
idxs = idxs.cast(Int32)
else:
if idxs.dtype in {Int8, Int16, Int32}:
idxs = idxs.cast(Int64)

idxs = pli.select(
pli.when(pli.lit(idxs) < 0)
.then(self.len() + pli.lit(idxs))
.otherwise(pli.lit(idxs))
).to_series()

return idxs.cast(idx_type)

if _NUMPY_AVAILABLE and isinstance(idxs, np.ndarray):
if idxs.ndim != 1:
raise ValueError("Only 1D numpy array is supported as index.")
if idxs.dtype.kind in ("i", "u"):
# Numpy array with signed or unsigned integers.

if idx_type == UInt32:
if idxs.dtype in {np.int64, np.uint64} and idxs.max() >= 2**32:
raise ValueError("Index positions should be smaller than 2^32.")
if idxs.dtype == np.int64 and idxs.min() < -(2**32):
raise ValueError(
"Index positions should be bigger than -2^32 + 1."
)
if idxs.dtype.kind == "i" and idxs.min() < 0:
if idx_type == UInt32:
if idxs.dtype in (np.int8, np.int16):
idxs = idxs.astype(np.int32)
else:
if idxs.dtype in (np.int8, np.int16, np.int32):
idxs = idxs.astype(np.int64)

# Update negative indexes to absolute indexes.
idxs = np.where(idxs < 0, self.len() + idxs, idxs)

return Series("", idxs, dtype=idx_type)

raise NotImplementedError("Unsupported idxs datatype.")

def __getitem__(self, item: int | Series | range | slice | np.ndarray) -> Any:
if isinstance(item, int):
if item < 0:
item = self.len() + item
Expand All @@ -476,9 +551,29 @@ def __getitem__(self, item: int | Series | range | slice) -> Any:
return out

return self._s.get_idx(item)
# assume it is boolean mask

if _NUMPY_AVAILABLE and isinstance(item, np.ndarray):
if item.ndim != 1:
raise ValueError("Only a 1D-Numpy array is supported as index.")
if item.dtype.kind in ("i", "u"):
# Numpy array with signed or unsigned integers.
return wrap_s(self._s.take_with_series(self._pos_idxs(item).inner()))
if item.dtype == bool:
return wrap_s(self._s.filter(pli.Series("", item).inner()))

if isinstance(item, Sequence):
if is_bool_sequence(item):
item = Series("", item) # fall through to next if isinstance
elif is_int_sequence(item):
item = Series("", item) # fall through to next if isinstance

if isinstance(item, Series):
return wrap_s(self._s.filter(item._s))
if item.dtype == Boolean:
return wrap_s(self._s.filter(item._s))
if item.dtype == UInt32:
return wrap_s(self._s.take_with_series(item.inner()))
if item.dtype in {UInt8, UInt16, UInt64, Int8, Int16, Int32, Int64}:
return wrap_s(self._s.take_with_series(self._pos_idxs(item).inner()))

if isinstance(item, range):
return self[range_to_slice(item)]
Expand Down
25 changes: 15 additions & 10 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ def _date_to_pl_date(d: date) -> int:
return int(dt.timestamp()) // (3600 * 24)


def _is_iterable_of(val: Iterable, itertype: type, eltype: type) -> bool:
"""Check whether the given iterable is of a certain type."""
return isinstance(val, itertype) and all(isinstance(x, eltype) for x in val)


def is_bool_sequence(val: Sequence[object]) -> TypeGuard[Sequence[bool]]:
"""Check whether the given sequence is a sequence of booleans."""
return _is_iterable_of(val, Sequence, bool)


def is_int_sequence(val: Sequence[object]) -> TypeGuard[Sequence[int]]:
"""Check whether the given sequence is a sequence of integers."""
return _is_iterable_of(val, Sequence, int)


def is_str_sequence(
val: Sequence[object], allow_str: bool = False
) -> TypeGuard[Sequence[str]]:
Expand All @@ -131,16 +146,6 @@ def is_str_sequence(
return _is_iterable_of(val, Sequence, str)


def is_int_sequence(val: Sequence[object]) -> TypeGuard[Sequence[int]]:
"""Check whether the given sequence is a sequence of integers."""
return _is_iterable_of(val, Sequence, int)


def _is_iterable_of(val: Iterable, itertype: type, eltype: type) -> bool:
"""Check whether the given iterable is of a certain type."""
return isinstance(val, itertype) and all(isinstance(x, eltype) for x in val)


def range_to_slice(rng: range) -> slice:
"""Return the given range as an equivalent slice."""
return slice(rng.start, rng.stop, rng.step)
Expand Down

0 comments on commit 08dd52c

Please sign in to comment.