Skip to content

Commit

Permalink
fix: Allow broadcasting in ranges (#11900)
Browse files: browse the repository at this point in the history
  • Loading branch information
stinodego committed Dec 12, 2023
1 parent 7ee12fb commit a6483c6
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 145 deletions.
63 changes: 28 additions & 35 deletions crates/polars-plan/src/dsl/function_expr/range/date_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_time::{datetime_range_impl, ClosedWindow, Duration};

use super::datetime_range::{datetime_range, datetime_ranges};
use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};
use crate::dsl::function_expr::FieldsMapper;

const CAPACITY_FACTOR: usize = 5;
Expand Down Expand Up @@ -73,50 +76,40 @@ fn date_ranges(s: &[Series], interval: Duration, closed: ClosedWindow) -> Polars
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);
let start = start.cast(&DataType::Int64)?;
let end = end.cast(&DataType::Int64)?;

let start = date_series_to_i64_ca(start)? * MILLISECONDS_IN_DAY;
let end = date_series_to_i64_ca(end)? * MILLISECONDS_IN_DAY;
let start = start.i64().unwrap() * MILLISECONDS_IN_DAY;
let end = end.i64().unwrap() * MILLISECONDS_IN_DAY;

let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new(
"date_range",
start.len(),
start.len() * CAPACITY_FACTOR,
DataType::Int32,
);
for (start, end) in start.as_ref().into_iter().zip(&end) {
match (start, end) {
(Some(start), Some(end)) => {
// TODO: Implement an i32 version of `date_range_impl`
let rng = datetime_range_impl(
"",
start,
end,
interval,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
let list = builder.finish().into_series();

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int32Type>| {
let rng = datetime_range_impl(
"",
start,
end,
interval,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

let out = ranges_impl_broadcast(&start, &end, range_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Date));
list.cast(&to_type)
}
fn date_series_to_i64_ca(s: &Series) -> PolarsResult<ChunkedArray<Int64Type>> {
let s = s.cast(&DataType::Int64)?;
let result = s.i64().unwrap();
Ok(result.clone())
out.cast(&to_type)
}

impl<'a> FieldsMapper<'a> {
Expand Down
33 changes: 14 additions & 19 deletions crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use polars_core::prelude::*;
use polars_core::series::Series;
use polars_time::{datetime_range_impl, ClosedWindow, Duration};

use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};
use crate::dsl::function_expr::FieldsMapper;

const CAPACITY_FACTOR: usize = 5;
Expand Down Expand Up @@ -91,11 +94,6 @@ pub(super) fn datetime_ranges(
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);

// Note: `start` and `end` have already been cast to their supertype,
// so only `start`'s dtype needs to be matched against.
#[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block.
Expand Down Expand Up @@ -159,31 +157,28 @@ pub(super) fn datetime_ranges(
let start = start.i64().unwrap();
let end = end.i64().unwrap();

let list = match dtype {
let out = match dtype {
DataType::Datetime(tu, ref tz) => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"datetime_range",
start.len(),
start.len() * CAPACITY_FACTOR,
DataType::Int64,
);
for (start, end) in start.into_iter().zip(end) {
match (start, end) {
(Some(start), Some(end)) => {
let rng =
datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
builder.finish().into_series()

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
let rng = datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

ranges_impl_broadcast(start, end, range_impl, &mut builder)?
},
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype));
list.cast(&to_type)
out.cast(&to_type)
}

impl<'a> FieldsMapper<'a> {
Expand Down
86 changes: 25 additions & 61 deletions crates/polars-plan/src/dsl/function_expr/range/int_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use polars_core::prelude::*;
use polars_core::series::{IsSorted, Series};
use polars_core::with_match_physical_integer_polars_type;

use super::utils::ensure_range_bounds_contain_exactly_one_value;
use super::utils::{ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast};

const CAPACITY_FACTOR: usize = 5;

pub(super) fn int_range(s: &[Series], step: i64, dtype: DataType) -> PolarsResult<Series> {
let mut start = &s[0];
Expand Down Expand Up @@ -69,73 +71,35 @@ where
}

pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult<Series> {
let start = &s[0].rechunk();
let end = &s[1].rechunk();

let output_name = "int_range";

let mut start = start.cast(&DataType::Int64)?;
let mut end = end.cast(&DataType::Int64)?;

if start.len() != end.len() {
if start.len() == 1 {
start = start.new_from_index(0, end.len())
} else if end.len() == 1 {
end = end.new_from_index(0, start.len())
} else {
polars_bail!(
ComputeError:
"lengths of `start`: {} and `end`: {} arguments `\
cannot be matched in the `int_ranges` expression",
start.len(), end.len()
);
}
}
let start = &s[0];
let end = &s[1];

let start = start.cast(&DataType::Int64)?;
let end = end.cast(&DataType::Int64)?;

let start = start.i64()?;
let end = end.i64()?;

let start = start.downcast_iter().next().unwrap();
let end = end.downcast_iter().next().unwrap();

// First do a pass to determine the required value capacity.
let mut values_capacity = 0;
for (opt_start, opt_end) in start.into_iter().zip(end) {
if let (Some(start_v), Some(end_v)) = (opt_start, opt_end) {
if step == 1 {
values_capacity += (end_v - start_v).unsigned_abs() as usize;
} else {
values_capacity +=
(((end_v - start_v).unsigned_abs() / step.unsigned_abs()) + 1) as usize;
}
}
}

let len = std::cmp::max(start.len(), end.len());
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
output_name,
start.len(),
values_capacity,
"int_range",
len,
len * CAPACITY_FACTOR,
DataType::Int64,
);

for (opt_start, opt_end) in start.into_iter().zip(end) {
match (opt_start, opt_end) {
(Some(&start_v), Some(&end_v)) => match step {
1 => {
builder.append_iter_values(start_v..end_v);
},
2.. => {
builder.append_iter_values((start_v..end_v).step_by(step as usize));
},
_ => builder.append_iter_values(
(end_v..start_v)
.step_by(step.unsigned_abs() as usize)
.map(|x| start_v - (x - end_v)),
),
},
_ => builder.append_null(),
}
}
let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
match step {
1 => builder.append_iter_values(start..end),
2.. => builder.append_iter_values((start..end).step_by(step as usize)),
_ => builder.append_iter_values(
(end..start)
.step_by(step.unsigned_abs() as usize)
.map(|x| start - (x - end)),
),
};
Ok(())
};

Ok(builder.finish().into_series())
ranges_impl_broadcast(start, end, range_impl, &mut builder)
}
48 changes: 22 additions & 26 deletions crates/polars-plan/src/dsl/function_expr/range/time_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use polars_core::prelude::*;
use polars_core::series::Series;
use polars_time::{time_range_impl, ClosedWindow, Duration};

use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};

const CAPACITY_FACTOR: usize = 5;

Expand Down Expand Up @@ -34,37 +37,30 @@ pub(super) fn time_ranges(
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);
let start = start.cast(&DataType::Time)?;
let end = end.cast(&DataType::Time)?;

let start = time_series_to_i64_ca(start)?;
let end = time_series_to_i64_ca(end)?;
let start_phys = start.to_physical_repr();
let end_phys = end.to_physical_repr();
let start = start_phys.i64().unwrap();
let end = end_phys.i64().unwrap();

let len = std::cmp::max(start.len(), end.len());
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"time_range",
start.len(),
start.len() * CAPACITY_FACTOR,
len,
len * CAPACITY_FACTOR,
DataType::Int64,
);
for (start, end) in start.as_ref().into_iter().zip(&end) {
match (start, end) {
(Some(start), Some(end)) => {
let rng = time_range_impl("", start, end, interval, closed)?;
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
let list = builder.finish().into_series();

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
let rng = time_range_impl("", start, end, interval, closed)?;
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

let out = ranges_impl_broadcast(start, end, range_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Time));
list.cast(&to_type)
}
fn time_series_to_i64_ca(s: &Series) -> PolarsResult<ChunkedArray<Int64Type>> {
let s = s.cast(&DataType::Time)?;
let s = s.to_physical_repr();
let result = s.i64().unwrap();
Ok(result.clone())
out.cast(&to_type)
}
Loading

0 comments on commit a6483c6

Please sign in to comment.