Skip to content

Commit

Permalink
feat[python]: df.__getitem__ allow boolean sequence in column position (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 22, 2022
1 parent fbcbc90 commit 55125e9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
24 changes: 23 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from polars import internals as pli
from polars._html import NotebookFormatter
from polars.datatypes import (
Boolean,
ColumnsType,
DataType,
Int8,
Expand Down Expand Up @@ -50,6 +51,7 @@
_process_null_values,
format_path,
handle_projection_columns,
is_bool_sequence,
is_int_sequence,
is_str_sequence,
range_to_slice,
Expand Down Expand Up @@ -123,7 +125,9 @@
# MultiColSelector indexes into the horizontal axis
# NOTE: wrapping these as strings is necessary for Python <3.10
MultiRowSelector: TypeAlias = "slice | range | list[int] | pli.Series"
MultiColSelector: TypeAlias = "slice | range | list[int] | list[str] | pli.Series"
MultiColSelector: TypeAlias = (
"slice | range | list[int] | list[str] | list[bool] | pli.Series"
)

# A type variable used to refer to a polars.DataFrame or any subclass of it.
# Used to annotate DataFrame methods which returns the same type as self.
Expand Down Expand Up @@ -1779,6 +1783,24 @@ def __getitem__(
df = self.__getitem__(self.columns[col_selection])
return df[row_selection]

# df[:, [True, False]]
if is_bool_sequence(col_selection) or (
isinstance(col_selection, pli.Series)
and col_selection.dtype == Boolean
):
if len(col_selection) != self.width:
raise ValueError(
f"Expected {self.width} values when selecting columns by"
f" boolean mask. Got {len(col_selection)}."
)
series_list = []
for (i, val) in enumerate(col_selection):
if val:
series_list.append(self.to_series(i))

df = self.__class__(series_list)
return df[row_selection]

# single slice
# df[:, unknown]
series = self.__getitem__(col_selection)
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ def test_selection() -> None:
expect = pl.DataFrame({"a": [1, 3], "b": [1.0, 3.0], "c": ["a", "c"]})
assert df[::2].frame_equal(expect)

# only allow boolean values in column position
df = pl.DataFrame(
{
"a": [1, 2],
"b": [2, 3],
"c": [3, 4],
}
)

assert df[:, [False, True, True]].columns == ["b", "c"]
assert df[:, pl.Series([False, True, True])].columns == ["b", "c"]
assert df[:, pl.Series([False, False, False])].columns == []


def test_mixed_sequence_selection() -> None:
df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
Expand Down

0 comments on commit 55125e9

Please sign in to comment.