Skip to content

Commit

Permalink
fix[python]: Set dtype for empty Series if possible. (#4564)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghuls committed Aug 26, 2022
1 parent 73e7a8e commit 067045c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
11 changes: 9 additions & 2 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,22 @@ def sequence_to_pyseries(
values: Sequence[Any],
dtype: PolarsDataType | None = None,
strict: bool = True,
dtype_if_empty: PolarsDataType | None = None,
) -> PySeries:
"""Construct a PySeries from a sequence."""
python_dtype: type | None = None
nested_dtype: PolarsDataType | type | None = None
temporal_unit: str | None = None

# empty sequence defaults to Float32 type
# empty sequence
if not values and dtype is None:
dtype = Float32
if dtype_if_empty:
# if dtype for empty sequence could be guessed
# (e.g. comparisons between self and other)
dtype = dtype_if_empty
else:
# default to Float32 type
dtype = Float32
# lists defer to subsequent handling; identify nested type
elif dtype == List:
nested_dtype = getattr(dtype, "inner", None)
Expand Down
14 changes: 11 additions & 3 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class Series:
nan_to_null
In case a numpy array is used to create this Series, indicate how to deal
with np.nan values.
dtype_if_empty : DataType, default None
If no dtype is specified and values is None or an empty sequence,
use this as the Polars dtype of the Series. If not specified, Float32 is used.
Examples
--------
Expand Down Expand Up @@ -193,6 +196,7 @@ def __init__(
dtype: type[DataType] | DataType | None = None,
strict: bool = True,
nan_to_null: bool = False,
dtype_if_empty: type[DataType] | DataType | None = None,
):

# Handle case where values are passed as the first argument
Expand All @@ -207,7 +211,9 @@ def __init__(
name = ""

if values is None:
self._s = sequence_to_pyseries(name, [], dtype=dtype)
self._s = sequence_to_pyseries(
name, [], dtype=dtype, dtype_if_empty=dtype_if_empty
)
elif isinstance(values, Series):
self._s = series_to_pyseries(name, values)
elif _PYARROW_AVAILABLE and isinstance(values, (pa.Array, pa.ChunkedArray)):
Expand All @@ -228,7 +234,9 @@ def __init__(
if dtype is not None:
self._s = self.cast(dtype, strict=True)._s
elif isinstance(values, Sequence):
self._s = sequence_to_pyseries(name, values, dtype=dtype, strict=strict)
self._s = sequence_to_pyseries(
name, values, dtype=dtype, strict=strict, dtype_if_empty=dtype_if_empty
)
elif _PANDAS_AVAILABLE and isinstance(values, (pd.Series, pd.DatetimeIndex)):
self._s = pandas_to_pyseries(name, values)
else:
Expand Down Expand Up @@ -314,7 +322,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
return wrap_s(f(d))

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
other = Series("", other, dtype_if_empty=self.dtype)
if isinstance(other, Series):
return wrap_s(getattr(self._s, op)(other._s))
other = maybe_cast(other, self.dtype, self.time_unit)
Expand Down
19 changes: 16 additions & 3 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def test_init_inputs(monkeypatch: Any) -> None:
assert pl.Series(values=[1, 2]).dtype == pl.Int64
assert pl.Series("a").dtype == pl.Float32 # f32 type used in case of no data
assert pl.Series().dtype == pl.Float32
assert pl.Series([]).dtype == pl.Float32
assert pl.Series(dtype_if_empty=pl.Utf8).dtype == pl.Utf8
assert pl.Series([], dtype_if_empty=pl.UInt16).dtype == pl.UInt16
# The "[]" in "== []" will be cast to an empty Series with Utf8 dtype.
pl.testing.assert_series_equal(
pl.Series([], dtype_if_empty=pl.Utf8) == [], pl.Series("", dtype=pl.Boolean)
)
assert pl.Series(values=[True, False]).dtype == pl.Boolean
assert pl.Series(values=np.array([True, False])).dtype == pl.Boolean
assert pl.Series(values=np.array(["foo", "bar"])).dtype == pl.Utf8
Expand Down Expand Up @@ -810,13 +817,19 @@ def test_describe() -> None:


def test_is_in() -> None:
s = pl.Series([1, 2, 3])
s = pl.Series(["a", "b", "c"])

out = s.is_in([1, 2])
out = s.is_in(["a", "b"])
assert out == [True, True, False]
df = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4]})

# Check if empty list is converted to pl.Utf8.
out = s.is_in([])
assert out == [False, False, False]

df = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4], "c": ["e", "d"]})

assert df.select(pl.col("a").is_in(pl.col("b"))).to_series() == [True, False]
assert df.select(pl.col("b").is_in([])).to_series() == [False, False]


def test_slice() -> None:
Expand Down

0 comments on commit 067045c

Please sign in to comment.