fix take of nested list
ritchie46 committed Nov 19, 2021
1 parent 2752967 commit 59ce6bb
Showing 4 changed files with 38 additions and 17 deletions.
7 changes: 7 additions & 0 deletions polars/polars-core/src/chunked_array/list/mod.rs
@@ -14,4 +14,11 @@ impl ListChunked {
     pub(crate) fn can_fast_explode(&self) -> bool {
         self.bit_settings & 1 << 2 != 0
     }
+
+    pub(crate) fn is_nested(&self) -> bool {
+        match self.dtype() {
+            DataType::List(inner) => matches!(&**inner, DataType::List(_)),
+            _ => unreachable!(),
+        }
+    }
 }
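
The new `is_nested` helper only checks whether the list's inner dtype is itself a list. A minimal, self-contained sketch of that one-level check, using a simplified stand-in for the DataType enum rather than the actual polars-core types:

// Simplified stand-in for polars' DataType, for illustration only.
enum DataType {
    Int32,
    List(Box<DataType>),
}

// Mirrors the shape of ListChunked::is_nested above: a list counts as
// "nested" when its inner dtype is itself a List.
fn is_nested(dtype: &DataType) -> bool {
    match dtype {
        DataType::List(inner) => matches!(&**inner, DataType::List(_)),
        // The real method is only called on list dtypes, hence unreachable!().
        _ => unreachable!(),
    }
}

fn main() {
    let flat = DataType::List(Box::new(DataType::Int32));
    let nested = DataType::List(Box::new(DataType::List(Box::new(DataType::Int32))));
    assert!(!is_nested(&flat));
    assert!(is_nested(&nested));
}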
39 changes: 24 additions & 15 deletions polars/polars-core/src/chunked_array/ops/take/mod.rs
@@ -16,6 +16,7 @@ use crate::prelude::*;
 use crate::utils::NoNull;

 use polars_arrow::array::PolarsArray;
+use std::borrow::Cow;
 pub use take_random::*;
 pub use traits::*;

@@ -324,47 +325,55 @@ impl ChunkTake for ListChunked {
         I: TakeIterator,
         INulls: TakeIteratorNulls,
     {
-        let mut chunks = self.downcast_iter();
+        let ca_self = if self.is_nested() {
+            Cow::Owned(self.rechunk())
+        } else {
+            Cow::Borrowed(self)
+        };
+        let mut chunks = ca_self.downcast_iter();
         match indices {
             TakeIdx::Array(array) => {
-                let array = match self.chunks.len() {
+                let array = match ca_self.chunks.len() {
                     1 => Arc::new(take_list_unchecked(chunks.next().unwrap(), array)) as ArrayRef,
                     _ => {
                         return if !array.has_validity() {
                             let iter = array.values().iter().map(|i| *i as usize);
-                            let mut ca: ListChunked = take_iter_n_chunks_unchecked!(self, iter);
-                            ca.rename(self.name());
+                            let mut ca: ListChunked =
+                                take_iter_n_chunks_unchecked!(ca_self.as_ref(), iter);
+                            ca.rename(ca_self.name());
                             ca
                         } else {
                             let iter = array
                                 .into_iter()
                                 .map(|opt_idx| opt_idx.map(|idx| *idx as usize));
-                            let mut ca: ListChunked = take_opt_iter_n_chunks_unchecked!(self, iter);
-                            ca.rename(self.name());
+                            let mut ca: ListChunked =
+                                take_opt_iter_n_chunks_unchecked!(ca_self.as_ref(), iter);
+                            ca.rename(ca_self.name());
                             ca
                         }
                     }
                 };
-                self.copy_with_chunks(vec![array])
+                ca_self.copy_with_chunks(vec![array])
             }
             // todo! fast path for single chunk
             TakeIdx::Iter(iter) => {
-                if self.chunks.len() == 1 {
+                if ca_self.chunks.len() == 1 {
                     let idx: NoNull<UInt32Chunked> = iter.map(|v| v as u32).collect();
-                    self.take_unchecked((&idx.into_inner()).into())
+                    ca_self.take_unchecked((&idx.into_inner()).into())
                 } else {
-                    let mut ca: ListChunked = take_iter_n_chunks_unchecked!(self, iter);
-                    ca.rename(self.name());
+                    let mut ca: ListChunked = take_iter_n_chunks_unchecked!(ca_self.as_ref(), iter);
+                    ca.rename(ca_self.name());
                     ca
                 }
            }
            TakeIdx::IterNulls(iter) => {
-                if self.chunks.len() == 1 {
+                if ca_self.chunks.len() == 1 {
                     let idx: UInt32Chunked = iter.map(|v| v.map(|v| v as u32)).collect();
-                    self.take_unchecked((&idx).into())
+                    ca_self.take_unchecked((&idx).into())
                 } else {
-                    let mut ca: ListChunked = take_opt_iter_n_chunks_unchecked!(self, iter);
-                    ca.rename(self.name());
+                    let mut ca: ListChunked =
+                        take_opt_iter_n_chunks_unchecked!(ca_self.as_ref(), iter);
+                    ca.rename(ca_self.name());
                     ca
                 }
            }
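
The heart of the fix is the Cow at the top of take_unchecked: when the list is nested (a list of lists), the chunked array is first rechunked into a single contiguous chunk and owned, so the code below always hits the single-chunk paths; a flat list keeps borrowing self and pays nothing extra. A small generic sketch of that copy-only-when-needed pattern, using a hypothetical Chunked stand-in rather than the real ListChunked:

use std::borrow::Cow;

// Hypothetical stand-in for a chunked container such as ListChunked.
#[derive(Clone)]
struct Chunked {
    chunks: Vec<Vec<i32>>,
}

impl Chunked {
    // Concatenate all chunks into one, analogous to rechunk() in polars.
    fn rechunk(&self) -> Chunked {
        Chunked {
            chunks: vec![self.chunks.concat()],
        }
    }
}

// Pay for the copy only when the single-chunk code path is required;
// otherwise keep borrowing the original. This is the same Cow pattern
// used in the diff above for nested lists.
fn normalized(ca: &Chunked, needs_single_chunk: bool) -> Cow<'_, Chunked> {
    if needs_single_chunk && ca.chunks.len() > 1 {
        Cow::Owned(ca.rechunk())
    } else {
        Cow::Borrowed(ca)
    }
}

fn main() {
    let ca = Chunked {
        chunks: vec![vec![1, 2], vec![3]],
    };
    assert_eq!(normalized(&ca, true).chunks.len(), 1); // owned, rechunked copy
    assert_eq!(normalized(&ca, false).chunks.len(), 2); // cheap borrow, untouched
}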
2 changes: 0 additions & 2 deletions polars/polars-lazy/src/test.rs
@@ -2165,8 +2165,6 @@ fn test_take_consistency() -> Result<()> {
 }

 #[test]
-#[cfg(feature = "ignore")]
-// todo! activate (strange abort in CI)
 fn test_groupby_on_lists() -> Result<()> {
     let s0 = Series::new("", [1i32, 2, 3]);
     let s1 = Series::new("groups", [4i32, 5]);
7 changes: 7 additions & 0 deletions polars/src/lib.rs
@@ -203,6 +203,13 @@
 //! * `POLARS_TABLE_WIDTH` -> width of the tables used during DataFrame formatting.
 //! * `POLARS_MAX_THREADS` -> maximum number of threads used to initialize thread pool (on startup).
 //! * `POLARS_VERBOSE` -> print logging info to stderr
+//! * `POLARS_NO_PARTITION` -> Polars may choose to partition the groupby operation based on data
+//!   cardinality. Setting this env var turns partitioned groupbys off.
+//! * `POLARS_PARTITION_SAMPLE_FRAC` -> how large a fraction of the dataset to sample to determine
+//!   cardinality; defaults to `0.001`.
+//! * `POLARS_PARTITION_CARDINALITY_FRAC` -> the (estimated) cardinality fraction at or below which a
+//!   partitioned groupby runs; defaults to `0.005`. Any higher cardinality will run the default groupby.
+//!
 //!
 //! ## Compile for WASM
 //! To be able to pretty print a `DataFrame` in `wasm32-wasi` you need to patch the `prettytable-rs`
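
These settings are read from the process environment, so they can also be set from code before any query runs. A hedged sketch (the variable names come from the docs above; the values are arbitrary examples, and std::env::set_var is safe to call under the 2021 edition, unsafe under the 2024 edition):

use std::env;

fn main() {
    // Turn partitioned groupbys off entirely; per the docs above, setting the
    // variable is what disables them.
    env::set_var("POLARS_NO_PARTITION", "1");

    // Sample 1% of the dataset when estimating cardinality (documented default: 0.001).
    env::set_var("POLARS_PARTITION_SAMPLE_FRAC", "0.01");

    // Allow a partitioned groupby up to an estimated cardinality fraction of 0.01
    // (documented default: 0.005).
    env::set_var("POLARS_PARTITION_CARDINALITY_FRAC", "0.01");

    // ...build and run Polars queries afterwards; the settings apply process-wide.
}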
