Skip to content

Commit

Permalink
python series compare date/datetime
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 20, 2021
1 parent 68eb203 commit f9077e4
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 2 deletions.
56 changes: 55 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
maybe_cast,
py_type_to_dtype,
)
from polars.utils import _ptr_to_numpy
from polars.utils import _date_to_pl_date, _datetime_to_pl_timestamp, _ptr_to_numpy

try:
import pandas as pd
Expand Down Expand Up @@ -290,6 +290,15 @@ def __rxor__(self, other: "Series") -> "Series":
return self.__xor__(other)

def __eq__(self, other: Any) -> "Series": # type: ignore[override]
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("eq_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("eq_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand All @@ -301,6 +310,15 @@ def __eq__(self, other: Any) -> "Series": # type: ignore[override]
return wrap_s(f(other))

def __ne__(self, other: Any) -> "Series": # type: ignore[override]
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("neq_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("neq_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand All @@ -312,6 +330,15 @@ def __ne__(self, other: Any) -> "Series": # type: ignore[override]
return wrap_s(f(other))

def __gt__(self, other: Any) -> "Series":
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("gt_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("gt_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand All @@ -323,6 +350,15 @@ def __gt__(self, other: Any) -> "Series":
return wrap_s(f(other))

def __lt__(self, other: Any) -> "Series":
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("lt_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("lt_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand All @@ -335,6 +371,15 @@ def __lt__(self, other: Any) -> "Series":
return wrap_s(f(other))

def __ge__(self, other: Any) -> "Series":
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("gt_eq_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("gt_eq_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand All @@ -346,6 +391,15 @@ def __ge__(self, other: Any) -> "Series":
return wrap_s(f(other))

def __le__(self, other: Any) -> "Series":
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other)
f = get_ffi_func("lt_eq_<>", Int64, self._s)
return wrap_s(f(ts)) # type: ignore
if isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func("lt_eq_<>", Int32, self._s)
return wrap_s(f(d)) # type: ignore

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ctypes
import typing as tp
from datetime import datetime, timedelta, timezone
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -47,3 +47,8 @@ def _datetime_to_pl_timestamp(dt: datetime) -> int:
Converts a python datetime to a timestamp in nanoseconds
"""
return int(dt.replace(tzinfo=timezone.utc).timestamp() * 1e9)


def _date_to_pl_date(d: date) -> int:
    """
    Converts a python date to a number of days relative to 1970-01-01
    """
    seconds_per_day = 24 * 3600
    # Midnight of `d`, pinned to UTC so the local timezone cannot shift the day.
    midnight_utc = datetime(d.year, d.month, d.day, tzinfo=timezone.utc)
    return int(midnight_utc.timestamp()) // seconds_per_day
36 changes: 36 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,39 @@ def test_date_range() -> None:
assert result.dt[1] == datetime(1985, 1, 2, 12, 0)
assert result.dt[2] == datetime(1985, 1, 4, 0, 0)
assert result.dt[-1] == datetime(2015, 6, 30, 12, 0)


def test_date_comp() -> None:
    """
    Series comparison operators accept python ``datetime`` and ``date`` scalars.

    Every operator is checked for both scalar types, and ``==`` is additionally
    checked for dates far from 1970 to make sure the date-to-integer conversion
    holds over a wide range.
    """
    dt_one = datetime(2001, 1, 1)
    dt_two = datetime(2001, 1, 2)
    a = pl.Series("a", [dt_one, dt_two])

    assert (a == dt_one).to_list() == [True, False]
    assert (a != dt_one).to_list() == [False, True]
    assert (a > dt_one).to_list() == [False, True]
    assert (a >= dt_one).to_list() == [True, True]
    assert (a < dt_one).to_list() == [False, False]
    assert (a <= dt_one).to_list() == [True, False]

    d_one = date(2001, 1, 1)
    d_two = date(2001, 1, 2)
    a = pl.Series("a", [d_one, d_two])
    assert (a == d_one).to_list() == [True, False]
    assert (a == d_two).to_list() == [False, True]
    # `!=` has its own date branch in Series.__ne__, so exercise it as well
    assert (a != d_one).to_list() == [False, True]
    assert (a > d_one).to_list() == [False, True]
    assert (a >= d_one).to_list() == [True, True]
    assert (a < d_one).to_list() == [False, False]
    assert (a <= d_one).to_list() == [True, False]

    # also test if the conversion stays correct with wide date ranges
    early_one = date(201, 1, 1)
    early_two = date(201, 1, 2)
    a = pl.Series("a", [early_one, early_two])
    assert (a == early_one).to_list() == [True, False]
    assert (a == early_two).to_list() == [False, True]

    late_one = date(5001, 1, 1)
    late_two = date(5001, 1, 2)
    a = pl.Series("a", [late_one, late_two])
    assert (a == late_one).to_list() == [True, False]
    assert (a == late_two).to_list() == [False, True]

0 comments on commit f9077e4

Please sign in to comment.