Skip to content

Commit

Permalink
Simplify int_ranges capacity
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Dec 8, 2023
1 parent 39a7b05 commit 80200bb
Showing 1 changed file with 9 additions and 46 deletions.
55 changes: 9 additions & 46 deletions crates/polars-plan/src/dsl/function_expr/range/int_range.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::iter::zip;

use polars_core::prelude::*;
use polars_core::series::{IsSorted, Series};
use polars_core::with_match_physical_integer_polars_type;

use super::utils::{ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast};

const CAPACITY_FACTOR: usize = 5;

pub(super) fn int_range(s: &[Series], step: i64, dtype: DataType) -> PolarsResult<Series> {
let mut start = &s[0];
let mut end = &s[1];
Expand Down Expand Up @@ -71,36 +71,20 @@ where
}

pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult<Series> {
let start = &s[0].rechunk();
let end = &s[1].rechunk();

let mut start = start.cast(&DataType::Int64)?;
let mut end = end.cast(&DataType::Int64)?;
let start = &s[0];
let end = &s[1];

(start, end) = broadcast_scalar_inputs(start, end)?;
let start = start.cast(&DataType::Int64)?;
let end = end.cast(&DataType::Int64)?;

let start = start.i64()?;
let end = end.i64()?;

// First do a pass to determine the required value capacity.
let mut capacity = 0;
let mut increment_capacity = |start: i64, end: i64| {
if step == 1 {
capacity += (end - start).unsigned_abs() as usize;
} else {
capacity += (((end - start).unsigned_abs() / step.unsigned_abs()) + 1) as usize;
}
};
for (opt_start, opt_end) in zip(start, end) {
if let (Some(start_v), Some(end_v)) = (opt_start, opt_end) {
increment_capacity(start_v, end_v)
}
}

let len = std::cmp::max(start.len(), end.len());
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
"int_range",
start.len(),
capacity,
len,
len * CAPACITY_FACTOR,
DataType::Int64,
);

Expand All @@ -119,24 +103,3 @@ pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult<Series> {

ranges_impl_broadcast(start, end, range_impl, &mut builder)
}

/// Bring `start` and `end` to a common length when one of them has length 1.
///
/// * Equal lengths: both inputs are returned untouched.
/// * Exactly one input has length 1: that input is replaced by
///   `new_from_index(0, other_len)` (presumably its single value repeated
///   `other_len` times — confirm against `Series::new_from_index`), and the
///   other input is returned as-is.
///
/// # Errors
///
/// Returns a `ComputeError` when the lengths differ and neither is 1.
fn broadcast_scalar_inputs(start: Series, end: Series) -> PolarsResult<(Series, Series)> {
    let start_len = start.len();
    let end_len = end.len();

    // Nothing to do when the inputs already line up.
    if start_len == end_len {
        return Ok((start, end));
    }

    // `start` is the length-1 side: stretch it to the length of `end`.
    if start_len == 1 {
        return Ok((start.new_from_index(0, end_len), end));
    }

    // `end` is the length-1 side: stretch it to the length of `start`.
    if end_len == 1 {
        let end_broadcast = end.new_from_index(0, start_len);
        return Ok((start, end_broadcast));
    }

    // Lengths differ and neither side can be broadcast.
    polars_bail!(
        ComputeError:
        "lengths of `start` ({}) and `end` ({}) do not match",
        start_len, end_len
    )
}

0 comments on commit 80200bb

Please sign in to comment.