feat[python]: improve frame init from @dataclass and namedtuple rows (#4807)
alexander-beedie committed Sep 12, 2022
1 parent 40f0320 commit af2d161
Showing 6 changed files with 181 additions and 34 deletions.
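Before the per-file diffs, a brief sketch of what this commit enables: frames can now be initialised directly from sequences of dataclass or namedtuple rows, with column names (and, where annotated, dtypes) inferred. This is a hedged illustration of the post-change behaviour; the `Trade` and `Quote` classes are hypothetical, invented purely for this example:

from dataclasses import dataclass
from typing import NamedTuple, Optional

import polars as pl


@dataclass
class Trade:  # hypothetical row type
    ticker: str
    price: float
    size: Optional[int]


class Quote(NamedTuple):  # hypothetical row type
    ticker: str
    bid: float
    ask: float


# dataclass rows: columns (and, where possible, dtypes) come from field annotations
df1 = pl.DataFrame([Trade("AAPL", 150.25, 100), Trade("MSFT", 290.10, None)])

# namedtuple rows: column names come from `_fields`; orientation is forced to "row"
df2 = pl.DataFrame([Quote("AAPL", 150.20, 150.30), Quote("MSFT", 290.00, 290.20)])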
96 changes: 71 additions & 25 deletions py-polars/polars/datatypes.py
@@ -3,6 +3,8 @@
import ctypes
import sys
from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
@@ -32,11 +34,28 @@
except ImportError:
_DOCUMENTING = True

+UnionType: type
+if sys.version_info >= (3, 10):
+    from types import UnionType
+else:
+    # infer equivalent class
+    UnionType = type(Union[int, float])
+
+if sys.version_info >= (3, 8):
+    from typing import get_args
+else:
+
+    # pass-through (only impact is that under 3.7 we'll end-up doing
+    # standard inference for dataclass fields with an option/union)
+    def get_args(tp: Any) -> Any:
+        return tp


if TYPE_CHECKING:
from polars.internals.type_aliases import TimeUnit


-def get_idx_type() -> type[DataType]:
+def get_idx_type() -> PolarsDataType:
"""
Get the datatype used for polars Indexing.
@@ -77,13 +96,16 @@ def __repr__(self) -> str:
return dtype_str_repr(self)


+# note: defined this way as some types can have instances that
+# act as specialisations (eg: "List" and "List[Int32]")
PolarsDataType = Union[Type[DataType], DataType]

ColumnsType = Union[
Sequence[str],
Mapping[str, PolarsDataType],
Sequence[Tuple[str, Optional[PolarsDataType]]],
]
+NoneType = type(None)


class Int8(DataType):
@@ -143,7 +165,9 @@ class Unknown(DataType):


class List(DataType):
-    def __init__(self, inner: type[DataType]):
+    inner: PolarsDataType | None = None
+
+    def __init__(self, inner: PolarsDataType):
"""
Nested list/array type.
@@ -155,7 +179,7 @@ def __init__(self, inner: type[DataType]):
"""
self.inner = py_type_to_dtype(inner)

-    def __eq__(self, other: type[DataType]) -> bool:  # type: ignore[override]
+    def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
# The comparison allows comparing objects to classes
# and specific inner types to none specific.
# if one of the arguments is not specific about its inner type
@@ -177,7 +201,7 @@ def __eq__(self, other: type[DataType]) -> bool:  # type: ignore[override]
return False

def __hash__(self) -> int:
-        return hash(List)
+        return hash((List, self.inner))


class Date(DataType):
@@ -203,7 +227,7 @@ def __init__(self, time_unit: TimeUnit = "us", time_zone: str | None = None):
self.tu = time_unit
self.tz = time_zone

-    def __eq__(self, other: type[DataType]) -> bool:  # type: ignore[override]
+    def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
# allow comparing object instances to class
if type(other) is type and issubclass(other, Datetime):
return True
@@ -231,7 +255,7 @@ def __init__(self, time_unit: TimeUnit = "us"):
"""
self.tu = time_unit

-    def __eq__(self, other: type[DataType]) -> bool:  # type: ignore[override]
+    def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
# allow comparing object instances to class
if type(other) is type and issubclass(other, Duration):
return True
@@ -257,7 +281,7 @@ class Categorical(DataType):


class Field:
-    def __init__(self, name: str, dtype: type[DataType]):
+    def __init__(self, name: str, dtype: PolarsDataType):
"""
Definition of a single field within a `Struct` DataType.
@@ -296,19 +320,17 @@ def __init__(self, fields: Sequence[Field]):
"""
self.fields = fields

-    def __eq__(self, other: type[DataType]) -> bool:  # type: ignore[override]
-        # The comparison allows comparing objects to classes
-        # and specific inner types to none specific.
-        # if one of the arguments is not specific about its inner type
-        # we infer it as being equal.
-        # See the list type for more info
-        if type(other) is type and issubclass(other, Struct):
+    def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
+        # The comparison allows comparing objects to classes, and specific
+        # inner types to those without (eg: inner=None). if one of the
+        # arguments is not specific about its inner type we infer it
+        # as being equal. (See the List type for more info).
+        if isclass(other) and issubclass(other, Struct):
             return True
-        if isinstance(other, Struct):
-            if self.fields is None or other.fields is None:
-                return True
-            else:
-                return self.fields == other.fields
+        elif isinstance(other, Struct):
+            return any((f is None) for f in (self.fields, other.fields)) or (
+                self.fields == other.fields
+            )
else:
return False

@@ -366,7 +388,7 @@ def __hash__(self) -> int:
_DTYPE_TO_CTYPE[Duration(tu)] = ctypes.c_int64


-_PY_TYPE_TO_DTYPE: dict[type, type[DataType]] = {
+_PY_TYPE_TO_DTYPE: dict[type, PolarsDataType] = {
float: Float64,
int: Int64,
str: Utf8,
@@ -377,6 +399,7 @@ def __hash__(self) -> int:
time: Time,
list: List,
tuple: List,
+    Decimal: Float64,
}


@@ -462,8 +485,16 @@ def __hash__(self) -> int:
}


+def _lookup_type(dtype: PolarsDataType) -> PolarsDataType:
+    """Normalise type so it can be looked-up correctly."""
+    # (currently only List requires this, due to arbitrary 'inner')
+    return List if dtype == List else dtype


def dtype_to_ctype(dtype: PolarsDataType) -> type[_SimpleCData]:
"""Convert a Polars dtype to a ctype."""
try:
+        dtype = _lookup_type(dtype)
return _DTYPE_TO_CTYPE[dtype]
except KeyError: # pragma: no cover
raise NotImplementedError(
@@ -472,7 +503,9 @@ def dtype_to_ctype(dtype: PolarsDataType) -> type[_SimpleCData]:


def dtype_to_ffiname(dtype: PolarsDataType) -> str:
"""Return FFI function name associated with the given Polars dtype."""
try:
+        dtype = _lookup_type(dtype)
return _DTYPE_TO_FFINAME[dtype]
except KeyError: # pragma: no cover
raise NotImplementedError(
@@ -481,7 +514,9 @@ def dtype_to_ffiname(dtype: PolarsDataType) -> str:


def dtype_to_py_type(dtype: PolarsDataType) -> type:
"""Convert a Polars dtype to a Python dtype."""
try:
+        dtype = _lookup_type(dtype)
return _DTYPE_TO_PY_TYPE[dtype]
except KeyError: # pragma: no cover
raise NotImplementedError(
@@ -490,18 +525,28 @@ def dtype_to_py_type(dtype: PolarsDataType) -> type:


def is_polars_dtype(data_type: Any) -> bool:
"""Indicate whether the given input is a Polars dtype, or dtype specialisation."""
return isinstance(data_type, DataType) or (
type(data_type) is type and issubclass(data_type, DataType)
)


-def py_type_to_dtype(data_type: Any) -> PolarsDataType:
+def py_type_to_dtype(data_type: Any, raise_unmatched: bool = True) -> PolarsDataType:
+    """Convert a Python dtype to a Polars dtype."""
# when the passed in is already a Polars datatype, return that
if is_polars_dtype(data_type):
return data_type
+    elif isinstance(data_type, UnionType):
+        # not exhaustive; currently handles the common "type | None" case,
+        # but ideally would pick appropriate supertype when n_types > 1
+        possible_types = [tp for tp in get_args(data_type) if tp is not NoneType]
+        if len(possible_types) == 1:
+            data_type = possible_types[0]
try:
return _PY_TYPE_TO_DTYPE[data_type]
except KeyError: # pragma: no cover
+        if not raise_unmatched:
+            return None  # type: ignore[return-value]
raise NotImplementedError(
f"Conversion of Python data type {data_type} to Polars data type not"
" implemented."
@@ -532,7 +577,8 @@ def supported_numpy_char_code(dtype: str) -> bool:
return dtype in _NUMPY_CHAR_CODE_TO_DTYPE


-def numpy_char_code_to_dtype(dtype: str) -> type[DataType]:
+def numpy_char_code_to_dtype(dtype: str) -> PolarsDataType:
+    """Convert a numpy character dtype to a Polars dtype."""
try:
return _NUMPY_CHAR_CODE_TO_DTYPE[dtype]
except KeyError: # pragma: no cover
@@ -542,8 +588,8 @@ def numpy_char_code_to_dtype(dtype: str) -> type[DataType]:


def maybe_cast(
-    el: type[DataType], dtype: type, time_unit: TimeUnit | None = None
-) -> type[DataType]:
+    el: PolarsDataType, dtype: type, time_unit: TimeUnit | None = None
+) -> PolarsDataType:
# cast el if it doesn't match
from polars.utils import _datetime_to_pl_timestamp, _timedelta_to_pl_timedelta

@@ -558,4 +604,4 @@ def maybe_cast(


#: Mapping of `~polars.DataFrame` / `~polars.LazyFrame` column names to their `DataType`
-Schema = Dict[str, Type[DataType]]
+Schema = Dict[str, PolarsDataType]
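Taken together, the py_type_to_dtype changes above give the following observable behaviour. A minimal sketch, assuming the post-change module and Python 3.10+ union syntax:

import polars as pl
from polars.datatypes import py_type_to_dtype

# a "type | None" union now resolves to the dtype of its non-None member
assert py_type_to_dtype(int | None) == pl.Int64
assert py_type_to_dtype(str | None) == pl.Utf8

# with raise_unmatched=False, an unmapped Python type yields None instead of raising
assert py_type_to_dtype(object, raise_unmatched=False) is None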
37 changes: 34 additions & 3 deletions py-polars/polars/internals/construction.py
@@ -1,9 +1,11 @@
from __future__ import annotations

from contextlib import suppress
+from dataclasses import astuple, is_dataclass
from datetime import date, datetime, time, timedelta
from itertools import zip_longest
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
+from sys import version_info
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, get_type_hints

from polars import internals as pli
from polars.datatypes import (
@@ -30,6 +32,17 @@
)
from polars.utils import threadpool_size

+if version_info >= (3, 10):
+
+    def dataclass_type_hints(obj: type) -> dict[str, Any]:
+        return get_type_hints(obj)
+
+else:
+
+    def dataclass_type_hints(obj: type) -> dict[str, Any]:
+        return obj.__annotations__


try:
from polars.polars import PyDataFrame, PySeries

@@ -545,7 +558,16 @@ def sequence_to_pydf(
return pydf

elif isinstance(data[0], Sequence) and not isinstance(data[0], str):
-        # Infer orientation
+        # infer orientation
+        if all(
+            hasattr(data[0], attr)
+            for attr in ("_fields", "_field_defaults", "_replace")
+        ):  # namedtuple
+            if columns is None:
+                columns = data[0]._fields  # type: ignore[attr-defined]
+            elif orient is None:
+                orient = "row"

if orient is None and columns is not None:
orient = "col" if len(columns) == len(data) else "row"

@@ -564,7 +586,16 @@ def sequence_to_pydf(
raise ValueError(
f"orient must be one of {{'col', 'row', None}}, got {orient} instead."
)

+    elif is_dataclass(data[0]):
+        columns = columns or [
+            (col, py_type_to_dtype(tp, raise_unmatched=False))
+            for col, tp in dataclass_type_hints(data[0].__class__).items()
+        ]
+        pydf = _post_apply_columns(
+            PyDataFrame.read_rows([astuple(dc) for dc in data], infer_schema_length),
+            columns=columns,
+        )
+        return pydf
else:
columns, dtypes = _unpack_columns(columns, n_expected=1)
data_series = [pli.Series(columns[0], data, dtypes.get(columns[0]))._s]
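Note that the namedtuple check in sequence_to_pydf is duck-typed rather than an isinstance test; a small self-contained illustration of the attributes it probes (the `Point` type is hypothetical, for illustration only):

from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])  # hypothetical example type
pt = Point(1, 2)

# the three attributes checked by the new branch above
assert all(hasattr(pt, attr) for attr in ("_fields", "_field_defaults", "_replace"))

# a plain tuple carries none of them, so it takes the ordinary sequence path
assert not any(hasattr((1, 2), attr) for attr in ("_fields", "_field_defaults", "_replace"))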
4 changes: 2 additions & 2 deletions py-polars/polars/internals/series/utils.py
@@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Callable, TypeVar

import polars.internals as pli
-from polars.datatypes import DataType, dtype_to_ffiname
+from polars.datatypes import PolarsDataType, dtype_to_ffiname

if TYPE_CHECKING:
from polars.polars import PySeries
@@ -126,7 +126,7 @@ def expr_dispatch(cls: type[T]) -> type[T]:


def get_ffi_func(
-    name: str, dtype: type[DataType], obj: PySeries
+    name: str, dtype: PolarsDataType, obj: PySeries
) -> Callable[..., Any] | None:
"""
Dynamically obtain the proper ffi function/ method.
2 changes: 1 addition & 1 deletion py-polars/tests/slow/test_csv.py
@@ -7,5 +7,5 @@ def test_csv_statistics_offset() -> None:
# this would fail if the statistics sample did not also sample
# from the end of the file
# the lines at the end have larger rows as the numbers increase
csv = "\n".join([str(x) for x in range(5_000)])
csv = "\n".join(str(x) for x in range(5_000))
assert pl.read_csv(io.StringIO(csv), n_rows=5000).height == 4999
18 changes: 16 additions & 2 deletions py-polars/tests/unit/test_datatypes.py
@@ -24,8 +24,8 @@ def test_dtype_temporal_units() -> None:
assert pl.Datetime == pl.Datetime(tu)
assert pl.Duration == pl.Duration(tu)

-        assert pl.Datetime(tu) == pl.Datetime()  # type: ignore[operator]
-        assert pl.Duration(tu) == pl.Duration()  # type: ignore[operator]
+        assert pl.Datetime(tu) == pl.Datetime()
+        assert pl.Duration(tu) == pl.Duration()

assert pl.Datetime("ms") != pl.Datetime("ns")
assert pl.Duration("ns") != pl.Duration("us")
@@ -40,3 +40,17 @@ def test_dtypes_picklable() -> None:
singleton_type = pl.Float64
assert pickle.loads(pickle.dumps(parametric_type)) == parametric_type
assert pickle.loads(pickle.dumps(singleton_type)) == singleton_type


+def test_dtypes_hashable() -> None:
+    # ensure that all the types can be hashed, and that their hashes
+    # are sufficient to ensure distinct entries in a dictionary/set
+
+    all_dtypes = [
+        getattr(datatypes, d)
+        for d in dir(datatypes)
+        if isinstance(getattr(datatypes, d), datatypes.DataType)
+    ]
+    assert len(set(all_dtypes + all_dtypes)) == len(all_dtypes)
+    assert len({pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")}) == 3
+    assert len({pl.List, pl.List(pl.Int16), pl.List(pl.Int32), pl.List(pl.Int64)}) == 4
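Together with the List.__hash__ change in datatypes.py, this test pins down the intended semantics: equality remains permissive between a specialised instance and its bare class, while hashes now distinguish the inner types. A quick sketch, assuming the post-change behaviour:

import polars as pl

assert pl.List(pl.Int32) == pl.List  # instance still compares equal to the class
assert hash(pl.List(pl.Int32)) != hash(pl.List(pl.Int64))  # hashes differ by inner type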
