Skip to content

Commit

Permalink
fix(rust, python): fix arange with column/literal input (#5703)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 2, 2022
1 parent 334299e commit f95a49b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
56 changes: 50 additions & 6 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,41 @@ pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
#[cfg(feature = "arange")]
#[cfg_attr(docsrs, doc(cfg(feature = "arange")))]
pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {
if (matches!(low, Expr::Literal(_)) && !matches!(low, Expr::Literal(LiteralValue::Series(_))))
|| matches!(high, Expr::Literal(_))
&& !matches!(high, Expr::Literal(LiteralValue::Series(_)))
{
let has_col_without_agg = |e: &Expr| {
has_expr(e, |ae| matches!(ae, Expr::Column(_)))
&&
// check if there is no aggregation
!has_expr(e, |ae| {
matches!(
ae,
Expr::Agg(_)
| Expr::Count
| Expr::AnonymousFunction {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
..
},
..
}
| Expr::Function {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
..
},
..
},
)
})
};
let has_lit = |e: &Expr| {
(matches!(e, Expr::Literal(_)) && !matches!(e, Expr::Literal(LiteralValue::Series(_))))
};

let any_column_no_agg = has_col_without_agg(&low) || has_col_without_agg(&high);
let literal_low = has_lit(&low);
let literal_high = has_lit(&high);

if (literal_low || literal_high) && !any_column_no_agg {
let f = move |sa: Series, sb: Series| {
let sa = sa.cast(&DataType::Int64)?;
let sb = sb.cast(&DataType::Int64)?;
Expand Down Expand Up @@ -354,8 +385,21 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {
)
} else {
let f = move |sa: Series, sb: Series| {
let sa = sa.cast(&DataType::Int64)?;
let sb = sb.cast(&DataType::Int64)?;
let mut sa = sa.cast(&DataType::Int64)?;
let mut sb = sb.cast(&DataType::Int64)?;

if sa.len() != sb.len() {
if sa.len() == 1 {
sa = sa.new_from_index(0, sb.len())
} else if sb.len() == 1 {
sb = sb.new_from_index(0, sa.len())
} else {
let msg = format!("The length of the 'low' and 'high' arguments cannot be matched in the 'arange' expression.. \
Length of 'low': {}, length of 'high': {}", sa.len(), sb.len());
return Err(PolarsError::ComputeError(msg.into()));
}
}

let low = sa.i64()?;
let high = sb.i64()?;
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,21 @@ def test_arange_expr() -> None:
assert out3.dtype == pl.List
assert out3[0].to_list() == [0, 2]

df = pl.DataFrame({"start": [1, 2, 3, 5, 5, 5], "stop": [8, 3, 12, 8, 8, 8]})

assert df.select(pl.arange(pl.lit(1), pl.col("stop") + 1).alias("test")).to_dict(
False
) == {
"test": [
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
]
}


def test_round() -> None:
a = pl.Series("f", [1.003, 2.003])
Expand Down

0 comments on commit f95a49b

Please sign in to comment.