Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow broadcasting in ranges #11900

Merged
merged 20 commits into from
Dec 12, 2023
Merged
63 changes: 28 additions & 35 deletions crates/polars-plan/src/dsl/function_expr/range/date_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_time::{datetime_range_impl, ClosedWindow, Duration};

use super::datetime_range::{datetime_range, datetime_ranges};
use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};
use crate::dsl::function_expr::FieldsMapper;

const CAPACITY_FACTOR: usize = 5;
Expand Down Expand Up @@ -73,50 +76,40 @@ fn date_ranges(s: &[Series], interval: Duration, closed: ClosedWindow) -> Polars
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);
let start = start.cast(&DataType::Int64)?;
let end = end.cast(&DataType::Int64)?;

let start = date_series_to_i64_ca(start)? * MILLISECONDS_IN_DAY;
let end = date_series_to_i64_ca(end)? * MILLISECONDS_IN_DAY;
let start = start.i64().unwrap() * MILLISECONDS_IN_DAY;
let end = end.i64().unwrap() * MILLISECONDS_IN_DAY;

let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new(
"date_range",
start.len(),
start.len() * CAPACITY_FACTOR,
DataType::Int32,
);
for (start, end) in start.as_ref().into_iter().zip(&end) {
match (start, end) {
(Some(start), Some(end)) => {
// TODO: Implement an i32 version of `date_range_impl`
let rng = datetime_range_impl(
"",
start,
end,
interval,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
let list = builder.finish().into_series();

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int32Type>| {
let rng = datetime_range_impl(
"",
start,
end,
interval,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

let out = ranges_impl_broadcast(&start, &end, range_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Date));
list.cast(&to_type)
}
fn date_series_to_i64_ca(s: &Series) -> PolarsResult<ChunkedArray<Int64Type>> {
let s = s.cast(&DataType::Int64)?;
let result = s.i64().unwrap();
Ok(result.clone())
out.cast(&to_type)
}

impl<'a> FieldsMapper<'a> {
Expand Down
33 changes: 14 additions & 19 deletions crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use polars_core::prelude::*;
use polars_core::series::Series;
use polars_time::{datetime_range_impl, ClosedWindow, Duration};

use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};
use crate::dsl::function_expr::FieldsMapper;

const CAPACITY_FACTOR: usize = 5;
Expand Down Expand Up @@ -91,11 +94,6 @@ pub(super) fn datetime_ranges(
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);

// Note: `start` and `end` have already been cast to their supertype,
// so only `start`'s dtype needs to be matched against.
#[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block.
Expand Down Expand Up @@ -159,31 +157,28 @@ pub(super) fn datetime_ranges(
let start = start.i64().unwrap();
let end = end.i64().unwrap();

let list = match dtype {
let out = match dtype {
DataType::Datetime(tu, ref tz) => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"datetime_range",
start.len(),
start.len() * CAPACITY_FACTOR,
DataType::Int64,
);
for (start, end) in start.into_iter().zip(end) {
match (start, end) {
(Some(start), Some(end)) => {
let rng =
datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
builder.finish().into_series()

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
let rng = datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

ranges_impl_broadcast(start, end, range_impl, &mut builder)?
},
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype));
list.cast(&to_type)
out.cast(&to_type)
}

impl<'a> FieldsMapper<'a> {
Expand Down
86 changes: 25 additions & 61 deletions crates/polars-plan/src/dsl/function_expr/range/int_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use polars_core::prelude::*;
use polars_core::series::{IsSorted, Series};
use polars_core::with_match_physical_integer_polars_type;

use super::utils::ensure_range_bounds_contain_exactly_one_value;
use super::utils::{ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast};

const CAPACITY_FACTOR: usize = 5;

pub(super) fn int_range(s: &[Series], step: i64, dtype: DataType) -> PolarsResult<Series> {
let mut start = &s[0];
Expand Down Expand Up @@ -69,73 +71,35 @@ where
}

pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult<Series> {
let start = &s[0].rechunk();
let end = &s[1].rechunk();

let output_name = "int_range";

let mut start = start.cast(&DataType::Int64)?;
let mut end = end.cast(&DataType::Int64)?;

if start.len() != end.len() {
if start.len() == 1 {
start = start.new_from_index(0, end.len())
} else if end.len() == 1 {
end = end.new_from_index(0, start.len())
} else {
polars_bail!(
ComputeError:
"lengths of `start`: {} and `end`: {} arguments `\
cannot be matched in the `int_ranges` expression",
start.len(), end.len()
);
}
}
let start = &s[0];
let end = &s[1];

let start = start.cast(&DataType::Int64)?;
let end = end.cast(&DataType::Int64)?;

let start = start.i64()?;
let end = end.i64()?;

let start = start.downcast_iter().next().unwrap();
let end = end.downcast_iter().next().unwrap();

// First do a pass to determine the required value capacity.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this pass and set a default capacity of 5 times the length of the column, similar to the other ranges functions. If this capacity calculation must be preserved, I have to make some adjustments as it relied on fully materializing start/end.

let mut values_capacity = 0;
for (opt_start, opt_end) in start.into_iter().zip(end) {
if let (Some(start_v), Some(end_v)) = (opt_start, opt_end) {
if step == 1 {
values_capacity += (end_v - start_v).unsigned_abs() as usize;
} else {
values_capacity +=
(((end_v - start_v).unsigned_abs() / step.unsigned_abs()) + 1) as usize;
}
}
}

let len = std::cmp::max(start.len(), end.len());
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
output_name,
start.len(),
values_capacity,
"int_range",
len,
len * CAPACITY_FACTOR,
DataType::Int64,
);

for (opt_start, opt_end) in start.into_iter().zip(end) {
match (opt_start, opt_end) {
(Some(&start_v), Some(&end_v)) => match step {
1 => {
builder.append_iter_values(start_v..end_v);
},
2.. => {
builder.append_iter_values((start_v..end_v).step_by(step as usize));
},
_ => builder.append_iter_values(
(end_v..start_v)
.step_by(step.unsigned_abs() as usize)
.map(|x| start_v - (x - end_v)),
),
},
_ => builder.append_null(),
}
}
let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
match step {
1 => builder.append_iter_values(start..end),
2.. => builder.append_iter_values((start..end).step_by(step as usize)),
_ => builder.append_iter_values(
(end..start)
.step_by(step.unsigned_abs() as usize)
.map(|x| start - (x - end)),
),
};
Ok(())
};

Ok(builder.finish().into_series())
ranges_impl_broadcast(start, end, range_impl, &mut builder)
}
48 changes: 22 additions & 26 deletions crates/polars-plan/src/dsl/function_expr/range/time_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use polars_core::prelude::*;
use polars_core::series::Series;
use polars_time::{time_range_impl, ClosedWindow, Duration};

use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar};
use super::utils::{
ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast,
temporal_series_to_i64_scalar,
};

const CAPACITY_FACTOR: usize = 5;

Expand Down Expand Up @@ -34,37 +37,30 @@ pub(super) fn time_ranges(
let start = &s[0];
let end = &s[1];

polars_ensure!(
start.len() == end.len(),
ComputeError: "`start` and `end` must have the same length",
);
let start = start.cast(&DataType::Time)?;
let end = end.cast(&DataType::Time)?;

let start = time_series_to_i64_ca(start)?;
let end = time_series_to_i64_ca(end)?;
let start_phys = start.to_physical_repr();
let end_phys = end.to_physical_repr();
let start = start_phys.i64().unwrap();
let end = end_phys.i64().unwrap();

let len = std::cmp::max(start.len(), end.len());
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"time_range",
start.len(),
start.len() * CAPACITY_FACTOR,
len,
len * CAPACITY_FACTOR,
DataType::Int64,
);
for (start, end) in start.as_ref().into_iter().zip(&end) {
match (start, end) {
(Some(start), Some(end)) => {
let rng = time_range_impl("", start, end, interval, closed)?;
builder.append_slice(rng.cont_slice().unwrap())
},
_ => builder.append_null(),
}
}
let list = builder.finish().into_series();

let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
let rng = time_range_impl("", start, end, interval, closed)?;
builder.append_slice(rng.cont_slice().unwrap());
Ok(())
};

let out = ranges_impl_broadcast(start, end, range_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Time));
list.cast(&to_type)
}
fn time_series_to_i64_ca(s: &Series) -> PolarsResult<ChunkedArray<Int64Type>> {
let s = s.cast(&DataType::Time)?;
let s = s.to_physical_repr();
let result = s.i64().unwrap();
Ok(result.clone())
out.cast(&to_type)
}
Loading