Skip to content

Commit

Permalink
Add default= kwarg to .list.get() accessor method (#10547)
Browse files Browse the repository at this point in the history
Closes #10540.

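As mentioned in the issue, this is a breaking change: an out-of-bounds index previously raised an ``IndexError`` and now yields <NA> (or ``default``, if provided) for the affected rows. We could instead introduce this change in a non-breaking way by using a sentinel value for the kwarg, if desired.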

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)
  - Bradley Dice (https://github.com/bdice)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10547
  • Loading branch information
shwina committed Apr 6, 2022
1 parent 956c7b5 commit 261879f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
49 changes: 38 additions & 11 deletions python/cudf/cudf/core/column/lists.py
Expand Up @@ -2,7 +2,7 @@

import pickle
from functools import cached_property
from typing import List, Sequence
from typing import List, Optional, Sequence

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -337,16 +337,20 @@ def __init__(self, parent: ParentType):
)
super().__init__(parent=parent)

def get(self, index: int) -> ParentType:
def get(
self, index: int, default: Optional[ScalarLike] = None
) -> ParentType:
"""
Extract element at the given index from each component
Extract element at the given index from each list.
Extract element from lists, tuples, or strings in
each element in the Series/Index.
If the index is out of bounds for any list,
return <NA> or, if provided, ``default``.
Thus, this method never raises an ``IndexError``.
Parameters
----------
index : int
default : scalar, optional
Returns
-------
Expand All @@ -360,14 +364,37 @@ def get(self, index: int) -> ParentType:
1 5
2 6
dtype: int64
>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2)
0 <NA>
1 5
2 6
dtype: int64
>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2, default=0)
0 0
1 5
2 6
dtype: int64
"""
min_col_list_len = self.len().min()
if -min_col_list_len <= index < min_col_list_len:
return self._return_or_inplace(
extract_element(self._column, index)
out = extract_element(self._column, index)

if not (default is None or default is cudf.NA):
# determine rows for which `index` is out-of-bounds
lengths = count_elements(self._column)
out_of_bounds_mask = (np.negative(index) > lengths) | (
index >= lengths
)
else:
raise IndexError("list index out of range")

# replace the value in those rows (should be NA) with `default`
if out_of_bounds_mask.any():
out = out._scatter_by_column(
out_of_bounds_mask, cudf.Scalar(default)
)

return self._return_or_inplace(out)

def contains(self, search_key: ScalarLike) -> ParentType:
"""
Expand Down
30 changes: 26 additions & 4 deletions python/cudf/cudf/tests/test_list.py
Expand Up @@ -292,10 +292,32 @@ def test_get_nested_lists():
assert_eq(expect, got)


def test_get_nulls():
with pytest.raises(IndexError, match="list index out of range"):
sr = cudf.Series([[], [], []])
sr.list.get(100)
def test_get_default():
    """Check ``.list.get()`` with and without the ``default=`` kwarg.

    Rows for which the index is out of bounds are expected to hold <NA>,
    or ``default`` when a non-NA default is passed.
    """
    sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]])

    # Index 2 is out of bounds only for the first (length-2) list. With no
    # default, or an explicit NA default, that row is expected to be <NA>.
    assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2))
    assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2, default=cudf.NA))
    # A non-NA default fills the out-of-bounds row, for positive and
    # negative indices alike (-3 is out of bounds for the length-2 list).
    assert_eq(cudf.Series([0, 5, 8]), sr.list.get(2, default=0))
    assert_eq(cudf.Series([0, 3, 7]), sr.list.get(-3, default=0))
    # -1 is in bounds for every row, so no NA/default appears.
    assert_eq(cudf.Series([2, 5, 9]), sr.list.get(-1))

    # The default may be a string when the lists hold strings.
    string_sr = cudf.Series(
        [["apple", "banana"], ["carrot", "daffodil", "elephant"]]
    )
    assert_eq(
        cudf.Series(["default", "elephant"]),
        string_sr.list.get(2, default="default"),
    )

    # An in-bounds element that is itself null must stay <NA>; only the
    # out-of-bounds row (the length-1 list) receives the default.
    sr_with_null = cudf.Series([[0, cudf.NA], [1]])
    assert_eq(cudf.Series([cudf.NA, 0]), sr_with_null.list.get(1, default=0))

    # Nested lists: extracted elements are lists, and the default may be
    # a list as well.
    sr_nested = cudf.Series([[[1, 2], [3, 4], [5, 6]], [[5, 6], [7, 8]]])
    assert_eq(cudf.Series([[3, 4], [7, 8]]), sr_nested.list.get(1))
    assert_eq(cudf.Series([[5, 6], cudf.NA]), sr_nested.list.get(2))
    assert_eq(
        cudf.Series([[5, 6], [0, 0]]), sr_nested.list.get(2, default=[0, 0])
    )


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/tests/test_accessor.py
Expand Up @@ -384,7 +384,7 @@ def test_contains(data, search_key):
"data, index, expectation",
[
(data_test_1(), 1, does_not_raise()),
(data_test_2(), 2, pytest.raises(IndexError)),
(data_test_2(), 2, does_not_raise()),
],
)
def test_get(data, index, expectation):
Expand Down

0 comments on commit 261879f

Please sign in to comment.