Skip to content

Commit

Permalink
fix shift window functions (#4390)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 13, 2022
1 parent 3b10b81 commit 14639ae
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 39 deletions.
1 change: 1 addition & 0 deletions polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ where
.map(|len| {
// safety:
// within bounds
debug_assert!((offset + len) <= array.len());
let out = unsafe { array.slice_unchecked(offset, len) };
offset += len;
out
Expand Down
16 changes: 8 additions & 8 deletions polars/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ where
}};
}

let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, sub);
let out = downcast_as_macro_arg_physical!(s, sub);
finish_cast(self, out)
}
}
Expand Down Expand Up @@ -519,7 +519,7 @@ where
$ca.add(rhs).into_series()
}};
}
let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, add);
let out = downcast_as_macro_arg_physical!(s, add);
finish_cast(self, out)
}
}
Expand Down Expand Up @@ -549,7 +549,7 @@ where
}};
}

let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, div);
let out = downcast_as_macro_arg_physical!(s, div);
finish_cast(self, out)
}
}
Expand Down Expand Up @@ -578,7 +578,7 @@ where
$ca.mul(rhs).into_series()
}};
}
let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, mul);
let out = downcast_as_macro_arg_physical!(s, mul);
finish_cast(self, out)
}
}
Expand Down Expand Up @@ -607,7 +607,7 @@ where
$ca.rem(rhs).into_series()
}};
}
let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, rem);
let out = downcast_as_macro_arg_physical!(s, rem);
finish_cast(self, out)
}
}
Expand Down Expand Up @@ -680,7 +680,7 @@ where
$rhs.lhs_sub(self).into_series()
}};
}
let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, sub);
let out = downcast_as_macro_arg_physical!(s, sub);

finish_cast(rhs, out)
}
Expand All @@ -691,7 +691,7 @@ where
$rhs.lhs_div(self).into_series()
}};
}
let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, div);
let out = downcast_as_macro_arg_physical!(s, div);

finish_cast(rhs, out)
}
Expand All @@ -707,7 +707,7 @@ where
}};
}

let out = match_arrow_data_type_apply_macro_ca_logical_num!(s, rem);
let out = downcast_as_macro_arg_physical!(s, rem);

