Skip to content

Commit

Permalink
feat(python): extend existing fast range->Series init to lists of ranges in a Series (#6099)
Browse files: browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 7, 2023
1 parent a78664c commit 4e13a81
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 23 deletions.
12 changes: 9 additions & 3 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import ShapeError
from polars.utils import _is_generator, arrlen, threadpool_size
from polars.utils import _is_generator, arrlen, range_to_series, threadpool_size

if version_info >= (3, 10):

Expand Down Expand Up @@ -269,6 +269,8 @@ def sequence_to_pyseries(
if value is not None:
if is_dataclass(value) or is_namedtuple(value, annotated=True):
return pli.DataFrame(values).to_struct(name)._s
elif isinstance(value, range):
values = [range_to_series("", v) for v in values]
else:
# for temporal dtypes:
# * if the values are integer, we take the physical branch.
Expand Down Expand Up @@ -319,7 +321,12 @@ def sequence_to_pyseries(
elif python_dtype in (list, tuple):
if nested_dtype is None:
nested_value = _get_first_non_none(value)
nested_dtype = type(nested_value) if nested_value is not None else float
if isinstance(nested_value, range):
nested_dtype = list
else:
nested_dtype = (
type(nested_value) if nested_value is not None else float
)

# recursively call Series constructor
if nested_dtype == list:
Expand Down Expand Up @@ -391,7 +398,6 @@ def sequence_to_pyseries(
return PySeries.new_series_list(name, values, strict)
else:
constructor = py_type_to_constructor(python_dtype)

if constructor == PySeries.new_object:
try:
return PySeries.new_from_anyvalues(name, values)
Expand Down
24 changes: 8 additions & 16 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
_time_to_pl_time,
is_bool_sequence,
is_int_sequence,
range_to_series,
range_to_slice,
scale_bytes,
sphinx_accessor,
Expand Down Expand Up @@ -215,17 +216,17 @@ def __init__(
raise ValueError(
f"Given dtype: '{dtype}' is not a valid Polars data type and cannot be converted into one." # noqa: E501
)

# Handle case where values are passed as the first argument
if name is not None and not isinstance(name, str):
if name is None:
name = ""
elif not isinstance(name, str):
if values is None:
values = name
name = None
name = ""
else:
raise ValueError("Series name must be a string.")

if name is None:
name = ""

if values is None:
self._s = sequence_to_pyseries(
name, [], dtype=dtype, dtype_if_empty=dtype_if_empty
Expand All @@ -234,17 +235,8 @@ def __init__(
self._s = series_to_pyseries(name, values)

elif isinstance(values, range):
self._s = (
pli.arange(
low=values.start,
high=values.stop,
step=values.step,
eager=True,
dtype=dtype,
)
.rename(name, in_place=True)
._s
)
self._s = range_to_series(name, values, dtype=dtype)._s

elif isinstance(values, Sequence):
self._s = sequence_to_pyseries(
name,
Expand Down
22 changes: 21 additions & 1 deletion py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
)

import polars.internals as pli
from polars.datatypes import DataType, Date, Datetime, PolarsDataType, is_polars_dtype
from polars.datatypes import (
DataType,
Date,
Datetime,
Int64,
PolarsDataType,
is_polars_dtype,
)
from polars.dependencies import _ZONEINFO_AVAILABLE, zoneinfo

try:
Expand Down Expand Up @@ -170,6 +177,19 @@ def is_str_sequence(
return isinstance(val, Sequence) and _is_iterable_of(val, str)


def range_to_series(
    name: str, rng: range, dtype: PolarsDataType | None = Int64
) -> pli.Series:
    """
    Convert the given range to a named Series via ``pli.arange``.

    Parameters
    ----------
    name
        Name given to the resulting Series.
    rng
        Python range whose ``start``, ``stop`` and ``step`` are forwarded
        to ``pli.arange`` as ``low``, ``high`` and ``step``.
    dtype
        Data type forwarded to ``pli.arange``; defaults to ``Int64``.

    Returns
    -------
    The eagerly-evaluated ``pli.arange`` result, renamed in place to ``name``.
    """
    bounds = {"low": rng.start, "high": rng.stop, "step": rng.step}
    series = pli.arange(**bounds, eager=True, dtype=dtype)
    return series.rename(name, in_place=True)


def range_to_slice(rng: range) -> slice:
    """Build the slice that has the same start, stop and step as ``rng``."""
    start, stop, step = rng.start, rng.stop, rng.step
    return slice(start, stop, step)
Expand Down
22 changes: 19 additions & 3 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,9 +1317,25 @@ def test_sqrt() -> None:


def test_range() -> None:
    """Check ranges as Series/DataFrame indexers and as Series init values."""
    # Indexing a Series with a range must match indexing with the same slice.
    s1 = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
    assert s1[2:5].series_equal(s1[range(2, 5)])

    ranges = [range(-2, 1), range(3), range(2, 8, 2)]

    # A list of ranges with an explicit List dtype: values, dtype and name
    # are all checked.
    s2 = pl.Series("b", ranges, dtype=pl.List(pl.Int8))
    assert s2.to_list() == [[-2, -1, 0], [0, 1, 2], [2, 4, 6]]
    assert s2.dtype == pl.List(pl.Int8)
    assert s2.name == "b"

    # A generator yielding lists of ranges, with no dtype given: expected to
    # come out as a doubly-nested List of Int64.
    s3 = pl.Series("c", (ranges for _ in range(3)))
    assert s3.to_list() == [
        [[-2, -1, 0], [0, 1, 2], [2, 4, 6]],
        [[-2, -1, 0], [0, 1, 2], [2, 4, 6]],
        [[-2, -1, 0], [0, 1, 2], [2, 4, 6]],
    ]
    assert s3.dtype == pl.List(pl.List(pl.Int64))

    # Indexing a DataFrame with a range must match indexing with the same slice.
    df = pl.DataFrame([s1])
    assert df[2:5].frame_equal(df[range(2, 5)])


Expand Down

0 comments on commit 4e13a81

Please sign in to comment.