Support + in strings_udf #12117

Merged
74 commits
40eb1e2
Add strings udf C++ classes and function for phase II
davidwendt Oct 12, 2022
5317db8
fix style error
davidwendt Oct 12, 2022
8e531a5
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 13, 2022
b8d7868
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 13, 2022
2e36b6a
support returning strings within strings_udf library
brandon-b-miller Oct 13, 2022
238c862
returning strings working
brandon-b-miller Oct 14, 2022
0544c23
clean up code a bit
brandon-b-miller Oct 17, 2022
ae1bbdc
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 17, 2022
edcaaf2
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 18, 2022
a5661bc
change void* to udf_string*
davidwendt Oct 18, 2022
9661c4e
update doxygens
davidwendt Oct 18, 2022
a6f03a3
remove unnecessary explicit casting
brandon-b-miller Oct 18, 2022
ece495f
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 18, 2022
5554ed9
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 19, 2022
ebaf088
add pad utility functions
davidwendt Oct 19, 2022
c3e17ac
fix doxygen for udf_apis.hpp
davidwendt Oct 19, 2022
2dae45d
fix to_string to use count_digits
davidwendt Oct 20, 2022
3467f34
add ALL_FLAGS
davidwendt Oct 20, 2022
4f63c54
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 20, 2022
7639039
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 21, 2022
84721d4
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 24, 2022
4c72149
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 25, 2022
cf72fc8
add noexcept decl to appropriate member functions
davidwendt Oct 25, 2022
28e917b
fix return types for split
davidwendt Oct 25, 2022
f82c454
fix doxygen for various functions
davidwendt Oct 25, 2022
3b513a3
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 26, 2022
7b9718c
create free_udf_strings_array function
davidwendt Oct 31, 2022
68e54e8
fix compare returns, null assignment, reuse ctors
davidwendt Oct 31, 2022
6eef0a4
fix some doxygen wording
davidwendt Oct 31, 2022
02aa5b4
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 31, 2022
69e0d7c
remove string_view const parameter decl
davidwendt Oct 31, 2022
a95c030
fix default-stream
davidwendt Oct 31, 2022
e0526e6
remove lstrip and rstrip
davidwendt Oct 31, 2022
bc903d6
reword split doxygen text for result=nullptr
davidwendt Oct 31, 2022
229c1f2
Merge branch 'branch-22.12' into udf-string-class
davidwendt Nov 1, 2022
eb6532e
add cuda_runtime.h to resolve device refs
davidwendt Nov 1, 2022
a8fca12
fix doxygen wording for pad()
davidwendt Nov 1, 2022
a249d13
refactor split; add count_tokens function
davidwendt Nov 1, 2022
96b06f6
refactor append, replace for better reuse
davidwendt Nov 1, 2022
7849307
expand spos/epos var names
davidwendt Nov 1, 2022
cadcf79
add more doc to replace() for count parm
davidwendt Nov 1, 2022
b3a43b8
Merge branch 'branch-22.12' into udf-string-class
davidwendt Nov 1, 2022
e0d1374
Merge remote-tracking branch 'david/udf-string-class' into fea-string…
brandon-b-miller Nov 1, 2022
1e02c26
adjust for changes
brandon-b-miller Nov 1, 2022
c9ef3ec
Merge branch 'branch-22.12' into fea-strings-udf-return-strings
brandon-b-miller Nov 2, 2022
1218c08
fix up cython
brandon-b-miller Nov 2, 2022
b9aabdd
merge the latest, resolve conflicts, pass tests
brandon-b-miller Nov 3, 2022
e864dea
from_udf_string_array -> column_from_udf_string_array, to_string_view…
brandon-b-miller Nov 3, 2022
9fccc9b
refactor
brandon-b-miller Nov 3, 2022
d5c37a8
prune imports
brandon-b-miller Nov 3, 2022
b7c1b1d
cleanup
brandon-b-miller Nov 3, 2022
267b904
begin to address reviews
brandon-b-miller Nov 4, 2022
8b7a412
Update python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx
brandon-b-miller Nov 4, 2022
4f821ca
finish addressing reviews, walrus everywhere!
brandon-b-miller Nov 4, 2022
b0a8681
support strip
brandon-b-miller Nov 7, 2022
18aee5a
updates
brandon-b-miller Nov 8, 2022
2cefbe4
merge 22.12
brandon-b-miller Nov 8, 2022
d7556b0
fix bad merge
brandon-b-miller Nov 8, 2022
c4f8847
add tests to cudf
brandon-b-miller Nov 8, 2022
7030108
plumb to maskedtype
brandon-b-miller Nov 8, 2022
11e966c
cleanup
brandon-b-miller Nov 8, 2022
9991c76
more cleanup
brandon-b-miller Nov 8, 2022
837a49c
Update python/strings_udf/strings_udf/lowering.py
brandon-b-miller Nov 9, 2022
302fe60
address reviews
brandon-b-miller Nov 9, 2022
e0f98cc
plumb concat to empty shim function for now
brandon-b-miller Nov 9, 2022
2bbba3b
troublesome segfaulting shim function
brandon-b-miller Nov 10, 2022
46caab0
merge latest and resolve conflicts
brandon-b-miller Nov 10, 2022
906f6d5
zero out preallocated udf_string
brandon-b-miller Nov 10, 2022
bdb64e9
add tests for maskedtype and a little extra typing
brandon-b-miller Nov 10, 2022
bfff98a
Update python/strings_udf/cpp/src/strings/udf/shim.cu
brandon-b-miller Nov 10, 2022
ab061f3
Merge branch 'branch-22.12' into fea-stringudf-concat
brandon-b-miller Nov 14, 2022
c22123a
move memset into lowering
brandon-b-miller Nov 14, 2022
7255d94
use placement new
brandon-b-miller Nov 15, 2022
d6d030c
add reflected concat test
brandon-b-miller Nov 16, 2022
10 changes: 10 additions & 0 deletions python/cudf/cudf/core/udf/strings_typing.py
@@ -59,6 +59,16 @@ def len_typing(self, args, kws):
        return nb_signature(size_type, args[0])


@register_string_function(operator.add)
def concat_typing(self, args, kws):
    if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]):
        return nb_signature(
            MaskedType(udf_string),
            MaskedType(string_view),
            MaskedType(string_view),
        )


@register_string_function(operator.contains)
def contains_typing(self, args, kws):
    if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]):
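The typing rule for `operator.add` returns a signature only when both operands pass `_is_valid_string_arg`; otherwise it falls off the end and returns `None`, which, as far as I understand Numba's typing templates, signals "no match here". A plain-Python sketch of that shape follows. `Masked`, `Signature`, `STRING_LIKE` and `is_valid_string_arg` are hypothetical stand-ins for illustration, not the real cudf or Numba types.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Masked:
    """Hypothetical stand-in for MaskedType: wraps a value type name."""
    value_type: str


@dataclass(frozen=True)
class Signature:
    """Hypothetical stand-in for a Numba signature (return type + arg types)."""
    return_type: Masked
    args: Tuple[Masked, ...]


# Assumption: these are the value types treated as "string-like" in this sketch.
STRING_LIKE = {"string_view", "udf_string"}


def is_valid_string_arg(ty) -> bool:
    # Only masked string-like types are accepted in this simplified model.
    return isinstance(ty, Masked) and ty.value_type in STRING_LIKE


def concat_typing(args) -> Optional[Signature]:
    # Both operands must be string-like; the result is an owning string type.
    if len(args) == 2 and all(is_valid_string_arg(a) for a in args):
        return Signature(
            Masked("udf_string"),
            (Masked("string_view"), Masked("string_view")),
        )
    # No match: let other typing rules (if any) handle the operands.
    return None
```

Note the asymmetry the real rule also has: the inputs are typed as non-owning `string_view`, while the result is an owning `udf_string`, since concatenation has to produce new storage.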
9 changes: 9 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
@@ -903,6 +903,15 @@ def func(row):
    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_concat(str_udf_data, concat_char):
    def func(row):
        return row["str_col"] + concat_char

    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize(
    "data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]]
)
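For readers unfamiliar with the masked path, here is a rough reference model of what `row["str_col"] + concat_char` is expected to produce per row. It assumes a null operand yields a null result (modelled as `None`), which is how masked binary operations usually behave; `masked_concat` and `apply_masked` are illustrative names, not cudf functions, and the sample column is made up.

```python
from typing import List, Optional


def masked_concat(lhs: Optional[str], rhs: Optional[str]) -> Optional[str]:
    """Concatenate two possibly-null strings; null in, null out (assumption)."""
    if lhs is None or rhs is None:
        return None
    return lhs + rhs


def apply_masked(column: List[Optional[str]], concat_char: str) -> List[Optional[str]]:
    """Row-wise model of the UDF `row["str_col"] + concat_char`."""
    return [masked_concat(value, concat_char) for value in column]


if __name__ == "__main__":
    # Hypothetical sample data, including an empty string and a null row.
    print(apply_masked(["abc", "", None], "@"))
```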
13 changes: 13 additions & 0 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
@@ -270,3 +270,16 @@ extern "C" __device__ int rstrip(int* nb_retval,

  return 0;
}

extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs, void* const* rhs)
{
  auto lhs_ptr = reinterpret_cast<cudf::string_view const*>(lhs);
  auto rhs_ptr = reinterpret_cast<cudf::string_view const*>(rhs);

  auto udf_str_ptr = new (udf_str) udf_string;

  udf_string result;
  result.append(*lhs_ptr).append(*rhs_ptr);
  *udf_str_ptr = result;
  return 0;
}
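At the byte level, the shim appends two non-owning views into one newly owned string. A small Python model of that idea is below, using `memoryview` for the views and `bytes` for the owned result. This is only a sketch of the concept: `concat_views` is a made-up name, and the real `udf_string::append` manages device memory in ways not shown here.

```python
def concat_views(lhs: memoryview, rhs: memoryview) -> bytes:
    """Copy two non-owning byte views into one newly allocated buffer."""
    out = bytearray(len(lhs) + len(rhs))  # owned storage sized for both inputs
    out[: len(lhs)] = lhs                 # first append
    out[len(lhs):] = rhs                  # second append
    return bytes(out)


if __name__ == "__main__":
    left = memoryview("héllo".encode("utf-8"))  # multi-byte UTF-8 is copied as-is
    right = memoryview(b"@")
    print(concat_views(left, right).decode("utf-8"))
```

Because both inputs are already valid UTF-8, joining their bytes back to back yields valid UTF-8 without any per-character work.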
5 changes: 5 additions & 0 deletions python/strings_udf/strings_udf/_typing.py
@@ -159,8 +159,13 @@ def generic(self, args, kws):
register_stringview_binaryop(operator.gt, types.boolean)
register_stringview_binaryop(operator.le, types.boolean)
register_stringview_binaryop(operator.ge, types.boolean)

# st in other
register_stringview_binaryop(operator.contains, types.boolean)

# st + other
register_stringview_binaryop(operator.add, udf_string)


def create_binary_attr(attrname, retty):
    """
28 changes: 28 additions & 0 deletions python/strings_udf/strings_udf/lowering.py
@@ -25,6 +25,9 @@
# CUDA function declarations
# read-only (input is a string_view, output is a fixed width type)
_string_view_len = cuda.declare_device("len", size_type(_STR_VIEW_PTR))
_concat_string_view = cuda.declare_device(
    "concat", types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)


def _declare_binary_func(lhs, rhs, out, name):
@@ -160,6 +163,31 @@ def len_impl(context, builder, sig, args):
    return result


def call_concat_string_view(result, lhs, rhs):
    return _concat_string_view(result, lhs, rhs)


@cuda_lower(operator.add, string_view, string_view)
def concat_impl(context, builder, sig, args):
    lhs_ptr = builder.alloca(args[0].type)
    rhs_ptr = builder.alloca(args[1].type)
    builder.store(args[0], lhs_ptr)
    builder.store(args[1], rhs_ptr)

    udf_str_ptr = builder.alloca(default_manager[udf_string].get_value_type())
    _ = context.compile_internal(
        builder,
        call_concat_string_view,
        types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR),
        (udf_str_ptr, lhs_ptr, rhs_ptr),
    )

    result = cgutils.create_struct_proxy(udf_string)(
        context, builder, value=builder.load(udf_str_ptr)
    )
    return result._getvalue()


def create_binary_string_func(binary_func, retty):
    """
    Provide a wrapper around numba's low-level extension API which
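The lowering follows an out-parameter pattern: the caller allocates a slot for the result, passes its address to the shim, then loads the filled-in value. The same control flow can be sketched in plain Python with `ctypes`. `FakeUdfString`, `MAX_BYTES`, `concat_shim` and `lowered_add` are hypothetical names, and the fixed-capacity struct is an assumption made to keep the sketch self-contained; it is not the real `udf_string` layout.

```python
import ctypes

MAX_BYTES = 32  # arbitrary capacity chosen for this sketch


class FakeUdfString(ctypes.Structure):
    """Hypothetical fixed-capacity stand-in for udf_string."""
    _fields_ = [("size", ctypes.c_int), ("data", ctypes.c_char * MAX_BYTES)]


def concat_shim(out, lhs: bytes, rhs: bytes) -> int:
    """Plays the role of the device function: fill *out, return a status code."""
    joined = lhs + rhs
    if len(joined) > MAX_BYTES:
        return 1  # does not fit in the sketch's fixed buffer
    out.contents.size = len(joined)
    out.contents.data = joined
    return 0


def lowered_add(lhs: bytes, rhs: bytes) -> bytes:
    """Plays the role of concat_impl: allocate, call, load."""
    slot = FakeUdfString()  # like builder.alloca(...); ctypes zero-fills it
    status = concat_shim(ctypes.pointer(slot), lhs, rhs)
    if status != 0:
        raise ValueError("result does not fit in the sketch's fixed buffer")
    return slot.data[: slot.size]  # like builder.load(udf_str_ptr)


if __name__ == "__main__":
    print(lowered_add(b"abc", b"@"))
```

One difference worth keeping in mind: `ctypes` zero-initializes the struct, whereas stack storage from `alloca` is uninitialized, which is presumably why the shim constructs the `udf_string` in place before assigning to it.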
16 changes: 16 additions & 0 deletions python/strings_udf/strings_udf/tests/test_string_udfs.py
@@ -302,3 +302,19 @@ def func(st):
        return st.rstrip(strip_char)

    run_udf_test(data, func, "str")


@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_concat(data, concat_char):
    def func(st):
        return st + concat_char

    run_udf_test(data, func, "str")


@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_concat_reflected(data, concat_char):
    def func(st):
        return concat_char + st

    run_udf_test(data, func, "str")
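The two tests differ only in operand order, so it may help to spell out what each order should produce for the parametrized values, including the empty string. The sketch below computes the expected results in plain Python; `expected_column` and `check_both_orders` are illustrative helpers, and the sample data in the usage line is made up rather than taken from the test fixtures.

```python
CONCAT_CHARS = ["1", "a", "12", " ", "", ".", "@"]  # same values as the tests


def expected_column(data, func):
    """Apply func to every string, as a host-side reference result."""
    return [func(st) for st in data]


def check_both_orders(data):
    """Check basic properties of `st + c` and the reflected `c + st`."""
    for concat_char in CONCAT_CHARS:
        forward = expected_column(data, lambda st: st + concat_char)
        reflected = expected_column(data, lambda st: concat_char + st)
        for st, fwd, ref in zip(data, forward, reflected):
            # Operand order decides which side the extra characters land on.
            assert fwd.startswith(st) and fwd.endswith(concat_char)
            assert ref.startswith(concat_char) and ref.endswith(st)
            assert len(fwd) == len(ref) == len(st) + len(concat_char)
            if concat_char == "":
                # Concatenating the empty string leaves the input unchanged.
                assert fwd == ref == st
    return True


if __name__ == "__main__":
    print(check_both_orders(["abc", "", "héllo", "a b"]))
```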