Skip to content

Commit

Permalink
fix[python, rust]: allow expr+series as well as series+expr, addr…
Browse files Browse the repository at this point in the history
…ess some micros/millis timeunit debt (#4988)
  • Loading branch information
alexander-beedie committed Sep 26, 2022
1 parent 0885f56 commit 589f364
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 63 deletions.
59 changes: 33 additions & 26 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ pub struct DatetimeArgs {
pub hour: Option<Expr>,
pub minute: Option<Expr>,
pub second: Option<Expr>,
pub millisecond: Option<Expr>,
pub microsecond: Option<Expr>,
}

#[cfg(feature = "temporal")]
Expand All @@ -414,7 +414,7 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
let hour = args.hour;
let minute = args.minute;
let second = args.second;
let millisecond = args.millisecond;
let microsecond = args.microsecond;

let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
assert_eq!(s.len(), 7);
Expand Down Expand Up @@ -452,11 +452,11 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
}
let second = second.u32()?;

let mut millisecond = s[6].cast(&DataType::UInt32)?;
if millisecond.len() < max_len {
millisecond = millisecond.expand_at_index(0, max_len);
let mut microsecond = s[6].cast(&DataType::UInt32)?;
if microsecond.len() < max_len {
microsecond = microsecond.expand_at_index(0, max_len);
}
let millisecond = millisecond.u32()?;
let microsecond = microsecond.u32()?;

let ca: Int64Chunked = year
.into_iter()
Expand All @@ -465,24 +465,25 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
.zip(hour.into_iter())
.zip(minute.into_iter())
.zip(second.into_iter())
.zip(millisecond.into_iter())
.map(|((((((y, m), d), h), mnt), s), ms)| {
if let (Some(y), Some(m), Some(d), Some(h), Some(mnt), Some(s), Some(ms)) =
(y, m, d, h, mnt, s, ms)
.zip(microsecond.into_iter())
.map(|((((((y, m), d), h), mnt), s), us)| {
if let (Some(y), Some(m), Some(d), Some(h), Some(mnt), Some(s), Some(us)) =
(y, m, d, h, mnt, s, us)
{
Some(
NaiveDate::from_ymd(y, m, d)
.and_hms_milli(h, mnt, s, ms)
.timestamp_millis(),
.and_hms_micro(h, mnt, s, us)
.timestamp_micros(),
)
} else {
None
}
})
.collect_trusted();

Ok(ca.into_datetime(TimeUnit::Milliseconds, None).into_series())
Ok(ca.into_datetime(TimeUnit::Microseconds, None).into_series())
}) as Arc<dyn SeriesUdf>);

