Skip to content

Commit

Permalink
Strong index types for equality comparator (#10883)
Browse files Browse the repository at this point in the history
This adds strong index types for equality comparator, along with #10730 to unblock #10548, #10656, and several others nested type feature requests.

Authors:
  - Nghia Truong (https://github.com/ttnghia)
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Bradley Dice (https://github.com/bdice)

URL: #10883
  • Loading branch information
ttnghia authored May 19, 2022
1 parent c9bc82e commit 54789ee
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 45 deletions.
5 changes: 3 additions & 2 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,14 @@ outputs:
- test -f $PREFIX/include/cudf/lists/detail/scatter_helper.cuh
- test -f $PREFIX/include/cudf/lists/detail/stream_compaction.hpp
- test -f $PREFIX/include/cudf/lists/combine.hpp
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/count_elements.hpp
- test -f $PREFIX/include/cudf/lists/explode.hpp
- test -f $PREFIX/include/cudf/lists/drop_list_duplicates.hpp
- test -f $PREFIX/include/cudf/lists/explode.hpp
- test -f $PREFIX/include/cudf/lists/extract.hpp
- test -f $PREFIX/include/cudf/lists/filling.hpp
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/gather.hpp
- test -f $PREFIX/include/cudf/lists/list_view.hpp
- test -f $PREFIX/include/cudf/lists/lists_column_view.hpp
- test -f $PREFIX/include/cudf/lists/sorting.hpp
- test -f $PREFIX/include/cudf/lists/stream_compaction.hpp
Expand Down
94 changes: 94 additions & 0 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ namespace equality {
template <typename Nullate>
class device_row_comparator {
friend class self_comparator;
friend class two_table_comparator;

public:
/**
Expand Down Expand Up @@ -855,6 +856,7 @@ struct preprocessed_table {

private:
friend class self_comparator;
friend class two_table_comparator;
friend class hash::row_hasher;

using table_device_view_owner =
Expand Down Expand Up @@ -923,6 +925,98 @@ class self_comparator {
std::shared_ptr<preprocessed_table> d_t;
};

template <typename Comparator>
struct strong_index_comparator_adapter {
__device__ constexpr bool operator()(lhs_index_type const lhs_index,
rhs_index_type const rhs_index) const noexcept
{
return comparator(static_cast<cudf::size_type>(lhs_index),
static_cast<cudf::size_type>(rhs_index));
}

__device__ constexpr bool operator()(rhs_index_type const rhs_index,
lhs_index_type const lhs_index) const noexcept
{
return this->operator()(lhs_index, rhs_index);
}

Comparator const comparator;
};

/**
* @brief An owning object that can be used to equality compare rows of two different tables.
*
* This class takes two table_views and preprocesses certain columns to allow for equality
* comparison. The preprocessed table and temporary data required for the comparison are created and
* owned by this class.
*
* Alternatively, `two_table_comparator` can be constructed from two existing
* `shared_ptr<preprocessed_table>`s when sharing the same tables among multiple comparators.
*
* This class can then provide a functor object that can used on the device.
* The object of this class must outlive the usage of the device functor.
*/
class two_table_comparator {
public:
/**
* @brief Construct an owning object for performing equality comparisons between two rows from two
* tables.
*
* The left and right table are expected to have the same number of columns and data types for
* each column.
*
* @param left The left table to compare.
* @param right The right table to compare.
* @param stream The stream to construct this object on. Not the stream that will be used for
* comparisons using this object.
*/
two_table_comparator(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream);

/**
* @brief Construct an owning object for performing equality comparisons between two rows from two
* tables.
*
* This constructor allows independently constructing a `preprocessed_table` and sharing it among
* multiple comparators.
*
* @param left The left table preprocessed for equality comparison.
* @param right The right table preprocessed for equality comparison.
*/
two_table_comparator(std::shared_ptr<preprocessed_table> left,
std::shared_ptr<preprocessed_table> right)
: d_left_table{std::move(left)}, d_right_table{std::move(right)}
{
}

/**
* @brief Return the binary operator for comparing rows in the table.
*
* Returns a binary callable, `F`, with signatures `bool F(lhs_index_type, rhs_index_type)` and
* `bool F(rhs_index_type, lhs_index_type)`.
*
* `F(lhs_index_type i, rhs_index_type j)` returns true if and only if row `i` of the left table
* compares equal to row `j` of the right table.
*
* Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and only if row `i` of the
* right table compares equal to row `j` of the left table.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
auto device_comparator(Nullate nullate = {},
null_equality nulls_are_equal = null_equality::EQUAL) const
{
return strong_index_comparator_adapter<device_row_comparator<Nullate>>{
device_row_comparator<Nullate>(nullate, *d_left_table, *d_right_table, nulls_are_equal)};
}

private:
std::shared_ptr<preprocessed_table> d_left_table;
std::shared_ptr<preprocessed_table> d_right_table;
};

} // namespace equality

namespace hash {
Expand Down
66 changes: 23 additions & 43 deletions cpp/src/structs/search/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
*/

#include <cudf/detail/structs/utilities.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/structs/detail/contains.hpp>
#include <cudf/table/experimental/row_operators.cuh>
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/table/table_view.hpp>
Expand All @@ -35,52 +35,32 @@ bool contains(structs_column_view const& haystack,
scalar const& needle,
rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(haystack.type() == needle.type(), "scalar and column types must match");
auto const haystack_tv = table_view{{haystack}};
// Create a (structs) column_view of one row having children given from the input scalar.
auto const needle_tv = static_cast<struct_scalar const*>(&needle)->view();
auto const needle_as_col =
column_view(data_type{type_id::STRUCT},
1,
nullptr,
nullptr,
0,
0,
std::vector<column_view>{needle_tv.begin(), needle_tv.end()});

auto const scalar_table = static_cast<struct_scalar const*>(&needle)->view();
CUDF_EXPECTS(haystack.num_children() == scalar_table.num_columns(),
"struct scalar and structs column must have the same number of children");
for (size_type i = 0; i < haystack.num_children(); ++i) {
CUDF_EXPECTS(haystack.child(i).type() == scalar_table.column(i).type(),
"scalar and column children types must match");
}
// Haystack and needle compatibility is checked by the table comparator constructor.
auto const comparator = cudf::experimental::row::equality::two_table_comparator(
haystack_tv, table_view{{needle_as_col}}, stream);
auto const has_nulls = has_nested_nulls(haystack_tv) || has_nested_nulls(needle_tv);
auto const d_comp = comparator.device_comparator(nullate::DYNAMIC{has_nulls});

// Prepare to flatten the structs column and scalar.
auto const has_null_elements = has_nested_nulls(table_view{std::vector<column_view>{
haystack.child_begin(), haystack.child_end()}}) ||
has_nested_nulls(scalar_table);
auto const flatten_nullability = has_null_elements
? structs::detail::column_nullability::FORCE
: structs::detail::column_nullability::MATCH_INCOMING;

// Flatten the input structs column, only materialize the bitmask if there is null in the input.
auto const haystack_flattened =
structs::detail::flatten_nested_columns(table_view{{haystack}}, {}, {}, flatten_nullability);
auto const needle_flattened =
structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability);

// The struct scalar only contains the struct member columns.
// Thus, if there is any null in the input, we must exclude the first column in the flattened
// table of the input column from searching because that column is the materialized bitmask of
// the input structs column.
auto const haystack_flattened_content = haystack_flattened.flattened_columns();
auto const haystack_flattened_children = table_view{std::vector<column_view>{
haystack_flattened_content.begin() + static_cast<size_type>(has_null_elements),
haystack_flattened_content.end()}};

auto const d_haystack_children_ptr =
table_device_view::create(haystack_flattened_children, stream);
auto const d_needle_ptr = table_device_view::create(needle_flattened, stream);

auto const start_iter = thrust::make_counting_iterator<size_type>(0);
auto const start_iter = cudf::experimental::row::lhs_iterator(0);
auto const end_iter = start_iter + haystack.size();
auto const comp = row_equality_comparator(nullate::DYNAMIC{has_null_elements},
*d_haystack_children_ptr,
*d_needle_ptr,
null_equality::EQUAL);
using cudf::experimental::row::rhs_index_type;

auto const found_iter = thrust::find_if(
rmm::exec_policy(stream), start_iter, end_iter, [comp] __device__(auto const idx) {
return comp(idx, 0); // compare haystack[idx] == val[0].
rmm::exec_policy(stream), start_iter, end_iter, [d_comp] __device__(auto const idx) {
// Compare haystack[idx] == needle_as_col[0].
return d_comp(idx, static_cast<rhs_index_type>(0));
});

return found_iter != end_iter;
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/table/row_operators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,15 @@ std::shared_ptr<preprocessed_table> preprocessed_table::create(table_view const&
new preprocessed_table(std::move(d_t), std::move(std::get<1>(null_pushed_table))));
}

two_table_comparator::two_table_comparator(table_view const& left,
table_view const& right,
rmm::cuda_stream_view stream)
: d_left_table{preprocessed_table::create(left, stream)},
d_right_table{preprocessed_table::create(right, stream)}
{
check_shape_compatibility(left, right);
}

} // namespace equality

} // namespace row
Expand Down

0 comments on commit 54789ee

Please sign in to comment.