Skip to content

Commit

Permalink
Add 'get' for list sublists
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 9, 2021
1 parent 4c0d7b3 commit 774bf27
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 1 deletion.
19 changes: 19 additions & 0 deletions polars/polars-arrow/src/index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
pub(crate) trait IndexToUsize {
/// Translate the negative index to an offset.
fn to_usize(self, length: usize) -> Option<usize>;
}

impl IndexToUsize for i64 {
fn to_usize(self, length: usize) -> Option<usize> {
if self >= 0 && (self as usize) < length {
Some(self as usize)
} else {
let subtract = self.abs() as usize;
if subtract > length {
None
} else {
Some(length - subtract)
}
}
}
}
100 changes: 100 additions & 0 deletions polars/polars-arrow/src/kernels/list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use crate::index::IndexToUsize;
use crate::kernels::take::take_unchecked;
use crate::utils::CustomIterTools;
use arrow::array::{ArrayRef, ListArray, PrimitiveArray};

/// Get the indices that would result in a get operation on the lists values.
/// for example, consider this list:
/// ```text
/// [[1, 2, 3],
/// [4, 5],
/// [6]]
///
/// This contains the following values array:
/// [1, 2, 3, 4, 5, 6]
///
/// get index 0
/// would lead to the following indexes:
/// [0, 3, 5].
/// if we use those in a take operation on the values array we get:
/// [1, 4, 6]
///
///
/// get index -1
/// would lead to the following indexes:
/// [2, 4, 5].
/// if we use those in a take operation on the values array we get:
/// [3, 5, 6]
///
/// ```
fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> PrimitiveArray<u32> {
let mut iter = arr.offsets().iter();

let mut cum_offset = 0u32;

if let Some(mut previous) = iter.next().copied() {
let a: PrimitiveArray<u32> = iter
.map(|&offset| {
let len = offset - previous;
previous = offset;

let out = index
.to_usize(len as usize)
.map(|idx| idx as u32 + cum_offset);
cum_offset += len as u32;
out
})
.collect_trusted();

a
} else {
PrimitiveArray::<u32>::from_slice(&[])
}
}

pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
let take_by = sublist_get_indexes(arr, index);
let values = arr.values();
// Safety:
// the indices we generate are in bounds
unsafe { take_unchecked(&**values, &take_by) }
}

#[cfg(test)]
mod test {
use super::*;
use arrow::array::Int32Array;
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;
use std::sync::Arc;

fn get_array() -> ListArray<i64> {
let values = Int32Array::from_slice(&[1, 2, 3, 4, 5, 6]);
let offsets = Buffer::from(&[0i64, 3, 5, 6]);

let dtype = ListArray::<i64>::default_datatype(DataType::Int32);
ListArray::<i64>::from_data(dtype, offsets, Arc::new(values), None)
}

#[test]
fn test_sublist_get_indexes() {
let arr = get_array();
let out = sublist_get_indexes(&arr, 0);
assert_eq!(out.values().as_slice(), &[0, 3, 5]);
let out = sublist_get_indexes(&arr, -1);
assert_eq!(out.values().as_slice(), &[2, 4, 5]);
}

#[test]
fn test_sublist_get() {
let arr = get_array();

let out = sublist_get(&arr, 0);
let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap();

assert_eq!(out.values().as_slice(), &[1, 4, 6]);
let out = sublist_get(&arr, -1);
let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap();
assert_eq!(out.values().as_slice(), &[3, 5, 6]);
}
}
1 change: 1 addition & 0 deletions polars/polars-arrow/src/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use arrow::array::BooleanArray;
use arrow::bitmap::utils::BitChunks;
use std::iter::Enumerate;
pub mod float;
pub mod list;
pub mod rolling;
pub mod set;
#[cfg(feature = "strings")]
Expand Down
30 changes: 29 additions & 1 deletion polars/polars-arrow/src/kernels/take.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
use crate::utils::with_match_primitive_type;
use crate::{bit_util::unset_bit_raw, prelude::*, utils::CustomIterTools};
use arrow::array::*;
use arrow::bitmap::MutableBitmap;
use arrow::buffer::{Buffer, MutableBuffer};
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, PhysicalType};
use arrow::types::NativeType;
use std::sync::Arc;

/// # Safety
/// Does not do bounds checks
pub unsafe fn take_unchecked(arr: &dyn Array, idx: &UInt32Array) -> ArrayRef {
use PhysicalType::*;
match arr.data_type().to_physical_type() {
Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let arr: &PrimitiveArray<$T> = arr.as_any().downcast_ref().unwrap();
if arr.validity().is_some() {
take_primitive_unchecked::<$T>(arr, idx)
} else {
take_no_null_primitive::<$T>(arr, idx)
}
}),
LargeUtf8 => {
let arr = arr.as_any().downcast_ref().unwrap();
take_utf8_unchecked(arr, idx)
}
// TODO! implement proper unchecked version
#[cfg(feature = "compute")]
Boolean => {
use arrow::compute::take::take;
Arc::from(take(arr, idx).unwrap())
}
_ => unimplemented!(),
}
}

