Skip to content

Commit

Permalink
simplifying tensor; adding gru; moar tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschmann committed Mar 28, 2024
1 parent 049236f commit 49a8f62
Show file tree
Hide file tree
Showing 12 changed files with 668 additions and 171 deletions.
125 changes: 98 additions & 27 deletions include/rl_tools/containers/tensor/operations_generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ namespace rl_tools{
}
template <typename DEVICE, typename SPEC>
void free(DEVICE& device, Tensor<SPEC>& tensor){
    // Releases the heap storage owned by a dynamically allocated tensor.
    // NOTE(review): the diff residue showed both `delete[] data_reference(tensor)`
    // and `delete data(tensor)`; keeping a single release. If the matching
    // allocation in malloc() uses `new T[...]`, this must be `delete[]` instead
    // of `delete` — TODO confirm against the allocator in operations_generic.h.
    delete data(tensor);
}

template <typename DEVICE, typename SPEC, typename TI, auto DIM=0, auto SIZE=0>
auto view_range(DEVICE& device, Tensor<SPEC>& tensor, TI index, tensor::ViewSpec<DIM, SIZE> = {}){
    // Returns a non-owning view that narrows dimension DIM to SIZE elements,
    // starting at element `index` along that dimension. The shape is replaced
    // via tensor::spec::view::range::Specification while the strides are
    // reused, so the view aliases the original tensor's storage.
    static_assert(SIZE > 0);
    auto offset = index * get<DIM>(typename SPEC::STRIDE{});
    Tensor<tensor::spec::view::range::Specification<SPEC, tensor::ViewSpec<DIM, SIZE>>> view;
    data_reference(view) = data(tensor) + offset;
    return view;
}
Expand Down Expand Up @@ -126,33 +127,21 @@ namespace rl_tools{
}
}