Expr::AnonymousFunction {
input: vec![
year,
Expand All @@ -491,10 +492,10 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
hour.unwrap_or_else(|| lit(0)),
minute.unwrap_or_else(|| lit(0)),
second.unwrap_or_else(|| lit(0)),
millisecond.unwrap_or_else(|| lit(0)),
microsecond.unwrap_or_else(|| lit(0)),
],
function,
output_type: GetOutput::from_type(DataType::Datetime(TimeUnit::Milliseconds, None)),
output_type: GetOutput::from_type(DataType::Datetime(TimeUnit::Microseconds, None)),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand All @@ -510,6 +511,7 @@ pub struct DurationArgs {
pub days: Option<Expr>,
pub seconds: Option<Expr>,
pub nanoseconds: Option<Expr>,
pub microseconds: Option<Expr>,
pub milliseconds: Option<Expr>,
pub minutes: Option<Expr>,
pub hours: Option<Expr>,
Expand All @@ -519,14 +521,15 @@ pub struct DurationArgs {
#[cfg(feature = "temporal")]
pub fn duration(args: DurationArgs) -> Expr {
let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
assert_eq!(s.len(), 7);
assert_eq!(s.len(), 8);
let days = s[0].cast(&DataType::Int64).unwrap();
let seconds = s[1].cast(&DataType::Int64).unwrap();
let mut nanoseconds = s[2].cast(&DataType::Int64).unwrap();
let milliseconds = s[3].cast(&DataType::Int64).unwrap();
let minutes = s[4].cast(&DataType::Int64).unwrap();
let hours = s[5].cast(&DataType::Int64).unwrap();
let weeks = s[6].cast(&DataType::Int64).unwrap();
let microseconds = s[3].cast(&DataType::Int64).unwrap();
let milliseconds = s[4].cast(&DataType::Int64).unwrap();
let minutes = s[5].cast(&DataType::Int64).unwrap();
let hours = s[6].cast(&DataType::Int64).unwrap();
let weeks = s[7].cast(&DataType::Int64).unwrap();

let max_len = s.iter().map(|s| s.len()).max().unwrap();

Expand All @@ -538,14 +541,17 @@ pub fn duration(args: DurationArgs) -> Expr {
if nanoseconds.len() != max_len {
nanoseconds = nanoseconds.expand_at_index(0, max_len);
}
if condition(&days) {
nanoseconds = nanoseconds + days * NANOSECONDS * SECONDS_IN_DAY;
if condition(&microseconds) {
nanoseconds = nanoseconds + (microseconds * 1_000);
}
if condition(&milliseconds) {
nanoseconds = nanoseconds + (milliseconds * 1_000_000);
}
if condition(&seconds) {
nanoseconds = nanoseconds + &seconds * NANOSECONDS;
nanoseconds = nanoseconds + (seconds * NANOSECONDS);
}
if condition(&milliseconds) {
nanoseconds = nanoseconds + milliseconds * 1_000_000;
if condition(&days) {
nanoseconds = nanoseconds + (days * NANOSECONDS * SECONDS_IN_DAY);
}
if condition(&minutes) {
nanoseconds = nanoseconds + minutes * NANOSECONDS * 60;
Expand All @@ -565,13 +571,14 @@ pub fn duration(args: DurationArgs) -> Expr {
args.days.unwrap_or_else(|| lit(0i64)),
args.seconds.unwrap_or_else(|| lit(0i64)),
args.nanoseconds.unwrap_or_else(|| lit(0i64)),
args.microseconds.unwrap_or_else(|| lit(0i64)),
args.milliseconds.unwrap_or_else(|| lit(0i64)),
args.minutes.unwrap_or_else(|| lit(0i64)),
args.hours.unwrap_or_else(|| lit(0i64)),
args.weeks.unwrap_or_else(|| lit(0i64)),
],
function,
output_type: GetOutput::from_type(DataType::Datetime(TimeUnit::Milliseconds, None)),
output_type: GetOutput::from_type(DataType::Datetime(TimeUnit::Microseconds, None)),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl Literal for NaiveDateTime {
if in_nanoseconds_window(&self) {
Expr::Literal(LiteralValue::DateTime(self, TimeUnit::Nanoseconds))
} else {
Expr::Literal(LiteralValue::DateTime(self, TimeUnit::Milliseconds))
Expr::Literal(LiteralValue::DateTime(self, TimeUnit::Microseconds))
}
}
}
Expand Down
66 changes: 40 additions & 26 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,13 +1223,14 @@ def argsort_by(

def duration(
*,
days: pli.Expr | str | None = None,
seconds: pli.Expr | str | None = None,
nanoseconds: pli.Expr | str | None = None,
milliseconds: pli.Expr | str | None = None,
minutes: pli.Expr | str | None = None,
hours: pli.Expr | str | None = None,
weeks: pli.Expr | str | None = None,
days: pli.Expr | str | int | None = None,
seconds: pli.Expr | str | int | None = None,
nanoseconds: pli.Expr | str | int | None = None,
microseconds: pli.Expr | str | int | None = None,
milliseconds: pli.Expr | str | int | None = None,
minutes: pli.Expr | str | int | None = None,
hours: pli.Expr | str | int | None = None,
weeks: pli.Expr | str | int | None = None,
) -> pli.Expr:
"""
Create polars `Duration` from distinct time components.
Expand Down Expand Up @@ -1281,25 +1282,37 @@ def duration(
seconds = pli.expr_to_lit_or_expr(seconds, str_to_lit=False)._pyexpr
if milliseconds is not None:
milliseconds = pli.expr_to_lit_or_expr(milliseconds, str_to_lit=False)._pyexpr
if microseconds is not None:
microseconds = pli.expr_to_lit_or_expr(microseconds, str_to_lit=False)._pyexpr
if nanoseconds is not None:
nanoseconds = pli.expr_to_lit_or_expr(nanoseconds, str_to_lit=False)._pyexpr
if days is not None:
days = pli.expr_to_lit_or_expr(days, str_to_lit=False)._pyexpr
if weeks is not None:
weeks = pli.expr_to_lit_or_expr(weeks, str_to_lit=False)._pyexpr

return pli.wrap_expr(
py_duration(days, seconds, nanoseconds, milliseconds, minutes, hours, weeks)
py_duration(
days,
seconds,
nanoseconds,
microseconds,
milliseconds,
minutes,
hours,
weeks,
)
)


def _datetime(
year: pli.Expr | str,
month: pli.Expr | str,
day: pli.Expr | str,
hour: pli.Expr | str | None = None,
minute: pli.Expr | str | None = None,
second: pli.Expr | str | None = None,
millisecond: pli.Expr | str | None = None,
year: pli.Expr | str | int,
month: pli.Expr | str | int,
day: pli.Expr | str | int,
hour: pli.Expr | str | int | None = None,
minute: pli.Expr | str | int | None = None,
second: pli.Expr | str | int | None = None,
microsecond: pli.Expr | str | int | None = None,
) -> pli.Expr:
"""
Create polars `Datetime` from distinct time components.
Expand All @@ -1313,13 +1326,13 @@ def _datetime(
day
column or literal, ranging from 1-31.
hour
column or literal, ranging from 1-24.
column or literal, ranging from 0-23.
minute
column or literal, ranging from 1-60.
column or literal, ranging from 0-59.
second
column or literal, ranging from 1-60.
millisecond
column or literal, ranging from 1-1000.
column or literal, ranging from 0-59.
microsecond
column or literal, ranging from 0-999999.
Returns
-------
Expand All @@ -1336,8 +1349,9 @@ def _datetime(
minute = pli.expr_to_lit_or_expr(minute, str_to_lit=False)._pyexpr
if second is not None:
second = pli.expr_to_lit_or_expr(second, str_to_lit=False)._pyexpr
if millisecond is not None:
millisecond = pli.expr_to_lit_or_expr(millisecond, str_to_lit=False)._pyexpr
if microsecond is not None:
microsecond = pli.expr_to_lit_or_expr(microsecond, str_to_lit=False)._pyexpr

return pli.wrap_expr(
py_datetime(
year_expr._pyexpr,
Expand All @@ -1346,15 +1360,15 @@ def _datetime(
hour,
minute,
second,
millisecond,
microsecond,
)
)


def _date(
year: pli.Expr | str,
month: pli.Expr | str,
day: pli.Expr | str,
year: pli.Expr | str | int,
month: pli.Expr | str | int,
day: pli.Expr | str | int,
) -> pli.Expr:
"""
Create polars Date from distinct time components.
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,13 @@ def __le__(self, other: Any) -> Series:
return self._comp(other, "lt_eq")

def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Series:
if isinstance(other, pli.Expr):
# expand pl.lit, pl.datetime, pl.duration Exprs to compatible Series
other = self.to_frame().select(other).to_series()
if isinstance(other, Series):
return wrap_s(getattr(self._s, op_s)(other._s))
# we recurse and the if statement above will
# ensure we return early

# recurse; the 'if' statement above will ensure we return early
if isinstance(other, (date, datetime, timedelta, str)):
other = Series("", [other])
return self._arithmetic(other, op_s, op_ffi)
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,10 @@ def is_categorical_dtype(data_type: Any) -> bool:
# validate compatibility.
Time: times(),
Date: dates(),
Duration: timedeltas(),
Duration: timedeltas(
min_value=timedelta(microseconds=-(2**63)),
max_value=timedelta(microseconds=(2**63) - 1),
),
# TODO: confirm datetime min/max limits with different timeunit granularity.
# TODO: specific strategies for temporal dtypes with timeunits.
Datetime: datetimes(min_value=datetime(1970, 1, 1)),
Expand Down
9 changes: 6 additions & 3 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ fn py_datetime(
hour: Option<dsl::PyExpr>,
minute: Option<dsl::PyExpr>,
second: Option<dsl::PyExpr>,
millisecond: Option<dsl::PyExpr>,
microsecond: Option<dsl::PyExpr>,
) -> dsl::PyExpr {
let hour = hour.map(|e| e.inner);
let minute = minute.map(|e| e.inner);
let second = second.map(|e| e.inner);
let millisecond = millisecond.map(|e| e.inner);
let microsecond = microsecond.map(|e| e.inner);

let args = DatetimeArgs {
year: year.inner,
Expand All @@ -233,17 +233,19 @@ fn py_datetime(
hour,
minute,
second,
millisecond,
microsecond,
};

polars::lazy::dsl::datetime(args).into()
}

#[allow(clippy::too_many_arguments)]
#[pyfunction]
fn py_duration(
days: Option<PyExpr>,
seconds: Option<PyExpr>,
nanoseconds: Option<PyExpr>,
microseconds: Option<PyExpr>,
milliseconds: Option<PyExpr>,
minutes: Option<PyExpr>,
hours: Option<PyExpr>,
Expand All @@ -253,6 +255,7 @@ fn py_duration(
days: days.map(|e| e.inner),
seconds: seconds.map(|e| e.inner),
nanoseconds: nanoseconds.map(|e| e.inner),
microseconds: microseconds.map(|e| e.inner),
milliseconds: milliseconds.map(|e| e.inner),
minutes: minutes.map(|e| e.inner),
hours: hours.map(|e| e.inner),
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/parametric/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# -------------------------------------------------
from __future__ import annotations

from decimal import Decimal

from hypothesis import given, settings
from hypothesis.strategies import sampled_from

Expand Down Expand Up @@ -39,3 +41,37 @@ def test_series_slice(

assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed"
assert_series_equal(srs, srs, check_exact=True)


@given(
    s1=series(min_size=1, max_size=10, dtype=pl.Datetime),
    s2=series(min_size=1, max_size=10, dtype=pl.Duration),
)
def test_series_timeunits(
    s1: pl.Series,
    s2: pl.Series,
) -> None:
    """
    Check sub-second accessors on generated Datetime and Duration series.

    Parameters
    ----------
    s1
        Generated series with dtype ``pl.Datetime``.
    s2
        Generated series with dtype ``pl.Duration``.

    Each section first verifies that ``to_list()`` and plain iteration yield
    the same Python objects, then compares the ``.dt`` accessors against
    values derived from those objects.
    """
    # datetime: every accessor is compared against the `microsecond` field
    # of the Python value obtained by iterating the series
    assert s1.to_list() == list(s1)
    assert list(s1.dt.millisecond()) == [v.microsecond // 1000 for v in s1]
    assert list(s1.dt.nanosecond()) == [v.microsecond * 1000 for v in s1]
    assert list(s1.dt.microsecond()) == [v.microsecond for v in s1]

    # duration
    millis = s2.dt.milliseconds().to_list()
    micros = s2.dt.microseconds().to_list()

    # round-trip the Duration series itself (the Datetime series was already
    # checked above)
    assert s2.to_list() == list(s2)
    # Decimal keeps the division exact for large magnitudes; int() then
    # truncates toward zero.
    # NOTE(review): presumably the integer cast of `s2` is in microseconds,
    # which is what both comparisons below rely on - confirm.
    assert millis == [int(Decimal(v) / 1000) for v in s2.cast(int)]
    assert micros == list(s2.cast(int))

    # special handling for ns timeunit (as we may generate a microsecs-based
    # timedelta that results in 64bit overflow on conversion to nanosecs)
    lower_bound, upper_bound = -(2**63), (2**63) - 1
    if all(
        (lower_bound <= (us * 1000) <= upper_bound)
        for us in micros
        if isinstance(us, int)
    ):
        for ns, us in zip(s2.dt.nanoseconds(), micros):
            assert ns == (us * 1000)  # type: ignore[operator]

0 comments on commit 589f364

Please sign in to comment.