Skip to content

Commit 0762a3c

Browse files
authored
Update strings split APIs with stream parameters (#19909)
Contributes to #15163 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: #19909
1 parent ebb3e0d commit 0762a3c

File tree

8 files changed

+177
-63
lines changed

8 files changed

+177
-63
lines changed
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
22
from libcpp.memory cimport unique_ptr
33
from libcpp.string cimport string
44
from pylibcudf.exception_handler cimport libcudf_exception_handler
55
from pylibcudf.libcudf.column.column cimport column
66
from pylibcudf.libcudf.column.column_view cimport column_view
77
from pylibcudf.libcudf.scalar.scalar cimport string_scalar
88
from pylibcudf.libcudf.table.table cimport table
9+
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
910

1011

1112
cdef extern from "cudf/strings/split/partition.hpp" namespace \
1213
"cudf::strings" nogil:
1314

1415
cdef unique_ptr[table] partition(
1516
column_view input,
16-
string_scalar delimiter) except +libcudf_exception_handler
17+
string_scalar delimiter,
18+
cuda_stream_view stream) except +libcudf_exception_handler
1719

1820
cdef unique_ptr[table] rpartition(
1921
column_view input,
20-
string_scalar delimiter) except +libcudf_exception_handler
22+
string_scalar delimiter,
23+
cuda_stream_view stream) except +libcudf_exception_handler
Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
22
from libcpp.memory cimport unique_ptr
33
from libcpp.string cimport string
44
from pylibcudf.exception_handler cimport libcudf_exception_handler
@@ -8,6 +8,7 @@ from pylibcudf.libcudf.scalar.scalar cimport string_scalar
88
from pylibcudf.libcudf.strings.regex_program cimport regex_program
99
from pylibcudf.libcudf.table.table cimport table
1010
from pylibcudf.libcudf.types cimport size_type
11+
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
1112

1213

1314
cdef extern from "cudf/strings/split/split.hpp" namespace \
@@ -16,22 +17,26 @@ cdef extern from "cudf/strings/split/split.hpp" namespace \
1617
cdef unique_ptr[table] split(
1718
column_view strings_column,
1819
string_scalar delimiter,
19-
size_type maxsplit) except +libcudf_exception_handler
20+
size_type maxsplit,
21+
cuda_stream_view stream) except +libcudf_exception_handler
2022

2123
cdef unique_ptr[table] rsplit(
2224
column_view strings_column,
2325
string_scalar delimiter,
24-
size_type maxsplit) except +libcudf_exception_handler
26+
size_type maxsplit,
27+
cuda_stream_view stream) except +libcudf_exception_handler
2528

2629
cdef unique_ptr[column] split_record(
2730
column_view strings,
2831
string_scalar delimiter,
29-
size_type maxsplit) except +libcudf_exception_handler
32+
size_type maxsplit,
33+
cuda_stream_view stream) except +libcudf_exception_handler
3034

3135
cdef unique_ptr[column] rsplit_record(
3236
column_view strings,
3337
string_scalar delimiter,
34-
size_type maxsplit) except +libcudf_exception_handler
38+
size_type maxsplit,
39+
cuda_stream_view stream) except +libcudf_exception_handler
3540

3641

3742
cdef extern from "cudf/strings/split/split_re.hpp" namespace \
@@ -40,19 +45,23 @@ cdef extern from "cudf/strings/split/split_re.hpp" namespace \
4045
cdef unique_ptr[table] split_re(
4146
const column_view& input,
4247
regex_program prog,
43-
size_type maxsplit) except +libcudf_exception_handler
48+
size_type maxsplit,
49+
cuda_stream_view stream) except +libcudf_exception_handler
4450

4551
cdef unique_ptr[table] rsplit_re(
4652
const column_view& input,
4753
regex_program prog,
48-
size_type maxsplit) except +libcudf_exception_handler
54+
size_type maxsplit,
55+
cuda_stream_view stream) except +libcudf_exception_handler
4956

5057
cdef unique_ptr[column] split_record_re(
5158
const column_view& input,
5259
regex_program prog,
53-
size_type maxsplit) except +libcudf_exception_handler
60+
size_type maxsplit,
61+
cuda_stream_view stream) except +libcudf_exception_handler
5462

5563
cdef unique_ptr[column] rsplit_record_re(
5664
const column_view& input,
5765
regex_program prog,
58-
size_type maxsplit) except +libcudf_exception_handler
66+
size_type maxsplit,
67+
cuda_stream_view stream) except +libcudf_exception_handler
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
# Copyright (c) 2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
22

