Skip to content

Commit

Permalink
feat[python]: Improve numpy ufunc dtype support in expressions (#3228,
Browse files Browse the repository at this point in the history
…#5012) (#5017)

Because numpy ufunc dtype detection for Series was fixed in #3583,
the numpy ufunc implementation on an expression can now dispatch
directly to the numpy ufunc implementation on a Series, so the
expression no longer needs to infer the output dtype itself.

Overriding the output dtype of a numpy ufunc on a Series (via the
`dtype` keyword argument) did not work; this is fixed as well.
  • Loading branch information
ghuls committed Sep 28, 2022
1 parent 6eb71ce commit b2d0c29
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
11 changes: 1 addition & 10 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from polars.datatypes import (
DataType,
Datetime,
Float64,
PolarsDataType,
UInt32,
is_polars_dtype,
Expand Down Expand Up @@ -239,21 +238,13 @@ def __array_ufunc__(
"""Numpy universal functions."""
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
out_type = ufunc(np.array([1])).dtype
dtype: type[DataType] | None
if "float" in str(out_type):
dtype = Float64
else:
dtype = None

args = [inp for inp in inputs if not isinstance(inp, Expr)]

def function(s: pli.Series) -> pli.Series: # pragma: no cover
return ufunc(s, *args, **kwargs)

dtype = kwargs.get("dtype", dtype)

return self.map(function, return_dtype=dtype)
return self.map(function)

def __getstate__(self) -> Any:
return self._pyexpr.__getstate__()
Expand Down
11 changes: 5 additions & 6 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,22 +772,21 @@ def __array_ufunc__(
break

# Override minimum dtype if requested.
dtype = (
dtype_char = (
np.dtype(kwargs.pop("dtype")).char
if "dtype" in kwargs
else dtype_char_minimum
)

f = get_ffi_func(
"apply_ufunc_<>", numpy_char_code_to_dtype(dtype_char_minimum), s
)
f = get_ffi_func("apply_ufunc_<>", numpy_char_code_to_dtype(dtype_char), s)

if f is None:
raise NotImplementedError(
f"Could not find `apply_ufunc_{numpy_char_code_to_dtype(dtype)}`."
"Could not find "
f"`apply_ufunc_{numpy_char_code_to_dtype(dtype_char)}`."
)

series = f(lambda out: ufunc(*args, out=out, **kwargs))
series = f(lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs))
return wrap_s(series)
else:
raise NotImplementedError(
Expand Down
20 changes: 17 additions & 3 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,15 +891,29 @@ def test_arithmetic() -> None:


def test_ufunc() -> None:
df = pl.DataFrame({"a": [1, 2]})
# NOTE: unfortunately we must use cast instead of a type: ignore comment
# 1. CI job with Python 3.10, numpy==1.23.1 -> mypy complains about arg-type
# 2. so we try to resolve it with type: ignore[arg-type]
# 3. CI job with Python 3.7, numpy==1.21.6 -> mypy complains about
# unused type: ignore comment
# for more information, see: https://github.com/python/mypy/issues/8823
out = df.select(np.log(cast(Any, col("a"))))
assert out["a"][1] == 0.6931471805599453
df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)])
out = df.select(
[
np.power(cast(Any, pl.col("a")), 2).alias("power_uint8"),
np.power(cast(Any, pl.col("a")), 2.0).alias("power_float64"),
np.power(cast(Any, pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"),
]
)
expected = pl.DataFrame(
[
pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8),
pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16),
]
)
assert out.frame_equal(expected)
assert out.dtypes == expected.dtypes


def test_clip() -> None:
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,10 @@ def test_ufunc() -> None:
cast(pl.Series, np.power(s_uint8, 2.0)),
pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
)
assert_series_equal(
cast(pl.Series, np.power(s_uint8, 2, dtype=np.uint16)),
pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16),
)

s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8)
assert_series_equal(
Expand All @@ -467,6 +471,10 @@ def test_ufunc() -> None:
cast(pl.Series, np.power(s_int8, 2.0)),
pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
)
assert_series_equal(
cast(pl.Series, np.power(s_int8, 2, dtype=np.int16)),
pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16),
)

s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32)
assert_series_equal(
Expand Down

0 comments on commit b2d0c29

Please sign in to comment.