Skip to content

Commit

Permalink
feat(rust, python): Add null_on_oob parameter to expr.list.get (#…
Browse files Browse the repository at this point in the history
…15395)

Co-authored-by: James Edwards <edwjames@umich.edu>
  • Loading branch information
JamesCE2001 and edwjames committed Apr 1, 2024
1 parent 5cdeea2 commit 2c9b6c1
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 44 deletions.
7 changes: 7 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/list.rs
Expand Up @@ -75,6 +75,13 @@ pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
unsafe { take_unchecked(&**values, &take_by) }
}

/// Check if an index is out of bounds for at least one sublist.
pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {
arr.offsets()
.lengths()
.any(|len| index.negative_to_usize(len).is_none())
}

/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]`
pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
let len = array.len();
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
@@ -1,7 +1,7 @@
use std::fmt::Write;

use arrow::array::ValueSize;
use arrow::legacy::kernels::list::sublist_get;
use arrow::legacy::kernels::list::{index_is_oob, sublist_get};
use polars_core::chunked_array::builder::get_list_builder;
#[cfg(feature = "list_gather")]
use polars_core::export::num::ToPrimitive;
Expand Down Expand Up @@ -341,8 +341,12 @@ pub trait ListNameSpaceImpl: AsList {
/// So index `0` would return the first item of every sublist
/// and index `-1` would return the last item of every sublist
/// if an index is out of bounds, it will return a `None`.
fn lst_get(&self, idx: i64) -> PolarsResult<Series> {
fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult<Series> {
let ca = self.as_list();
if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) {
polars_bail!(ComputeError: "get index is out of bounds");
}

let chunks = ca
.downcast_iter()
.map(|arr| sublist_get(arr, idx))
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/list/to_struct.rs
Expand Up @@ -72,7 +72,7 @@ pub trait ToStruct: AsList {
(0..n_fields)
.into_par_iter()
.map(|i| {
ca.lst_get(i as i64).map(|mut s| {
ca.lst_get(i as i64, true).map(|mut s| {
s.rename(&name_generator(i));
s
})
Expand Down
33 changes: 19 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Expand Up @@ -21,7 +21,7 @@ pub enum ListFunction {
},
Slice,
Shift,
Get,
Get(bool),
#[cfg(feature = "list_gather")]
Gather(bool),
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -71,7 +71,7 @@ impl ListFunction {
Sample { .. } => mapper.with_same_dtype(),
Slice => mapper.with_same_dtype(),
Shift => mapper.with_same_dtype(),
Get => mapper.map_to_list_and_array_inner_dtype(),
Get(_) => mapper.map_to_list_and_array_inner_dtype(),
#[cfg(feature = "list_gather")]
Gather(_) => mapper.with_same_dtype(),
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -136,7 +136,7 @@ impl Display for ListFunction {
},
Slice => "slice",
Shift => "shift",
Get => "get",
Get(_) => "get",
#[cfg(feature = "list_gather")]
Gather(_) => "gather",
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -203,9 +203,9 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
},
Slice => wrap!(slice),
Shift => map_as_slice!(shift),
Get => wrap!(get),
Get(null_on_oob) => wrap!(get, null_on_oob),
#[cfg(feature = "list_gather")]
Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob),
Gather(null_on_oob) => map_as_slice!(gather, null_on_oob),
#[cfg(feature = "list_gather")]
GatherEvery => map_as_slice!(gather_every),
#[cfg(feature = "list_count")]
Expand Down Expand Up @@ -414,7 +414,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult<Option<Series>> {
first_ca.lst_concat(other).map(|ca| Some(ca.into_series()))
}

pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult<Option<Series>> {
let ca = s[0].list()?;
let index = s[1].cast(&DataType::Int64)?;
let index = index.i64().unwrap();
Expand All @@ -423,7 +423,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
1 => {
let index = index.get(0);
if let Some(index) = index {
ca.lst_get(index).map(Some)
ca.lst_get(index, null_on_oob).map(Some)
} else {
Ok(Some(Series::full_null(
ca.name(),
Expand All @@ -440,19 +440,24 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
let take_by = index
.into_iter()
.enumerate()
.map(|(i, opt_idx)| {
opt_idx.and_then(|idx| {
.map(|(i, opt_idx)| match opt_idx {
Some(idx) => {
let (start, end) =
unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) };
let offset = if idx >= 0 { start + idx } else { end + idx };
if offset >= end || offset < start || start == end {
None
if null_on_oob {
Ok(None)
} else {
polars_bail!(ComputeError: "get index is out of bounds");
}
} else {
Some(offset as IdxSize)
Ok(Some(offset as IdxSize))
}
})
},
None => Ok(None),
})
.collect::<IdxCa>();
.collect::<Result<IdxCa, _>>()?;
let s = Series::try_from((ca.name(), arr.values().clone())).unwrap();
unsafe { s.take_unchecked(&take_by) }
.cast(&ca.inner_dtype())
Expand All @@ -475,7 +480,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult<Series>
if idx.len() == 1 && null_on_oob {
// fast path
let idx = idx.get(0)?.try_extract::<i64>()?;
let out = ca.lst_get(idx)?;
let out = ca.lst_get(idx, null_on_oob)?;
// make sure we return a list
out.reshape(&[-1, 1])
} else {
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Expand Up @@ -718,6 +718,14 @@ macro_rules! wrap {
($e:expr) => {
SpecialEq::new(Arc::new($e))
};

($e:expr, $($args:expr),*) => {{
let f = move |s: &mut [Series]| {
$e(s, $($args),*)
};

SpecialEq::new(Arc::new(f))
}};
}

// Fn(&[Series], args)
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-plan/src/dsl/list.rs
Expand Up @@ -151,9 +151,9 @@ impl ListNameSpace {
}

/// Get items in every sublist by index.
pub fn get(self, index: Expr) -> Expr {
pub fn get(self, index: Expr, null_on_oob: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Get),
FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)),
&[index],
false,
false,
Expand Down Expand Up @@ -187,12 +187,12 @@ impl ListNameSpace {

/// Get first item of every sublist.
pub fn first(self) -> Expr {
self.get(lit(0i64))
self.get(lit(0i64), true)
}

/// Get last item of every sublist.
pub fn last(self) -> Expr {
self.get(lit(-1i64))
self.get(lit(-1i64), true)
}

/// Join all string items in a sublist and place a separator between them.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/src/functions.rs
Expand Up @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> {
// Array functions
// ----
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)),
ArrayLength => self.visit_unary(|e| e.list().len()),
ArrayMax => self.visit_unary(|e| e.list().max()),
ArrayMean => self.visit_unary(|e| e.list().mean()),
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/expr/list.py
Expand Up @@ -505,7 +505,12 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E
other_list.insert(0, wrap_expr(self._pyexpr))
return F.concat_list(other_list)

def get(self, index: int | Expr | str) -> Expr:
def get(
self,
index: int | Expr | str,
*,
null_on_oob: bool = True,
) -> Expr:
"""
Get the value by index in the sublists.
Expand All @@ -517,6 +522,10 @@ def get(self, index: int | Expr | str) -> Expr:
----------
index
Index to return per sublist
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error
Examples
--------
Expand All @@ -534,7 +543,7 @@ def get(self, index: int | Expr | str) -> Expr:
└───────────┴──────┘
"""
index = parse_as_expression(index)
return wrap_expr(self._pyexpr.list_get(index))
return wrap_expr(self._pyexpr.list_get(index, null_on_oob))

def gather(
self,
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/series/list.py
Expand Up @@ -373,7 +373,12 @@ def concat(self, other: list[Series] | Series | list[Any]) -> Series:
]
"""

def get(self, index: int | Series | list[int]) -> Series:
def get(
self,
index: int | Series | list[int],
*,
null_on_oob: bool = True,
) -> Series:
"""
Get the value by index in the sublists.
Expand All @@ -385,11 +390,15 @@ def get(self, index: int | Series | list[int]) -> Series:
----------
index
Index to return per sublist
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error
Examples
--------
>>> s = pl.Series("a", [[3, 2, 1], [], [1, 2]])
>>> s.list.get(0)
>>> s.list.get(0, null_on_oob=True)
shape: (3,)
Series: 'a' [i64]
[
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/list.rs
Expand Up @@ -44,8 +44,12 @@ impl PyExpr {
self.inner.clone().list().eval(expr.inner, parallel).into()
}

fn list_get(&self, index: PyExpr) -> Self {
self.inner.clone().list().get(index.inner).into()
fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self {
self.inner
.clone()
.list()
.get(index.inner, null_on_oob)
.into()
}

fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_list.py
Expand Up @@ -781,7 +781,7 @@ def test_list_gather_null_struct_14927() -> None:
{"index": [1], "col_0": [None], "field_0": [None]},
schema={**df.schema, "field_0": pl.Float64},
)
expr = pl.col("col_0").list.get(0).struct.field("field_0")
expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0")
out = df.filter(pl.col("index") > 0).with_columns(expr)
assert_frame_equal(out, expected)

Expand Down

0 comments on commit 2c9b6c1

Please sign in to comment.