Skip to content

Commit

Permalink
fix arr.get() offsets (#3731)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 18, 2022
1 parent 2b7d403 commit 697b5a6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
33 changes: 32 additions & 1 deletion polars/polars-arrow/src/kernels/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {
let len = offset - previous;
previous = offset;
// make sure that empty lists don't get accessed
- if len == 0 || index >= len {
+ // and out of bounds return null
+ if len == 0 {
+ return None;
+ }
+ if index >= len {
+ cum_offset += len as IdxSize;
return None;
}

Expand Down Expand Up @@ -111,6 +116,32 @@ mod test {
assert_eq!(out.values().as_slice(), &[2, 4, 5]);
let out = sublist_get_indexes(&arr, 3);
assert_eq!(out.null_count(), 3);

let values = Int32Array::from_iter([
Some(1),
Some(1),
Some(3),
Some(4),
Some(5),
Some(6),
Some(7),
Some(8),
Some(9),
None,
Some(11),
]);
let offsets = Buffer::from(vec![0i64, 1, 2, 3, 6, 9, 11]);

let dtype = ListArray::<i64>::default_datatype(DataType::Int32);
let arr = ListArray::<i64>::from_data(dtype, offsets, Box::new(values), None);

let out = sublist_get_indexes(&arr, 1);
assert_eq!(
out.into_iter()
.map(|opt_v| opt_v.cloned())
.collect::<Vec<_>>(),
&[None, None, None, Some(4), Some(7), Some(10)]
);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object = ["polars-core/object"]
to_dummies = []
list_to_struct = ["polars-core/dtype-struct", "list"]
list = []
- diff = []
+ diff = ["polars-core/diff"]
strings = ["polars-core/strings"]
string_justify = ["polars-core/strings"]
log = []
14 changes: 14 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def test_list_arr_get() -> None:
expected = pl.Series("a", [1, None, 7])
testing.assert_series_equal(out, expected)

assert pl.DataFrame(
{"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]}
).with_columns(
[pl.col("a").arr.get(i).alias(f"get_{i}") for i in range(4)]
).to_dict(
False
) == {
"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]],
"get_0": [1, 2, 3, 4, 7, None],
"get_1": [None, None, None, 5, 8, 11],
"get_2": [None, None, None, 6, 9, None],
"get_3": [None, None, None, None, None, None],
}


def test_contains() -> None:
a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_list_to_struct() -> None:
[pl.col("a").arr.to_struct(n_field_strategy="max_width")]
).to_series().to_list() == [
{"field_0": 1, "field_1": 2, "field_2": None},
- {"field_0": 1, "field_1": 2, "field_2": 1},
+ {"field_0": 1, "field_1": 2, "field_2": 3},
]


Expand Down

0 comments on commit 697b5a6

Please sign in to comment.