Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex flags to strings findall functions #10208

Merged
15 changes: 10 additions & 5 deletions cpp/include/cudf/strings/findall.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/
#pragma once

#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table.hpp>

Expand Down Expand Up @@ -47,14 +48,16 @@ namespace strings {
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param input Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned table's device memory.
* @return New table of strings columns.
*/
std::unique_ptr<table> findall(
strings_column_view const& strings,
strings_column_view const& input,
vyasr marked this conversation as resolved.
Show resolved Hide resolved
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -77,14 +80,16 @@ std::unique_ptr<table> findall(
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param input Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New lists column of strings.
*/
std::unique_ptr<column> findall_record(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
Expand Down
18 changes: 10 additions & 8 deletions cpp/src/strings/search/findall.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,17 +110,18 @@ struct findall_count_fn : public findall_fn<stack_size> {

//
std::unique_ptr<table> findall(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);
auto const strings_count = input.size();
auto const d_strings = column_device_view::create(input.parent(), stream);

auto const d_flags = detail::get_character_flags_table();
// compile regex into device object
auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto const d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
vyasr marked this conversation as resolved.
Show resolved Hide resolved
auto const regex_insts = d_prog->insts_counts();

rmm::device_uvector<size_type> find_counts(strings_count, stream);
Expand Down Expand Up @@ -205,12 +206,13 @@ std::unique_ptr<table> findall(

// external API

std::unique_ptr<table> findall(strings_column_view const& strings,
std::unique_ptr<table> findall(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::findall(strings, pattern, rmm::cuda_stream_default, mr);
return detail::findall(input, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/strings/search/findall_record.cu
Expand Up @@ -79,17 +79,18 @@ struct findall_fn {

//
std::unique_ptr<column> findall_record(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);
auto const strings_count = input.size();
auto const d_strings = column_device_view::create(input.parent(), stream);

// compile regex into device object
auto const d_prog =
reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

// Create lists offsets column
auto offsets = count_matches(*d_strings, *d_prog, stream, mr);
Expand Down Expand Up @@ -159,12 +160,13 @@ std::unique_ptr<column> findall_record(

// external API

std::unique_ptr<column> findall_record(strings_column_view const& strings,
std::unique_ptr<column> findall_record(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::findall_record(strings, pattern, rmm::cuda_stream_default, mr);
return detail::findall_record(input, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
45 changes: 45 additions & 0 deletions cpp/tests/strings/findall_tests.cpp
Expand Up @@ -97,6 +97,51 @@ TEST_F(StringsFindallTests, FindallRecord)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}

TEST_F(StringsFindallTests, Multiline)
{
cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"});
auto view = cudf::strings_column_view(input);

{
auto results = cudf::strings::findall(view, "(^abc$)", cudf::strings::regex_flags::MULTILINE);
auto col0 =
cudf::test::strings_column_wrapper({"abc", "abc", "abc", "", "abc"}, {1, 1, 1, 0, 1});
auto col1 = cudf::test::strings_column_wrapper({"abc", "", "", "", ""}, {1, 0, 0, 0, 0});
auto expected = cudf::table_view({col0, col1});
CUDF_TEST_EXPECT_TABLES_EQUAL(results->view(), expected);
}
{
auto results =
cudf::strings::findall_record(view, "(^abc$)", cudf::strings::regex_flags::MULTILINE);
bool valids[] = {1, 1, 1, 0, 1};
using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;
LCW expected({LCW{"abc", "abc"}, LCW{"abc"}, LCW{"abc"}, LCW{}, LCW{"abc"}}, valids);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}
}

TEST_F(StringsFindallTests, DotAll)
{
cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""});
auto view = cudf::strings_column_view(input);

{
auto results = cudf::strings::findall(view, "(b.*f)", cudf::strings::regex_flags::DOTALL);
auto col0 =
cudf::test::strings_column_wrapper({"bc\nfa\nef", "bbc\nfff", "bcdef", ""}, {1, 1, 1, 0});
auto expected = cudf::table_view({col0});
CUDF_TEST_EXPECT_TABLES_EQUAL(results->view(), expected);
}
{
auto results =
cudf::strings::findall_record(view, "(b.*f)", cudf::strings::regex_flags::DOTALL);
bool valids[] = {1, 1, 1, 0};
using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;
LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdef"}, LCW{}}, valids);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}
}

TEST_F(StringsFindallTests, MediumRegex)
{
// This results in 15 regex instructions and falls in the 'medium' range.
Expand Down
7 changes: 5 additions & 2 deletions python/cudf/cudf/_lib/cpp/strings/findall.pxd
Expand Up @@ -5,15 +5,18 @@ from libcpp.string cimport string

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.strings.contains cimport regex_flags
from cudf._lib.cpp.table.table cimport table


cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[table] findall(
const column_view& source_strings,
const string& pattern) except +
const string& pattern,
regex_flags flags) except +

cdef unique_ptr[column] findall_record(
const column_view& source_strings,
const string& pattern) except +
const string& pattern,
regex_flags flags) except +
16 changes: 11 additions & 5 deletions python/cudf/cudf/_lib/strings/findall.pyx
@@ -1,5 +1,6 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.

from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move
Expand All @@ -8,6 +9,7 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.scalar.scalar cimport string_scalar
from cudf._lib.cpp.strings.contains cimport regex_flags
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
from cudf._lib.cpp.strings.findall cimport (
findall as cpp_findall,
findall_record as cpp_findall_record,
Expand All @@ -17,7 +19,7 @@ from cudf._lib.scalar cimport DeviceScalar
from cudf._lib.utils cimport data_from_unique_ptr


def findall(Column source_strings, pattern):
def findall(Column source_strings, object pattern, uint32_t flags):
"""
Returns data with all non-overlapping matches of `pattern`
in each string of `source_strings`.
Expand All @@ -26,11 +28,13 @@ def findall(Column source_strings, pattern):
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags

with nogil:
c_result = move(cpp_findall(
source_view,
pattern_string
pattern_string,
c_flags
))

return data_from_unique_ptr(
Expand All @@ -39,7 +43,7 @@ def findall(Column source_strings, pattern):
)


def findall_record(Column source_strings, pattern):
def findall_record(Column source_strings, object pattern, uint32_t flags):
"""
Returns data with all non-overlapping matches of `pattern`
in each string of `source_strings` as a lists column.
Expand All @@ -48,11 +52,13 @@ def findall_record(Column source_strings, pattern):
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags

with nogil:
c_result = move(cpp_findall_record(
source_view,
pattern_string
pattern_string,
c_flags
))

return Column.from_unique_ptr(move(c_result))
16 changes: 12 additions & 4 deletions python/cudf/cudf/core/column/string.py
Expand Up @@ -3410,6 +3410,8 @@ def findall(
----------
pat : str
Pattern or regular expression.
flags : int, default 0 (no flags)
Flags to pass through to the regex engine (e.g. re.MULTILINE)

Returns
-------
Expand All @@ -3419,7 +3421,8 @@ def findall(

Notes
-----
`flags` parameter is currently not supported.
The `flags` parameter currently only supports re.DOTALL and
re.MULTILINE.

Examples
--------
Expand Down Expand Up @@ -3462,10 +3465,15 @@ def findall(
1 <NA> <NA>
2 b b
"""
if flags != 0:
raise NotImplementedError("`flags` parameter is not yet supported")
if isinstance(pat, re.Pattern):
flags = pat.flags & ~re.U
pat = pat.pattern
vyasr marked this conversation as resolved.
Show resolved Hide resolved
if not _is_supported_regex_flags(flags):
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

data, index = libstrings.findall(self._column, pat)
data, index = libstrings.findall(self._column, pat, flags)
return self._return_or_inplace(
cudf.core.frame.Frame(data, index), expand=expand
)
Expand Down
15 changes: 12 additions & 3 deletions python/cudf/cudf/tests/test_string.py
Expand Up @@ -1775,14 +1775,23 @@ def test_string_count(data, pat, flags):


def test_string_findall():
ps = pd.Series(["Lion", "Monkey", "Rabbit"])
gs = cudf.Series(["Lion", "Monkey", "Rabbit"])
test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"]
ps = pd.Series(test_data)
gs = cudf.Series(test_data)

assert_eq(ps.str.findall("Monkey")[1][0], gs.str.findall("Monkey")[0][1])
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
assert_eq(ps.str.findall("on")[0][0], gs.str.findall("on")[0][0])
assert_eq(ps.str.findall("on")[1][0], gs.str.findall("on")[0][1])
assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0])
assert_eq(ps.str.findall("b")[2][1], gs.str.findall("b")[1][2])
assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0])
assert_eq(
ps.str.findall("on$", re.MULTILINE)[3][0],
gs.str.findall("on$", re.MULTILINE)[0][3],
)
assert_eq(
ps.str.findall("o.*k", re.DOTALL)[3][0],
gs.str.findall("o.*k", re.DOTALL)[0][3],
)


def test_string_replace_multi():
Expand Down