33
from pylibcudf.column cimport Column
44
from pylibcudf.scalar cimport Scalar
55
from pylibcudf.table cimport Table
6+
from rmm.pylibrmm.stream cimport Stream
67

78

8-
cpdef Table partition(Column input, Scalar delimiter=*)
9+
cpdef Table partition(Column input, Scalar delimiter=*, Stream stream=*)
910

10-
cpdef Table rpartition(Column input, Scalar delimiter=*)
11+
cpdef Table rpartition(Column input, Scalar delimiter=*, Stream stream=*)
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
# Copyright (c) 2024, NVIDIA CORPORATION.
22

3+
from rmm.pylibrmm.stream import Stream
4+
35
from pylibcudf.column import Column
46
from pylibcudf.scalar import Scalar
57
from pylibcudf.table import Table
68

7-
def partition(input: Column, delimiter: Scalar | None = None) -> Table: ...
8-
def rpartition(input: Column, delimiter: Scalar | None = None) -> Table: ...
9+
def partition(
10+
input: Column,
11+
delimiter: Scalar | None = None,
12+
stream: Stream | None = None,
13+
) -> Table: ...
14+
def rpartition(
15+
input: Column,
16+
delimiter: Scalar | None = None,
17+
stream: Stream | None = None,
18+
) -> Table: ...

python/pylibcudf/pylibcudf/strings/split/partition.pyx

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from cython.operator import dereference
1717

1818
__all__ = ["partition", "rpartition"]
1919

20-
cpdef Table partition(Column input, Scalar delimiter=None):
20+
cpdef Table partition(Column input, Scalar delimiter=None, Stream stream=None):
2121
"""
2222
Returns a set of 3 columns by splitting each string using the
2323
specified delimiter.
@@ -41,23 +41,25 @@ cpdef Table partition(Column input, Scalar delimiter=None):
4141
cdef const string_scalar* c_delimiter = <const string_scalar*>(
4242
delimiter.c_obj.get()
4343
)
44-
cdef Stream stream
44+
cdef Stream stream_local
45+
46+
stream_local = _get_stream(stream)
4547

4648
if delimiter is None:
47-
stream = _get_stream(None)
4849
delimiter = Scalar.from_libcudf(
49-
cpp_make_string_scalar("".encode(), stream.view())
50+
cpp_make_string_scalar("".encode(), stream_local.view())
5051
)
5152

5253
with nogil:
5354
c_result = cpp_partition.partition(
5455
input.view(),
55-
dereference(c_delimiter)
56+
dereference(c_delimiter),
57+
stream_local.view()
5658
)
5759

58-
return Table.from_libcudf(move(c_result))
60+
return Table.from_libcudf(move(c_result), stream_local)
5961

60-
cpdef Table rpartition(Column input, Scalar delimiter=None):
62+
cpdef Table rpartition(Column input, Scalar delimiter=None, Stream stream=None):
6163
"""
6264
Returns a set of 3 columns by splitting each string using the
6365
specified delimiter starting from the end of each string.
@@ -81,18 +83,20 @@ cpdef Table rpartition(Column input, Scalar delimiter=None):
8183
cdef const string_scalar* c_delimiter = <const string_scalar*>(
8284
delimiter.c_obj.get()
8385
)
84-
cdef Stream stream
86+
cdef Stream stream_local
87+
88+
stream_local = _get_stream(stream)
8589

8690
if delimiter is None:
87-
stream = _get_stream(None)
8891
delimiter = Scalar.from_libcudf(
89-
cpp_make_string_scalar("".encode(), stream.view())
92+
cpp_make_string_scalar("".encode(), stream_local.view())
9093
)
9194

9295
with nogil:
9396
c_result = cpp_partition.rpartition(
9497
input.view(),
95-
dereference(c_delimiter)
98+
dereference(c_delimiter),
99+
stream_local.view()
96100
)
97101

98-
return Table.from_libcudf(move(c_result))
102+
return Table.from_libcudf(move(c_result), stream_local)
Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,41 @@
1-
# Copyright (c) 2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
22

33
from pylibcudf.column cimport Column
44
from pylibcudf.libcudf.types cimport size_type
55
from pylibcudf.scalar cimport Scalar
66
from pylibcudf.strings.regex_program cimport RegexProgram
77
from pylibcudf.table cimport Table
8+
from rmm.pylibrmm.stream cimport Stream
89

