Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Move template param to member var to improve compile of hash/groupby.cu #6835

Merged
merged 4 commits into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
- PR #6829 Enable workaround to write categorical columns in csv
- PR #6819 Use CMake 3.19 for RMM when building cuDF jar
- PR #6833 Use settings.xml if existing for internal build
- PR #6835 Move template param to member var to improve compile of hash/groupby.cu
- PR #6837 Avoid gather when copying strings view from start of strings column

## Bug Fixes
Expand Down
34 changes: 13 additions & 21 deletions cpp/src/groupby/hash/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,27 +279,19 @@ void compute_single_pass_aggs(table_view const& keys,

bool skip_key_rows_with_nulls = keys_have_nulls and include_null_keys == null_policy::EXCLUDE;

if (skip_key_rows_with_nulls) {
auto row_bitmask{cudf::detail::bitmask_and(keys, stream)};
thrust::for_each_n(
rmm::exec_policy(stream)->on(stream.value()),
thrust::make_counting_iterator(0),
keys.num_rows(),
hash::compute_single_pass_aggs<true, Map>{map,
keys.num_rows(),
*d_values,
*d_sparse_table,
d_aggs.data().get(),
static_cast<bitmask_type*>(row_bitmask.data())});
} else {
thrust::for_each_n(
rmm::exec_policy(stream)->on(stream.value()),
thrust::make_counting_iterator(0),
keys.num_rows(),
hash::compute_single_pass_aggs<false, Map>{
map, keys.num_rows(), *d_values, *d_sparse_table, d_aggs.data().get(), nullptr});
}

auto row_bitmask =
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
skip_key_rows_with_nulls ? cudf::detail::bitmask_and(keys, stream) : rmm::device_buffer{};
thrust::for_each_n(
rmm::exec_policy(stream)->on(stream.value()),
thrust::make_counting_iterator(0),
keys.num_rows(),
hash::compute_single_pass_aggs_fn<Map>{map,
keys.num_rows(),
*d_values,
*d_sparse_table,
d_aggs.data().get(),
static_cast<bitmask_type*>(row_bitmask.data()),
skip_key_rows_with_nulls});
// Add results back to sparse_results cache
auto sparse_result_cols = sparse_table.release();
for (size_t i = 0; i < aggs.size(); i++) {
Expand Down
31 changes: 16 additions & 15 deletions cpp/src/groupby/hash/groupby_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,20 @@ namespace hash {
* rows. In this way, after all rows are aggregated, `output_values` will likely
* be "sparse", meaning that not all rows contain the result of an aggregation.
*
* @tparam skip_rows_with_nulls Indicates if rows in `input_keys` containing
* null values should be skipped. If `true`, it is assumed `row_bitmask` is a
* bitmask where bit `i` indicates the presence of a null value in row `i`.
* @tparam Map The type of the hash map
*/
template <bool skip_rows_with_nulls, typename Map>
struct compute_single_pass_aggs {
template <typename Map>
struct compute_single_pass_aggs_fn {
Map map;
size_type num_keys;
table_device_view input_values;
mutable_table_device_view output_values;
aggregation::Kind const* __restrict__ aggs;
bitmask_type const* __restrict__ row_bitmask;
bool skip_rows_with_nulls;

/**
* @brief Construct a new compute_single_pass_aggs functor object
* @brief Construct a new compute_single_pass_aggs_fn functor object
*
* @param map Hash map object to insert key,value pairs into.
* @param num_keys The number of rows in input keys table
Expand All @@ -84,19 +82,24 @@ struct compute_single_pass_aggs {
* columns of the `input_values` rows
* @param row_bitmask Bitmask where bit `i` indicates the presence of a null
* value in row `i` of input keys. Only used if `skip_rows_with_nulls` is `true`.
* @param skip_rows_with_nulls Indicates if rows in `input_keys` containing
* null values should be skipped. If `true`, it is assumed `row_bitmask` is a
* bitmask where bit `i` indicates the presence of a null value in row `i`.
*/
compute_single_pass_aggs(Map map,
size_type num_keys,
table_device_view input_values,
mutable_table_device_view output_values,
aggregation::Kind const* aggs,
bitmask_type const* row_bitmask)
compute_single_pass_aggs_fn(Map map,
size_type num_keys,
table_device_view input_values,
mutable_table_device_view output_values,
aggregation::Kind const* aggs,
bitmask_type const* row_bitmask,
bool skip_rows_with_nulls)
: map(map),
num_keys(num_keys),
input_values(input_values),
output_values(output_values),
aggs(aggs),
row_bitmask(row_bitmask)
row_bitmask(row_bitmask),
skip_rows_with_nulls(skip_rows_with_nulls)
{
}

Expand All @@ -111,8 +114,6 @@ struct compute_single_pass_aggs {
}
};

// TODO (dm): variance kernel

} // namespace hash
} // namespace detail
} // namespace groupby
Expand Down