Skip to content

Commit

Permalink
fix(rust, python): fix date_range on expressions (#5750)
Browse files: browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 8, 2022
1 parent 4fa0b01 commit d4e4ce3
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 41 deletions.
22 changes: 12 additions & 10 deletions polars/polars-core/src/chunked_array/builder/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,19 @@ where
}
}

pub fn append_slice(&mut self, opt_v: Option<&[T::Native]>) {
match opt_v {
Some(items) => {
let values = self.builder.mut_values();
values.extend_from_slice(items);
self.builder.try_push_valid().unwrap();
pub fn append_slice(&mut self, items: &[T::Native]) {
let values = self.builder.mut_values();
values.extend_from_slice(items);
self.builder.try_push_valid().unwrap();

if items.is_empty() {
self.fast_explode = false;
}
}
if items.is_empty() {
self.fast_explode = false;
}
}

pub fn append_opt_slice(&mut self, opt_v: Option<&[T::Native]>) {
match opt_v {
Some(items) => self.append_slice(items),
None => {
self.builder.push_null();
}
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ mod test {
fn test_cast_list() -> PolarsResult<()> {
let mut builder =
ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10, DataType::Int32);
builder.append_slice(Some(&[1i32, 2, 3]));
builder.append_slice(Some(&[1i32, 2, 3]));
builder.append_opt_slice(Some(&[1i32, 2, 3]));
builder.append_opt_slice(Some(&[1i32, 2, 3]));
let ca = builder.finish();

let new = ca.cast(&DataType::List(DataType::Float64.into()))?;
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-core/src/chunked_array/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ mod test {

let mut builder =
ListPrimitiveChunkedBuilder::<Float64Type>::new("", 10, 10, DataType::Float64);
builder.append_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_slice(Some(&[2.0, 4.0, 5.0]));
builder.append_slice(Some(&[6.0, 7.0, 8.0]));
builder.append_opt_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_opt_slice(Some(&[2.0, 4.0, 5.0]));
builder.append_opt_slice(Some(&[6.0, 7.0, 8.0]));
let list = builder.finish();

let ndarr = list.to_ndarray::<Float64Type>()?;
Expand All @@ -187,9 +187,9 @@ mod test {
// test list array that is not square
let mut builder =
ListPrimitiveChunkedBuilder::<Float64Type>::new("", 10, 10, DataType::Float64);
builder.append_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_slice(Some(&[2.0]));
builder.append_slice(Some(&[6.0, 7.0, 8.0]));
builder.append_opt_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_opt_slice(Some(&[2.0]));
builder.append_opt_slice(Some(&[6.0, 7.0, 8.0]));
let list = builder.finish();
assert!(list.to_ndarray::<Float64Type>().is_err());
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,8 @@ mod test {
fn test_fmt_list() {
let mut builder =
ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10, DataType::Int32);
builder.append_slice(Some(&[1, 2, 3]));
builder.append_slice(None);
builder.append_opt_slice(Some(&[1, 2, 3]));
builder.append_opt_slice(None);
let list = builder.finish().into_series();

assert_eq!(
Expand Down
121 changes: 104 additions & 17 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,25 +166,112 @@ pub(super) fn date_range_dispatch(
let start = &s[0];
let stop = &s[1];

match start.dtype() {
DataType::Date => {
let start = start.to_physical_repr();
let stop = stop.to_physical_repr();
// to milliseconds
let start = start.get(0).extract::<i64>().unwrap() * SECONDS_IN_DAY * 1000;
let stop = stop.get(0).extract::<i64>().unwrap() * SECONDS_IN_DAY * 1000;

date_range_impl(name, start, stop, every, closed, TimeUnit::Milliseconds, tz)
if start.len() != stop.len() {
return Err(PolarsError::ComputeError(
"'start' and 'stop' should have the same length.".into(),
));
}
const TO_MS: i64 = SECONDS_IN_DAY * 1000;

if start.len() == 1 && stop.len() == 1 {
match start.dtype() {
DataType::Date => {
let start = start.to_physical_repr();
let stop = stop.to_physical_repr();
// to milliseconds
let start = start.get(0).extract::<i64>().unwrap() * TO_MS;
let stop = stop.get(0).extract::<i64>().unwrap() * TO_MS;

date_range_impl(
name,
start,
stop,
every,
closed,
TimeUnit::Milliseconds,
tz.as_ref(),
)
.cast(&DataType::Date)
}
DataType::Datetime(tu, _) => {
let start = start.to_physical_repr();
let stop = stop.to_physical_repr();
let start = start.get(0).extract::<i64>().unwrap();
let stop = stop.get(0).extract::<i64>().unwrap();
}
DataType::Datetime(tu, _) => {
let start = start.to_physical_repr();
let stop = stop.to_physical_repr();
let start = start.get(0).extract::<i64>().unwrap();
let stop = stop.get(0).extract::<i64>().unwrap();

Ok(date_range_impl(name, start, stop, every, closed, *tu, tz).into_series())
Ok(
date_range_impl(name, start, stop, every, closed, *tu, tz.as_ref())
.into_series(),
)
}
_ => unimplemented!(),
}
_ => todo!(),
} else {
let dtype = start.dtype();

let mut start = start.to_physical_repr().cast(&DataType::Int64)?;
let mut stop = stop.to_physical_repr().cast(&DataType::Int64)?;

let (tu, tz) = match dtype {
DataType::Date => {
start = &start * TO_MS;
stop = &stop * TO_MS;
(TimeUnit::Milliseconds, None)
}
DataType::Datetime(tu, tz) => (*tu, tz.as_ref()),
_ => unimplemented!(),
};

let start = start.i64().unwrap();
let stop = stop.i64().unwrap();

let list = match dtype {
DataType::Date => {
let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new(
name,
start.len(),
start.len() * 5,
DataType::Int32,
);
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let date_range =
date_range_impl("", start, stop, every, closed, tu, tz);
let date_range = date_range.cast(&DataType::Date).unwrap();
let date_range = date_range.to_physical_repr();
let date_range = date_range.i32().unwrap();
builder.append_slice(date_range.cont_slice().unwrap())
}
_ => builder.append_null(),
}
}
builder.finish().into_series()
}
DataType::Datetime(_, _) => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
start.len() * 5,
DataType::Int64,
);

for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let date_range =
date_range_impl("", start, stop, every, closed, tu, tz);
builder.append_slice(date_range.cont_slice().unwrap())
}
_ => builder.append_null(),
}
}
builder.finish().into_series()
}
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype.clone()));
list.cast(&to_type)
}
}
6 changes: 3 additions & 3 deletions polars/polars-time/src/date_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn date_range_impl(
every: Duration,
closed: ClosedWindow,
tu: TimeUnit,
_tz: Option<TimeZone>,
_tz: Option<&TimeZone>,
) -> DatetimeChunked {
let mut out = Int64Chunked::new_vec(name, date_range_vec(start, stop, every, closed, tu))
.into_datetime(tu, None);
Expand All @@ -26,7 +26,7 @@ pub fn date_range_impl(
if let Some(tz) = _tz {
out = out
.with_time_zone(Some("UTC".to_string()))
.cast_time_zone(&tz)
.cast_time_zone(tz)
.unwrap()
}
out.set_sorted(start > stop);
Expand All @@ -51,5 +51,5 @@ pub fn date_range(
),
TimeUnit::Milliseconds => (start.timestamp_millis(), stop.timestamp_millis()),
};
date_range_impl(name, start, stop, every, closed, tu, tz)
date_range_impl(name, start, stop, every, closed, tu, tz.as_ref())
}
2 changes: 1 addition & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ fn py_date_range(
Duration::parse(every),
closed.0,
tu.0,
tz,
tz.as_ref(),
)
.into_series()
.into()
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,50 @@ def test_date_range_lazy() -> None:
)
]

