Skip to content

Commit

Permalink
feat(python): support Series init from generator (#5411)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Nov 3, 2022
1 parent 6669a84 commit 39e061c
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 17 deletions.
53 changes: 51 additions & 2 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
from contextlib import suppress
from dataclasses import astuple, is_dataclass
from datetime import date, datetime, time, timedelta
from itertools import zip_longest
from itertools import islice, zip_longest
from sys import version_info
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, get_type_hints
from typing import (
TYPE_CHECKING,
Any,
Generator,
Iterable,
Mapping,
Sequence,
get_type_hints,
)

from polars import internals as pli
from polars.datatypes import (
Expand Down Expand Up @@ -178,6 +186,47 @@ def sequence_from_anyvalue_or_object(name: str, values: Sequence[Any]) -> PySeri
return PySeries.new_object(name, values, False)


def iterable_to_pyseries(
name: str,
values: Iterable[Any],
dtype: PolarsDataType | None = None,
strict: bool = True,
dtype_if_empty: PolarsDataType | None = None,
chunk_size: int = 1_000_000,
) -> PySeries:
"""Construct a PySeries from an iterable/generator."""
if not isinstance(values, Generator):
values = iter(values)

def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> pli.Series:
return pli.Series(
name=name,
values=values,
dtype=dtype,
strict=strict,
dtype_if_empty=dtype_if_empty,
)

n_chunks = 0
series: pli.Series = None # type: ignore[assignment]
while True:
slice_values = list(islice(values, chunk_size))
if not slice_values:
break
schunk = to_series_chunk(slice_values, dtype)
if series is None:
series = schunk
dtype = series.dtype
else:
series.append(schunk, append_chunks=True)
n_chunks += 1

if n_chunks > 0:
series.rechunk(in_place=True)

return series._s


def sequence_to_pyseries(
name: str,
values: Sequence[Any],
Expand Down
23 changes: 16 additions & 7 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,7 @@ def arange(
step: int = ...,
*,
eager: Literal[True],
dtype: PolarsDataType | None = ...,
) -> pli.Series:
...

Expand All @@ -1351,6 +1352,7 @@ def arange(
step: int = ...,
*,
eager: bool = False,
dtype: PolarsDataType | None = ...,
) -> pli.Expr | pli.Series:
...

Expand All @@ -1361,13 +1363,13 @@ def arange(
step: int = 1,
*,
eager: bool = False,
dtype: PolarsDataType | None = None,
) -> pli.Expr | pli.Series:
"""
Create a range expression.
Create a range expression (or Series).
This can be used in a `select`, `with_column` etc.
Be sure that the range size is equal to the DataFrame you are collecting.
This can be used in a `select`, `with_column` etc. Be sure that the resulting
range size is equal to the length of the DataFrame you are collecting.
Examples
--------
Expand All @@ -1383,18 +1385,25 @@ def arange(
Step size of the range.
eager
If eager evaluation is `True`, a Series is returned instead of an Expr.
dtype
Apply an explicit integer dtype to the resulting expression (default is Int64).
"""
low = pli.expr_to_lit_or_expr(low, str_to_lit=False)
high = pli.expr_to_lit_or_expr(high, str_to_lit=False)
if eager:
range_expr = pli.wrap_expr(pyarange(low._pyexpr, high._pyexpr, step))

if dtype is not None and dtype != Int64:
range_expr = range_expr.cast(dtype)
if not eager:
return range_expr
else:
return (
pli.DataFrame()
.select(arange(low, high, step))
.select(range_expr)
.to_series()
.rename("arange", in_place=True)
)
return pli.wrap_expr(pyarange(low._pyexpr, high._pyexpr, step))


def argsort_by(
Expand Down
33 changes: 31 additions & 2 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
import math
import warnings
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Sequence, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Iterable,
Mapping,
NoReturn,
Sequence,
Union,
overload,
)
from warnings import warn

from polars import internals as pli
Expand Down Expand Up @@ -46,6 +57,7 @@
from polars.dependencies import pyarrow as pa
from polars.internals.construction import (
arrow_to_pyseries,
iterable_to_pyseries,
numpy_to_pyseries,
pandas_to_pyseries,
sequence_to_pyseries,
Expand Down Expand Up @@ -228,7 +240,13 @@ def __init__(

elif isinstance(values, range):
self._s = (
pli.arange(values.start, values.stop, values.step, eager=True)
pli.arange(
low=values.start,
high=values.stop,
step=values.step,
eager=True,
dtype=dtype,
)
.rename(name, in_place=True)
._s
)
Expand All @@ -238,6 +256,17 @@ def __init__(
)
elif _PANDAS_TYPE(values) and isinstance(values, (pd.Series, pd.DatetimeIndex)):
self._s = pandas_to_pyseries(name, values)

elif isinstance(values, (Generator, Iterable)) and not isinstance(
values, Mapping
):
self._s = iterable_to_pyseries(
name,
values,
dtype=dtype,
strict=strict,
dtype_if_empty=dtype_if_empty,
)
else:
raise ValueError(f"Series constructor not called properly. Got {values}.")

Expand Down
45 changes: 39 additions & 6 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Iterator, cast

import numpy as np
import pandas as pd
Expand All @@ -22,6 +22,7 @@
UInt32,
UInt64,
)
from polars.internals.construction import iterable_to_pyseries
from polars.internals.type_aliases import EpochTimeUnit
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing._private import verify_series_and_expr_api
Expand Down Expand Up @@ -107,11 +108,6 @@ def test_init_inputs(monkeypatch: Any) -> None:
with pytest.raises(OverflowError):
pl.Series("bigint", [2**64])

# numpy not available
monkeypatch.setattr(pl.internals.series.series, "_NUMPY_TYPE", lambda x: False)
with pytest.raises(ValueError):
pl.DataFrame(np.array([1, 2, 3]), columns=["a"])


def test_init_dataclass_namedtuple() -> None:
from dataclasses import dataclass
Expand Down Expand Up @@ -1379,6 +1375,43 @@ def test_to_numpy(monkeypatch: Any) -> None:
assert np_array_with_missing_values.flags.writeable == writable


def test_from_generator_or_iterable() -> None:
# iterable object
class Data:
def __init__(self, n: int):
self._n = n

def __iter__(self) -> Iterator[int]:
yield from range(self._n)

# generator function
def gen(n: int) -> Iterator[int]:
yield from range(n)

expected = pl.Series("s", range(10))
assert expected.dtype == pl.Int64

for generated_series in (
pl.Series("s", values=gen(10)),
pl.Series("s", values=Data(10)),
pl.Series("s", values=(x for x in gen(10))),
):
assert_series_equal(expected, generated_series)

# test 'iterable_to_pyseries' directly to validate 'chunk_size' behaviour
ps1 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8)
ps2 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8, chunk_size=3)
ps3 = iterable_to_pyseries("s", Data(10), dtype=pl.UInt8, chunk_size=6)

expected = pl.Series("s", range(10), dtype=pl.UInt8)
assert expected.dtype == pl.UInt8

for ps in (ps1, ps2, ps3):
generated_series = pl.Series("s")
generated_series._s = ps
assert_series_equal(expected, generated_series)


def test_from_sequences(monkeypatch: Any) -> None:
# test int, str, bool, flt
values = [
Expand Down

0 comments on commit 39e061c

Please sign in to comment.