Skip to content

Commit

Permalink
Merge pull request #68 from pola-rs/new-polars-compat
Browse files Browse the repository at this point in the history
update for new polars version
  • Loading branch information
MarcoGorelli committed Mar 21, 2024
2 parents 6e86dc5 + 88461db commit fb3e9c9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 24 deletions.
22 changes: 21 additions & 1 deletion .github/workflows/CI.yml
Expand Up @@ -32,7 +32,7 @@ jobs:
strategy:
matrix:
target: [x86_64]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -47,6 +47,26 @@ jobs:
- run: make install
- run: make test

linux_min_version_tests:
runs-on: ubuntu-latest
strategy:
matrix:
target: [x86_64]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Set up Rust
run: rustup show
- uses: mozilla-actions/sccache-action@v0.0.3
- run: make venv
- run: venv/bin/python -m pip install polars==0.20.6 # min version
- run: make install
- run: make test

linux:
runs-on: ubuntu-latest
strategy:
Expand Down
50 changes: 28 additions & 22 deletions polars_xdt/functions.py
Expand Up @@ -3,12 +3,12 @@
import re
import sys
from datetime import date, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Sequence

import polars as pl
from polars.utils.udfs import _get_shared_lib_location

from polars_xdt.utils import parse_into_expr
from polars_xdt.utils import parse_into_expr, parse_version, register_plugin

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -21,7 +21,12 @@
RollStrategy: TypeAlias = Literal["raise", "forward", "backward"]


lib = _get_shared_lib_location(__file__)
if parse_version(pl.__version__) < parse_version("0.20.16"):
from polars.utils.udfs import _get_shared_lib_location

lib: str | Path = _get_shared_lib_location(__file__)
else:
lib = Path(__file__).parent

mapping = {"Mon": 1, "Tue": 2, "Wed": 3, "Thu": 4, "Fri": 5, "Sat": 6, "Sun": 7}
reverse_mapping = {value: key for key, value in mapping.items()}
Expand Down Expand Up @@ -174,11 +179,11 @@ def offset_by(
)
weekmask = get_weekmask(weekend)

result = expr.register_plugin(
result = register_plugin(
args=[expr, n],
lib=lib,
symbol="advance_n_days",
is_elementwise=True,
args=[n],
kwargs={
"holidays": holidays_int,
"weekmask": weekmask,
Expand Down Expand Up @@ -248,11 +253,11 @@ def is_workday(
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="is_workday",
is_elementwise=True,
args=[],
args=[expr],
kwargs={
"weekmask": weekmask,
"holidays": holidays_int,
Expand Down Expand Up @@ -329,11 +334,11 @@ def from_local_datetime(
"""
expr = parse_into_expr(expr)
from_tz = parse_into_expr(from_tz, str_as_lit=True)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="from_local_datetime",
is_elementwise=True,
args=[from_tz],
args=[expr, from_tz],
kwargs={
"to_tz": to_tz,
"ambiguous": ambiguous,
Expand Down Expand Up @@ -396,11 +401,11 @@ def to_local_datetime(
"""
expr = parse_into_expr(expr)
time_zone = parse_into_expr(time_zone, str_as_lit=True)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="to_local_datetime",
is_elementwise=True,
args=[time_zone],
args=[expr, time_zone],
)


Expand Down Expand Up @@ -454,11 +459,11 @@ def format_localized(
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="format_localized",
is_elementwise=True,
args=[],
args=[expr],
kwargs={"format": format, "locale": locale},
)

Expand Down Expand Up @@ -493,11 +498,11 @@ def to_julian_date(expr: str | pl.Expr) -> pl.Expr:
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="to_julian_date",
is_elementwise=True,
args=[],
args=[expr],
)


Expand Down Expand Up @@ -721,11 +726,11 @@ def workday_count(
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
return start_dates.register_plugin(
return register_plugin(
lib=lib,
symbol="workday_count",
is_elementwise=True,
args=[end_dates],
args=[start_dates, end_dates],
kwargs={
"weekmask": weekmask,
"holidays": holidays_int,
Expand Down Expand Up @@ -794,11 +799,11 @@ def month_delta(
start_dates = parse_into_expr(start_dates)
end_dates = parse_into_expr(end_dates)

return start_dates.register_plugin(
return register_plugin(
lib=lib,
symbol="month_delta",
is_elementwise=True,
args=[end_dates],
args=[start_dates, end_dates],
)


Expand Down Expand Up @@ -891,10 +896,11 @@ def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
return register_plugin(
lib=lib,
symbol="arg_previous_greater",
is_elementwise=False,
args=[expr],
)


Expand Down Expand Up @@ -974,10 +980,10 @@ def ewma_by_time(
half_life_us = (
int(half_life.total_seconds()) * 1_000_000 + half_life.microseconds
)
return times.register_plugin(
return register_plugin(
lib=lib,
symbol="ewma_by_time",
is_elementwise=False,
args=[values],
args=[times, values],
kwargs={"half_life": half_life_us},
)
41 changes: 40 additions & 1 deletion polars_xdt/utils.py
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import re
from typing import TYPE_CHECKING, Any, Sequence

import polars as pl

if TYPE_CHECKING:
from pathlib import Path

from polars.type_aliases import IntoExpr, PolarsDataType


Expand Down Expand Up @@ -47,3 +50,39 @@ def parse_into_expr(
expr = pl.lit(expr, dtype=dtype)

return expr


def register_plugin(
*,
lib: str | Path,
symbol: str,
is_elementwise: bool,
kwargs: dict[str, Any] | None = None,
args: list[IntoExpr],
) -> pl.Expr:
if parse_version(pl.__version__) < parse_version("0.20.16"):
assert isinstance(args[0], pl.Expr)
assert isinstance(lib, str)
return args[0].register_plugin(
lib=lib,
symbol=symbol,
args=args[1:],
kwargs=kwargs,
is_elementwise=is_elementwise,
)
from polars.plugins import register_plugin_function

return register_plugin_function(
args=args,
plugin_path=lib,
function_name=symbol,
kwargs=kwargs,
is_elementwise=is_elementwise,
)


def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
# Simple version parser; split into a tuple of ints for comparison.
if isinstance(version, str):
version = version.split(".")
return tuple(int(re.sub(r"\D", "", str(v))) for v in version)

0 comments on commit fb3e9c9

Please sign in to comment.