finish_cast(rhs, out)
}
Expand Down
4 changes: 3 additions & 1 deletion polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ macro_rules! match_arrow_data_type_apply_macro_ca {
/// Apply a macro on the Downcasted ChunkedArray's of DataTypes that are logical numerics.
/// So no logical.
#[macro_export]
macro_rules! match_arrow_data_type_apply_macro_ca_logical_num {
macro_rules! downcast_as_macro_arg_physical {
($self:expr, $macro:ident $(, $opt_args:expr)*) => {{
match $self.dtype() {
#[cfg(feature = "dtype-u8")]
Expand Down Expand Up @@ -838,6 +838,8 @@ where
B: PolarsDataType,
C: PolarsDataType,
{
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(b.len(), c.len());
match (a.chunks.len(), b.chunks.len(), c.chunks.len()) {
(1, 1, 1) => (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c)),
(_, 1, 1) => (
Expand Down
82 changes: 78 additions & 4 deletions polars/polars-lazy/src/dsl/function_expr/shift_and_fill.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
use super::*;
use polars_core::downcast_as_macro_arg_physical;

pub(super) fn shift_and_fill(args: &mut [Series], periods: i64) -> Result<Series> {
let s = &args[0];
let fill_value = &args[1];
fn shift_and_fill_numeric<T>(
ca: &ChunkedArray<T>,
periods: i64,
fill_value: AnyValue,
) -> ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkShiftFill<T, Option<T::Native>>,
{
let fill_value = fill_value.extract::<T::Native>();
ca.shift_and_fill(periods, fill_value)
}

fn shift_and_fill_with_mask(s: &Series, periods: i64, fill_value: &Series) -> Result<Series> {
let mask: BooleanChunked = if periods > 0 {
let len = s.len();
let mut bits = MutableBitmap::with_capacity(s.len());
Expand All @@ -21,6 +32,69 @@ pub(super) fn shift_and_fill(args: &mut [Series], periods: i64) -> Result<Series
let mask = BooleanArray::from_data_default(bits.into(), None);
mask.into()
};

s.shift(periods).zip_with_same_type(&mask, fill_value)
}

pub(super) fn shift_and_fill(args: &mut [Series], periods: i64) -> Result<Series> {
let s = &args[0];
let logical = s.dtype();
let physical = s.to_physical_repr();
let fill_value_s = &args[1];
let fill_value = fill_value_s.get(0);

use DataType::*;
match logical {
Boolean => {
let ca = s.bool().unwrap();
let fill_value = match fill_value {
AnyValue::Boolean(v) => Some(v),
AnyValue::Null => None,
_ => unimplemented!(),
};
ca.shift_and_fill(periods, fill_value)
.into_series()
.cast(logical)
}
Utf8 => {
let ca = s.utf8().unwrap();
let fill_value = match fill_value {
AnyValue::Utf8(v) => Some(v),
AnyValue::Null => None,
_ => unimplemented!(),
};
ca.shift_and_fill(periods, fill_value)
.into_series()
.cast(logical)
}
List(_) => {
let ca = s.list().unwrap();
let fill_value = match fill_value {
AnyValue::List(v) => Some(v),
AnyValue::Null => None,
_ => unimplemented!(),
};
ca.shift_and_fill(periods, fill_value.as_ref())
.into_series()
.cast(logical)
}
#[cfg(feature = "object")]
Object(_) => shift_and_fill_with_mask(s, periods, fill_value_s),
#[cfg(feature = "dtype-struct")]
Struct(_) => shift_and_fill_with_mask(s, periods, fill_value_s),
#[cfg(feature = "dtype-categorical")]
Categorical(_) => shift_and_fill_with_mask(s, periods, fill_value_s),
dt if dt.is_numeric() || dt.is_logical() => {
macro_rules! dispatch {
($ca:expr, $periods:expr, $fill_value:expr) => {{
shift_and_fill_numeric($ca, $periods, $fill_value).into_series()
}};
}

let out = downcast_as_macro_arg_physical!(physical, dispatch, periods, fill_value);
out.cast(logical)
}
_ => {
unimplemented!()
}
}
}
18 changes: 1 addition & 17 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct WindowExpr {
pub(crate) expr: Expr,
}

#[derive(Debug)]
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
Join,
// explode now
Expand Down Expand Up @@ -150,22 +150,6 @@ impl WindowExpr {
// no explicit aggregations, map over the groups
//`(col("x").sum() * col("y")).over("groups")`
(false, false, AggState::AggregatedList(_)) => {
// if the output of a window expression is a list type, but not an explicit list
// e.g. due to a `col().shift()` we join back as the flattening and exploding makes
// my brain hurt atm.

// only select the aggregation columns to save allocations in computing the schema
if !self.phys_function.is_literal(){
// 'literal' would fail, but also 'count()' would fail
// so on failure ignore and continue.
if let Ok(df) = gb.df.select(&self.apply_columns) {
let schema = df.schema();
if matches!(self.phys_function.to_field(&schema).map(|fld| fld.dtype), Ok(DataType::List(_))) {
return Ok(MapStrategy::Join)
}
}
}

if sorted_keys {
if let GroupsProxy::Idx(g) = gb.get_groups() {
debug_assert!(g.is_sorted())
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,9 @@ def shift(self, periods: int = 1) -> Expr:
return wrap_expr(self._pyexpr.shift(periods))

def shift_and_fill(
self, periods: int, fill_value: int | float | bool | str | Expr | list[Any]
self,
periods: int,
fill_value: int | float | bool | str | Expr | list[Any],
) -> Expr:
"""
Shift the values by a given period and fill the parts that will be empty due to
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,13 @@ def test_cast_null_to_categorical() -> None:
assert pl.DataFrame().with_columns(
[pl.lit(None).cast(pl.Categorical).alias("nullable_enum")]
).dtypes == [pl.Categorical]


def test_shift_and_fill() -> None:
df = pl.DataFrame({"a": ["a", "b"]}).with_columns(
[pl.col("a").cast(pl.Categorical)]
)

s = df.with_column(pl.col("a").shift_and_fill(1, "c"))["a"]
assert s.dtype == pl.Categorical
assert s.to_list() == ["c", "a"]
22 changes: 14 additions & 8 deletions py-polars/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,24 @@ def test_window_functions_list_types() -> None:
"col_list": [[1], [1], [2], [2]],
}
)
assert (df.select(pl.col("col_list").shift(1).alias("list_shifted")))[
"list_shifted"
].to_list() == [None, [1], [1], [2]]

# filling with None is allowed, but does not make any sense
# as it is the same as shift.
# that's why we don't add it to the allowed types.
assert (
df.with_column(
df.select(
pl.col("col_list")
.shift_and_fill(1, [])
.over("col_int")
.shift_and_fill(1, None) # type: ignore[arg-type]
.alias("list_shifted")
)
).to_dict(False) == {
"col_int": [1, 1, 2, 2],
"col_list": [[1], [1], [2], [2]],
"list_shifted": [[[], [1]], [[], [1]], [[], [2]], [[], [2]]],
}
)["list_shifted"].to_list() == [None, [1], [1], [2]]

assert (df.select(pl.col("col_list").shift_and_fill(1, []).alias("list_shifted")))[
"list_shifted"
].to_list() == [[], [1], [1], [2]]


def test_sorted_window_expression() -> None:
Expand Down

0 comments on commit 14639ae

Please sign in to comment.