Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default= kwarg to .list.get() accessor method #10547

Merged
merged 13 commits into from Apr 6, 2022
41 changes: 28 additions & 13 deletions python/cudf/cudf/core/column/lists.py
Expand Up @@ -2,7 +2,7 @@

import pickle
from functools import cached_property
from typing import List, Sequence
from typing import List, Optional, Sequence

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -337,16 +337,18 @@ def __init__(self, parent: ParentType):
)
super().__init__(parent=parent)

def get(self, index: int) -> ParentType:
def get(
self, index: int, default: Optional[ScalarLike] = None
) -> ParentType:
"""
Extract element at the given index from each component

Extract element from lists, tuples, or strings in
each element in the Series/Index.
Extract element at the given index from each list.
bdice marked this conversation as resolved.
Show resolved Hide resolved
If the index is out of bounds for any list,
return <NA> or, if provided, ``default``.

Parameters
----------
index : int
default : scalar, optional

Returns
-------
Expand All @@ -360,14 +362,27 @@ def get(self, index: int) -> ParentType:
1 5
2 6
dtype: int64

>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2)
0 <NA>
1 5
2 6
dtype: int64

>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2, default=0)
0 0
1 5
2 6
dtype: int64
"""
min_col_list_len = self.len().min()
if -min_col_list_len <= index < min_col_list_len:
return self._return_or_inplace(
extract_element(self._column, index)
)
else:
raise IndexError("list index out of range")
out = extract_element(self._column, index)
return self._return_or_inplace(
out
if (default is None or default is cudf.NA)
else out.fillna(default)
)

def contains(self, search_key: ScalarLike) -> ParentType:
"""
Expand Down
10 changes: 6 additions & 4 deletions python/cudf/cudf/tests/test_list.py
Expand Up @@ -292,10 +292,12 @@ def test_get_nested_lists():
assert_eq(expect, got)


def test_get_nulls():
with pytest.raises(IndexError, match="list index out of range"):
sr = cudf.Series([[], [], []])
sr.list.get(100)
def test_get_default():
sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]])

assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2))
assert_eq(cudf.Series([cudf.NA, 5, 8]), sr.list.get(2, default=cudf.NA))
assert_eq(cudf.Series([0, 5, 8]), sr.list.get(2, default=0))
bdice marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
Expand Down