Skip to content

Commit

Permalink
add new arr namespace expressions (#2803)
Browse files Browse the repository at this point in the history
* arr.arg_min
* arr.arg_max
* arr.diff
* arr.shift
  • Loading branch information
ritchie46 committed Mar 2, 2022
1 parent 40aac31 commit 3687cb4
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 11 deletions.
1 change: 1 addition & 0 deletions polars/polars-core/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Special list utility methods
mod iterator;
#[cfg(feature = "list")]
#[cfg_attr(docsrs, doc(cfg(feature = "list")))]
pub mod namespace;

use crate::prelude::*;
Expand Down
29 changes: 29 additions & 0 deletions polars/polars-core/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;
use crate::series::ops::NullBehavior;
use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use std::convert::TryFrom;
Expand Down Expand Up @@ -137,6 +138,34 @@ impl ListChunked {
self.try_apply_amortized(|s| s.as_ref().unique())
}

/// Compute, for every sublist, the index of its minimal value.
///
/// A null sublist produces a null, and so does a sublist for which
/// `arg_min` returns `None`. The result carries the name of this array.
pub fn lst_arg_min(&self) -> IdxCa {
    let indices = self.amortized_iter().map(|sublist| match sublist {
        Some(s) => s.as_ref().arg_min().map(|pos| pos as IdxSize),
        None => None,
    });
    let mut result: IdxCa = indices.collect_trusted();
    result.rename(self.name());
    result
}

/// Compute, for every sublist, the index of its maximal value.
///
/// A null sublist produces a null, and so does a sublist for which
/// `arg_max` returns `None`. The result carries the name of this array.
pub fn lst_arg_max(&self) -> IdxCa {
    let indices = self.amortized_iter().map(|sublist| {
        sublist
            .and_then(|s| {
                let series = s.as_ref();
                series.arg_max()
            })
            .map(|pos| pos as IdxSize)
    });
    let mut result: IdxCa = indices.collect_trusted();
    result.rename(self.name());
    result
}

/// Run `diff(n, null_behavior)` on every sublist and collect the
/// results into a new list array.
#[cfg(feature = "diff")]
#[cfg_attr(docsrs, doc(cfg(feature = "diff")))]
pub fn lst_diff(&self, n: usize, null_behavior: NullBehavior) -> ListChunked {
    self.apply_amortized(|sublist| {
        let series = sublist.as_ref();
        series.diff(n, null_behavior)
    })
}

/// Run `shift(periods)` on every sublist and collect the results into
/// a new list array.
pub fn lst_shift(&self, periods: i64) -> ListChunked {
    self.apply_amortized(|sublist| {
        let series = sublist.as_ref();
        series.shift(periods)
    })
}

pub fn lst_lengths(&self) -> UInt32Chunked {
let mut lengths = Vec::with_capacity(self.len());
self.downcast_iter().for_each(|arr| {
Expand Down
32 changes: 23 additions & 9 deletions polars/polars-core/src/chunked_array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,18 +427,29 @@ impl<'a> ChunkApply<'a, Series, Series> for ListChunked {
F: Fn(Series) -> S::Native + Copy,
S: PolarsNumericType,
{
let dtype = self.inner_dtype();
let chunks = self
.downcast_iter()
.into_iter()
.map(|array| {
let values: Vec<_> = (0..array.len())
.map(|idx| {
let arrayref: ArrayRef = unsafe { array.value_unchecked(idx) }.into();
let series = Series::try_from(("", arrayref)).unwrap();
f(series)
})
.collect_trusted();
to_array::<S>(values, array.validity().cloned())
unsafe {
let values = array
.values_iter()
.map(|array| {
// safety
// reported dtype is correct
let series = Series::from_chunks_and_dtype_unchecked(
"",
vec![array.into()],
&dtype,
);
f(series)
})
.trust_my_length(self.len())
.collect_trusted::<Vec<_>>();

to_array::<S>(values, array.validity().cloned())
}
})
.collect();
ChunkedArray::from_chunks(self.name(), chunks)
Expand All @@ -449,14 +460,17 @@ impl<'a> ChunkApply<'a, Series, Series> for ListChunked {
F: Fn(Option<Series>) -> S::Native + Copy,
S: PolarsNumericType,
{
let dtype = self.inner_dtype();
let chunks = self
.downcast_iter()
.into_iter()
.map(|array| {
let values = array.iter().map(|x| {
let x = x.map(|x| {
let x: ArrayRef = x.into();
Series::try_from(("", x)).unwrap()
// safety
// reported dtype is correct
unsafe { Series::from_chunks_and_dtype_unchecked("", vec![x], &dtype) }
});
f(x)
});
Expand Down
43 changes: 43 additions & 0 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::prelude::*;
use polars_core::prelude::*;
use polars_core::series::ops::NullBehavior;

/// Specialized expressions for [`Series`] of [`DataType::List`].
///
/// Wraps the [`Expr`] that the list-specific methods of this namespace
/// are applied to.
pub struct ListNameSpace(pub(crate) Expr);
Expand Down Expand Up @@ -140,4 +141,46 @@ impl ListNameSpace {
)
.with_fmt("arr.join")
}

/// Return the index of the minimal value of every sublist
pub fn arg_min(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_arg_min().into_series()),
GetOutput::from_type(IDX_DTYPE),
)
.with_fmt("arr.arg_min")
}

/// Return the index of the maximum value of every sublist
pub fn arg_max(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_arg_max().into_series()),
GetOutput::from_type(IDX_DTYPE),
)
.with_fmt("arr.arg_max")
}

/// Compute the difference of every sublist.
///
/// `n` and `null_behavior` are forwarded unchanged to
/// `ListChunked::lst_diff`; the output keeps the input dtype.
#[cfg(feature = "diff")]
#[cfg_attr(docsrs, doc(cfg(feature = "diff")))]
pub fn diff(self, n: usize, null_behavior: NullBehavior) -> Expr {
    let ListNameSpace(expr) = self;
    expr.map(
        move |s| {
            let lists = s.list()?;
            Ok(lists.lst_diff(n, null_behavior).into_series())
        },
        GetOutput::same_type(),
    )
    .with_fmt("arr.diff")
}

/// Shift every sublist.
///
/// `periods` is forwarded unchanged to `ListChunked::lst_shift`; the
/// output keeps the input dtype.
pub fn shift(self, periods: i64) -> Expr {
    self.0
        .map(
            move |s| Ok(s.list()?.lst_shift(periods).into_series()),
            GetOutput::same_type(),
        )
        // Label follows the `arr.<method>` convention used by the other
        // expressions of this namespace (this was mislabelled "arr.diff").
        .with_fmt("arr.shift")
}
}
4 changes: 4 additions & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ The following methods are available under the `expr.arr` attribute.
ExprListNameSpace.last
ExprListNameSpace.contains
ExprListNameSpace.join
ExprListNameSpace.arg_min
ExprListNameSpace.arg_max
ExprListNameSpace.diff
ExprListNameSpace.shift

Categories
----------
Expand Down
4 changes: 4 additions & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ The following methods are available under the `Series.arr` attribute.
ListNameSpace.last
ListNameSpace.contains
ListNameSpace.join
ListNameSpace.arg_min
ListNameSpace.arg_max
ListNameSpace.diff
ListNameSpace.shift

Categories
----------
Expand Down
47 changes: 46 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def take(self, index: Union[List[int], "Expr", "pli.Series", np.ndarray]) -> "Ex
def shift(self, periods: int = 1) -> "Expr":
"""
Shift the values by a given period and fill the parts that will be empty due to this operation
with `Nones`.
with nulls.
Parameters
----------
Expand Down Expand Up @@ -3041,6 +3041,51 @@ def join(self, separator: str) -> "Expr":

return wrap_expr(self._pyexpr.lst_join(separator))

def arg_min(self) -> "Expr":
    """
    Get, for every sublist, the index at which its minimal value sits.

    Returns
    -------
    Series of dtype UInt32/UInt64 (depending on compilation)
    """
    pyexpr = self._pyexpr.lst_arg_min()
    return wrap_expr(pyexpr)

def arg_max(self) -> "Expr":
    """
    Get, for every sublist, the index at which its maximum value sits.

    Returns
    -------
    Series of dtype UInt32/UInt64 (depending on compilation)
    """
    pyexpr = self._pyexpr.lst_arg_max()
    return wrap_expr(pyexpr)

def diff(self, n: int = 1, null_behavior: str = "ignore") -> "Expr":
    """
    Compute the n-th discrete difference within every sublist.

    Parameters
    ----------
    n
        number of slots to shift
    null_behavior
        {'ignore', 'drop'}
    """
    pyexpr = self._pyexpr.lst_diff(n, null_behavior)
    return wrap_expr(pyexpr)

def shift(self, periods: int = 1) -> "Expr":
    """
    Shift the values of every sublist by a given period; the slots that
    become empty because of the shift are filled with nulls.

    Parameters
    ----------
    periods
        Number of places to shift (may be negative).
    """
    pyexpr = self._pyexpr.lst_shift(periods)
    return wrap_expr(pyexpr)


class ExprStringNameSpace:
"""
Expand Down
48 changes: 47 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3984,7 +3984,6 @@ def join(self, separator: str) -> "Series":
-------
Series of dtype Utf8
"""

return pli.select(pli.lit(wrap_s(self._s)).arr.join(separator)).to_series()

def first(self) -> "Series":
Expand Down Expand Up @@ -4017,6 +4016,53 @@ def contains(self, item: Union[float, str, bool, int, date, datetime]) -> "Serie
out = s.is_in(s_list)
return out.rename(s_list.name)

def arg_min(self) -> "Series":
    """
    Get, for every sublist, the index at which its minimal value sits.

    Returns
    -------
    Series of dtype UInt32/UInt64 (depending on compilation)
    """
    # Evaluate through the expression API on a literal of this Series.
    literal = pli.lit(wrap_s(self._s))
    return pli.select(literal.arr.arg_min()).to_series()

def arg_max(self) -> "Series":
    """
    Get, for every sublist, the index at which its maximum value sits.

    Returns
    -------
    Series of dtype UInt32/UInt64 (depending on compilation)
    """
    # Evaluate through the expression API on a literal of this Series.
    literal = pli.lit(wrap_s(self._s))
    return pli.select(literal.arr.arg_max()).to_series()

def diff(self, n: int = 1, null_behavior: str = "ignore") -> "Series":
    """
    Compute the n-th discrete difference within every sublist.

    Parameters
    ----------
    n
        number of slots to shift
    null_behavior
        {'ignore', 'drop'}
    """
    # Evaluate through the expression API on a literal of this Series.
    literal = pli.lit(wrap_s(self._s))
    diff_expr = literal.arr.diff(n, null_behavior)
    return pli.select(diff_expr).to_series()

def shift(self, periods: int = 1) -> "Series":
    """
    Shift the values of every sublist by a given period; the slots that
    become empty because of the shift are filled with nulls.

    Parameters
    ----------
    periods
        Number of places to shift (may be negative).
    """
    # Evaluate through the expression API on a literal of this Series.
    literal = pli.lit(wrap_s(self._s))
    return pli.select(literal.arr.shift(periods)).to_series()


class DateTimeNameSpace:
"""
Expand Down
17 changes: 17 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,23 @@ impl PyExpr {
self.inner.clone().arr().join(separator).into()
}

/// Build the `arr().arg_min()` expression from a clone of the wrapped expression.
fn lst_arg_min(&self) -> Self {
    let expr = self.inner.clone();
    expr.arr().arg_min().into()
}

/// Build the `arr().arg_max()` expression from a clone of the wrapped expression.
fn lst_arg_max(&self) -> Self {
    let expr = self.inner.clone();
    expr.arr().arg_max().into()
}

/// Build the `arr().diff(n, ..)` expression from a clone of the wrapped expression.
///
/// `null_behavior` is parsed with `str_to_null_behavior`; a parse
/// failure is propagated to the caller as the error of the `PyResult`.
fn lst_diff(&self, n: usize, null_behavior: &str) -> PyResult<Self> {
    let behavior = str_to_null_behavior(null_behavior)?;
    let expr = self.inner.clone().arr().diff(n, behavior);
    Ok(expr.into())
}

/// Build the `arr().shift(periods)` expression from a clone of the wrapped expression.
fn lst_shift(&self, periods: i64) -> Self {
    let expr = self.inner.clone();
    expr.arr().shift(periods).into()
}

fn rank(&self, method: &str, reverse: bool) -> Self {
let method = str_to_rankmethod(method).unwrap();
let options = RankOptions {
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from test_series import verify_series_and_expr_api

import polars as pl
from polars import testing

Expand Down Expand Up @@ -139,3 +141,23 @@ def test_list_arr_empty() -> None:
{"cars_first": [1, 2, 4, None], "cars_literal": [2, 1, 3, 3]}
)
assert out.frame_equal(expected)


def test_list_argminmax() -> None:
    # Check arg_min and then arg_max through both the Series and the Expr API.
    s = pl.Series("a", [[1, 2], [3, 2, 1]])
    cases = (
        ("arr.arg_min", [0, 2]),
        ("arr.arg_max", [1, 0]),
    )
    for method, indices in cases:
        expected = pl.Series("a", indices, dtype=pl.UInt32)
        verify_series_and_expr_api(s, expected, method)


def test_list_shift() -> None:
    # A default shift moves every sublist one slot and pads the front with a null.
    s = pl.Series("a", [[1, 2], [3, 2, 1]])
    expected = pl.Series("a", [[None, 1], [None, 3, 2]])
    shifted = s.arr.shift().to_list()
    wanted = expected.to_list()
    assert shifted == wanted


def test_list_diff() -> None:
    # A default diff leaves a leading null in every sublist.
    s = pl.Series("a", [[1, 2], [10, 2, 1]])
    expected = pl.Series("a", [[None, 1], [None, -8, -1]])
    diffed = s.arr.diff().to_list()
    wanted = expected.to_list()
    assert diffed == wanted

0 comments on commit 3687cb4

Please sign in to comment.