Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scan_aggregation and reduce_aggregation derived types. #10357

Merged
merged 10 commits into from Mar 11, 2022
Merged
22 changes: 22 additions & 0 deletions cpp/include/cudf/aggregation.hpp
Expand Up @@ -148,6 +148,28 @@ class groupby_scan_aggregation : public virtual aggregation {
groupby_scan_aggregation() {}
};

/**
* @brief Derived class intended for reduction usage.
*/
class reduce_aggregation : public virtual aggregation {
public:
~reduce_aggregation() override = default;

protected:
reduce_aggregation() {}
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
};

/**
* @brief Derived class intended for scan usage.
*/
class scan_aggregation : public virtual aggregation {
public:
~scan_aggregation() override = default;

protected:
scan_aggregation() {}
};

enum class udf_type : bool { CUDA, PTX };
enum class correlation_type : int32_t { PEARSON, KENDALL, SPEARMAN };

Expand Down
50 changes: 34 additions & 16 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Expand Up @@ -147,7 +147,9 @@ class aggregation_finalizer { // Declares the interface for the finalizer
*/
class sum_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation {
public:
sum_aggregation() : aggregation(SUM) {}

Expand All @@ -166,7 +168,9 @@ class sum_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying a product aggregation
*/
class product_aggregation final : public groupby_aggregation {
class product_aggregation final : public groupby_aggregation,
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
public reduce_aggregation,
public scan_aggregation {
public:
product_aggregation() : aggregation(PRODUCT) {}

Expand All @@ -187,7 +191,9 @@ class product_aggregation final : public groupby_aggregation {
*/
class min_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation {
public:
min_aggregation() : aggregation(MIN) {}

Expand All @@ -208,7 +214,9 @@ class min_aggregation final : public rolling_aggregation,
*/
class max_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation {
public:
max_aggregation() : aggregation(MAX) {}

Expand Down Expand Up @@ -248,7 +256,7 @@ class count_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying an any aggregation
*/
class any_aggregation final : public aggregation {
class any_aggregation final : public reduce_aggregation {
public:
any_aggregation() : aggregation(ANY) {}

Expand All @@ -267,7 +275,7 @@ class any_aggregation final : public aggregation {
/**
* @brief Derived class for specifying an all aggregation
*/
class all_aggregation final : public aggregation {
class all_aggregation final : public reduce_aggregation {
public:
all_aggregation() : aggregation(ALL) {}

Expand All @@ -286,7 +294,7 @@ class all_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a sum_of_squares aggregation
*/
class sum_of_squares_aggregation final : public groupby_aggregation {
class sum_of_squares_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {}

Expand All @@ -305,7 +313,9 @@ class sum_of_squares_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a mean aggregation
*/
class mean_aggregation final : public rolling_aggregation, public groupby_aggregation {
class mean_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
mean_aggregation() : aggregation(MEAN) {}

Expand Down Expand Up @@ -343,7 +353,9 @@ class m2_aggregation : public groupby_aggregation {
/**
* @brief Derived class for specifying a standard deviation/variance aggregation
*/
class std_var_aggregation : public rolling_aggregation, public groupby_aggregation {
class std_var_aggregation : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
size_type _ddof; ///< Delta degrees of freedom

Expand Down Expand Up @@ -415,7 +427,7 @@ class std_aggregation final : public std_var_aggregation {
/**
* @brief Derived class for specifying a median aggregation
*/
class median_aggregation final : public groupby_aggregation {
class median_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
median_aggregation() : aggregation(MEDIAN) {}

Expand All @@ -434,7 +446,7 @@ class median_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a quantile aggregation
*/
class quantile_aggregation final : public groupby_aggregation {
class quantile_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
quantile_aggregation(std::vector<double> const& q, interpolation i)
: aggregation{QUANTILE}, _quantiles{q}, _interpolation{i}
Expand Down Expand Up @@ -521,7 +533,7 @@ class argmin_aggregation final : public rolling_aggregation, public groupby_aggr
/**
* @brief Derived class for specifying a nunique aggregation
*/
class nunique_aggregation final : public groupby_aggregation {
class nunique_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nunique_aggregation(null_policy null_handling)
: aggregation{NUNIQUE}, _null_handling{null_handling}
Expand Down Expand Up @@ -560,7 +572,7 @@ class nunique_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation {
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down Expand Up @@ -622,7 +634,9 @@ class row_number_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a rank aggregation
*/
class rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
rank_aggregation() : aggregation{RANK} {}

Expand All @@ -641,7 +655,9 @@ class rank_aggregation final : public rolling_aggregation, public groupby_scan_a
/**
* @brief Derived class for specifying a dense rank aggregation
*/
class dense_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class dense_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
dense_rank_aggregation() : aggregation{DENSE_RANK} {}

Expand All @@ -657,7 +673,9 @@ class dense_rank_aggregation final : public rolling_aggregation, public groupby_
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

class percent_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class percent_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
percent_rank_aggregation() : aggregation{PERCENT_RANK} {}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/detail/scan.hpp
Expand Up @@ -47,7 +47,7 @@ namespace detail {
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_exclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand All @@ -73,7 +73,7 @@ std::unique_ptr<column> scan_exclusive(column_view const& input,
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_inclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/cudf/reduction.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,7 +65,7 @@ enum class scan_type : bool { INCLUSIVE, EXCLUSIVE };
*/
std::unique_ptr<scalar> reduce(
column_view const& col,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<reduce_aggregation> const& agg,
data_type output_dtype,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand All @@ -89,7 +89,7 @@ std::unique_ptr<scalar> reduce(
*/
std::unique_ptr<column> scan(
const column_view& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
scan_type inclusive,
null_policy null_handling = null_policy::EXCLUDE,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Expand Down
26 changes: 26 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Expand Up @@ -417,6 +417,8 @@ template std::unique_ptr<aggregation> make_sum_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_sum_aggregation<rolling_aggregation>();
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
template std::unique_ptr<groupby_aggregation> make_sum_aggregation<groupby_aggregation>();
template std::unique_ptr<groupby_scan_aggregation> make_sum_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<reduce_aggregation> make_sum_aggregation<reduce_aggregation>();
template std::unique_ptr<scan_aggregation> make_sum_aggregation<scan_aggregation>();

/// Factory to create a PRODUCT aggregation
template <typename Base>
Expand All @@ -426,6 +428,8 @@ std::unique_ptr<Base> make_product_aggregation()
}
template std::unique_ptr<aggregation> make_product_aggregation<aggregation>();
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
template std::unique_ptr<groupby_aggregation> make_product_aggregation<groupby_aggregation>();
template std::unique_ptr<reduce_aggregation> make_product_aggregation<reduce_aggregation>();
template std::unique_ptr<scan_aggregation> make_product_aggregation<scan_aggregation>();

/// Factory to create a MIN aggregation
template <typename Base>
Expand All @@ -437,6 +441,8 @@ template std::unique_ptr<aggregation> make_min_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_min_aggregation<rolling_aggregation>();
template std::unique_ptr<groupby_aggregation> make_min_aggregation<groupby_aggregation>();
template std::unique_ptr<groupby_scan_aggregation> make_min_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<reduce_aggregation> make_min_aggregation<reduce_aggregation>();
template std::unique_ptr<scan_aggregation> make_min_aggregation<scan_aggregation>();

/// Factory to create a MAX aggregation
template <typename Base>
Expand All @@ -448,6 +454,8 @@ template std::unique_ptr<aggregation> make_max_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_max_aggregation<rolling_aggregation>();
template std::unique_ptr<groupby_aggregation> make_max_aggregation<groupby_aggregation>();
template std::unique_ptr<groupby_scan_aggregation> make_max_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<reduce_aggregation> make_max_aggregation<reduce_aggregation>();
template std::unique_ptr<scan_aggregation> make_max_aggregation<scan_aggregation>();

/// Factory to create a COUNT aggregation
template <typename Base>
Expand All @@ -473,6 +481,7 @@ std::unique_ptr<Base> make_any_aggregation()
return std::make_unique<detail::any_aggregation>();
}
template std::unique_ptr<aggregation> make_any_aggregation<aggregation>();
template std::unique_ptr<reduce_aggregation> make_any_aggregation<reduce_aggregation>();

/// Factory to create a ALL aggregation
template <typename Base>
Expand All @@ -481,6 +490,7 @@ std::unique_ptr<Base> make_all_aggregation()
return std::make_unique<detail::all_aggregation>();
}
template std::unique_ptr<aggregation> make_all_aggregation<aggregation>();
template std::unique_ptr<reduce_aggregation> make_all_aggregation<reduce_aggregation>();

/// Factory to create a SUM_OF_SQUARES aggregation
template <typename Base>
Expand All @@ -491,6 +501,7 @@ std::unique_ptr<Base> make_sum_of_squares_aggregation()
template std::unique_ptr<aggregation> make_sum_of_squares_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation>
make_sum_of_squares_aggregation<groupby_aggregation>();
template std::unique_ptr<reduce_aggregation> make_sum_of_squares_aggregation<reduce_aggregation>();

/// Factory to create a MEAN aggregation
template <typename Base>
Expand All @@ -501,6 +512,7 @@ std::unique_ptr<Base> make_mean_aggregation()
template std::unique_ptr<aggregation> make_mean_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_mean_aggregation<rolling_aggregation>();
template std::unique_ptr<groupby_aggregation> make_mean_aggregation<groupby_aggregation>();
template std::unique_ptr<reduce_aggregation> make_mean_aggregation<reduce_aggregation>();

/// Factory to create a M2 aggregation
template <typename Base>
Expand All @@ -522,6 +534,8 @@ template std::unique_ptr<rolling_aggregation> make_variance_aggregation<rolling_
size_type ddof);
template std::unique_ptr<groupby_aggregation> make_variance_aggregation<groupby_aggregation>(
size_type ddof);
template std::unique_ptr<reduce_aggregation> make_variance_aggregation<reduce_aggregation>(
size_type ddof);

/// Factory to create a STD aggregation
template <typename Base>
Expand All @@ -534,6 +548,8 @@ template std::unique_ptr<rolling_aggregation> make_std_aggregation<rolling_aggre
size_type ddof);
template std::unique_ptr<groupby_aggregation> make_std_aggregation<groupby_aggregation>(
size_type ddof);
template std::unique_ptr<reduce_aggregation> make_std_aggregation<reduce_aggregation>(
size_type ddof);

/// Factory to create a MEDIAN aggregation
template <typename Base>
Expand All @@ -543,6 +559,7 @@ std::unique_ptr<Base> make_median_aggregation()
}
template std::unique_ptr<aggregation> make_median_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_median_aggregation<groupby_aggregation>();
template std::unique_ptr<reduce_aggregation> make_median_aggregation<reduce_aggregation>();

/// Factory to create a QUANTILE aggregation
template <typename Base>
Expand All @@ -555,6 +572,8 @@ template std::unique_ptr<aggregation> make_quantile_aggregation<aggregation>(
std::vector<double> const& quantiles, interpolation interp);
template std::unique_ptr<groupby_aggregation> make_quantile_aggregation<groupby_aggregation>(
std::vector<double> const& quantiles, interpolation interp);
template std::unique_ptr<reduce_aggregation> make_quantile_aggregation<reduce_aggregation>(
std::vector<double> const& quantiles, interpolation interp);

/// Factory to create an ARGMAX aggregation
template <typename Base>
Expand Down Expand Up @@ -586,6 +605,8 @@ template std::unique_ptr<aggregation> make_nunique_aggregation<aggregation>(
null_policy null_handling);
template std::unique_ptr<groupby_aggregation> make_nunique_aggregation<groupby_aggregation>(
null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nunique_aggregation<reduce_aggregation>(
null_policy null_handling);

/// Factory to create an NTH_ELEMENT aggregation
template <typename Base>
Expand All @@ -597,6 +618,8 @@ template std::unique_ptr<aggregation> make_nth_element_aggregation<aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<groupby_aggregation> make_nth_element_aggregation<groupby_aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nth_element_aggregation<reduce_aggregation>(
size_type n, null_policy null_handling);

/// Factory to create a ROW_NUMBER aggregation
template <typename Base>
Expand All @@ -616,6 +639,7 @@ std::unique_ptr<Base> make_rank_aggregation()
template std::unique_ptr<aggregation> make_rank_aggregation<aggregation>();
template std::unique_ptr<groupby_scan_aggregation>
make_rank_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<scan_aggregation> make_rank_aggregation<scan_aggregation>();

/// Factory to create a DENSE_RANK aggregation
template <typename Base>
Expand All @@ -626,6 +650,7 @@ std::unique_ptr<Base> make_dense_rank_aggregation()
template std::unique_ptr<aggregation> make_dense_rank_aggregation<aggregation>();
template std::unique_ptr<groupby_scan_aggregation>
make_dense_rank_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<scan_aggregation> make_dense_rank_aggregation<scan_aggregation>();

/// Factory to create a PERCENT_RANK aggregation
template <typename Base>
Expand All @@ -636,6 +661,7 @@ std::unique_ptr<Base> make_percent_rank_aggregation()
template std::unique_ptr<aggregation> make_percent_rank_aggregation<aggregation>();
template std::unique_ptr<groupby_scan_aggregation>
make_percent_rank_aggregation<groupby_scan_aggregation>();
template std::unique_ptr<scan_aggregation> make_percent_rank_aggregation<scan_aggregation>();

/// Factory to create a COLLECT_LIST aggregation
template <typename Base>
Expand Down