Skip to content

Commit

Permalink
Merge branch 'fp-reduce' of https://github.com/codereport/cudf into f…
Browse files Browse the repository at this point in the history
…p-reduce2
  • Loading branch information
codereport committed Nov 26, 2020
2 parents 2e414c6 + 99e04d7 commit 442c8e1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- PR #6711 Implement `cudf::cast` for `decimal32/64` to/from integer and floating point
- PR #6777 Implement `cudf::unary_operation` for `decimal32` & `decimal64`
- PR #6729 Implement `cudf::cast` for `decimal32/64` to/from different `type_id`
- PR #6814 Implement `cudf::reduce` for `decimal32` and `decimal64`
- PR #6792 Implement `cudf::clamp` for `decimal32` and `decimal64`
- PR #6845 Implement `cudf::copy_if_else` for `decimal32` and `decimal64`
- PR #6805 Implement `cudf::detail::copy_if` for `decimal32` and `decimal64`
Expand Down
32 changes: 16 additions & 16 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,32 @@ struct reduce_dispatch_functor {
std::unique_ptr<scalar> operator()(std::unique_ptr<aggregation> const &agg)
{
switch (k) {
case aggregation::SUM: return reduction::sum(col, output_dtype, stream, mr); break;
case aggregation::PRODUCT: return reduction::product(col, output_dtype, stream, mr); break;
case aggregation::MIN: return reduction::min(col, output_dtype, stream, mr); break;
case aggregation::MAX: return reduction::max(col, output_dtype, stream, mr); break;
case aggregation::ANY: return reduction::any(col, output_dtype, stream, mr); break;
case aggregation::ALL: return reduction::all(col, output_dtype, stream, mr); break;
case aggregation::SUM_OF_SQUARES:
return reduction::sum_of_squares(col, output_dtype, stream, mr);
break;
case aggregation::MEAN: return reduction::mean(col, output_dtype, stream, mr); break;
// clang-format off
case aggregation::SUM: return reduction::sum (col, output_dtype, stream, mr);
case aggregation::PRODUCT: return reduction::product (col, output_dtype, stream, mr);
case aggregation::MIN: return reduction::min (col, output_dtype, stream, mr);
case aggregation::MAX: return reduction::max (col, output_dtype, stream, mr);
case aggregation::ANY: return reduction::any (col, output_dtype, stream, mr);
case aggregation::ALL: return reduction::all (col, output_dtype, stream, mr);
case aggregation::SUM_OF_SQUARES: return reduction::sum_of_squares(col, output_dtype, stream, mr);
case aggregation::MEAN: return reduction::mean (col, output_dtype, stream, mr);
// clang-format on
case aggregation::VARIANCE: {
auto var_agg = static_cast<std_var_aggregation const *>(agg.get());
return reduction::variance(col, output_dtype, var_agg->_ddof, stream, mr);
} break;
}
case aggregation::STD: {
auto var_agg = static_cast<std_var_aggregation const *>(agg.get());
return reduction::standard_deviation(col, output_dtype, var_agg->_ddof, stream, mr);
} break;
}
case aggregation::MEDIAN: {
auto sorted_indices =
detail::sorted_order(table_view{{col}}, {}, {null_order::AFTER}, stream, mr);
auto valid_sorted_indices = split(*sorted_indices, {col.size() - col.null_count()})[0];
auto col_ptr = detail::quantile(
col, {0.5}, interpolation::LINEAR, valid_sorted_indices, true, stream, mr);
return get_element(*col_ptr, 0, mr);
} break;
}
case aggregation::QUANTILE: {
auto quantile_agg = static_cast<quantile_aggregation const *>(agg.get());
CUDF_EXPECTS(quantile_agg->_quantiles.size() == 1,
Expand All @@ -88,19 +88,19 @@ struct reduce_dispatch_functor {
stream,
mr);
return get_element(*col_ptr, 0, mr);
} break;
}
case aggregation::NUNIQUE: {
auto nunique_agg = static_cast<nunique_aggregation const *>(agg.get());
return make_fixed_width_scalar(
detail::distinct_count(
col, nunique_agg->_null_handling, nan_policy::NAN_IS_VALID, stream.value()),
stream.value(),
mr);
} break;
}
case aggregation::NTH_ELEMENT: {
auto nth_agg = static_cast<nth_element_aggregation const *>(agg.get());
return reduction::nth_element(col, nth_agg->_n, nth_agg->_null_handling, stream, mr);
} break;
}
default: CUDF_FAIL("Unsupported reduction operator");
}
}
Expand Down
12 changes: 7 additions & 5 deletions cpp/src/reductions/simple.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,13 @@ struct element_type_dispatcher {
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
if (output_type == col.type())
return cudf::reduction::simple::simple_reduction<ElementType, ElementType, Op>(
col, stream, mr);
auto result =
cudf::reduction::simple::simple_reduction<ElementType, int64_t, Op>(col, stream, mr);
using namespace cudf::reduction::simple;

if (output_type == col.type()) {
return simple_reduction<ElementType, ElementType, Op>(col, stream, mr);
}

auto result = simple_reduction<ElementType, int64_t, Op>(col, stream, mr);
if (output_type == result->type()) return result;
// this will cast the result to the output_type
return cudf::type_dispatcher(output_type,
Expand Down

0 comments on commit 442c8e1

Please sign in to comment.