Skip to content

Commit

Permalink
Preserve Series types when constructing a list-series
Browse files Browse the repository at this point in the history
  • Loading branch information
nmandery authored and ritchie46 committed Nov 28, 2021
1 parent c3ac948 commit 8c449e6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def sequence_to_pyseries(
)
return arrow_to_pyseries(name, pa.array(values))

elif dtype_ == list or dtype_ == tuple or dtype_ == pli.Series:
elif dtype_ == list or dtype_ == tuple:
nested_value = _get_first_non_none(value)
nested_dtype = type(nested_value) if value is not None else float

Expand Down Expand Up @@ -168,6 +168,10 @@ def sequence_to_pyseries(

# Convert mixed sequences like `[[12], "foo", 9]`
return PySeries.new_object(name, values, strict)

elif dtype_ == pli.Series:
return PySeries.new_series_list(name, [v.inner() for v in values], strict)

else:
constructor = py_type_to_constructor(dtype_)
return constructor(name, values, strict)
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ impl PySeries {
s.into()
}

#[staticmethod]
pub fn new_series_list(name: &str, val: Vec<Self>, _strict: bool) -> Self {
let series_vec = to_series_collection(val);
Series::new(name, &series_vec).into()
}

#[staticmethod]
pub fn repeat(name: &str, val: &PyAny, n: usize, dtype: &PyAny) -> Self {
let str_repr = dtype.str().unwrap().to_str().unwrap();
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,3 +1064,10 @@ def test_init_categorical() -> None:
expected = pl.Series("a", values, dtype=pl.Utf8).cast(pl.Categorical)
a = pl.Series("a", values, dtype=pl.Categorical)
testing.assert_series_equal(a, expected)


def test_nested_list_types_preserved() -> None:
expected_dtype = pl.UInt32
srs1 = pl.Series([pl.Series([3, 4, 5, 6], dtype=expected_dtype) for _ in range(5)])
for srs2 in srs1:
assert srs2.dtype == expected_dtype

0 comments on commit 8c449e6

Please sign in to comment.