Skip to content

Commit

Permalink
Literal in groupby context, arange and repeat (#2958)
Browse files Browse the repository at this point in the history
1. This fixes the arange function to work in
the groupby context.

2. We improve performance and memory usage of literals in grouping
By doing so we found an inconstency in the literal expansion which
lead to the following expression working in grouping:

lit(1).cumsum() would lead to a list of the group length.

This now does what you expect: e.g. [1].

To get old behavior we introduce a new expression: `repeat`

repeat(1, count()).cumsum()
  • Loading branch information
ritchie46 committed Mar 24, 2022
1 parent 4447e3d commit 4d3b95c
Show file tree
Hide file tree
Showing 19 changed files with 290 additions and 72 deletions.
36 changes: 24 additions & 12 deletions polars/polars-core/src/series/ops/to_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ use crate::prelude::*;
use polars_arrow::kernels::list::array_to_unit_list;
use std::borrow::Cow;

fn reshape_fast_path(name: &str, s: &Series) -> Series {
let chunks = s
.chunks()
.iter()
.map(|arr| Arc::new(array_to_unit_list(arr.clone())) as ArrayRef)
.collect::<Vec<_>>();

let mut ca = ListChunked::from_chunks(name, chunks);
ca.set_fast_explode();
ca.into_series()
}

impl Series {
/// Convert the values of this Series to a ListChunked with a length of 1,
/// So a Series of:
Expand All @@ -29,11 +41,21 @@ impl Series {
}

pub fn reshape(&self, dims: &[i64]) -> Result<Series> {
if dims.is_empty() {
panic!("dimensions cannot be empty")
}
let s = if let DataType::List(_) = self.dtype() {
Cow::Owned(self.explode()?)
} else {
Cow::Borrowed(self)
};

// no rows
if dims[0] == 0 {
let s = reshape_fast_path(self.name(), &s);
return Ok(s);
}

let s_ref = s.as_ref();

let mut dims = dims.to_vec();
Expand All @@ -56,9 +78,6 @@ impl Series {
}

match dims.len() {
0 => {
panic!("dimensions cannot be empty")
}
1 => Ok(s_ref.slice(0, dims[0] as usize)),
2 => {
let mut rows = dims[0];
Expand All @@ -74,15 +93,8 @@ impl Series {

// fast path, we can create a unit list so we only allocate offsets
if rows as usize == s_ref.len() && cols == 1 {
let chunks = s_ref
.chunks()
.iter()
.map(|arr| Arc::new(array_to_unit_list(arr.clone())) as ArrayRef)
.collect::<Vec<_>>();

let mut ca = ListChunked::from_chunks(self.name(), chunks);
ca.set_fast_explode();
return Ok(ca.into_series());
let s = reshape_fast_path(self.name(), s_ref);
return Ok(s);
}

let mut builder =
Expand Down
14 changes: 12 additions & 2 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {
Ok(Int64Chunked::from_iter_values("arange", low..high).into_series())
}
};
map_binary(
apply_binary(
low,
high,
f,
Expand Down Expand Up @@ -293,7 +293,7 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {

Ok(builder.finish().into_series())
};
map_binary(
apply_binary(
low,
high,
f,
Expand Down Expand Up @@ -748,3 +748,13 @@ pub fn as_struct(exprs: &[Expr]) -> Expr {
options
})
}

pub fn repeat<L: Literal>(value: L, n_times: Expr) -> Expr {
let function = |s: Series, n: Series| {
let n = n.get(0).extract::<usize>().ok_or_else(|| {
PolarsError::ComputeError(format!("could not extract a size from {:?}", n).into())
})?;
Ok(s.expand_at_index(0, n))
};
apply_binary(lit(value), n_times, function, GetOutput::same_type())
}
4 changes: 3 additions & 1 deletion polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,8 @@ impl Expr {

#[cfg(feature = "repeat_by")]
#[cfg_attr(docsrs, doc(cfg(feature = "repeat_by")))]
/// Repeat the column `n` times, where `n` is determined by the values in `by`.
/// This yields an `Expr` of dtype `List`
pub fn repeat_by(self, by: Expr) -> Expr {
let function = |s: &mut [Series]| {
let by = &s[1];
Expand All @@ -1491,7 +1493,7 @@ impl Expr {
Ok(s.repeat_by(by.idx()?).into_series())
};

self.map_many(
self.apply_many(
function,
&[by],
GetOutput::map_dtype(|dt| DataType::List(dt.clone().into())),
Expand Down
5 changes: 4 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ impl PhysicalExpr for ApplyExpr {
let lists = acs
.iter_mut()
.map(|ac| {
let s = ac.aggregated();
let s = match ac.agg_state() {
AggState::AggregatedFlat(s) => s.reshape(&[-1, 1]).unwrap(),
_ => ac.aggregated(),
};
s.list().unwrap().clone()
})
.collect::<Vec<_>>();
Expand Down
10 changes: 5 additions & 5 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl PhysicalExpr for BinaryExpr {
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
let l = ac_l.aggregated();
let l = ac_l.aggregated_arity_operation();
let l = l.as_ref();
let arr_l = &l.chunks()[0];

Expand All @@ -145,7 +145,7 @@ impl PhysicalExpr for BinaryExpr {
let mut us = UnstableSeries::new(&dummy);

// this is now a list
let r = ac_r.aggregated();
let r = ac_r.aggregated_arity_operation();
let r = r.list().unwrap();

let mut ca: ListChunked = r
Expand Down Expand Up @@ -182,11 +182,11 @@ impl PhysicalExpr for BinaryExpr {
_,
) if s.len() != df.height() => {
// this is now a list
let l = ac_l.aggregated();
let l = ac_l.aggregated_arity_operation();
let l = l.list().unwrap();

// this is a flat series of len eq to group tuples
let r = ac_r.aggregated();
let r = ac_r.aggregated_arity_operation();
assert_eq!(l.len(), groups.len());
let r = r.as_ref();
let arr_r = &r.chunks()[0];
Expand Down Expand Up @@ -289,7 +289,7 @@ impl PhysicalAggregation for BinaryExpr {
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.evaluate_on_groups(df, groups, state)?;
let s = ac.aggregated();
let s = ac.aggregated_arity_operation();
Ok(Some(s))
}
}
Expand Down
27 changes: 24 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,34 @@ impl<'a> AggregationContext<'a> {
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => s,
AggState::Literal(s) => {
self.groups();
// todo! optimize this, we don't have to call agg_list, create the list directly.
let s = s.expand_at_index(0, self.groups.iter().map(|g| g.len()).sum());
s.agg_list(&self.groups).unwrap()
let rows = self.groups.len();
let s = s.expand_at_index(0, rows);
s.reshape(&[rows as i64, -1]).unwrap()
}
}
}

/// Different from aggregated, in arity operations we expect literals to expand to the size of the
/// group
/// eg:
///
/// lit(9) in groups [[1, 1], [2, 2, 2]]
/// becomes: [[9, 9], [9, 9, 9]]
///
/// where in [`Self::aggregated`] this becomes [9, 9]
///
/// this is because comparisons need to create mask that have a correct length.
fn aggregated_arity_operation(&mut self) -> Series {
if let AggState::Literal(s) = self.agg_state() {
let s = s.clone();
// // todo! optimize this, we don't have to call agg_list, create the list directly.
let s = s.expand_at_index(0, self.groups.iter().map(|g| g.len()).sum());
s.agg_list(&self.groups).unwrap()
} else {
self.aggregated()
}
}

/// Get the not-aggregated version of the series.
/// Note that we call it naive, because if a previous expr
/// has filtered or sorted this, this information is in the
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ The predicate produced {} values. Where the original DataFrame has {} values",
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
let truthy = ac_truthy.aggregated();
let truthy = ac_truthy.aggregated_arity_operation();
let truthy = truthy.as_ref();
let arr_truthy = &truthy.chunks()[0];
assert_eq!(truthy.len(), groups.len());
Expand All @@ -81,11 +81,11 @@ The predicate produced {} values. Where the original DataFrame has {} values",
let mut us = UnstableSeries::new(&dummy);

// this is now a list
let falsy = ac_falsy.aggregated();
let falsy = ac_falsy.aggregated_arity_operation();
let falsy = falsy.as_ref();
let falsy = falsy.list().unwrap();

let mask = ac_mask.aggregated();
let mask = ac_mask.aggregated_arity_operation();
let mask = mask.as_ref();
let mask = mask.list()?;
if !matches!(mask.inner_dtype(), DataType::Boolean) {
Expand Down Expand Up @@ -128,12 +128,12 @@ The predicate produced {} values. Where the original DataFrame has {} values",
if s.len() != df.height() =>
{
// this is now a list
let truthy = ac_truthy.aggregated();
let truthy = ac_truthy.aggregated_arity_operation();
let truthy = truthy.as_ref();
let truthy = truthy.list().unwrap();

// this is a flat series of len eq to group tuples
let falsy = ac_falsy.aggregated();
let falsy = ac_falsy.aggregated_arity_operation();
assert_eq!(falsy.len(), groups.len());
let falsy = falsy.as_ref();
let arr_falsy = &falsy.chunks()[0];
Expand All @@ -144,7 +144,7 @@ The predicate produced {} values. Where the original DataFrame has {} values",
let dummy = Series::try_from(("dummy", vec![arr_falsy.clone()])).unwrap();
let mut us = UnstableSeries::new(&dummy);

let mask = ac_mask.aggregated();
let mask = ac_mask.aggregated_arity_operation();
let mask = mask.as_ref();
let mask = mask.list()?;
if !matches!(mask.inner_dtype(), DataType::Boolean) {
Expand Down
20 changes: 20 additions & 0 deletions polars/tests/it/lazy/expressions/apply.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use super::*;

#[test]
#[cfg(feature = "arange")]
fn test_arange_agg() -> Result<()> {
let df = df![
"x" => [5, 5, 4, 4, 2, 2]
]?;

let out = df
.lazy()
.with_columns([arange(lit(0i32), count(), 1).over([col("x")])])
.collect()?;
assert_eq!(
Vec::from_iter(out.column("literal")?.i64()?.into_no_null_iter()),
&[0, 1, 0, 1, 0, 1]
);

Ok(())
}
1 change: 1 addition & 0 deletions polars/tests/it/lazy/expressions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod apply;
mod arity;
mod is_in;
mod slice;
Expand Down
2 changes: 1 addition & 1 deletion polars/tests/it/lazy/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ fn test_literal_window_fn() -> Result<()> {

let out = df
.lazy()
.select([lit(1)
.select([repeat(1, count())
.cumsum(false)
.list()
.over([col("chars")])
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ These functions can be used as expression and sometimes also in eager contexts.
groups
quantile
arange
repeat
argsort_by
concat_str
concat_list
Expand Down
9 changes: 2 additions & 7 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,7 @@ def version() -> str:
DataFrame,
wrap_df,
)
from polars.internals.functions import (
arg_where,
concat,
date_range,
get_dummies,
repeat,
)
from polars.internals.functions import arg_where, concat, date_range, get_dummies
from polars.internals.lazy_frame import LazyFrame
from polars.internals.lazy_functions import _date as date
from polars.internals.lazy_functions import _datetime as datetime
Expand Down Expand Up @@ -94,6 +88,7 @@ def version() -> str:
n_unique,
pearson_corr,
quantile,
repeat,
select,
spearman_rank_corr,
std,
Expand Down
24 changes: 0 additions & 24 deletions py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional, Sequence, Union, overload

from polars import internals as pli
from polars.datatypes import py_type_to_dtype
from polars.utils import (
_datetime_to_pl_timestamp,
_timedelta_to_pl_duration,
Expand Down Expand Up @@ -128,29 +127,6 @@ def concat(
return out


def repeat(
val: Union[int, float, str, bool], n: int, name: Optional[str] = None
) -> "pli.Series":
"""
Repeat a single value n times and collect into a Series.
Parameters
----------
val
Value to repeat.
n
Number of repeats.
name
Optional name of the Series.
"""
if name is None:
name = ""

dtype = py_type_to_dtype(type(val))
s = pli.Series._repeat(name, val, n, dtype)
return s


def arg_where(mask: "pli.Series") -> "pli.Series":
"""
Get index values where Boolean mask evaluate True.
Expand Down

0 comments on commit 4d3b95c

Please sign in to comment.