Skip to content

Commit

Permalink
Allow Series.set with boolean numpy mask (#2125)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Dec 22, 2021
1 parent d4ba72e commit f3bc446
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def __setitem__(self, key: Any, value: Any) -> None:
elif key.dtype == UInt32:
self._s = self.set_at_idx(key, value)._s
# TODO: implement for these types without casting to series
elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
# boolean numpy mask
self._s = self.set_at_idx(np.argwhere(key)[:, 0], value)._s
elif isinstance(key, (np.ndarray, list, tuple)):
s = wrap_s(PySeries.new_u32("", np.array(key, np.uint32), True))
self.__setitem__(s, value)
Expand Down
24 changes: 24 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import date
from typing import Any, Sequence

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -282,6 +283,29 @@ def test_set() -> None:
a = pl.Series("a", [True, False, True])
mask = pl.Series("msk", [True, False, True])
a[mask] = False
testing.assert_series_equal(a, pl.Series("", [False] * 3))


def test_set_np_array_boolean_mask() -> None:
a = pl.Series("a", [1, 2, 3])
mask = np.array([True, False, True])
a[mask] = 4
testing.assert_series_equal(a, pl.Series("a", [4, 2, 4]))


@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32, np.uint64])
def test_set_np_array(dtype: Any) -> None:
a = pl.Series("a", [1, 2, 3])
idx = np.array([0, 2], dtype=dtype)
a[idx] = 4
testing.assert_series_equal(a, pl.Series("a", [4, 2, 4]))


@pytest.mark.parametrize("idx", [[0, 2], (0, 2)])
def test_set_list_and_tuple(idx: Sequence) -> None:
a = pl.Series("a", [1, 2, 3])
a[idx] = 4
testing.assert_series_equal(a, pl.Series("a", [4, 2, 4]))


def test_fill_null() -> None:
Expand Down

0 comments on commit f3bc446

Please sign in to comment.