910

10-
cpdef Table split(Column strings_column, Scalar delimiter, size_type maxsplit)
11+
cpdef Table split(
12+
Column strings_column, Scalar delimiter, size_type maxsplit, Stream stream=*
13+
)
1114

12-
cpdef Table rsplit(Column strings_column, Scalar delimiter, size_type maxsplit)
15+
cpdef Table rsplit(
16+
Column strings_column, Scalar delimiter, size_type maxsplit, Stream stream=*
17+
)
1318

14-
cpdef Column split_record(Column strings, Scalar delimiter, size_type maxsplit)
19+
cpdef Column split_record(
20+
Column strings, Scalar delimiter, size_type maxsplit, Stream stream=*
21+
)
1522

16-
cpdef Column rsplit_record(Column strings, Scalar delimiter, size_type maxsplit)
23+
cpdef Column rsplit_record(
24+
Column strings, Scalar delimiter, size_type maxsplit, Stream stream=*
25+
)
1726

18-
cpdef Table split_re(Column input, RegexProgram prog, size_type maxsplit)
27+
cpdef Table split_re(
28+
Column input, RegexProgram prog, size_type maxsplit, Stream stream=*
29+
)
1930

20-
cpdef Table rsplit_re(Column input, RegexProgram prog, size_type maxsplit)
31+
cpdef Table rsplit_re(
32+
Column input, RegexProgram prog, size_type maxsplit, Stream stream=*
33+
)
2134

22-
cpdef Column split_record_re(Column input, RegexProgram prog, size_type maxsplit)
35+
cpdef Column split_record_re(
36+
Column input, RegexProgram prog, size_type maxsplit, Stream stream=*
37+
)
2338

24-
cpdef Column rsplit_record_re(Column input, RegexProgram prog, size_type maxsplit)
39+
cpdef Column rsplit_record_re(
40+
Column input, RegexProgram prog, size_type maxsplit, Stream stream=*
41+
)
Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,57 @@
11
# Copyright (c) 2024, NVIDIA CORPORATION.
22

3+
from rmm.pylibrmm.stream import Stream
4+
35
from pylibcudf.column import Column
46
from pylibcudf.scalar import Scalar
57
from pylibcudf.strings.regex_program import RegexProgram
68
from pylibcudf.table import Table
79

810
def split(
9-
strings_column: Column, delimiter: Scalar, maxsplit: int
11+
strings_column: Column,
12+
delimiter: Scalar,
13+
maxsplit: int,
14+
stream: Stream | None = None,
1015
) -> Table: ...
1116
def rsplit(
12-
strings_column: Column, delimiter: Scalar, maxsplit: int
17+
strings_column: Column,
18+
delimiter: Scalar,
19+
maxsplit: int,
20+
stream: Stream | None = None,
1321
) -> Table: ...
1422
def split_record(
15-
strings: Column, delimiter: Scalar, maxsplit: int
23+
strings: Column,
24+
delimiter: Scalar,
25+
maxsplit: int,
26+
stream: Stream | None = None,
1627
) -> Column: ...
1728
def rsplit_record(
18-
strings: Column, delimiter: Scalar, maxsplit: int
29+
strings: Column,
30+
delimiter: Scalar,
31+
maxsplit: int,
32+
stream: Stream | None = None,
1933
) -> Column: ...
20-
def split_re(input: Column, prog: RegexProgram, maxsplit: int) -> Table: ...
21-
def rsplit_re(input: Column, prog: RegexProgram, maxsplit: int) -> Table: ...
34+
def split_re(
35+
input: Column,
36+
prog: RegexProgram,
37+
maxsplit: int,
38+
stream: Stream | None = None,
39+
) -> Table: ...
40+
def rsplit_re(
41+
input: Column,
42+
prog: RegexProgram,
43+
maxsplit: int,
44+
stream: Stream | None = None,
45+
) -> Table: ...
2246
def split_record_re(
23-
input: Column, prog: RegexProgram, maxsplit: int
47+
input: Column,
48+
prog: RegexProgram,
49+
maxsplit: int,
50+
stream: Stream | None = None,
2451
) -> Column: ...
2552
def rsplit_record_re(
26-
input: Column, prog: RegexProgram, maxsplit: int
53+
input: Column,
54+
prog: RegexProgram,
55+
maxsplit: int,
56+
stream: Stream | None = None,
2757
) -> Column: ...

0 commit comments

Comments
 (0)