Skip to content

Commit

Permalink
reduce some branching in list iter (#2244)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 1, 2022
1 parent d57a28f commit 214bddc
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 54 deletions.
8 changes: 7 additions & 1 deletion polars/polars-core/src/chunked_array/iterator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,20 @@ impl<'a> IntoIterator for &'a ListChunked {
type Item = Option<Series>;
type IntoIter = Box<dyn PolarsIterator<Item = Self::Item> + 'a>;
fn into_iter(self) -> Self::IntoIter {
let dtype = self.inner_dtype().to_arrow();

// we know that we only iterate over length == self.len()
unsafe {
Box::new(
self.downcast_iter()
.map(|arr| arr.iter())
.flatten()
.trust_my_length(self.len())
.map(|arr| arr.map(|arr| Series::try_from(("", arr)).unwrap())),
.map(move |arr| {
arr.map(|arr| {
Series::try_from_unchecked("", vec![Arc::from(arr)], &dtype).unwrap()
})
}),
)
}
}
Expand Down
8 changes: 1 addition & 7 deletions polars/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,7 @@ impl ListChunked {
series_container,
inner: NonNull::new(ptr).unwrap(),
lifetime: PhantomData,
// safety: we know the iterators len
iter: unsafe {
self.downcast_iter()
.map(|arr| arr.iter())
.flatten()
.trust_my_length(self.len())
},
iter: self.downcast_iter().map(|arr| arr.iter()).flatten(),
}
}

Expand Down
107 changes: 61 additions & 46 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,16 @@ use arrow::temporal_conversions::NANOSECONDS;
use polars_arrow::compute::cast::cast;
use std::convert::TryFrom;

// Converts a list array whose inner field is `Utf8` (i32 offsets) into a
// `LargeListArray` with `LargeUtf8` (i64 offsets) values; any other inner
// dtype is passed through unchanged (cheap `Arc` clone).
fn convert_list_inner(arr: &ArrayRef, fld: &ArrowField) -> ArrayRef {
    // if inner type is Utf8, we need to convert that to large utf8
    match fld.data_type() {
        ArrowDataType::Utf8 => {
            // NOTE(review): assumes `arr` is a `ListArray<i64>` whenever the
            // inner field is Utf8 — the unwrap panics otherwise; confirm with callers.
            let arr = arr.as_any().downcast_ref::<ListArray<i64>>().unwrap();
            // Widen the list offsets to i64 for the LargeList layout.
            let offsets = arr.offsets().iter().map(|x| *x as i64).collect();
            let values = arr.values();
            // Re-encode the child Utf8 array (i32 offsets) as LargeUtf8 (i64 offsets).
            let values =
                utf8_to_large_utf8(values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap());

            Arc::new(LargeListArray::from_data(
                ArrowDataType::LargeList(
                    ArrowField::new(fld.name(), ArrowDataType::LargeUtf8, true).into(),
                ),
                offsets,
                Arc::new(values),
                // Preserve the original validity (null) bitmap, if any.
                arr.validity().cloned(),
            ))
        }
        _ => arr.clone(),
    }
}

// TODO: add types
impl TryFrom<(&str, Vec<ArrayRef>)> for Series {
type Error = PolarsError;

fn try_from(name_arr: (&str, Vec<ArrayRef>)) -> Result<Self> {
let (name, chunks) = name_arr;

let mut chunks_iter = chunks.iter();
let data_type: &ArrowDataType = chunks_iter
.next()
.ok_or_else(|| PolarsError::NoData("Expected at least on ArrayRef".into()))?
.data_type();

for chunk in chunks_iter {
if chunk.data_type() != data_type {
return Err(PolarsError::InvalidOperation(
"Cannot create series from multiple arrays with different types".into(),
));
}
}

match data_type {
impl Series {
// Create a new Series without checking if the inner dtype of the chunks is correct
// # Safety
// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes.
pub(crate) unsafe fn try_from_unchecked(
name: &str,
chunks: Vec<ArrayRef>,
dtype: &ArrowDataType,
) -> Result<Self> {
match dtype {
ArrowDataType::LargeUtf8 => {
Ok(Utf8Chunked::new_from_chunks(name, chunks).into_series())
}
Expand Down Expand Up @@ -327,7 +292,7 @@ impl TryFrom<(&str, Vec<ArrayRef>)> for Series {
// this is highly unsafe. it will dereference a raw ptr on the heap
// make sure the ptr is allocated and from this pid
// (the pid is checked before dereference)
let s = unsafe {
let s = {
let pe = PolarsExtension::new(arr.clone());
let s = pe.get_series();
pe.take_and_forget();
Expand All @@ -342,6 +307,56 @@ impl TryFrom<(&str, Vec<ArrayRef>)> for Series {
}
}

// Converts a list array whose inner field is `Utf8` (i32 offsets) into a
// `LargeListArray` with `LargeUtf8` (i64 offsets) values; any other inner
// dtype is passed through unchanged (cheap `Arc` clone).
fn convert_list_inner(arr: &ArrayRef, fld: &ArrowField) -> ArrayRef {
    // if inner type is Utf8, we need to convert that to large utf8
    match fld.data_type() {
        ArrowDataType::Utf8 => {
            // NOTE(review): assumes `arr` is a `ListArray<i64>` whenever the
            // inner field is Utf8 — the unwrap panics otherwise; confirm with callers.
            let arr = arr.as_any().downcast_ref::<ListArray<i64>>().unwrap();
            // Widen the list offsets to i64 for the LargeList layout.
            let offsets = arr.offsets().iter().map(|x| *x as i64).collect();
            let values = arr.values();
            // Re-encode the child Utf8 array (i32 offsets) as LargeUtf8 (i64 offsets).
            let values =
                utf8_to_large_utf8(values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap());

            Arc::new(LargeListArray::from_data(
                ArrowDataType::LargeList(
                    ArrowField::new(fld.name(), ArrowDataType::LargeUtf8, true).into(),
                ),
                offsets,
                Arc::new(values),
                // Preserve the original validity (null) bitmap, if any.
                arr.validity().cloned(),
            ))
        }
        _ => arr.clone(),
    }
}

// TODO: add types
impl TryFrom<(&str, Vec<ArrayRef>)> for Series {
    type Error = PolarsError;

    /// Builds a `Series` named `name` from `chunks`, validating that every
    /// chunk shares the same arrow dtype before delegating to the unchecked
    /// constructor.
    ///
    /// # Errors
    /// - `PolarsError::NoData` when `chunks` is empty.
    /// - `PolarsError::InvalidOperation` when the chunks disagree on dtype.
    fn try_from(name_arr: (&str, Vec<ArrayRef>)) -> Result<Self> {
        let (name, chunks) = name_arr;

        let mut chunks_iter = chunks.iter();
        let data_type: ArrowDataType = chunks_iter
            .next()
            // fixed typo in error message: "on" -> "one"
            .ok_or_else(|| PolarsError::NoData("Expected at least one ArrayRef".into()))?
            .data_type()
            .clone();

        // All remaining chunks must match the first chunk's dtype.
        for chunk in chunks_iter {
            if chunk.data_type() != &data_type {
                return Err(PolarsError::InvalidOperation(
                    "Cannot create series from multiple arrays with different types".into(),
                ));
            }
        }
        // Safety:
        // dtype is checked against every chunk above.
        unsafe { Series::try_from_unchecked(name, chunks, &data_type) }
    }
}

impl TryFrom<(&str, ArrayRef)> for Series {
type Error = PolarsError;

Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 214bddc

Please sign in to comment.