Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default= kwarg to .list.get() accessor method #10547

Merged
merged 13 commits into from Apr 6, 2022
49 changes: 38 additions & 11 deletions python/cudf/cudf/core/column/lists.py
Expand Up @@ -2,7 +2,7 @@

import pickle
from functools import cached_property
from typing import List, Sequence
from typing import List, Optional, Sequence

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -337,16 +337,20 @@ def __init__(self, parent: ParentType):
)
super().__init__(parent=parent)

def get(self, index: int) -> ParentType:
def get(
self, index: int, default: Optional[ScalarLike] = None
) -> ParentType:
"""
Extract element at the given index from each component
Extract element at the given index from each list.
bdice marked this conversation as resolved.
Show resolved Hide resolved

Extract element from lists, tuples, or strings in
each element in the Series/Index.
If the index is out of bounds for any list,
return <NA> or, if provided, ``default``.
Thus, this method never raises an ``IndexError``.

Parameters
----------
index : int
default : scalar, optional

Returns
-------
Expand All @@ -360,14 +364,37 @@ def get(self, index: int) -> ParentType:
1 5
2 6
dtype: int64

>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2)
0 <NA>
1 5
2 6
dtype: int64

>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2, default=0)
0 0
1 5
2 6
dtype: int64
"""
min_col_list_len = self.len().min()
if -min_col_list_len <= index < min_col_list_len:
return self._return_or_inplace(
extract_element(self._column, index)
out = extract_element(self._column, index)

if not (default is None or default is cudf.NA):
bdice marked this conversation as resolved.
Show resolved Hide resolved
# determine rows for which `index` is out-of-bounds
lengths = count_elements(self._column)
out_of_bounds_mask = (np.negative(index) > lengths) | (
index >= lengths
)
else:
raise IndexError("list index out of range")

# replace the value in those rows (should be NA) with `default`
if out_of_bounds_mask.any():
out = out._scatter_by_column(
out_of_bounds_mask, cudf.Scalar(default)
)

return self._return_or_inplace(out)

def contains(self, search_key: ScalarLike) -> ParentType:
"""
Expand Down
30 changes: 26 additions & 4 deletions python/cudf/cudf/tests/test_list.py
Expand Up @@ -292,10 +292,32 @@ def test_get_nested_lists():
assert_eq(expect, got)


def test_get_nulls():
with pytest.raises(IndexError, match="list index out of range"):
sr = cudf.Series([[], [], []])
sr.list.get(100)
def test_get_default():
sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]])

assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2))
assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2, default=cudf.NA))
assert_eq(cudf.Series([0, 5, 8]), sr.list.get(2, default=0))
bdice marked this conversation as resolved.
Show resolved Hide resolved
assert_eq(cudf.Series([0, 3, 7]), sr.list.get(-3, default=0))
assert_eq(cudf.Series([2, 5, 9]), sr.list.get(-1))

string_sr = cudf.Series(
[["apple", "banana"], ["carrot", "daffodil", "elephant"]]
)
assert_eq(
cudf.Series(["default", "elephant"]),
string_sr.list.get(2, default="default"),
)

sr_with_null = cudf.Series([[0, cudf.NA], [1]])
assert_eq(cudf.Series([cudf.NA, 0]), sr_with_null.list.get(1, default=0))

sr_nested = cudf.Series([[[1, 2], [3, 4], [5, 6]], [[5, 6], [7, 8]]])
assert_eq(cudf.Series([[3, 4], [7, 8]]), sr_nested.list.get(1))
assert_eq(cudf.Series([[5, 6], cudf.NA]), sr_nested.list.get(2))
assert_eq(
cudf.Series([[5, 6], [0, 0]]), sr_nested.list.get(2, default=[0, 0])
)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/tests/test_accessor.py
Expand Up @@ -384,7 +384,7 @@ def test_contains(data, search_key):
"data, index, expectation",
[
(data_test_1(), 1, does_not_raise()),
(data_test_2(), 2, pytest.raises(IndexError)),
(data_test_2(), 2, does_not_raise()),
],
)
def test_get(data, index, expectation):
Expand Down