diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs index aba3a9f5384e..cf7efcfbdc15 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -5,7 +5,7 @@ use polars_time::{datetime_range_impl, ClosedWindow, Duration}; use super::datetime_range::{datetime_range, datetime_ranges}; use super::utils::{ - ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast, + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, temporal_series_to_i64_scalar, }; use crate::dsl::function_expr::FieldsMapper; @@ -106,7 +106,7 @@ fn date_ranges(s: &[Series], interval: Duration, closed: ClosedWindow) -> Polars Ok(()) }; - let out = ranges_impl_broadcast(&start, &end, range_impl, &mut builder)?; + let out = temporal_ranges_impl_broadcast(&start, &end, range_impl, &mut builder)?; let to_type = DataType::List(Box::new(DataType::Date)); out.cast(&to_type) diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index 3ebbbad3481a..97c1a4988898 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -5,7 +5,7 @@ use polars_core::series::Series; use polars_time::{datetime_range_impl, ClosedWindow, Duration}; use super::utils::{ - ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast, + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, temporal_series_to_i64_scalar, }; use crate::dsl::function_expr::FieldsMapper; @@ -202,7 +202,7 @@ pub(super) fn datetime_ranges( Ok(()) }; - ranges_impl_broadcast(start, end, range_impl, &mut builder)? + temporal_ranges_impl_broadcast(start, end, range_impl, &mut builder)? }, _ => unimplemented!(), }; diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs index 8b1cb1dc05c4..e67a73297541 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -2,7 +2,7 @@ use polars_core::prelude::*; use polars_core::series::{IsSorted, Series}; use polars_core::with_match_physical_integer_polars_type; -use super::utils::{ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast}; +use super::utils::{ensure_range_bounds_contain_exactly_one_value, numeric_ranges_impl_broadcast}; const CAPACITY_FACTOR: usize = 5; @@ -70,15 +70,18 @@ where Ok(ca.into_series()) } -pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult { +pub(super) fn int_ranges(s: &[Series]) -> PolarsResult { let start = &s[0]; let end = &s[1]; + let step = &s[2]; let start = start.cast(&DataType::Int64)?; let end = end.cast(&DataType::Int64)?; + let step = step.cast(&DataType::Int64)?; let start = start.i64()?; let end = end.i64()?; + let step = step.i64()?; let len = std::cmp::max(start.len(), end.len()); let mut builder = ListPrimitiveChunkedBuilder::::new( @@ -88,18 +91,19 @@ pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult { DataType::Int64, ); - let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { - match step { - 1 => builder.append_iter_values(start..end), - 2.. => builder.append_iter_values((start..end).step_by(step as usize)), - _ => builder.append_iter_values( - (end..start) - .step_by(step.unsigned_abs() as usize) - .map(|x| start - (x - end)), - ), + let range_impl = + |start, end, step: i64, builder: &mut ListPrimitiveChunkedBuilder| { + match step { + 1 => builder.append_iter_values(start..end), + 2.. => builder.append_iter_values((start..end).step_by(step as usize)), + _ => builder.append_iter_values( + (end..start) + .step_by(step.unsigned_abs() as usize) + .map(|x| start - (x - end)), + ), + }; + Ok(()) }; - Ok(()) - }; - ranges_impl_broadcast(start, end, range_impl, &mut builder) + numeric_ranges_impl_broadcast(start, end, step, range_impl, &mut builder) } diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs index 97da2fa5f7da..aa1add9d400d 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -28,9 +28,7 @@ pub enum RangeFunction { step: i64, dtype: DataType, }, - IntRanges { - step: i64, - }, + IntRanges, #[cfg(feature = "temporal")] DateRange { interval: Duration, @@ -76,7 +74,7 @@ impl RangeFunction { use RangeFunction::*; let field = match self { IntRange { dtype, .. } => Field::new("int", dtype.clone()), - IntRanges { .. } => Field::new("int_range", DataType::List(Box::new(DataType::Int64))), + IntRanges => Field::new("int_range", DataType::List(Box::new(DataType::Int64))), #[cfg(feature = "temporal")] DateRange { interval, @@ -158,7 +156,7 @@ impl Display for RangeFunction { use RangeFunction::*; let s = match self { IntRange { .. } => "int_range", - IntRanges { .. } => "int_ranges", + IntRanges => "int_ranges", #[cfg(feature = "temporal")] DateRange { .. } => "date_range", #[cfg(feature = "temporal")] @@ -183,8 +181,8 @@ impl From for SpecialEq> { IntRange { step, dtype } => { map_as_slice!(int_range::int_range, step, dtype.clone()) }, - IntRanges { step } => { - map_as_slice!(int_range::int_ranges, step) + IntRanges => { + map_as_slice!(int_range::int_ranges) }, #[cfg(feature = "temporal")] DateRange { diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs index 7b03b77b4ab6..4f506ea934a6 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -3,7 +3,7 @@ use polars_core::series::Series; use polars_time::{time_range_impl, ClosedWindow, Duration}; use super::utils::{ - ensure_range_bounds_contain_exactly_one_value, ranges_impl_broadcast, + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, temporal_series_to_i64_scalar, }; @@ -59,7 +59,7 @@ pub(super) fn time_ranges( Ok(()) }; - let out = ranges_impl_broadcast(start, end, range_impl, &mut builder)?; + let out = temporal_ranges_impl_broadcast(start, end, range_impl, &mut builder)?; let to_type = DataType::List(Box::new(DataType::Time)); out.cast(&to_type) diff --git a/crates/polars-plan/src/dsl/function_expr/range/utils.rs b/crates/polars-plan/src/dsl/function_expr/range/utils.rs index 100b582ffabb..b748daf0879a 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/utils.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/utils.rs @@ -1,5 +1,5 @@ use polars_core::prelude::{ - polars_bail, polars_ensure, ChunkedArray, IntoSeries, ListBuilderTrait, + polars_bail, polars_ensure, ChunkedArray, Int64Chunked, IntoSeries, ListBuilderTrait, ListPrimitiveChunkedBuilder, PolarsIntegerType, PolarsResult, Series, }; @@ -21,8 +21,124 @@ pub(super) fn ensure_range_bounds_contain_exactly_one_value( Ok(()) } +/// Create a numeric ranges column from the given start/end/step columns and a range function. +pub(super) fn numeric_ranges_impl_broadcast( + start: &ChunkedArray, + end: &ChunkedArray, + step: &Int64Chunked, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult +where + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, i64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, +{ + match (start.len(), end.len(), step.len()) { + (len_start, len_end, len_step) if len_start == len_end && len_start == len_step => { + build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + end.downcast_iter().flatten(), + step.downcast_iter().flatten(), + range_impl, + builder, + )?; + }, + (1, len_end, 1) => { + let start_scalar = start.get(0); + let step_scalar = step.get(0); + match (start_scalar, step_scalar) { + (Some(start), Some(step)) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + end.downcast_iter().flatten(), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_end), + } + }, + (len_start, 1, 1) => { + let end_scalar = end.get(0); + let step_scalar = step.get(0); + match (end_scalar, step_scalar) { + (Some(end), Some(step)) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + std::iter::repeat(Some(&end)), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_start), + } + }, + (1, 1, len_step) => { + let start_scalar = start.get(0); + let end_scalar = end.get(0); + match (start_scalar, end_scalar) { + (Some(start), Some(end)) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + std::iter::repeat(Some(&end)), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_step), + } + }, + (len_start, len_end, 1) if len_start == len_end => { + let step_scalar = step.get(0); + match step_scalar { + Some(step) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + end.downcast_iter().flatten(), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + None => build_nulls(builder, len_start), + } + }, + (len_start, 1, len_step) if len_start == len_step => { + let end_scalar = end.get(0); + match end_scalar { + Some(end) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + std::iter::repeat(Some(&end)), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + None => build_nulls(builder, len_start), + } + }, + (1, len_end, len_step) if len_end == len_step => { + let start_scalar = start.get(0); + match start_scalar { + Some(start) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + end.downcast_iter().flatten(), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + None => build_nulls(builder, len_end), + } + }, + (len_start, len_end, len_step) => { + polars_bail!( + ComputeError: + "lengths of `start` ({}), `end` ({}) and `step` ({}) do not match", + len_start, len_end, len_step + ) + }, + }; + let out = builder.finish().into_series(); + Ok(out) +} + /// Create a ranges column from the given start/end columns and a range function. -pub(super) fn ranges_impl_broadcast( +pub(super) fn temporal_ranges_impl_broadcast( start: &ChunkedArray, end: &ChunkedArray, range_impl: F, @@ -35,7 +151,7 @@ where { match (start.len(), end.len()) { (len_start, len_end) if len_start == len_end => { - build_ranges::<_, _, T, U, F>( + build_temporal_ranges::<_, _, T, U, F>( start.downcast_iter().flatten(), end.downcast_iter().flatten(), range_impl, @@ -45,7 +161,7 @@ where (1, len_end) => { let start_scalar = start.get(0); match start_scalar { - Some(start) => build_ranges::<_, _, T, U, F>( + Some(start) => build_temporal_ranges::<_, _, T, U, F>( std::iter::repeat(Some(&start)), end.downcast_iter().flatten(), range_impl, @@ -57,7 +173,7 @@ where (len_start, 1) => { let end_scalar = end.get(0); match end_scalar { - Some(end) => build_ranges::<_, _, T, U, F>( + Some(end) => build_temporal_ranges::<_, _, T, U, F>( start.downcast_iter().flatten(), std::iter::repeat(Some(&end)), range_impl, @@ -78,8 +194,33 @@ where Ok(out) } +/// Iterate over a start and end column and create a range with the step for each entry. +fn build_numeric_ranges<'a, I, J, K, T, U, F>( + start: I, + end: J, + step: K, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult<()> +where + I: Iterator>, + J: Iterator>, + K: Iterator>, + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, i64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, +{ + for ((start, end), step) in start.zip(end).zip(step) { + match (start, end, step) { + (Some(start), Some(end), Some(step)) => range_impl(*start, *end, *step, builder)?, + _ => builder.append_null(), + } + } + Ok(()) +} + /// Iterate over a start and end column and create a range for each entry. -fn build_ranges<'a, I, J, T, U, F>( +fn build_temporal_ranges<'a, I, J, T, U, F>( start: I, end: J, range_impl: F, diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs index 498f3b885c39..6b6fa1af73fb 100644 --- a/crates/polars-plan/src/dsl/functions/range.rs +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -22,12 +22,12 @@ pub fn int_range(start: Expr, end: Expr, step: i64, dtype: DataType) -> Expr { } /// Generate a range of integers for each row of the input columns. -pub fn int_ranges(start: Expr, end: Expr, step: i64) -> Expr { - let input = vec![start, end]; +pub fn int_ranges(start: Expr, end: Expr, step: Expr) -> Expr { + let input = vec![start, end, step]; Expr::Function { input, - function: FunctionExpr::Range(RangeFunction::IntRanges { step }), + function: FunctionExpr::Range(RangeFunction::IntRanges), options: FunctionOptions { allow_rename: true, ..Default::default() diff --git a/py-polars/polars/functions/range/int_range.py b/py-polars/polars/functions/range/int_range.py index 0f5c68eeb63b..b5895a825092 100644 --- a/py-polars/polars/functions/range/int_range.py +++ b/py-polars/polars/functions/range/int_range.py @@ -200,7 +200,7 @@ def int_range( def int_ranges( start: int | IntoExprColumn, end: int | IntoExprColumn, - step: int = ..., + step: int | IntoExprColumn = ..., *, dtype: PolarsIntegerType = ..., eager: Literal[False] = ..., @@ -212,7 +212,7 @@ def int_ranges( def int_ranges( start: int | IntoExprColumn, end: int | IntoExprColumn, - step: int = ..., + step: int | IntoExprColumn = ..., *, dtype: PolarsIntegerType = ..., eager: Literal[True], @@ -224,7 +224,7 @@ def int_ranges( def int_ranges( start: int | IntoExprColumn, end: int | IntoExprColumn, - step: int = ..., + step: int | IntoExprColumn = ..., *, dtype: PolarsIntegerType = ..., eager: bool, @@ -235,7 +235,7 @@ def int_ranges( def int_ranges( start: int | IntoExprColumn, end: int | IntoExprColumn, - step: int = 1, + step: int | IntoExprColumn = 1, *, dtype: PolarsIntegerType = Int64, eager: bool = False, @@ -283,6 +283,7 @@ def int_ranges( """ start = parse_as_expression(start) end = parse_as_expression(end) + step = parse_as_expression(step) result = wrap_expr(plr.int_ranges(start, end, step, dtype)) if eager: diff --git a/py-polars/src/functions/range.rs b/py-polars/src/functions/range.rs index 8ec73b53bc88..d034ef20157e 100644 --- a/py-polars/src/functions/range.rs +++ b/py-polars/src/functions/range.rs @@ -13,10 +13,10 @@ pub fn int_range(start: PyExpr, end: PyExpr, step: i64, dtype: Wrap) - } #[pyfunction] -pub fn int_ranges(start: PyExpr, end: PyExpr, step: i64, dtype: Wrap) -> PyExpr { +pub fn int_ranges(start: PyExpr, end: PyExpr, step: PyExpr, dtype: Wrap) -> PyExpr { let dtype = dtype.0; - let mut result = dsl::int_ranges(start.inner, end.inner, step); + let mut result = dsl::int_ranges(start.inner, end.inner, step.inner); if dtype != DataType::Int64 { result = result.cast(DataType::List(Box::new(dtype))) diff --git a/py-polars/tests/unit/functions/range/test_int_range.py b/py-polars/tests/unit/functions/range/test_int_range.py index 286f5f484128..6fe6b60fdf59 100644 --- a/py-polars/tests/unit/functions/range/test_int_range.py +++ b/py-polars/tests/unit/functions/range/test_int_range.py @@ -190,21 +190,40 @@ def test_int_range_non_integer_dtype() -> None: def test_int_ranges_broadcasting() -> None: df = pl.DataFrame({"int": [1, 2, 3]}) result = df.select( - pl.int_ranges("int", 3).alias("end"), - pl.int_ranges(1, "int").alias("start"), + # result column name means these columns will be broadcast + pl.int_ranges(1, pl.Series([2, 4, 6]), "int").alias("start"), + pl.int_ranges("int", 6, "int").alias("end"), + pl.int_ranges("int", pl.col("int") + 2, 1).alias("step"), + pl.int_ranges("int", 3, 1).alias("end_step"), + pl.int_ranges(1, "int", 1).alias("start_step"), + pl.int_ranges(1, 6, "int").alias("start_end"), + pl.int_ranges("int", pl.Series([4, 5, 10]), "int").alias("no_broadcast"), ) expected = pl.DataFrame( { + "start": [[1], [1, 3], [1, 4]], "end": [ + [1, 2, 3, 4, 5], + [2, 4], + [3], + ], + "step": [[1, 2], [2, 3], [3, 4]], + "end_step": [ [1, 2], [2], [], ], - "start": [ + "start_step": [ [], [1], [1, 2], ], + "start_end": [ + [1, 2, 3, 4, 5], + [1, 3, 5], + [1, 4], + ], + "no_broadcast": [[1, 2, 3], [2, 4], [3, 6, 9]], } ) assert_frame_equal(result, expected)