/// Take kernel for single chunk with nulls and arrow array as index that may have nulls.
/// # Safety
/// caller must ensure indices are in bounds
Expand Down
1 change: 1 addition & 0 deletions polars/polars-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod bit_util;
#[cfg(feature = "compute")]
pub mod compute;
pub mod error;
pub mod index;
pub mod is_valid;
pub mod kernels;
pub mod prelude;
Expand Down
46 changes: 46 additions & 0 deletions polars/polars-arrow/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::trusted_len::{FromIteratorReversed, PushUnchecked, TrustedLen};
use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use std::ops::BitAnd;

pub struct TrustMyLength<I: Iterator<Item = J>, J> {
Expand Down Expand Up @@ -109,3 +111,47 @@ impl<T> FromTrustedLenIterator<T> for Vec<T> {
v
}
}

impl<T: NativeType> FromTrustedLenIterator<Option<T>> for PrimitiveArray<T> {
fn from_iter_trusted_length<I: IntoIterator<Item = Option<T>>>(iter: I) -> Self
where
I::IntoIter: TrustedLen,
{
let iter = iter.into_iter();
unsafe { PrimitiveArray::from_trusted_len_iter_unchecked(iter) }
}
}

impl<T: NativeType> FromTrustedLenIterator<T> for PrimitiveArray<T> {
fn from_iter_trusted_length<I: IntoIterator<Item = T>>(iter: I) -> Self
where
I::IntoIter: TrustedLen,
{
let iter = iter.into_iter();
unsafe { PrimitiveArray::from_trusted_len_values_iter_unchecked(iter) }
}
}

macro_rules! with_match_primitive_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use arrow::datatypes::PrimitiveType::*;
use arrow::types::{days_ms, months_days_ns};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}
pub(crate) use with_match_primitive_type;
14 changes: 14 additions & 0 deletions polars/polars-core/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;
use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use std::convert::TryFrom;

impl ListChunked {
pub fn lst_max(&self) -> Series {
Expand Down Expand Up @@ -55,6 +57,18 @@ impl ListChunked {
UInt32Chunked::new_from_aligned_vec(self.name(), lengths)
}

/// Get the value by index in the sublists.
/// So index `0` would return the first item of every sublist
/// and index `-1` would return the last item of every sublist
/// if an index is out of bounds, it will return a `None`.
pub fn lst_get(&self, idx: i64) -> Result<Series> {
let chunks = self
.downcast_iter()
.map(|arr| sublist_get(arr, idx))
.collect::<Vec<_>>();
Series::try_from((self.name(), chunks))
}

pub fn lst_concat(&self, other: &[Series]) -> Result<ListChunked> {
let mut iters = Vec::with_capacity(other.len() + 1);
let dtype = self.dtype();
Expand Down
14 changes: 14 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,20 @@ def concat(self, other: Union[tp.List[Union[Expr, str]], Expr, str]) -> "Expr":
other_list.insert(0, wrap_expr(self._pyexpr))
return pli.concat_list(other_list)

def get(self, index: int) -> "Expr":
"""
Get the value by index in the sublists.
So index `0` would return the first item of every sublist
and index `-1` would return the last item of every sublist
if an index is out of bounds, it will return a `None`.
Parameters
----------
index
Index to return per sublist
"""
return wrap_expr(self._pyexpr.lst_get(index))


class ExprStringNameSpace:
"""
Expand Down
14 changes: 14 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3395,6 +3395,20 @@ def concat(self, other: Union[tp.List[Series], Series]) -> "Series":
df.insert_at_idx(0, s)
return df.select(pli.concat_list(names))[s.name] # type: ignore

def get(self, index: int) -> "Series":
"""
Get the value by index in the sublists.
So index `0` would return the first item of every sublist
and index `-1` would return the last item of every sublist
if an index is out of bounds, it will return a `None`.
Parameters
----------
index
Index to return per sublist
"""
return pli.select(pli.lit(wrap_s(self._s)).arr.get(index)).to_series()


class DateTimeNameSpace:
"""
Expand Down
13 changes: 13 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,19 @@ impl PyExpr {
.into()
}

fn lst_get(&self, index: i64) -> Self {
self.inner
.clone()
.map(
move |s| s.list()?.lst_get(index),
GetOutput::map_field(|field| match field.data_type() {
DataType::List(inner) => Field::new(field.name(), *inner.clone()),
_ => panic!("should be a list type"),
}),
)
.into()
}

fn rank(&self, method: &str) -> Self {
let method = str_to_rankmethod(method).unwrap();
self.inner.clone().rank(method).into()
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import polars as pl
from polars import testing


def test_list_arr_get() -> None:
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
out = a.arr.get(0)
expected = pl.Series("a", [1, 4, 6])
testing.assert_series_equal(out, expected)

out = a.arr.get(-1)
expected = pl.Series("a", [3, 5, 9])
testing.assert_series_equal(out, expected)

out = a.arr.get(-3)
expected = pl.Series("a", [1, None, 7])
testing.assert_series_equal(out, expected)

0 comments on commit 774bf27

Please sign in to comment.