template<typename DEVICE, typename SPEC>
typename SPEC::T sum(DEVICE& device, Tensor<SPEC>& t){
    // Sums every element of the tensor: recurse over the leading dimension
    // until a rank-1 tensor remains, then accumulate its scalars.
    using T = typename SPEC::T;
    using TI = typename DEVICE::index_t;
    T total = 0;
    if constexpr(length(typename SPEC::SHAPE{}) <= 1){
        // Base case: rank-1, accumulate the raw elements.
        for(TI element_i = 0; element_i < get<0>(typename SPEC::SHAPE{}); ++element_i){
            total += get(device, t, element_i);
        }
    }
    else{
        // Recursive case: reduce each sub-tensor along dimension 0.
        for(TI sub_i = 0; sub_i < get<0>(typename SPEC::SHAPE{}); ++sub_i){
            auto sub_tensor = view(device, t, sub_i);
            total += sum(device, sub_tensor);
        }
    }
    return total;
}
namespace tensor{
// Tag type for operations that take no runtime parameter.
struct OperationEmptyParameter{};
// Wraps a compile-time element-wise operation (non-type template parameter)
// together with its runtime parameter payload.
template <auto T_OPERATION, typename PARAMETER>
struct Operation{
static constexpr auto OPERATION = T_OPERATION;
PARAMETER parameter; // runtime argument forwarded to OPERATION on each element
};
// Describes an associative reduction: OPERATION folds CURRENT_TYPE values
// into an ACCUMULATOR_TYPE, starting from initial_value.
template <typename PARAMETER, typename T_ACCUMULATOR_TYPE, typename T_CURRENT_TYPE, auto T_OPERATION>
struct ReduceOperation{
using ACCUMULATOR_TYPE = T_ACCUMULATOR_TYPE;
using CURRENT_TYPE = T_CURRENT_TYPE;
static constexpr auto OPERATION = T_OPERATION;
PARAMETER parameter; // runtime argument forwarded to OPERATION on each step
ACCUMULATOR_TYPE initial_value; // seed for the reduction (e.g. 0 for sum)
};
namespace binary_operations{
template <typename T>
T add(T a, T b){
Expand Down Expand Up @@ -185,9 +174,19 @@ namespace rl_tools{
return parameter;
}
}
namespace unary_reduce_operations{
namespace impl{
// Reduction step for summation: fold `current` into `accumulator`.
// `parameter` is unused (OperationEmptyParameter) but kept for the
// uniform ReduceOperation calling convention.
template <typename PARAMETER, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE>
ACCUMULATOR_TYPE sum(const PARAMETER& parameter, const ACCUMULATOR_TYPE& accumulator, CURRENT_TYPE current){
return accumulator + current;
}
}
// Sum reduction over elements of type T (accumulator and element share T).
template <typename T>
using Sum = ReduceOperation<OperationEmptyParameter, T, T, impl::sum<OperationEmptyParameter, T, T>>;
}
}
template<typename DEVICE, typename SPEC_1, typename SPEC_2, typename SPEC_OUT, auto BINARY_OPERATION, typename OPERATION_PARAMETER>
void binary_operation(DEVICE& device, const tensor::Operation<BINARY_OPERATION, OPERATION_PARAMETER>, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2, Tensor<SPEC_OUT>& result){
inline void binary_operation(DEVICE& device, const tensor::Operation<BINARY_OPERATION, OPERATION_PARAMETER>, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2, Tensor<SPEC_OUT>& result){
using T = typename SPEC_1::T;
using TI = typename DEVICE::index_t;
using BOP = tensor::Operation<BINARY_OPERATION, OPERATION_PARAMETER>;
Expand Down Expand Up @@ -234,6 +233,36 @@ namespace rl_tools{
}
}

template<typename DEVICE, typename SPEC, auto UNARY_REDUCE_OPERATION, typename ACCUMULATOR_TYPE, typename CURRENT_TYPE, typename OPERATION_PARAMETER>
ACCUMULATOR_TYPE unary_associative_reduce(DEVICE& device, const tensor::ReduceOperation<OPERATION_PARAMETER, ACCUMULATOR_TYPE, CURRENT_TYPE, UNARY_REDUCE_OPERATION>& op, Tensor<SPEC>& t){
    // Folds all elements of `t` into an accumulator using the associative
    // operation described by `op`, recursing over the leading dimension
    // until a rank-1 tensor is reached. Starts from op.initial_value.
    using T = typename SPEC::T;
    using TI = typename DEVICE::index_t;
    ACCUMULATOR_TYPE result = op.initial_value;
    if constexpr(length(typename SPEC::SHAPE{}) > 1){
        // Higher rank: reduce each sub-tensor, then fold the partial results.
        for(TI sub_i = 0; sub_i < get<0>(typename SPEC::SHAPE{}); ++sub_i){
            auto sub_tensor = view(device, t, sub_i);
            result = UNARY_REDUCE_OPERATION(op.parameter, result, unary_associative_reduce(device, op, sub_tensor));
        }
    }
    else{
        // Rank-1 base case: fold the raw elements directly.
        for(TI element_i = 0; element_i < get<0>(typename SPEC::SHAPE{}); ++element_i){
            T element = get(device, t, element_i);
            result = UNARY_REDUCE_OPERATION(op.parameter, result, element);
        }
    }
    return result;
}

template<typename DEVICE, typename SPEC>
typename SPEC::T sum(DEVICE& device, Tensor<SPEC>& t){
tensor::unary_reduce_operations::Sum<typename SPEC::T> op;
op.initial_value = 0;
return unary_associative_reduce(device, op, t);
}


template<typename DEVICE, typename SPEC>
void abs(DEVICE& device, Tensor<SPEC>& t){
using T = typename SPEC::T;
Expand All @@ -251,8 +280,50 @@ namespace rl_tools{
unary_operation(device, op, t);
}

// Dense matrix multiplication: result = t1 * t2 (naive triple loop).
// t1 is [M x K], t2 is [K x N], result is [M x N]; all shapes are
// verified at compile time. (The previous "Y^T = WX^T" comment belonged
// to multiply_transpose, not this plain product.)
template<typename DEVICE, typename SPEC_1, typename SPEC_2, typename SPEC_OUT>
void multiply(DEVICE& device, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2, Tensor<SPEC_OUT>& result){
static_assert(length(typename SPEC_1::SHAPE{}) == 2);
static_assert(length(typename SPEC_2::SHAPE{}) == 2);
static_assert(length(typename SPEC_OUT::SHAPE{}) == 2);
static_assert(get<1>(typename SPEC_1::SHAPE{}) == get<0>(typename SPEC_2::SHAPE{})); // inner dims match: K
static_assert(get<0>(typename SPEC_1::SHAPE{}) == get<0>(typename SPEC_OUT::SHAPE{})); // rows: M
static_assert(get<1>(typename SPEC_2::SHAPE{}) == get<1>(typename SPEC_OUT::SHAPE{})); // cols: N
using T = typename SPEC_1::T;
using TI = typename DEVICE::index_t;
for(TI row_i=0; row_i < get<0>(typename SPEC_1::SHAPE{}); ++row_i){
for(TI col_j=0; col_j < get<1>(typename SPEC_2::SHAPE{}); ++col_j){
T acc = 0;
for(TI k=0; k < get<1>(typename SPEC_1::SHAPE{}); ++k){
acc += get(device, t1, row_i, k) * get(device, t2, k, col_j);
}
set(device, result, acc, row_i, col_j);
}
}
}


// Computes result = (t1 * t2^T)^T, i.e. Y^T = W X^T with W = t1, X = t2:
// acc = sum_k t1[i,k] * t2[j,k] is written to result[j,i].
// t1 is [M x K], t2 is [N x K], result is [N x M]; shapes checked at
// compile time.
template<typename DEVICE, typename SPEC_1, typename SPEC_2, typename SPEC_OUT>
void multiply_transpose(DEVICE& device, Tensor<SPEC_1>& t1, Tensor<SPEC_2>& t2, Tensor<SPEC_OUT>& result){
static_assert(length(typename SPEC_1::SHAPE{}) == 2);
static_assert(length(typename SPEC_2::SHAPE{}) == 2);
static_assert(length(typename SPEC_OUT::SHAPE{}) == 2);
static_assert(get<1>(typename SPEC_1::SHAPE{}) == get<1>(typename SPEC_2::SHAPE{})); // shared inner dim: K
static_assert(get<0>(typename SPEC_2::SHAPE{}) == get<0>(typename SPEC_OUT::SHAPE{})); // result rows: N
static_assert(get<0>(typename SPEC_1::SHAPE{}) == get<1>(typename SPEC_OUT::SHAPE{})); // result cols: M
using T = typename SPEC_1::T;
using TI = typename DEVICE::index_t;
for(TI i=0; i < get<0>(typename SPEC_1::SHAPE{}); ++i){
for(TI j=0; j < get<0>(typename SPEC_2::SHAPE{}); ++j){
T acc = 0;
for(TI k=0; k < get<1>(typename SPEC_1::SHAPE{}); ++k){
acc += get(device, t1, i, k) * get(device, t2, j, k);
}
// transposed write: column i of result holds row i of t1's products
set(device, result, acc, j, i);
}
}
}
}
RL_TOOLS_NAMESPACE_WRAPPER_END

#endif
38 changes: 29 additions & 9 deletions include/rl_tools/containers/tensor/persist.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,43 @@ namespace rl_tools {

template<typename DEVICE, typename SPEC>
void save(DEVICE& device, Tensor<SPEC>& tensor, HighFive::Group group, std::string dataset_name) {
    // Serializes the tensor into an HDF5 dataset under `group`.
    // NOTE(review): assumes `to_vector` converts a Tensor into (nested)
    // std::vector data HighFive can ingest — confirm in the persist helpers.
    // Renamed the local from `data` to avoid shadowing the data() accessor.
    auto host_data = to_vector(device, tensor);
    group.createDataSet(dataset_name, host_data);
}

template<typename DEVICE, typename SPEC>
void load(DEVICE& device, const HighFive::DataSet& dataset, Tensor<SPEC>& tensor, bool fallback_to_zero = false) {
    // Reads an HDF5 dataset into a dense tensor, validating rank and
    // per-dimension sizes before touching the storage.
    using T = typename SPEC::T;
    static_assert(tensor::dense_layout<SPEC>(), "Load only supports dense tensors for now");
    auto dims = dataset.getDimensions();
    utils::assert_exit(device, dims.size() == length(typename SPEC::SHAPE{}), "Rank mismatch");
    utils::assert_exit(device, tensor::check_dimensions(device, tensor, dims), "Dimension mismatch");
    T* data_ptr = data(tensor);
    utils::assert_exit(device, data_ptr != nullptr, "Data pointer is null");
    // Debug toggle: read through nested std::vector buffers (rank <= 3)
    // instead of the raw pointer; disabled by default.
    constexpr bool VIA_VECTOR = false;
    if constexpr(VIA_VECTOR){
        static_assert(!VIA_VECTOR || (length(typename SPEC::SHAPE{}) <= 3));
        if constexpr(length(typename SPEC::SHAPE{}) == 1){
            dataset.read(data_ptr);
        }
        else if constexpr(length(typename SPEC::SHAPE{}) == 2){
            std::vector<std::vector<T>> buffer;
            dataset.read(buffer);
            from_vector(device, buffer, tensor);
        }
        else if constexpr(length(typename SPEC::SHAPE{}) == 3){
            std::vector<std::vector<std::vector<T>>> buffer;
            dataset.read(buffer);
            from_vector(device, buffer, tensor);
        }
    }
    else{
        // Dense layout is guaranteed above, so a flat read is safe.
        dataset.read(data_ptr);
    }
    // NOTE(review): `fallback_to_zero` is kept for interface compatibility
    // but is currently unused — TODO confirm intended semantics.
    (void)fallback_to_zero;
}
}
RL_TOOLS_NAMESPACE_WRAPPER_END
Expand Down
48 changes: 40 additions & 8 deletions include/rl_tools/containers/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,25 @@ namespace rl_tools{
static_assert(length(ELEMENT{}) > ELEMENT_OFFSET);
};





// Default strides for a row-major (last-dimension-contiguous) layout derived
// from SHAPE; the innermost stride is 1. Presumably Product yields trailing
// extent products and PopFront/Append reshape that tuple — TODO confirm
// against the tensor meta-function definitions.
template <typename SHAPE>
using RowMajorStride = Append<PopFront<Product<SHAPE>>, 1>;

template <typename T_T, typename T_TI, typename T_SHAPE, typename T_STRIDE = RowMajorStride<T_SHAPE>>
template <typename T_T, typename T_TI, typename T_SHAPE, typename T_STRIDE = RowMajorStride<T_SHAPE>, bool T_STATIC=false>
struct Specification{
using T = T_T;
using TI = T_TI;
using SHAPE = T_SHAPE;
using STRIDE = T_STRIDE;
static constexpr bool STATIC = T_STATIC;
static constexpr TI SIZE = Product<SHAPE>::VALUE;
static constexpr TI SIZE_BYTES = SIZE * sizeof(T);

};
// Selects dimension DIM (and, for range views, SIZE elements along it).
// The diff residue contained both an empty-body definition and this one;
// keeping the member-exposing version, since spec::view::range::Shape
// reads VIEW_SPEC::SIZE and VIEW_SPEC::DIM.
template<auto T_DIM, auto T_SIZE=0>
struct ViewSpec{
    static constexpr auto DIM = T_DIM;
    static constexpr auto SIZE = T_SIZE;
};
template <typename A, typename B>
bool constexpr _same_dimensions_shape(){
static_assert(length(A{}) == length(B{}));
Expand Down Expand Up @@ -185,14 +186,37 @@ namespace rl_tools{
bool constexpr dense_layout(){
return _dense_layout_shape<typename SPEC::SHAPE, typename SPEC::STRIDE>();
}
namespace spec::view{
// Meta-functions deriving the Specification of a view from the parent
// tensor's Specification and a ViewSpec.
namespace range{
// Range view: dimension DIM is narrowed to SIZE elements.
template <typename SHAPE, typename VIEW_SPEC>
using Shape = tensor::Replace<SHAPE, VIEW_SPEC::SIZE, VIEW_SPEC::DIM>;
// Strides are unchanged: the view aliases the parent's storage.
template <typename STRIDE, typename VIEW_SPEC>
using Stride = STRIDE;
template <typename SPEC, typename VIEW_SPEC>
using Specification = tensor::Specification<typename SPEC::T, typename SPEC::TI, Shape<typename SPEC::SHAPE, VIEW_SPEC>, Stride<typename SPEC::STRIDE, VIEW_SPEC>>;
}
namespace point{
// Point view: dimension DIM is removed entirely (rank decreases by one).
template <typename SHAPE, typename VIEW_SPEC>
using Shape = tensor::Remove<SHAPE, VIEW_SPEC::DIM>;
// The corresponding stride entry is removed as well.
template <typename STRIDE, typename VIEW_SPEC>
using Stride = tensor::Remove<STRIDE, VIEW_SPEC::DIM>;
template <typename SPEC, typename VIEW_SPEC>
using Specification = tensor::Specification<typename SPEC::T, typename SPEC::TI, Shape<typename SPEC::SHAPE, VIEW_SPEC>, Stride<typename SPEC::STRIDE, VIEW_SPEC>>;
}
}
}

// Tensor handle: a Specification plus its storage. The diff residue left
// two `_data` members (old `T*` and the new conditional one); keeping the
// conditional member, which stores the elements inline for STATIC specs
// and a raw pointer (owned via malloc/free) otherwise.
template <typename T_SPEC>
struct Tensor{
    using SPEC = T_SPEC;
    using T = typename SPEC::T;
    // View type aliases: a point view drops a dimension, a range view
    // narrows one — both derived from this tensor's SPEC.
    template <typename VIEW_SPEC>
    using VIEW_POINT = Tensor<tensor::spec::view::point::Specification<SPEC, VIEW_SPEC>>;
    template <typename VIEW_SPEC>
    using VIEW_RANGE = Tensor<tensor::spec::view::range::Specification<SPEC, VIEW_SPEC>>;
    utils::typing::conditional_t<SPEC::STATIC, T[SPEC::SIZE], T*> _data;
};

template <typename SPEC>
constexpr typename SPEC::T* data(const Tensor<SPEC>& tensor){
return tensor._data;
Expand All @@ -201,6 +225,14 @@ namespace rl_tools{
constexpr typename SPEC::T*& data_reference(Tensor<SPEC>& tensor){
return tensor._data;
}
// Tag that maps a tensor Specification to its heap-allocated (STATIC=false) variant.
struct TensorDynamicTag{
template<typename SPEC>
using type = Tensor<tensor::Specification<typename SPEC::T, typename SPEC::TI, typename SPEC::SHAPE, typename SPEC::STRIDE, false>>;
};
// Tag that maps a tensor Specification to its inline-storage (STATIC=true) variant.
struct TensorStaticTag{
template<typename SPEC>
using type = Tensor<tensor::Specification<typename SPEC::T, typename SPEC::TI, typename SPEC::SHAPE, typename SPEC::STRIDE, true>>;
};
}

#endif
Loading

0 comments on commit 49a8f62

Please sign in to comment.