Skip to content

Commit

Permalink
feat(rust, python): Add null_on_oob parameter to expr.array.get (#…
Browse files Browse the repository at this point in the history
…15426)

Co-authored-by: James Edwards <edwjames@umich.edu>
Co-authored-by: Ritchie Vink <ritchie46@gmail.com>
  • Loading branch information
3 people committed Apr 10, 2024
1 parent 89468f7 commit 8440457
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 46 deletions.
26 changes: 22 additions & 4 deletions crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs
@@ -1,3 +1,5 @@
use polars_error::{polars_bail, PolarsResult};
use polars_utils::index::NullCount;
use polars_utils::IdxSize;

use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray};
Expand Down Expand Up @@ -38,18 +40,34 @@ fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray<i64>) ->
.collect_trusted()
}

pub fn sub_fixed_size_list_get_literal(arr: &FixedSizeListArray, index: i64) -> ArrayRef {
pub fn sub_fixed_size_list_get_literal(
arr: &FixedSizeListArray,
index: i64,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
}

let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { take_unchecked(&**values, &take_by) }
unsafe { Ok(take_unchecked(&**values, &take_by)) }
}

pub fn sub_fixed_size_list_get(arr: &FixedSizeListArray, index: &PrimitiveArray<i64>) -> ArrayRef {
pub fn sub_fixed_size_list_get(
arr: &FixedSizeListArray,
index: &PrimitiveArray<i64>,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes(arr.size(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
}

let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { take_unchecked(&**values, &take_by) }
unsafe { Ok(take_unchecked(&**values, &take_by)) }
}
41 changes: 34 additions & 7 deletions crates/polars-ops/src/chunked_array/array/get.rs
@@ -1,15 +1,16 @@
use arrow::array::Array;
use arrow::legacy::kernels::fixed_size_list::{
sub_fixed_size_list_get, sub_fixed_size_list_get_literal,
};
use polars_core::prelude::arity::binary_to_series;
use polars_core::utils::align_chunks_binary;

use super::*;

fn array_get_literal(ca: &ArrayChunked, idx: i64) -> PolarsResult<Series> {
fn array_get_literal(ca: &ArrayChunked, idx: i64, null_on_oob: bool) -> PolarsResult<Series> {
let chunks = ca
.downcast_iter()
.map(|arr| sub_fixed_size_list_get_literal(arr, idx))
.collect::<Vec<_>>();
.map(|arr| sub_fixed_size_list_get_literal(arr, idx, null_on_oob))
.collect::<PolarsResult<Vec<_>>>()?;
Series::try_from((ca.name(), chunks))
.unwrap()
.cast(&ca.inner_dtype())
Expand All @@ -19,18 +20,24 @@ fn array_get_literal(ca: &ArrayChunked, idx: i64) -> PolarsResult<Series> {
/// So index `0` would return the first item of every sub-array
/// and index `-1` would return the last item of every sub-array
/// If an index is out of bounds, it returns `None` when `null_on_oob` is `true` and an error otherwise.
pub fn array_get(ca: &ArrayChunked, index: &Int64Chunked) -> PolarsResult<Series> {
pub fn array_get(
ca: &ArrayChunked,
index: &Int64Chunked,
null_on_oob: bool,
) -> PolarsResult<Series> {
match index.len() {
1 => {
let index = index.get(0);
if let Some(index) = index {
array_get_literal(ca, index)
array_get_literal(ca, index, null_on_oob)
} else {
Ok(Series::full_null(ca.name(), ca.len(), &ca.inner_dtype()))
}
},
len if len == ca.len() => {
let out = binary_to_series(ca, index, |arr, idx| sub_fixed_size_list_get(arr, idx));
let out = binary_to_series_arr_get(ca, index, null_on_oob, |arr, idx, nob| {
sub_fixed_size_list_get(arr, idx, nob)
});
out?.cast(&ca.inner_dtype())
},
len => polars_bail!(
Expand All @@ -40,3 +47,23 @@ pub fn array_get(ca: &ArrayChunked, index: &Int64Chunked) -> PolarsResult<Series
),
}
}

/// Applies `op` to each pair of chunks of `lhs` and `rhs` (after running both
/// through `align_chunks_binary`), forwarding `null_on_oob` on every call, and
/// builds a `Series` named after `lhs` from the produced arrays.
///
/// # Errors
///
/// Returns the first error yielded by `op`; chunks after the failing one are
/// not processed. Also returns an error if the `Series` cannot be built from
/// the resulting chunks.
pub fn binary_to_series_arr_get<T, U, F>(
    lhs: &ChunkedArray<T>,
    rhs: &ChunkedArray<U>,
    null_on_oob: bool,
    mut op: F,
) -> PolarsResult<Series>
where
    T: PolarsDataType,
    U: PolarsDataType,
    F: FnMut(&T::Array, &U::Array, bool) -> PolarsResult<Box<dyn Array>>,
{
    let (left, right) = align_chunks_binary(lhs, rhs);

    let mut chunks = Vec::new();
    for (left_arr, right_arr) in left.downcast_iter().zip(right.downcast_iter()) {
        // `?` stops at the first failing chunk, same as collecting into a `PolarsResult`.
        chunks.push(op(left_arr, right_arr, null_on_oob)?);
    }

    Series::try_from((left.name(), chunks))
}
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Expand Up @@ -122,9 +122,9 @@ pub trait ArrayNameSpace: AsArray {
})
}

fn array_get(&self, index: &Int64Chunked) -> PolarsResult<Series> {
fn array_get(&self, index: &Int64Chunked, null_on_oob: bool) -> PolarsResult<Series> {
let ca = self.as_array();
array_get(ca, index)
array_get(ca, index, null_on_oob)
}

fn array_join(&self, separator: &StringChunked, ignore_nulls: bool) -> PolarsResult<Series> {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/array/to_struct.rs
Expand Up @@ -28,7 +28,7 @@ pub trait ToStruct: AsArray {
(0..n_fields)
.into_par_iter()
.map(|i| {
ca.array_get(&Int64Chunked::from_slice("", &[i as i64]))
ca.array_get(&Int64Chunked::from_slice("", &[i as i64]), true)
.map(|mut s| {
s.rename(&name_generator(i));
s
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/array.rs
Expand Up @@ -105,9 +105,9 @@ impl ArrayNameSpace {
}

/// Get items in every sub-array by index.
pub fn get(self, index: Expr) -> Expr {
pub fn get(self, index: Expr, null_on_oob: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ArrayExpr(ArrayFunction::Get),
FunctionExpr::ArrayExpr(ArrayFunction::Get(null_on_oob)),
&[index],
false,
false,
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Expand Up @@ -23,7 +23,7 @@ pub enum ArrayFunction {
Reverse,
ArgMin,
ArgMax,
Get,
Get(bool),
Join(bool),
#[cfg(feature = "is_in")]
Contains,
Expand All @@ -49,7 +49,7 @@ impl ArrayFunction {
Sort(_) => mapper.with_same_dtype(),
Reverse => mapper.with_same_dtype(),
ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
Get => mapper.map_to_list_and_array_inner_dtype(),
Get(_) => mapper.map_to_list_and_array_inner_dtype(),
Join(_) => mapper.with_dtype(DataType::String),
#[cfg(feature = "is_in")]
Contains => mapper.with_dtype(DataType::Boolean),
Expand Down Expand Up @@ -89,7 +89,7 @@ impl Display for ArrayFunction {
Reverse => "reverse",
ArgMin => "arg_min",
ArgMax => "arg_max",
Get => "get",
Get(_) => "get",
Join(_) => "join",
#[cfg(feature = "is_in")]
Contains => "contains",
Expand Down Expand Up @@ -122,7 +122,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Reverse => map!(reverse),
ArgMin => map!(arg_min),
ArgMax => map!(arg_max),
Get => map_as_slice!(get),
Get(null_on_oob) => map_as_slice!(get, null_on_oob),
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
#[cfg(feature = "is_in")]
Contains => map_as_slice!(contains),
Expand Down Expand Up @@ -201,11 +201,11 @@ pub(super) fn arg_max(s: &Series) -> PolarsResult<Series> {
Ok(s.array()?.array_arg_max().into_series())
}

pub(super) fn get(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn get(s: &[Series], null_on_oob: bool) -> PolarsResult<Series> {
let ca = s[0].array()?;
let index = s[1].cast(&DataType::Int64)?;
let index = index.i64().unwrap();
ca.array_get(index)
ca.array_get(index, null_on_oob)
}

pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult<Series> {
Expand Down
30 changes: 17 additions & 13 deletions py-polars/polars/expr/array.py
Expand Up @@ -441,7 +441,7 @@ def arg_max(self) -> Expr:
"""
return wrap_expr(self._pyexpr.arr_arg_max())

def get(self, index: int | IntoExprColumn) -> Expr:
def get(self, index: int | IntoExprColumn, *, null_on_oob: bool = True) -> Expr:
"""
Get the value by index in the sub-arrays.
Expand All @@ -453,28 +453,32 @@ def get(self, index: int | IntoExprColumn) -> Expr:
----------
index
Index to return per sub-array
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error
Examples
--------
>>> df = pl.DataFrame(
... {"arr": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "idx": [1, -2, 4]},
... {"arr": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "idx": [1, -2, 0]},
... schema={"arr": pl.Array(pl.Int32, 3), "idx": pl.Int32},
... )
>>> df.with_columns(get=pl.col("arr").arr.get("idx"))
>>> df.with_columns(get=pl.col("arr").arr.get("idx", null_on_oob=True))
shape: (3, 3)
┌───────────────┬─────┬─────
│ arr ┆ idx ┆ get
│ --- ┆ --- ┆ ---
│ array[i32, 3] ┆ i32 ┆ i32
╞═══════════════╪═════╪═════
│ [1, 2, 3] ┆ 1 ┆ 2
│ [4, 5, 6] ┆ -2 ┆ 5
│ [7, 8, 9] ┆ 4 ┆ null
└───────────────┴─────┴─────
┌───────────────┬─────┬─────┐
│ arr ┆ idx ┆ get │
│ --- ┆ --- ┆ --- │
│ array[i32, 3] ┆ i32 ┆ i32 │
╞═══════════════╪═════╪═════╡
│ [1, 2, 3] ┆ 1 ┆ 2 │
│ [4, 5, 6] ┆ -2 ┆ 5 │
│ [7, 8, 9] ┆ 0 ┆ 7 │
└───────────────┴─────┴─────┘
"""
index = parse_as_expression(index)
return wrap_expr(self._pyexpr.arr_get(index))
return wrap_expr(self._pyexpr.arr_get(index, null_on_oob))

def first(self) -> Expr:
"""
Expand Down
10 changes: 7 additions & 3 deletions py-polars/polars/series/array.py
Expand Up @@ -330,7 +330,7 @@ def arg_max(self) -> Series:
"""

def get(self, index: int | IntoExprColumn) -> Series:
def get(self, index: int | IntoExprColumn, *, null_on_oob: bool = True) -> Series:
"""
Get the value by index in the sub-arrays.
Expand All @@ -342,6 +342,10 @@ def get(self, index: int | IntoExprColumn) -> Series:
----------
index
Index to return per sublist
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error
Returns
-------
Expand All @@ -353,13 +357,13 @@ def get(self, index: int | IntoExprColumn) -> Series:
>>> s = pl.Series(
... "a", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=pl.Array(pl.Int32, 3)
... )
>>> s.arr.get(pl.Series([1, -2, 4]))
>>> s.arr.get(pl.Series([1, -2, 0]), null_on_oob=True)
shape: (3,)
Series: 'a' [i32]
[
2
5
null
7
]
"""
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/array.rs
Expand Up @@ -80,8 +80,12 @@ impl PyExpr {
self.inner.clone().arr().arg_max().into()
}

fn arr_get(&self, index: PyExpr) -> Self {
self.inner.clone().arr().get(index.inner).into()
fn arr_get(&self, index: PyExpr, null_on_oob: bool) -> Self {
self.inner
.clone()
.arr()
.get(index.inner, null_on_oob)
.into()
}

fn arr_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
Expand Down
60 changes: 54 additions & 6 deletions py-polars/tests/unit/namespaces/array/test_array.py
Expand Up @@ -135,27 +135,75 @@ def test_array_get() -> None:
)

# Test index literal.
out = s.arr.get(1)
out = s.arr.get(1, null_on_oob=False)
expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64)
assert_series_equal(out, expected)

# Null index literal.
out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None)))
out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=False))
expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame()
assert_frame_equal(out_df, expected_df)

# Out-of-bounds index literal.
out = s.arr.get(100)
with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
out = s.arr.get(100, null_on_oob=False)

# Negative index literal.
out = s.arr.get(-2, null_on_oob=False)
expected = pl.Series("a", [3, None, 9], dtype=pl.Int64)
assert_series_equal(out, expected)

# Test index expr.
with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
out = s.arr.get(pl.Series([1, -2, 100]), null_on_oob=False)

out = s.arr.get(pl.Series([1, -2, 0]), null_on_oob=False)
expected = pl.Series("a", [2, None, 7], dtype=pl.Int64)
assert_series_equal(out, expected)

# Test logical type.
s = pl.Series(
"a",
[
[datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)],
[datetime.date(2001, 10, 1), None],
[None, None],
],
dtype=pl.Array(pl.Date, 2),
)
with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
out = s.arr.get(pl.Series([1, -2, 4]), null_on_oob=False)


def test_array_get_null_on_oob() -> None:
s = pl.Series(
"a",
[[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]],
dtype=pl.Array(pl.Int64, 4),
)

# Test index literal.
out = s.arr.get(1, null_on_oob=True)
expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64)
assert_series_equal(out, expected)

# Null index literal.
out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=True))
expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame()
assert_frame_equal(out_df, expected_df)

# Out-of-bounds index literal.
out = s.arr.get(100, null_on_oob=True)
expected = pl.Series("a", [None, None, None], dtype=pl.Int64)
assert_series_equal(out, expected)

# Negative index literal.
out = s.arr.get(-2)
out = s.arr.get(-2, null_on_oob=True)
expected = pl.Series("a", [3, None, 9], dtype=pl.Int64)
assert_series_equal(out, expected)

# Test index expr.
out = s.arr.get(pl.Series([1, -2, 100]))
out = s.arr.get(pl.Series([1, -2, 100]), null_on_oob=True)
expected = pl.Series("a", [2, None, None], dtype=pl.Int64)
assert_series_equal(out, expected)

Expand All @@ -169,7 +217,7 @@ def test_array_get() -> None:
],
dtype=pl.Array(pl.Date, 2),
)
out = s.arr.get(pl.Series([1, -2, 4]))
out = s.arr.get(pl.Series([1, -2, 4]), null_on_oob=True)
expected = pl.Series(
"a",
[datetime.date(2000, 1, 1), datetime.date(2001, 10, 1), None],
Expand Down

0 comments on commit 8440457

Please sign in to comment.