-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
268 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
pub(crate) trait IndexToUsize { | ||
/// Translate the negative index to an offset. | ||
fn to_usize(self, length: usize) -> Option<usize>; | ||
} | ||
|
||
impl IndexToUsize for i64 { | ||
fn to_usize(self, length: usize) -> Option<usize> { | ||
if self >= 0 && (self as usize) < length { | ||
Some(self as usize) | ||
} else { | ||
let subtract = self.abs() as usize; | ||
if subtract > length { | ||
None | ||
} else { | ||
Some(length - subtract) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
use crate::index::IndexToUsize; | ||
use crate::kernels::take::take_unchecked; | ||
use crate::utils::CustomIterTools; | ||
use arrow::array::{ArrayRef, ListArray, PrimitiveArray}; | ||
|
||
/// Get the indices that would result in a get operation on the lists values. | ||
/// for example, consider this list: | ||
/// ```text | ||
/// [[1, 2, 3], | ||
/// [4, 5], | ||
/// [6]] | ||
/// | ||
/// This contains the following values array: | ||
/// [1, 2, 3, 4, 5, 6] | ||
/// | ||
/// get index 0 | ||
/// would lead to the following indexes: | ||
/// [0, 3, 5]. | ||
/// if we use those in a take operation on the values array we get: | ||
/// [1, 4, 6] | ||
/// | ||
/// | ||
/// get index -1 | ||
/// would lead to the following indexes: | ||
/// [2, 4, 5]. | ||
/// if we use those in a take operation on the values array we get: | ||
/// [3, 5, 6] | ||
/// | ||
/// ``` | ||
fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> PrimitiveArray<u32> { | ||
let mut iter = arr.offsets().iter(); | ||
|
||
let mut cum_offset = 0u32; | ||
|
||
if let Some(mut previous) = iter.next().copied() { | ||
let a: PrimitiveArray<u32> = iter | ||
.map(|&offset| { | ||
let len = offset - previous; | ||
previous = offset; | ||
|
||
let out = index | ||
.to_usize(len as usize) | ||
.map(|idx| idx as u32 + cum_offset); | ||
cum_offset += len as u32; | ||
out | ||
}) | ||
.collect_trusted(); | ||
|
||
a | ||
} else { | ||
PrimitiveArray::<u32>::from_slice(&[]) | ||
} | ||
} | ||
|
||
pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef { | ||
let take_by = sublist_get_indexes(arr, index); | ||
let values = arr.values(); | ||
// Safety: | ||
// the indices we generate are in bounds | ||
unsafe { take_unchecked(&**values, &take_by) } | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use super::*; | ||
use arrow::array::Int32Array; | ||
use arrow::buffer::Buffer; | ||
use arrow::datatypes::DataType; | ||
use std::sync::Arc; | ||
|
||
fn get_array() -> ListArray<i64> { | ||
let values = Int32Array::from_slice(&[1, 2, 3, 4, 5, 6]); | ||
let offsets = Buffer::from(&[0i64, 3, 5, 6]); | ||
|
||
let dtype = ListArray::<i64>::default_datatype(DataType::Int32); | ||
ListArray::<i64>::from_data(dtype, offsets, Arc::new(values), None) | ||
} | ||
|
||
#[test] | ||
fn test_sublist_get_indexes() { | ||
let arr = get_array(); | ||
let out = sublist_get_indexes(&arr, 0); | ||
assert_eq!(out.values().as_slice(), &[0, 3, 5]); | ||
let out = sublist_get_indexes(&arr, -1); | ||
assert_eq!(out.values().as_slice(), &[2, 4, 5]); | ||
} | ||
|
||
#[test] | ||
fn test_sublist_get() { | ||
let arr = get_array(); | ||
|
||
let out = sublist_get(&arr, 0); | ||
let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap(); | ||
|
||
assert_eq!(out.values().as_slice(), &[1, 4, 6]); | ||
let out = sublist_get(&arr, -1); | ||
let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap(); | ||
assert_eq!(out.values().as_slice(), &[3, 5, 6]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import polars as pl | ||
from polars import testing | ||
|
||
|
||
def test_list_arr_get() -> None: | ||
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) | ||
out = a.arr.get(0) | ||
expected = pl.Series("a", [1, 4, 6]) | ||
testing.assert_series_equal(out, expected) | ||
|
||
out = a.arr.get(-1) | ||
expected = pl.Series("a", [3, 5, 9]) | ||
testing.assert_series_equal(out, expected) | ||
|
||
out = a.arr.get(-3) | ||
expected = pl.Series("a", [1, None, 7]) | ||
testing.assert_series_equal(out, expected) |