Add string concatenation (operator.add) support to string UDFs.

Summary of the change carried by this patch:

* strings_udf/_typing.py: registers operator.add through
  register_stringview_binaryop with udf_string as the result type.
* strings_udf/lowering.py: declares the "concat" device function and lowers
  operator.add on (string_view, string_view). Both operands are stored to
  stack slots so they can be passed by pointer, and the result is loaded
  back from a udf_string stack slot that the device function fills in.
* strings_udf shim.cu: adds the extern "C" device function `concat`, which
  placement-constructs a udf_string in the caller-provided storage, appends
  lhs then rhs into a local udf_string, assigns it to that storage, and
  returns 0.
* cudf/core/udf/strings_typing.py: adds masked typing for operator.add that
  applies when both arguments pass _is_valid_string_arg.
* Tests: `st + literal` and `literal + st` in strings_udf, and
  `row["str_col"] + literal` in the cudf masked-UDF tests.

NOTE(review): in the copy this patch was recovered from, the two
reinterpret_cast expressions in shim.cu had lost their target type. It is
restored below as `cudf::string_view const*`, inferred from how the pointers
are used -- confirm against upstream before applying.

diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py
index f8f50600b12..e8a35c12f71 100644
--- a/python/cudf/cudf/core/udf/strings_typing.py
+++ b/python/cudf/cudf/core/udf/strings_typing.py
@@ -59,6 +59,16 @@ def len_typing(self, args, kws):
     return nb_signature(size_type, args[0])
 
 
+@register_string_function(operator.add)
+def concat_typing(self, args, kws):
+    if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]):
+        return nb_signature(
+            MaskedType(udf_string),
+            MaskedType(string_view),
+            MaskedType(string_view),
+        )
+
+
 @register_string_function(operator.contains)
 def contains_typing(self, args, kws):
     if _is_valid_string_arg(args[0]) and _is_valid_string_arg(args[1]):
diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py
index 7af47f981d6..fbe6b3f8888 100644
--- a/python/cudf/cudf/tests/test_udf_masked_ops.py
+++ b/python/cudf/cudf/tests/test_udf_masked_ops.py
@@ -903,6 +903,15 @@ def func(row):
     run_masked_udf_test(func, str_udf_data, check_dtype=False)
 
 
+@string_udf_test
+@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
+def test_string_udf_concat(str_udf_data, concat_char):
+    def func(row):
+        return row["str_col"] + concat_char
+
+    run_masked_udf_test(func, str_udf_data, check_dtype=False)
+
+
 @pytest.mark.parametrize(
     "data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]]
 )
diff --git a/python/strings_udf/cpp/src/strings/udf/shim.cu b/python/strings_udf/cpp/src/strings/udf/shim.cu
index 63e740c5226..737afff4a1b 100644
--- a/python/strings_udf/cpp/src/strings/udf/shim.cu
+++ b/python/strings_udf/cpp/src/strings/udf/shim.cu
@@ -270,3 +270,16 @@ extern "C" __device__ int rstrip(int* nb_retval,
 
   return 0;
 }
+
+extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs, void* const* rhs)
+{
+  auto lhs_ptr = reinterpret_cast<cudf::string_view const*>(lhs);
+  auto rhs_ptr = reinterpret_cast<cudf::string_view const*>(rhs);
+
+  auto udf_str_ptr = new (udf_str) udf_string;
+
+  udf_string result;
+  result.append(*lhs_ptr).append(*rhs_ptr);
+  *udf_str_ptr = result;
+  return 0;
+}
diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py
index a309a9cb93c..b678db88b95 100644
--- a/python/strings_udf/strings_udf/_typing.py
+++ b/python/strings_udf/strings_udf/_typing.py
@@ -159,8 +159,13 @@ def generic(self, args, kws):
 register_stringview_binaryop(operator.gt, types.boolean)
 register_stringview_binaryop(operator.le, types.boolean)
 register_stringview_binaryop(operator.ge, types.boolean)
+
+# st in other
 register_stringview_binaryop(operator.contains, types.boolean)
 
+# st + other
+register_stringview_binaryop(operator.add, udf_string)
+
 
 def create_binary_attr(attrname, retty):
     """
diff --git a/python/strings_udf/strings_udf/lowering.py b/python/strings_udf/strings_udf/lowering.py
index 17a1869e881..d98384c0d02 100644
--- a/python/strings_udf/strings_udf/lowering.py
+++ b/python/strings_udf/strings_udf/lowering.py
@@ -25,6 +25,9 @@
 # CUDA function declarations
 # read-only (input is a string_view, output is a fixed with type)
 _string_view_len = cuda.declare_device("len", size_type(_STR_VIEW_PTR))
+_concat_string_view = cuda.declare_device(
+    "concat", types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
+)
 
 
 def _declare_binary_func(lhs, rhs, out, name):
@@ -160,6 +163,31 @@ def len_impl(context, builder, sig, args):
     return result
 
 
+def call_concat_string_view(result, lhs, rhs):
+    return _concat_string_view(result, lhs, rhs)
+
+
+@cuda_lower(operator.add, string_view, string_view)
+def concat_impl(context, builder, sig, args):
+    lhs_ptr = builder.alloca(args[0].type)
+    rhs_ptr = builder.alloca(args[1].type)
+    builder.store(args[0], lhs_ptr)
+    builder.store(args[1], rhs_ptr)
+
+    udf_str_ptr = builder.alloca(default_manager[udf_string].get_value_type())
+    _ = context.compile_internal(
+        builder,
+        call_concat_string_view,
+        types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR),
+        (udf_str_ptr, lhs_ptr, rhs_ptr),
+    )
+
+    result = cgutils.create_struct_proxy(udf_string)(
+        context, builder, value=builder.load(udf_str_ptr)
+    )
+    return result._getvalue()
+
+
 def create_binary_string_func(binary_func, retty):
     """
     Provide a wrapper around numba's low-level extension API which
diff --git a/python/strings_udf/strings_udf/tests/test_string_udfs.py b/python/strings_udf/strings_udf/tests/test_string_udfs.py
index 522433d404f..49663ee02ec 100644
--- a/python/strings_udf/strings_udf/tests/test_string_udfs.py
+++ b/python/strings_udf/strings_udf/tests/test_string_udfs.py
@@ -302,3 +302,19 @@ def func(st):
         return st.rstrip(strip_char)
 
     run_udf_test(data, func, "str")
+
+
+@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
+def test_string_udf_concat(data, concat_char):
+    def func(st):
+        return st + concat_char
+
+    run_udf_test(data, func, "str")
+
+
+@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
+def test_string_udf_concat_reflected(data, concat_char):
+    def func(st):
+        return concat_char + st
+
+    run_udf_test(data, func, "str")