Skip to content

Commit

Permalink
fix explode with empty lists (#4113)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 22, 2022
1 parent a0066a9 commit 480dd14
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 54 deletions.
38 changes: 25 additions & 13 deletions polars/polars-core/src/chunked_array/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,33 +249,38 @@ pub(crate) fn offsets_to_indexes(offsets: &[i64], capacity: usize) -> Vec<IdxSiz
}
let mut idx = Vec::with_capacity(capacity);

let mut count = 0;
// `value_count` counts the values taken from the list values
// and is in the same unit as `offsets`
let mut value_count = 0;
// `empty_count` counts the duplicates taken because of empty list
let mut empty_count = 0;
let mut last_idx = 0;
let mut previous_empty = false;
for offset in &offsets[1..] {
while count < *offset {
count += 1;
// this gets all the elements up to the current offset
while value_count < *offset {
value_count += 1;
idx.push(last_idx)
}

// then we compute the previous offsets
// Safety:
// we started iterating from 1, so there is always a previous offset
// we take the pointer to the previous element and deref that to get
// the previous offset
let previous_offset = unsafe { *(offset as *const i64).offset(-1) };

if !previous_empty && (previous_offset != *offset) {
last_idx += 1;
} else {
count += 1;
// if the previous offset is equal to the current offset we have an empty
// list and we duplicate previous index
if previous_offset == *offset {
empty_count += 1;
idx.push(last_idx);
last_idx += 1;
}
previous_empty = previous_offset == *offset;

last_idx += 1;
}
// undo latest increment
last_idx -= 1;

for _ in 0..(capacity - count as usize) {
// take the remaining values
for _ in 0..(capacity - value_count as usize - empty_count as usize) {
idx.push(last_idx);
}
idx
Expand Down Expand Up @@ -564,4 +569,11 @@ mod test {

Ok(())
}

#[test]
fn test_row_offsets() {
let offsets = &[0, 1, 2, 2, 3, 4, 4];
let out = offsets_to_indexes(offsets, 6);
assert_eq!(out, &[0, 1, 2, 3, 4, 5]);
}
}
38 changes: 0 additions & 38 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,44 +278,6 @@ impl<'a> From<&AnyValue<'a>> for Field {
Field::new("", val.into())
}
}
impl<'a> From<&AnyValue<'a>> for DataType {
fn from(val: &AnyValue<'a>) -> Self {
use AnyValue::*;
match val {
Null => DataType::Null,
Boolean(_) => DataType::Boolean,
Utf8(_) => DataType::Utf8,
Utf8Owned(_) => DataType::Utf8,
UInt32(_) => DataType::UInt32,
UInt64(_) => DataType::UInt64,
Int32(_) => DataType::Int32,
Int64(_) => DataType::Int64,
Float32(_) => DataType::Float32,
Float64(_) => DataType::Float64,
#[cfg(feature = "dtype-date")]
Date(_) => DataType::Date,
#[cfg(feature = "dtype-datetime")]
Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()),
#[cfg(feature = "dtype-time")]
Time(_) => DataType::Time,
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-struct")]
StructOwned(payload) => DataType::Struct(payload.1.to_vec()),
#[cfg(feature = "dtype-struct")]
Struct(_, fields) => DataType::Struct(fields.to_vec()),
#[cfg(feature = "dtype-duration")]
Duration(_, tu) => DataType::Duration(*tu),
UInt8(_) => DataType::UInt8,
UInt16(_) => DataType::UInt16,
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
#[cfg(feature = "dtype-categorical")]
Categorical(_, rev_map) => DataType::Categorical(Some(Arc::new((*rev_map).clone()))),
#[cfg(feature = "object")]
Object(o) => DataType::Object(o.type_name()),
}
}
}

impl From<&Row<'_>> for Schema {
fn from(row: &Row) -> Self {
Expand Down
39 changes: 39 additions & 0 deletions polars/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,42 @@ impl Series {
}
}
}

impl<'a> From<&AnyValue<'a>> for DataType {
fn from(val: &AnyValue<'a>) -> Self {
use AnyValue::*;
match val {
Null => DataType::Null,
Boolean(_) => DataType::Boolean,
Utf8(_) => DataType::Utf8,
Utf8Owned(_) => DataType::Utf8,
UInt32(_) => DataType::UInt32,
UInt64(_) => DataType::UInt64,
Int32(_) => DataType::Int32,
Int64(_) => DataType::Int64,
Float32(_) => DataType::Float32,
Float64(_) => DataType::Float64,
#[cfg(feature = "dtype-date")]
Date(_) => DataType::Date,
#[cfg(feature = "dtype-datetime")]
Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).clone()),
#[cfg(feature = "dtype-time")]
Time(_) => DataType::Time,
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-struct")]
StructOwned(payload) => DataType::Struct(payload.1.to_vec()),
#[cfg(feature = "dtype-struct")]
Struct(_, fields) => DataType::Struct(fields.to_vec()),
#[cfg(feature = "dtype-duration")]
Duration(_, tu) => DataType::Duration(*tu),
UInt8(_) => DataType::UInt8,
UInt16(_) => DataType::UInt16,
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
#[cfg(feature = "dtype-categorical")]
Categorical(_, rev_map) => DataType::Categorical(Some(Arc::new((*rev_map).clone()))),
#[cfg(feature = "object")]
Object(o) => DataType::Object(o.type_name()),
}
}
}
4 changes: 1 addition & 3 deletions py-polars/src/arrow_interop/to_rust.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::error::PyPolarsErr;
use polars_core::export::rayon::prelude::*;
use polars_core::prelude::*;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_core::utils::arrow::ffi;
use polars_core::utils::{
accumulate_dataframes_vertical, accumulate_dataframes_vertical_unchecked,
};
use polars_core::POOL;
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ def test_explode_empty_list_4003() -> None:
"id": [1, 2, 3],
"nested": [None, 1, 2],
}


def test_explode_empty_list_4107() -> None:
    """Exploding a list column that contains empty lists keeps row counts aligned.

    The frame has a list column ``b`` with empty lists plus a row-count
    column. After exploding ``b``, the carried-along ``row_nr`` column must
    equal a row count computed afresh on the exploded frame.
    """
    frame = pl.DataFrame({"b": [[1], [2], []] * 2}).with_row_count()

    actual = frame.explode(["b"])
    expected = frame.explode(["b"]).drop("row_nr").with_row_count()

    pl.testing.assert_frame_equal(actual, expected)

0 comments on commit 480dd14

Please sign in to comment.