Skip to content

Commit

Permalink
Only pass dtype to __array__, if not None: Fixes #3253 (#3257)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghuls committed Apr 29, 2022
1 parent d1c3235 commit deec56d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
5 changes: 4 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,10 @@ def view(self, ignore_nulls: bool = False) -> np.ndarray:
return array

def __array__(self, dtype: Any = None) -> np.ndarray:
return self.to_numpy().__array__(dtype)
if dtype:
return self.to_numpy().__array__(dtype)
else:
return self.to_numpy().__array__()

def __array_ufunc__(
self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any
Expand Down
36 changes: 35 additions & 1 deletion py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, Sequence, Union
from typing import Dict, Sequence, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -47,6 +47,40 @@ def test_from_numpy() -> None:
assert out == df.dtypes


def test_to_numpy() -> None:
def test_series_to_numpy(
name: str,
values: list,
pl_dtype: Type[pl.DataType],
np_dtype: Union[
Type[np.signedinteger],
Type[np.unsignedinteger],
Type[np.floating],
Type[np.object_],
],
) -> None:
pl_series_to_numpy_array = np.array(pl.Series(name, values, pl_dtype))
numpy_array = np.array(values, dtype=np_dtype)
assert pl_series_to_numpy_array.dtype == numpy_array.dtype
assert np.all(pl_series_to_numpy_array == numpy_array) == np.bool_(True)

test_series_to_numpy("int8", [1, 3, 2], pl.Int8, np.int8)
test_series_to_numpy("int16", [1, 3, 2], pl.Int16, np.int16)
test_series_to_numpy("int32", [1, 3, 2], pl.Int32, np.int32)
test_series_to_numpy("int64", [1, 3, 2], pl.Int64, np.int64)

test_series_to_numpy("uint8", [1, 3, 2], pl.UInt8, np.uint8)
test_series_to_numpy("uint16", [1, 3, 2], pl.UInt16, np.uint16)
test_series_to_numpy("uint32", [1, 3, 2], pl.UInt32, np.uint32)
test_series_to_numpy("uint64", [1, 3, 2], pl.UInt64, np.uint64)

test_series_to_numpy("float32", [21.7, 21.8, 21], pl.Float32, np.float32)
test_series_to_numpy("float64", [21.7, 21.8, 21], pl.Float64, np.float64)

test_series_to_numpy("str", ["string1", "string2", "string3"], pl.Utf8, np.object_)
# test_series_to_numpy("bytes", ["byte_string1", "byte_string2", "byte_string3"], pl.Object, np.bytes_)


def test_from_pandas() -> None:
df = pd.DataFrame(
{
Expand Down

0 comments on commit deec56d

Please sign in to comment.