Skip to content

Commit

Permalink
fix several list related bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 19, 2021
1 parent ac89b6d commit eafba50
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 17 deletions.
12 changes: 6 additions & 6 deletions polars/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::prelude::*;
use crate::series::unstable::{ArrayBox, UnstableSeries};
use crate::utils::CustomIterTools;
use arrow::array::ArrayRef;
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr::NonNull;
Expand Down Expand Up @@ -56,12 +57,11 @@ impl ListChunked {
/// that Series.
#[cfg(feature = "private")]
pub fn amortized_iter(&self) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
let series_container = if self.is_empty() {
// in case of no data, the actual Series does not matter
Box::pin(Series::new("", &[true]))
} else {
Box::pin(self.get(0).unwrap())
};
// we create the series container from the inner array
// so that the container has the proper dtype.
let arr = self.downcast_iter().next().unwrap();
let inner_values = arr.values();
let series_container = Box::pin(Series::try_from(("", inner_values.clone())).unwrap());

let ptr = &series_container.chunks()[0] as *const ArrayRef as *mut ArrayRef;

Expand Down
11 changes: 10 additions & 1 deletion polars/polars-core/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,25 @@ impl ListChunked {
Some(s) => s,
None => {
builder.append_null();
// make sure that the iterators advance before we continue
for it in &mut iters {
it.next().unwrap();
}
continue;
}
};
let mut already_null = false;
for it in &mut iters {
match it.next().unwrap() {
Some(s) => {
acc.append(s.as_ref())?;
}
None => {
builder.append_null();
if !already_null {
builder.append_null();
already_null = true;
}

continue;
}
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl ListChunked {
inner_dtype: &DataType,
) -> ListChunked {
let arr = new_null_array(
ArrowDataType::List(Box::new(ArrowField::new(
ArrowDataType::LargeList(Box::new(ArrowField::new(
"item",
inner_dtype.to_arrow(),
true,
Expand Down
1 change: 0 additions & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,6 @@ def concat_list(exprs: Sequence[Union[str, "pli.Expr"]]) -> "pli.Expr":
... ]
... )
... )
shape: (5, 1)
┌─────────────────┐
│ A_rolling │
Expand Down
37 changes: 29 additions & 8 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,37 @@ def test_list_concat_rolling_window() -> None:
# inspired by: https://stackoverflow.com/questions/70377100/use-the-rolling-function-of-polars-to-get-a-list-of-all-values-in-the-rolling-wi
# this tests if it works without specifically creating list dtype upfront.
# note that the given answer is prefered over this snippet as that reuses the list array when shifting
df = pl.DataFrame(
{
"A": [1.0, 2.0, 9.0, 2.0, 13.0],
}
)

out = df.with_columns(
[pl.col("A").shift(i).alias(f"A_lag_{i}") for i in range(3)]
).select(
[pl.concat_list([f"A_lag_{i}" for i in range(3)][::-1]).alias("A_rolling")]
)
assert out.shape == (5, 1)
assert out.to_series().dtype == pl.List

# this test proper null behavior of concat list
out = (
pl.DataFrame(
{
"A": [1.0, 2.0, 9.0, 2.0, 13.0],
}
df.with_column(pl.col("A").reshape((-1, 1))) # first turn into a list
.with_columns(
[
pl.col("A").shift(i).alias(f"A_lag_{i}")
for i in range(3) # slice the lists to a lag
]
)
.with_columns([pl.col("A").shift(i).alias(f"A_lag_{i}") for i in range(3)])
.select(
[pl.concat_list([f"A_lag_{i}" for i in range(3)][::-1]).alias("A_rolling")]
[
pl.all(),
pl.concat_list([f"A_lag_{i}" for i in range(3)][::-1]).alias(
"A_rolling"
),
]
)
)
assert out.shape == (5, 1)
assert out.to_series().dtype == pl.List
assert out.shape == (5, 5)
assert out["A_rolling"].dtype == pl.List

0 comments on commit eafba50

Please sign in to comment.