Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Nov 29, 2023
1 parent 9765105 commit 4acdc4d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 57 deletions.
49 changes: 39 additions & 10 deletions crates/polars-plan/src/dsl/function_expr/range/time_range.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use std::cmp;
use std::iter::{zip, Zip};

use polars_core::prelude::*;
use polars_core::series::Series;
use polars_time::{time_range_impl, ClosedWindow, Duration};

use super::utils::{
broadcast_scalar_inputs_iter, ensure_range_bounds_contain_exactly_one_value,
temporal_series_to_i64_scalar,
};
use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};

const CAPACITY_FACTOR: usize = 5;

Expand Down Expand Up @@ -44,13 +41,45 @@ pub(super) fn time_ranges(

let start_phys = start.to_physical_repr();
let end_phys = end.to_physical_repr();
let start_ca = start_phys.i64().unwrap();
let end_ca = end_phys.i64().unwrap();
let start = start_phys.i64().unwrap();
let end = end_phys.i64().unwrap();

let (start_iter, end_iter) = broadcast_scalar_inputs_iter(start_ca, end_ca)?;
let start_end_iter = std::iter::zip(start_iter, end_iter);
// let (start_iter, end_iter) = broadcast_scalar_inputs_iter(start_ca, end_ca)?;
// let start_end_iter = std::iter::zip(start_iter, end_iter);

let len = cmp::max(start.len(), end.len());
match (start.len(), end.len()) {
(len_start, len_end) if len_start == len_end => {
let start_end_iter = zip(start, end);
time_ranges_impl(start_end_iter, len_start, interval, closed)
},
(1, len_end) => {
let start_scalar = unsafe { start.get_unchecked(0) };
let start_iter = std::iter::repeat(start_scalar).take(len_end);
let start_end_iter = zip(start_iter, end);
time_ranges_impl(start_end_iter, len_end, interval, closed)
},
(len_start, 1) => {
let end_scalar = unsafe { end.get_unchecked(0) };
let end_iter = std::iter::repeat(end_scalar).take(len_start);
let start_end_iter = zip(start, end_iter);
time_ranges_impl(start_end_iter, len_start, interval, closed)
},
(len_start, len_end) => {
polars_bail!(
ComputeError:
"lengths of `start` ({}) and `end` ({}) do not match",
len_start, len_end
)
},
}
}

fn time_ranges_impl(
start_end_iter: Zip<impl Iterator<Item = Option<i64>>, impl Iterator<Item = Option<i64>>>,
len: usize,
interval: Duration,
closed: ClosedWindow,
) -> PolarsResult<Series> {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"time_range",
len,
Expand Down
88 changes: 41 additions & 47 deletions crates/polars-plan/src/dsl/function_expr/range/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::iter::zip;
// use std::iter::zip;

use polars_core::prelude::{
polars_bail, polars_ensure, ChunkedArray, PolarsIterator, PolarsResult, *,
};
// use polars_core::prelude::{
// polars_bail, polars_ensure, ChunkedArray, PolarsIterator, PolarsResult, *,
// };
use polars_core::prelude::{polars_bail, polars_ensure, PolarsResult};
use polars_core::series::Series;

pub(super) fn temporal_series_to_i64_scalar(s: &Series) -> Option<i64> {
Expand Down Expand Up @@ -47,46 +48,39 @@ pub(super) fn broadcast_scalar_inputs(
}
}

pub(super) fn broadcast_scalar_inputs_iter<T>(
start: &ChunkedArray<T>,
end: &ChunkedArray<T>,
) -> PolarsResult<(
Box<dyn PolarsIterator<Item = Option<T::Native>>>,
Box<dyn PolarsIterator<Item = Option<T::Native>>>,
)>
where
T: PolarsNumericType,
{
match (start.len(), end.len()) {
(len1, len2) if len1 == len2 => {
// let zipped = zip(start.into_iter(), end.into_iter());
let zipped = (start.into_iter(), end.into_iter());
Ok(zipped)
},
(1, len2) => {
let start_scalar: Option<T::Native> = unsafe { start.get_unchecked(0) };
let start_iter = Box::new(std::iter::repeat(start_scalar).take(len2));

// let end_iter = end.into_iter();

// let zipped = zip(start_iter, end.into_iter());
let zipped = (start_iter, end.into_iter());
Ok(zipped)
},
(len1, 1) => {
let end_scalar: Option<T::Native> = unsafe { end.get_unchecked(0) };
let end_iter = std::iter::repeat(end_scalar); //.take(len1)

// let zipped = zip(start.into_iter(), end_iter);
let zipped = (start.into_iter(), end_iter);
Ok(zipped)
},
(len1, len2) => {
polars_bail!(
ComputeError:
"lengths of `start` ({}) and `end` ({}) do not match",
len1, len2
)
},
}
}
// pub(super) fn broadcast_scalar_inputs_iter<T>(
// start: &ChunkedArray<T>,
// end: &ChunkedArray<T>,
// ) -> PolarsResult<(
// Box<dyn PolarsIterator<Item = Option<T::Native>>>,
// Box<dyn PolarsIterator<Item = Option<T::Native>>>,
// )>
// where
// T: PolarsNumericType,
// {
// match (start.len(), end.len()) {
// (len_start, len_end) if len_start == len_end => {
// let zipped = zip(start.into_iter(), end.into_iter());
// Ok(zipped)
// },
// (1, len_end) => {
// let start_scalar: Option<T::Native> = unsafe { start.get_unchecked(0) };
// let start_iter = Box::new(std::iter::repeat(start_scalar).take(len_end));
// let zipped = zip(start_iter, end.into_iter());
// Ok(zipped)
// },
// (len_start, 1) => {
// let end_scalar: Option<T::Native> = unsafe { end.get_unchecked(0) };
// let end_iter = std::iter::repeat(end_scalar).take(len_start);
// let zipped = zip(start.into_iter(), end_iter);
// Ok(zipped)
// },
// (len_start, len_end) => {
// polars_bail!(
// ComputeError:
// "lengths of `start` ({}) and `end` ({}) do not match",
// len_start, len_end
// )
// },
// }
// }

0 comments on commit 4acdc4d

Please sign in to comment.