Skip to content

Commit

Permalink
put one/zero lookup behind lru_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Dec 30, 2023
1 parent ac3d0c4 commit 35a246f
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions py-polars/polars/functions/repeat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
from functools import lru_cache
from typing import TYPE_CHECKING, Any, overload

from polars import functions as F
Expand All @@ -27,6 +28,23 @@
from polars.type_aliases import IntoExpr, PolarsDataType


# create a lookup of dtypes that have a reasonable one/zero mapping; for
# anything more elaborate should use `repeat`
@lru_cache(16)
def _one_or_zero_by_dtype(value: int, dtype: PolarsDataType) -> Any:
if dtype in INTEGER_DTYPES:
return value
elif dtype in FLOAT_DTYPES:
return float(value)
elif dtype == Boolean:
return bool(value)
elif dtype == Utf8:
return str(value)
elif isinstance(dtype, List) or (isinstance(dtype, Array) and dtype.width == 1):
return [_one_or_zero_by_dtype(value, dtype.inner)]
return None


@overload
def repeat(
value: IntoExpr | None,
Expand Down Expand Up @@ -127,20 +145,6 @@ def repeat(
return expr


# dtypes that have a reasonable one/zero mapping;
# for anything more elaborate should use `repeat`
_ones_zeros = {
Utf8: ("0", "1"),
Boolean: (False, True),
List(Utf8): (["0"], ["1"]),
List(Boolean): ([False], [True]),
}
for dtype in INTEGER_DTYPES | FLOAT_DTYPES:
_ones_zeros[dtype] = (0, 1)
_ones_zeros[List(dtype)] = ([0], [1])
_ones_zeros[Array(dtype, width=1)] = ([0], [1])


@overload
def ones(
n: int | Expr,
Expand Down Expand Up @@ -214,10 +218,9 @@ def ones(
]
"""
if (one_zero := _ones_zeros.get(dtype)) is None:
if (one := _one_or_zero_by_dtype(1, dtype)) is None:
raise TypeError(f"invalid dtype for `ones`; found {dtype})")

one: Any = one_zero[1]
return repeat(one, n=n, dtype=dtype, eager=eager).alias("ones")


Expand Down Expand Up @@ -294,8 +297,7 @@ def zeros(
]
"""
if (one_zero := _ones_zeros.get(dtype)) is None:
if (zero := _one_or_zero_by_dtype(0, dtype)) is None:
raise TypeError(f"invalid dtype for `zeros`; found {dtype})")

zero: Any = one_zero[0]
return repeat(zero, n=n, dtype=dtype, eager=eager).alias("zeros")

0 comments on commit 35a246f

Please sign in to comment.