Skip to content

Commit

Permalink
feat(rust, python): add arr.take expression (#6116)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 8, 2023
1 parent 9565988 commit 05447ab
Show file tree
Hide file tree
Showing 18 changed files with 232 additions and 9 deletions.
3 changes: 2 additions & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ list_eval = ["polars-lazy/list_eval"]
cumulative_eval = ["polars-lazy/cumulative_eval"]
chunked_ids = ["polars-core/chunked_ids", "polars-lazy/chunked_ids", "polars-core/chunked_ids"]
to_dummies = ["polars-ops/to_dummies"]
bigidx = ["polars-core/bigidx", "polars-lazy/bigidx"]
bigidx = ["polars-core/bigidx", "polars-lazy/bigidx", "polars-ops/big_idx"]
list_to_struct = ["polars-ops/list_to_struct", "polars-lazy/list_to_struct"]
list_take = ["polars-ops/list_take", "polars-lazy/list_take"]
describe = ["polars-core/describe"]
timezones = ["polars-core/timezones", "polars-lazy/timezones"]
string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"]
Expand Down
21 changes: 13 additions & 8 deletions polars/polars-arrow/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@
use arrow::array::UInt32Array;
#[cfg(feature = "bigidx")]
use arrow::array::UInt64Array;
use num::{NumCast, Signed, Zero};

pub trait IndexToUsize {
/// Translate the negative index to an offset.
fn negative_to_usize(self, index: usize) -> Option<usize>;
fn negative_to_usize(self, len: usize) -> Option<usize>;
}

impl IndexToUsize for i64 {
fn negative_to_usize(self, index: usize) -> Option<usize> {
if self >= 0 && (self as usize) < index {
Some(self as usize)
impl<I> IndexToUsize for I
where
I: PartialOrd + PartialEq + NumCast + Signed + Zero,
{
#[inline]
fn negative_to_usize(self, len: usize) -> Option<usize> {
if self >= Zero::zero() && (self.to_usize().unwrap()) < len {
Some(self.to_usize().unwrap())
} else {
let subtract = self.unsigned_abs() as usize;
if subtract > index {
let subtract = self.abs().to_usize().unwrap();
if subtract > len {
None
} else {
Some(index - subtract)
Some(len - subtract)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ date_offset = ["polars-plan/date_offset"]
trigonometry = ["polars-plan/trigonometry"]
sign = ["polars-plan/sign"]
timezones = ["polars-plan/timezones"]
list_take = ["polars-ops/list_take", "polars-plan/list_take"]

true_div = ["polars-plan/true_div"]

Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dtype-struct = ["polars-core/dtype-struct"]
dtype-binary = ["polars-core/dtype-binary"]
object = ["polars-core/object"]
date_offset = ["polars-time"]
list_take = ["polars-ops/list_take"]
trigonometry = []
sign = []
timezones = ["polars-time/timezones", "polars-core/timezones"]
Expand Down
21 changes: 21 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub enum ListFunction {
Contains,
Slice,
Get,
#[cfg(feature = "list_take")]
Take,
}

impl Display for ListFunction {
Expand All @@ -22,6 +24,8 @@ impl Display for ListFunction {
Contains => "contains",
Slice => "slice",
Get => "get",
#[cfg(feature = "list_take")]
Take => "take",
};
write!(f, "{name}")
}
Expand Down Expand Up @@ -185,3 +189,20 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Series> {

}
}

#[cfg(feature = "list_take")]
pub(super) fn take(args: &[Series]) -> PolarsResult<Series> {
let ca = &args[0];
let idx = &args[1];
let ca = ca.list()?;

if idx.len() == 1 {
// fast path
let idx = idx.get(0)?.try_extract::<i64>()?;
let out = ca.lst_get(idx)?;
// make sure we return a list
out.reshape(&[-1, 1])
} else {
ca.lst_take(idx)
}
}
2 changes: 2 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Contains => wrap!(list::contains),
Slice => wrap!(list::slice),
Get => wrap!(list::get),
#[cfg(feature = "list_take")]
Take => map_as_slice!(list::take),
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ impl FunctionExpr {
Contains => with_dtype(DataType::Boolean),
Slice => same_type(),
Get => inner_type_list(),
#[cfg(feature = "list_take")]
Take => same_type(),
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ impl ListNameSpace {
.map_many_private(FunctionExpr::ListExpr(ListFunction::Get), &[index], false)
}

/// Get items in every sublist by multiple indexes.
#[cfg(feature = "list_take")]
#[cfg_attr(docsrs, doc(cfg(feature = "list_take")))]
pub fn take(self, index: Expr) -> Expr {
self.0
.map_many_private(FunctionExpr::ListExpr(ListFunction::Take), &[index], false)
}

/// Get first item of every sublist.
pub fn first(self) -> Expr {
self.get(lit(0i64))
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dtype-i16 = ["polars-core/dtype-i16"]
object = ["polars-core/object"]
propagate_nans = []
performant = ["polars-core/performant"]
big_idx = ["polars-core/bigidx"]

# ops
to_dummies = []
Expand All @@ -56,3 +57,4 @@ cross_join = ["polars-core/cross_join"]
chunked_ids = ["polars-core/chunked_ids"]
asof_join = ["polars-core/asof_join"]
semi_anti_join = ["polars-core/semi_anti_join"]
list_take = []
122 changes: 122 additions & 0 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::fmt::Write;
use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use polars_core::chunked_array::builder::get_list_builder;
#[cfg(feature = "list_take")]
use polars_core::export::num::{NumCast, Signed, Zero};
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
use polars_core::utils::{try_get_supertype, CustomIterTools};
Expand Down Expand Up @@ -213,6 +215,72 @@ pub trait ListNameSpaceImpl: AsList {
Series::try_from((ca.name(), chunks))
}

#[cfg(feature = "list_take")]
fn lst_take(&self, idx: &Series) -> PolarsResult<Series> {
let list_ca = self.as_list();

let index_typed_index = |idx: &Series| {
let other = idx.cast(&IDX_DTYPE).unwrap();
let idx = other.idx().unwrap();
list_ca
.amortized_iter()
.map(|s| s.map(|s| s.as_ref().take(idx)).transpose())
.collect::<PolarsResult<ListChunked>>()
.map(|mut ca| {
ca.rename(list_ca.name());
ca.into_series()
})
};

use DataType::*;
match idx.dtype() {
List(_) => {
let idx_ca = idx.list().unwrap();
let mut out = list_ca
.amortized_iter()
.zip(idx_ca.into_iter())
.map(|(opt_s, opt_idx)| {
{
match (opt_s, opt_idx) {
(Some(s), Some(idx)) => take_series(s.as_ref(), idx),
_ => None,
}
}
.transpose()
})
.collect::<PolarsResult<ListChunked>>()?;
out.rename(list_ca.name());

Ok(out.into_series())
}
UInt32 | UInt64 => index_typed_index(idx),
dt if dt.is_signed() => {
if let Some(min) = idx.min::<i64>() {
if min > 0 {
let idx = idx.cast(&IDX_DTYPE).unwrap();
index_typed_index(&idx)
} else {
let mut out = list_ca
.amortized_iter()
.map(|opt_s| {
opt_s
.and_then(|s| take_series(s.as_ref(), idx.clone()))
.transpose()
})
.collect::<PolarsResult<ListChunked>>()?;
out.rename(list_ca.name());
Ok(out.into_series())
}
} else {
Err(PolarsError::ComputeError("All indices are null".into()))
}
}
dt => Err(PolarsError::ComputeError(
format!("Cannot use dtype: '{dt}' as index.").into(),
)),
}
}

fn lst_concat(&self, other: &[Series]) -> PolarsResult<ListChunked> {
let ca = self.as_list();
let other_len = other.len();
Expand Down Expand Up @@ -360,3 +428,57 @@ pub trait ListNameSpaceImpl: AsList {
}

impl ListNameSpaceImpl for ListChunked {}

#[cfg(feature = "list_take")]
fn take_series(s: &Series, idx: Series) -> Option<PolarsResult<Series>> {
let len = s.len();
let idx = cast_index(idx, len);
let idx = idx.idx().unwrap();
Some(s.take(idx))
}

#[cfg(feature = "list_take")]
fn cast_index_ca<T: PolarsNumericType>(idx: &ChunkedArray<T>, len: usize) -> Series
where
T::Native: Copy + PartialOrd + PartialEq + NumCast + Signed + Zero,
{
idx.into_iter()
.map(|opt_idx| opt_idx.and_then(|idx| idx.negative_to_usize(len).map(|idx| idx as IdxSize)))
.collect::<IdxCa>()
.into_series()
}

#[cfg(feature = "list_take")]
fn cast_index(idx: Series, len: usize) -> Series {
use DataType::*;
match idx.dtype() {
#[cfg(feature = "big_idx")]
UInt32 => idx.cast(&IDX_DTYPE).unwrap(),
#[cfg(feature = "big_idx")]
UInt64 => idx,
#[cfg(not(feature = "big_idx"))]
UInt64 => idx.cast(&IDX_DTYPE).unwrap(),
#[cfg(not(feature = "big_idx"))]
UInt32 => idx,
dt if dt.is_unsigned() => idx.cast(&IDX_DTYPE).unwrap(),
Int8 => {
let a = idx.i8().unwrap();
cast_index_ca(a, len)
}
Int16 => {
let a = idx.i16().unwrap();
cast_index_ca(a, len)
}
Int32 => {
let a = idx.i32().unwrap();
cast_index_ca(a, len)
}
Int64 => {
let a = idx.i64().unwrap();
cast_index_ca(a, len)
}
_ => {
unreachable!()
}
}
}
1 change: 1 addition & 0 deletions polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
//! - `interpolate` [interpolate None values](crate::chunked_array::ops::Interpolate)
//! - `extract_jsonpath` - [Run jsonpath queries on Utf8Chunked](https://goessner.net/articles/JsonPath/)
//! - `list` - List utils.
//! - `list_take` take sublist by multiple indices
//! - `rank` - Ranking algorithms.
//! - `moment` - kurtosis and skew statistics
//! - `ewma` - Exponential moving average windows
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ performant = ["polars/performant"]
timezones = ["polars/timezones"]
cse = ["polars/cse"]
merge_sorted = ["polars/merge_sorted"]
list_take = ["polars/list_take"]

all = [
"json",
Expand Down Expand Up @@ -100,6 +101,7 @@ all = [
"polars/binary_encoding",
"streaming",
"performant",
"list_take",
]

# we cannot conditionaly activate simd
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.sort
Expr.arr.sum
Expr.arr.tail
Expr.arr.take
Expr.arr.to_struct
Expr.arr.unique
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.sort
Series.arr.sum
Series.arr.tail
Series.arr.take
Series.arr.to_struct
Series.arr.unique
18 changes: 18 additions & 0 deletions py-polars/polars/internals/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,24 @@ def get(self, index: int | pli.Expr | str) -> pli.Expr:
index = pli.expr_to_lit_or_expr(index, str_to_lit=False)._pyexpr
return pli.wrap_expr(self._pyexpr.lst_get(index))

def take(self, index: pli.Expr | pli.Series | list[int]) -> pli.Expr:
"""
Take sublists by multiple indices.
The indices may be defined in a single column, or by sublists in another
column of dtype ``List``.
Parameters
----------
index
Indices to return per sublist
"""
if isinstance(index, list):
index = pli.Series(index)
index = pli.expr_to_lit_or_expr(index, str_to_lit=False)._pyexpr
return pli.wrap_expr(self._pyexpr.lst_take(index))

def __getitem__(self, item: int) -> pli.Expr:
return self.get(item)

Expand Down
14 changes: 14 additions & 0 deletions py-polars/polars/internals/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ def get(self, index: int | pli.Series | list[int]) -> pli.Series:
"""

def take(self, index: pli.Series | list[int]) -> pli.Series:
"""
Take sublists by multiple indices.
The indices may be defined in a single column, or by sublists in another
column of dtype ``List``.
Parameters
----------
index
Indices to return per sublist
"""

def __getitem__(self, item: int) -> pli.Series:
return self.get(item)

Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,11 @@ impl PyExpr {
self.inner.clone().arr().get(index.inner).into()
}

#[cfg(feature = "list_take")]
fn lst_take(&self, index: PyExpr) -> Self {
self.inner.clone().arr().take(index.inner).into()
}

fn lst_join(&self, separator: &str) -> Self {
self.inner.clone().arr().join(separator).into()
}
Expand Down

0 comments on commit 05447ab

Please sign in to comment.