Skip to content

Commit

Permalink
feat[python]: Optionally return a writable numpy array with Series()…
Browse files Browse the repository at this point in the history
….to_numpy(writable=True) (#4563) (#4566)
  • Loading branch information
ghuls committed Aug 29, 2022
1 parent 560f87e commit 2a3514a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
26 changes: 19 additions & 7 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,7 +2361,7 @@ def __array_ufunc__(
)

def to_numpy(
self, *args: Any, zero_copy_only: bool = False, **kwargs: Any
self, *args: Any, zero_copy_only: bool = False, writable: bool = False
) -> np.ndarray[Any, Any]:
"""
Convert this Series to numpy. This operation clones data but is completely safe.
Expand Down Expand Up @@ -2390,6 +2390,11 @@ def to_numpy(
If True, an exception will be raised if the conversion to a numpy
array would require copying the underlying data (e.g. in presence
of nulls, or for non-primitive types).
writable
For numpy arrays created with zero copy (view on the Arrow data),
the resulting array is not writable (Arrow data is immutable).
By setting this to True, a copy of the array is made to ensure
it is writable.
kwargs
kwargs will be sent to pyarrow.Array.to_numpy
Expand All @@ -2406,16 +2411,23 @@ def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:

if _PYARROW_AVAILABLE and not self.is_datelike():
return self.to_arrow().to_numpy(
*args, zero_copy_only=zero_copy_only, **kwargs
*args, zero_copy_only=zero_copy_only, writable=writable
)
else:
if not self.has_validity():
if self.is_datelike():
return convert_to_date(self.view(ignore_nulls=True))
return self.view(ignore_nulls=True)
if self.is_datelike():
return convert_to_date(self._s.to_numpy())
return self._s.to_numpy()
np_array = convert_to_date(self.view(ignore_nulls=True))
else:
np_array = self.view(ignore_nulls=True)
elif self.is_datelike():
np_array = convert_to_date(self._s.to_numpy())
else:
np_array = self._s.to_numpy()

if writable and not np_array.flags.writeable:
return np_array.copy()
else:
return np_array

def to_arrow(self) -> pa.Array:
"""
Expand Down
37 changes: 32 additions & 5 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,11 +1129,38 @@ def test_bitwise() -> None:


def test_to_numpy(monkeypatch: Any) -> None:
monkeypatch.setattr(pl.internals.series.series, "_PYARROW_AVAILABLE", False)
a = pl.Series("a", [1, 2, 3])
assert np.all(a.to_numpy() == np.array([1, 2, 3]))
a = pl.Series("a", [1, 2, None])
np.testing.assert_array_equal(a.to_numpy(), np.array([1.0, 2.0, np.nan]))
for writable in [False, True]:
for flag in [False, True]:
monkeypatch.setattr(pl.internals.series.series, "_PYARROW_AVAILABLE", flag)

np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable)

np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8))
# Test if numpy array is readonly or writable.
assert np_array.flags.writeable == writable

if writable:
np_array[1] += 10
np.testing.assert_array_equal(
np_array, np.array([1, 12, 3], dtype=np.uint8)
)

np_array_with_missing_values = pl.Series(
"a", [None, 2, 3], pl.UInt8
).to_numpy(writable=writable)

np.testing.assert_array_equal(
np_array_with_missing_values,
np.array(
[np.NaN, 2.0, 3.0],
dtype=(np.float64 if flag is True else np.float32),
),
)

if writable:
# As Null values can't be encoded natively in a numpy array,
# this array will never be a view.
assert np_array_with_missing_values.flags.writeable == writable


def test_from_sequences(monkeypatch: Any) -> None:
Expand Down

0 comments on commit 2a3514a

Please sign in to comment.