Skip to content

Commit

Permalink
Support creation of polars series/dataframes from float16 numpy array…
Browse files Browse the repository at this point in the history
…s. (#3142)
  • Loading branch information
ghuls committed Apr 14, 2022
1 parent 00e5c00 commit b9832de
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
3 changes: 3 additions & 0 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def numpy_to_pyseries(

if len(values.shape) == 1:
dtype = values.dtype.type
if dtype == np.float16:
values = values.astype(np.float32)
dtype = values.dtype.type
constructor = numpy_type_to_constructor(dtype)
if dtype == np.float32 or dtype == np.float64:
return constructor(name, values, nan_to_null)
Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@
import polars as pl


def test_from_numpy() -> None:
df = pl.DataFrame(
{
"int8": np.array([1, 3, 2], dtype=np.int8),
"int16": np.array([1, 3, 2], dtype=np.int16),
"int32": np.array([1, 3, 2], dtype=np.int32),
"int64": np.array([1, 3, 2], dtype=np.int64),
"uint8": np.array([1, 3, 2], dtype=np.uint8),
"uint16": np.array([1, 3, 2], dtype=np.uint16),
"uint32": np.array([1, 3, 2], dtype=np.uint32),
"uint64": np.array([1, 3, 2], dtype=np.uint64),
"float16": np.array([21.7, 21.8, 21], dtype=np.float16),
"float32": np.array([21.7, 21.8, 21], dtype=np.float32),
"float64": np.array([21.7, 21.8, 21], dtype=np.float64),
"str": np.array(["string1", "string2", "string3"], dtype=np.str_),
"bytes": np.array(
["byte_string1", "byte_string2", "byte_string3"], dtype=np.bytes_
),
}
)
out = [
pl.datatypes.Int8,
pl.datatypes.Int16,
pl.datatypes.Int32,
pl.datatypes.Int64,
pl.datatypes.UInt8,
pl.datatypes.UInt16,
pl.datatypes.UInt32,
pl.datatypes.UInt64,
pl.datatypes.Float32, # np.float16 gets converted to float32 as Rust does not support float16.
pl.datatypes.Float32,
pl.datatypes.Float64,
pl.datatypes.Utf8,
pl.datatypes.Object,
]
assert out == df.dtypes


def test_from_pandas() -> None:
df = pd.DataFrame(
{
Expand Down

0 comments on commit b9832de

Please sign in to comment.