Skip to content

Commit

Permalink
adding associative binary reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschmann committed Mar 28, 2024
1 parent 445daec commit 383c06e
Showing 1 changed file with 65 additions and 10 deletions.
75 changes: 65 additions & 10 deletions include/rl_tools/containers/tensor/operations_generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,6 @@ namespace rl_tools{
static constexpr auto OPERATION = T_OPERATION;
PARAMETER parameter;
};
template <typename PARAMETER, typename T_ACCUMULATOR_TYPE, typename T_CURRENT_TYPE, auto T_OPERATION>
struct ReduceOperation{
using ACCUMULATOR_TYPE = T_ACCUMULATOR_TYPE;
using CURRENT_TYPE = T_CURRENT_TYPE;
static constexpr auto OPERATION = T_OPERATION;
PARAMETER parameter;
ACCUMULATOR_TYPE initial_value;
};
namespace binary_operations{
template <typename T>
T add(T a, T b){
Expand Down Expand Up @@ -184,6 +176,14 @@ namespace rl_tools{
return parameter;
}
}
template <typename PARAMETER, typename T_ACCUMULATOR_TYPE, typename T_CURRENT_TYPE, auto T_OPERATION>
struct UnaryReduceOperation{
using ACCUMULATOR_TYPE = T_ACCUMULATOR_TYPE;
using CURRENT_TYPE = T_CURRENT_TYPE;
static constexpr auto OPERATION = T_OPERATION;
PARAMETER parameter;
ACCUMULATOR_TYPE initial_value;
};
namespace unary_reduce_operations{
namespace impl{
template <typename PARAMETER, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE>
Expand All @@ -192,7 +192,31 @@ namespace rl_tools{
}
}
template <typename T>
using Sum = ReduceOperation<OperationEmptyParameter, T, T, impl::sum<OperationEmptyParameter, T, T>>;
using Sum = UnaryReduceOperation<OperationEmptyParameter, T, T, impl::sum<OperationEmptyParameter, T, T>>;
}
template <typename PARAMETER, typename T_ACCUMULATOR_TYPE, typename T_CURRENT_TYPE1, typename T_CURRENT_TYPE2, auto T_OPERATION, auto T_REDUCE_OPERATION>
struct BinaryReduceOperation{
using ACCUMULATOR_TYPE = T_ACCUMULATOR_TYPE;
using CURRENT_TYPE1 = T_CURRENT_TYPE1;
using CURRENT_TYPE2 = T_CURRENT_TYPE2;
static constexpr auto OPERATION = T_OPERATION;
static constexpr auto REDUCE_OPERATION = T_REDUCE_OPERATION; // the associative part
PARAMETER parameter;
ACCUMULATOR_TYPE initial_value;
};
namespace binary_reduce_operations{
namespace impl{
template <typename DEVICE, typename PARAMETER, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE1, typename CURRENT_TYPE2>
ACCUMULATOR_TYPE absolute_difference(const DEVICE& device, const PARAMETER& parameter, const ACCUMULATOR_TYPE& accumulator, CURRENT_TYPE1 current1, CURRENT_TYPE2 current2){
return accumulator + math::abs(device.math, current1 - current2);
}
template <typename DEVICE, typename PARAMETER, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE>
ACCUMULATOR_TYPE sum(const DEVICE& device, const PARAMETER& parameter, const ACCUMULATOR_TYPE& accumulator, CURRENT_TYPE current){
return accumulator + current;
}
}
template <typename DEVICE, typename T1, typename T2>
using AbsoluteDifference = BinaryReduceOperation<OperationEmptyParameter, T1, T1, T2, impl::absolute_difference<DEVICE, OperationEmptyParameter, T1, T1, T2>, impl::sum<DEVICE, OperationEmptyParameter, T1, T1>>;
}
}
template<typename DEVICE, typename SPEC_1, typename SPEC_2, typename SPEC_OUT, auto BINARY_OPERATION, typename OPERATION_PARAMETER>
Expand Down Expand Up @@ -244,7 +268,7 @@ namespace rl_tools{
}

template<typename DEVICE, typename SPEC, auto UNARY_REDUCE_OPERATION, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE, typename OPERATION_PARAMETER>
ACCUMULATOR_TYPE unary_associative_reduce(DEVICE& device, const tensor::ReduceOperation<OPERATION_PARAMETER, ACCUMULATOR_TYPE, CURRENT_TYPE, UNARY_REDUCE_OPERATION>& op, Tensor<SPEC>& t){
ACCUMULATOR_TYPE unary_associative_reduce(DEVICE& device, const tensor::UnaryReduceOperation<OPERATION_PARAMETER, ACCUMULATOR_TYPE, CURRENT_TYPE, UNARY_REDUCE_OPERATION>& op, Tensor<SPEC>& t){
using T = typename SPEC::T;
using TI = typename DEVICE::index_t;
if constexpr(length(typename SPEC::SHAPE{}) > 1){
Expand Down Expand Up @@ -272,6 +296,37 @@ namespace rl_tools{
return unary_associative_reduce(device, op, t);
}

template<typename DEVICE, typename SPEC_1, typename SPEC_2, auto BINARY_REDUCE_OPERATION, auto BINARY_ASSOCIATIVE_REDUCE_OPERATION, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE1, typename CURRENT_TYPE2, typename OPERATION_PARAMETER>
ACCUMULATOR_TYPE binary_associative_reduce(DEVICE& device, const tensor::BinaryReduceOperation<OPERATION_PARAMETER, ACCUMULATOR_TYPE, CURRENT_TYPE1, CURRENT_TYPE2, BINARY_REDUCE_OPERATION, BINARY_ASSOCIATIVE_REDUCE_OPERATION>& op, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2){
using T = typename SPEC_1::T;
using TI = typename DEVICE::index_t;
static_assert(tensor::same_dimensions<SPEC_1, SPEC_2>());
if constexpr(length(typename SPEC_1::SHAPE{}) > 1){
ACCUMULATOR_TYPE accumulator = op.initial_value;
for(TI i=0; i < get<0>(typename SPEC_1::SHAPE{}); ++i){
auto next_t1 = view(device, t1, i);
auto next_t2 = view(device, t2, i);
accumulator = BINARY_ASSOCIATIVE_REDUCE_OPERATION(device, op.parameter, accumulator, binary_associative_reduce(device, op, next_t1, next_t2));
}
return accumulator;
}
else{
ACCUMULATOR_TYPE accumulator = op.initial_value;
for(TI i=0; i < get<0>(typename SPEC_1::SHAPE{}); i++){
T t1_value = get(device, t1, i);
T t2_value = get(device, t2, i);
accumulator = BINARY_REDUCE_OPERATION(device, op.parameter, accumulator, t1_value, t2_value);
}
return accumulator;
}
}
template<typename DEVICE, typename SPEC_1, typename SPEC_2>
typename SPEC_1::T absolute_difference(DEVICE& device, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2){
tensor::binary_reduce_operations::AbsoluteDifference<DEVICE, typename SPEC_1::T, typename SPEC_2::T> op;
op.initial_value = 0;
return binary_associative_reduce(device, op, t1, t2);
}


template<typename DEVICE, typename SPEC>
void abs(DEVICE& device, Tensor<SPEC>& t){
Expand Down

0 comments on commit 383c06e

Please sign in to comment.