Skip to content

Commit

Permalink
Wrap all dtype dicts (#1845)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Nov 21, 2021
1 parent 6aca48d commit a9347cc
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
4 changes: 2 additions & 2 deletions py-polars/polars/_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import TracebackType
from typing import Dict, Iterable, Optional, Type

from polars.datatypes import DTYPE_TO_FFINAME, Object
from polars.datatypes import Object, dtype_to_ffiname


class Tag:
Expand Down Expand Up @@ -75,7 +75,7 @@ def write_header(self) -> None:
self.elements.append(col)
with Tag(self.elements, "tr"):
for dtype in self.df.dtypes:
ffi_name = DTYPE_TO_FFINAME[dtype]
ffi_name = dtype_to_ffiname(dtype)
with Tag(self.elements, "td"):
self.elements.append(ffi_name)

Expand Down
27 changes: 21 additions & 6 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Categorical(DataType):
Object,
Categorical,
]
DTYPE_TO_FFINAME: Dict[Type[DataType], str] = {
_DTYPE_TO_FFINAME: Dict[Type[DataType], str] = {
Int8: "i8",
Int16: "i16",
Int32: "i32",
Expand All @@ -130,7 +130,7 @@ class Categorical(DataType):
Categorical: "categorical",
}

DTYPE_TO_CTYPE = {
_DTYPE_TO_CTYPE = {
UInt8: ctypes.c_uint8,
UInt16: ctypes.c_uint16,
UInt32: ctypes.c_uint32,
Expand Down Expand Up @@ -192,7 +192,21 @@ def date_like_to_physical(dtype: Type[DataType]) -> Type[DataType]:

def dtype_to_ctype(dtype: Type[DataType]) -> Type[_SimpleCData]:
try:
return DTYPE_TO_CTYPE[dtype]
return _DTYPE_TO_CTYPE[dtype]
except KeyError:
raise NotImplementedError


def dtype_to_ffiname(dtype: Type[DataType]) -> str:
try:
return _DTYPE_TO_FFINAME[dtype]
except KeyError:
raise NotImplementedError


def dtype_to_py_type(dtype: Type[DataType]) -> Type:
try:
return _DTYPE_TO_PY_TYPE[dtype]
except KeyError:
raise NotImplementedError

Expand All @@ -218,8 +232,9 @@ def py_type_to_arrow_type(dtype: Type[Any]) -> "pa.lib.DataType":
raise ValueError(f"Cannot parse dtype {dtype} into Arrow dtype.")


def _maybe_cast(el: Type[DataType], dtype: Type) -> Type[DataType]:
def maybe_cast(el: Type[DataType], dtype: Type) -> Type[DataType]:
# cast el if it doesn't match
if not isinstance(el, _DTYPE_TO_PY_TYPE[dtype]):
el = _DTYPE_TO_PY_TYPE[dtype](el)
py_type = dtype_to_py_type(dtype)
if not isinstance(el, py_type):
el = py_type(el)
return el
34 changes: 17 additions & 17 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
_DOCUMENTING = True

from polars.datatypes import (
DTYPE_TO_FFINAME,
DTYPES,
Boolean,
DataType,
Expand All @@ -48,9 +47,10 @@
UInt32,
UInt64,
Utf8,
_maybe_cast,
date_like_to_physical,
dtype_to_ctype,
dtype_to_ffiname,
maybe_cast,
py_type_to_dtype,
)
from polars.utils import _ptr_to_numpy
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_ffi_func(
-------
ffi function
"""
ffi_name = DTYPE_TO_FFINAME[dtype]
ffi_name = dtype_to_ffiname(dtype)
fname = name.replace("<>", ffi_name)
if obj:
return getattr(obj, fname, default)
Expand Down Expand Up @@ -296,7 +296,7 @@ def __eq__(self, other: Any) -> "Series": # type: ignore[override]
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.eq(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -307,7 +307,7 @@ def __ne__(self, other: Any) -> "Series": # type: ignore[override]
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.neq(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("neq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -318,7 +318,7 @@ def __gt__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.gt(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("gt_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -330,7 +330,7 @@ def __lt__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.lt(other._s))
# cast other if it doesn't match
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("lt_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -341,7 +341,7 @@ def __ge__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.gt_eq(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("gt_eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -352,7 +352,7 @@ def __le__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.lt_eq(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
f = get_ffi_func("lt_eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -363,7 +363,7 @@ def __add__(self, other: Any) -> "Series":
other = Series("", [other])
if isinstance(other, Series):
return wrap_s(self._s.add(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("add_<>", dtype, self._s)
if f is None:
Expand All @@ -373,7 +373,7 @@ def __add__(self, other: Any) -> "Series":
def __sub__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.sub(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("sub_<>", dtype, self._s)
if f is None:
Expand All @@ -388,7 +388,7 @@ def __truediv__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.div(other._s))

other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("div_<>", dtype, self._s)
return wrap_s(f(other))
Expand All @@ -403,7 +403,7 @@ def __floordiv__(self, other: Any) -> "Series":
return Series._from_pyseries(self._s.div(other._s)).floor()
return Series._from_pyseries(self._s.div(other._s))

other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("div_<>", dtype, self._s)
if self.is_float():
Expand All @@ -413,7 +413,7 @@ def __floordiv__(self, other: Any) -> "Series":
def __mul__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.mul(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("mul_<>", dtype, self._s)
if f is None:
Expand All @@ -423,7 +423,7 @@ def __mul__(self, other: Any) -> "Series":
def __mod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.rem(other._s))
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("rem_<>", dtype, self._s)
if f is None:
Expand All @@ -434,7 +434,7 @@ def __rmod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(other._s.rem(self._s))
dtype = date_like_to_physical(self.dtype)
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
other = match_dtype(other, dtype)
f = get_ffi_func("rem_<>_rhs", dtype, self._s)
if f is None:
Expand All @@ -445,7 +445,7 @@ def __radd__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.add(other._s))
dtype = date_like_to_physical(self.dtype)
other = _maybe_cast(other, self.dtype)
other = maybe_cast(other, self.dtype)
other = match_dtype(other, dtype)
f = get_ffi_func("add_<>_rhs", dtype, self._s)
if f is None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any

from polars.datatypes import (
_DTYPE_TO_PY_TYPE,
Boolean,
Float32,
Float64,
Expand All @@ -13,6 +12,7 @@
UInt32,
UInt64,
Utf8,
dtype_to_py_type,
)
from polars.internals import Series

Expand Down Expand Up @@ -56,7 +56,7 @@ def assert_series_equal(
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_can_be_subtracted = hasattr(_DTYPE_TO_PY_TYPE[left.dtype], "__sub__")
_can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
if check_exact or not _can_be_subtracted:
if any((left != right).to_list()):
raise_assert_detail(
Expand Down

0 comments on commit a9347cc

Please sign in to comment.