assert pl.DataFrame(
{
"start": [date(2000, 1, 1), date(2022, 6, 1)],
"stop": [date(2000, 1, 2), date(2022, 6, 2)],
}
).with_columns(
pl.date_range(
pl.col("start"),
pl.col("stop"),
interval="1d",
).alias("dts")
).to_dict(
False
) == {
"start": [date(2000, 1, 1), date(2022, 6, 1)],
"stop": [date(2000, 1, 2), date(2022, 6, 2)],
"dts": [
[date(2000, 1, 1), date(2000, 1, 2)],
[date(2022, 6, 1), date(2022, 6, 2)],
],
}

assert pl.DataFrame(
{
"start": [datetime(2000, 1, 1), datetime(2022, 6, 1)],
"stop": [datetime(2000, 1, 2), datetime(2022, 6, 2)],
}
).with_columns(
pl.date_range(
pl.col("start"),
pl.col("stop"),
interval="1d",
).alias("dts")
).to_dict(
False
) == {
"start": [datetime(2000, 1, 1, 0, 0), datetime(2022, 6, 1, 0, 0)],
"stop": [datetime(2000, 1, 2, 0, 0), datetime(2022, 6, 2, 0, 0)],
"dts": [
[datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 2, 0, 0)],
[datetime(2022, 6, 1, 0, 0), datetime(2022, 6, 2, 0, 0)],
],
}


@pytest.mark.parametrize(
"one,two",
Expand Down

0 comments on commit d4e4ce3

Please sign in to comment.