Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use single kernel to extract all groups in cudf::strings::extract #9358

Merged
merged 11 commits into from
Oct 13, 2021
9 changes: 5 additions & 4 deletions cpp/benchmarks/string/extract_benchmark.cpp
Expand Up @@ -48,7 +48,7 @@ static void BM_extract(benchmark::State& state, int groups)
});

std::string pattern;
while (static_cast<int>(pattern.size()) < groups) {
while (groups--) {
pattern += "(\\d+) ";
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down Expand Up @@ -86,6 +86,7 @@ static void generate_bench_args(benchmark::internal::Benchmark* b)
->UseManualTime() \
->Unit(benchmark::kMillisecond);

STRINGS_BENCHMARK_DEFINE(small, 2)
STRINGS_BENCHMARK_DEFINE(medium, 10)
STRINGS_BENCHMARK_DEFINE(large, 30)
STRINGS_BENCHMARK_DEFINE(one, 1)
STRINGS_BENCHMARK_DEFINE(two, 2)
STRINGS_BENCHMARK_DEFINE(four, 4)
STRINGS_BENCHMARK_DEFINE(eight, 8)
120 changes: 65 additions & 55 deletions cpp/src/strings/extract.cu
Expand Up @@ -19,13 +19,12 @@

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand All @@ -47,29 +46,36 @@ using string_index_pair = thrust::pair<const char*, size_type>;
template <int stack_size>
struct extract_fn {
reprog_device prog;
column_device_view d_strings;
size_type column_index;
column_device_view const d_strings;
cudf::detail::device_2dspan<string_index_pair> d_indices;

__device__ string_index_pair operator()(size_type idx)
__device__ void operator()(size_type idx)
{
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {
auto extracted = prog.extract<stack_size>(idx, d_str, begin, end, column_index);
if (extracted) {
auto const offset = d_str.byte_offset(extracted.value().first);
// build index-pair
result = string_index_pair{d_str.data() + offset,
d_str.byte_offset(extracted.value().second) - offset};
auto const groups = prog.group_counts();
auto d_output = d_indices[idx];

if (d_strings.is_valid(idx)) {
auto const d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if (prog.find<stack_size>(idx, d_str, begin, end) > 0) {
for (auto col_idx = 0; col_idx < groups; ++col_idx) {
auto const extracted = prog.extract<stack_size>(idx, d_str, begin, end, col_idx);
d_output[col_idx] = [&] {
if (!extracted) return string_index_pair{nullptr, 0};
auto const offset = d_str.byte_offset((*extracted).first);
return string_index_pair{d_str.data() + offset,
d_str.byte_offset((*extracted).second) - offset};
}();
}
return;
}
}
return result;

// if null row or no match found, fill the output with null entries
thrust::fill(thrust::seq, d_output.begin(), d_output.end(), string_index_pair{nullptr, 0});
}
};

} // namespace

//
Expand All @@ -79,9 +85,9 @@ std::unique_ptr<table> extract(
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
auto const strings_count = strings.size();
auto const strings_column = column_device_view::create(strings.parent(), stream);
auto const d_strings = *strings_column;

// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
Expand All @@ -90,41 +96,45 @@ std::unique_ptr<table> extract(
auto const groups = d_prog.group_counts();
CUDF_EXPECTS(groups > 0, "Group indicators not found in regex pattern");

rmm::device_uvector<string_index_pair> indices(strings_count * groups, stream);
cudf::detail::device_2dspan<string_index_pair> d_indices(indices.data(), strings_count, groups);

auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, d_indices});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, d_indices});
} else {
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
extract_fn<RX_STACK_ANY>{d_prog, d_strings, d_indices});
}

// build a result column for each group
std::vector<std::unique_ptr<column>> results;
auto regex_insts = d_prog.insts_counts();

for (int32_t column_index = 0; column_index < groups; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if (regex_insts <= RX_SMALL_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_SMALL>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_MEDIUM_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
} else if (regex_insts <= RX_LARGE_INSTS) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
} else {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});
}

results.emplace_back(make_strings_column(indices, stream, mr));
for (auto column_index = 0; column_index < groups; ++column_index) {
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
// this iterator transposes the extract results into column order
auto indices_itr = thrust::make_permutation_iterator(
indices.begin(),
thrust::make_transform_iterator(thrust::make_counting_iterator<size_type>(0),
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
[column_index, groups] __device__(size_type idx) {
return (idx * groups) + column_index;
}));
results.emplace_back(make_strings_column(indices_itr, indices_itr + strings_count, stream, mr));
}

return std::make_unique<table>(std::move(results));
}

Expand Down