Skip to content

Commit

Permalink
Simplify get_ffi_func (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Nov 22, 2021
1 parent 62557ee commit 446debf
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ def match_dtype(value: Any, dtype: "Type[DataType]") -> Any:
def get_ffi_func(
name: str,
dtype: Type["DataType"],
obj: Optional["Series"] = None,
default: Optional[Callable[[Any], Any]] = None,
) -> Callable[..., Any]:
obj: "PySeries",
) -> Optional[Callable[..., Any]]:
"""
Dynamically obtain the proper ffi function/ method.
Expand All @@ -91,20 +90,15 @@ def get_ffi_func(
dtype
polars dtype.
obj
Optional object to find the method for. If none provided globals are used.
default
default function to use if not found.
Object to find the method for.
Returns
-------
ffi function
ffi function, or None if not found
"""
ffi_name = dtype_to_ffiname(dtype)
fname = name.replace("<>", ffi_name)
if obj:
return getattr(obj, fname, default)
else:
return globals().get(fname, default)
return getattr(obj, fname, None)


def wrap_s(s: "PySeries") -> "Series":
Expand Down Expand Up @@ -391,6 +385,8 @@ def __truediv__(self, other: Any) -> "Series":
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("div_<>", dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))

if self.dtype != physical_type:
Expand All @@ -406,6 +402,8 @@ def __floordiv__(self, other: Any) -> "Series":
other = maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("div_<>", dtype, self._s)
if f is None:
return NotImplemented
if self.is_float():
return wrap_s(f(other)).floor()
return wrap_s(f(other))
Expand Down Expand Up @@ -1879,13 +1877,17 @@ def __array_ufunc__(

try:
f = get_ffi_func("apply_ufunc_<>", dtype, self._s)
if f is None:
return NotImplemented
series = f(lambda out: ufunc(*args, out=out, **kwargs))
return wrap_s(series)
except TypeError:
# some integer to float ufuncs do not work, try on f64
s = self.cast(Float64)
args[0] = s.view(ignore_nulls=True)
f = get_ffi_func("apply_ufunc_<>", Float64, self._s)
if f is None:
return NotImplemented
series = f(lambda out: ufunc(*args, out=out, **kwargs))
return wrap_s(series)

Expand Down

0 comments on commit 446debf

Please sign in to comment.