diff --git a/albatross/core/distribution.h b/albatross/core/distribution.h index a829507a..40edc7ca 100644 --- a/albatross/core/distribution.h +++ b/albatross/core/distribution.h @@ -13,6 +13,9 @@ #ifndef ALBATROSS_CORE_DISTRIBUTION_H #define ALBATROSS_CORE_DISTRIBUTION_H +#include "cereal/cereal.hpp" +#include "core/traits.h" +#include "eigen/serializable_diagonal_matrix.h" #include "indexing.h" #include #include @@ -52,11 +55,40 @@ template struct Distribution { Distribution(const Eigen::VectorXd &mean_) : mean(mean_), covariance(){}; Distribution(const Eigen::VectorXd &mean_, const CovarianceType &covariance_) : mean(mean_), covariance(covariance_){}; + + /* + * If the CovarianceType is serializable, add a serialize method. + */ + template + typename std::enable_if< + valid_in_out_serializer::value, void>::type + serialize(Archive &archive) { + archive(cereal::make_nvp("mean", mean)); + archive(cereal::make_nvp("covariance", covariance)); + } + + /* + * If you try to serialize a Distribution for which the covariance + * type is not serializable you'll get an error. + */ + template + typename std::enable_if< + !valid_in_out_serializer::value, void>::type + save(Archive &archive) { + static_assert(delay_static_assert::value, + "In order to serialize a Distribution the corresponding " + "CovarianceType must be serializable."); + } + + bool operator==(const Distribution &other) const { + return (mean == other.mean && covariance == other.covariance); + } }; +using DiagonalMatrixXd = + Eigen::SerializableDiagonalMatrix; using DenseDistribution = Distribution; -using DiagonalDistribution = - Distribution>; +using DiagonalDistribution = Distribution; template Distribution subset(const std::vector &indices, @@ -69,6 +101,7 @@ Distribution subset(const std::vector &indices, return Distribution(mean); } } + } // namespace albatross #endif diff --git a/albatross/core/model.h b/albatross/core/model.h index d66096e2..f40d4328 100644 --- a/albatross/core/model.h +++ b/albatross/core/model.h @@ -24,9 +24,8 @@ namespace albatross { -using DiagonalMatrixXd = Eigen::DiagonalMatrix; -using TargetDistribution = Distribution; -using PredictDistribution = Distribution; +using TargetDistribution = DiagonalDistribution; +using PredictDistribution = DenseDistribution; /* * A RegressionDataset holds two vectors of data, the features @@ -53,6 +52,23 @@ template struct RegressionDataset { RegressionDataset(const std::vector &features_, const Eigen::VectorXd &targets_) : RegressionDataset(features_, TargetDistribution(targets_)) {} + + template + typename std::enable_if::value, + void>::type + serialize(Archive &archive) { + archive(cereal::make_nvp("features", features)); + archive(cereal::make_nvp("targets", targets)); + } + + template + typename std::enable_if::value, + void>::type + serialize(Archive &archive) { + static_assert(delay_static_assert::value, + "In order to serialize a RegressionDataset the corresponding " + "FeatureType must be serializable."); + } }; typedef int32_t s32; diff --git a/albatross/core/serialize.h b/albatross/core/serialize.h index 07183625..3e5e3ef1 100644 --- a/albatross/core/serialize.h +++ b/albatross/core/serialize.h @@ -13,6 +13,7 @@ #ifndef ALBATROSS_CORE_SERIALIZE_H #define ALBATROSS_CORE_SERIALIZE_H +#include "core/traits.h" #include #include #include @@ -38,21 +39,52 @@ class SerializableRegressionModel : public RegressionModel { model_fit_ == other.get_fit()); } - // todo: enable if ModelFit is serializable. - template void save(Archive &archive) const { + /* + * Include save/load methods conditional on the ability to serialize + * ModelFit. + */ + template + typename std::enable_if::value, + void>::type + save(Archive &archive) const { archive(cereal::make_nvp( "model_definition", cereal::base_class>(this))); archive(cereal::make_nvp("model_fit", this->model_fit_)); } - template void load(Archive &archive) { + template + typename std::enable_if::value, + void>::type + load(Archive &archive) { archive(cereal::make_nvp( "model_definition", cereal::base_class>(this))); archive(cereal::make_nvp("model_fit", this->model_fit_)); } + /* + * If ModelFit does not have valid serialization methods and you attempt to + * (de)serialize a SerializableRegressionModel you'll get an error. + */ + template + typename std::enable_if::value, + void>::type + save(Archive &archive) const { + static_assert(delay_static_assert::value, + "SerializableRegressionModel requires a ModelFit type which " + "is serializable."); + } + + template + typename std::enable_if::value, + void>::type + load(Archive &archive) const { + static_assert(delay_static_assert::value, + "SerializableRegressionModel requires a ModelFit type which " + "is serializable."); + } + virtual ModelFit get_fit() const { return model_fit_; } protected: diff --git a/albatross/core/traits.h b/albatross/core/traits.h index abf33307..aea5235c 100644 --- a/albatross/core/traits.h +++ b/albatross/core/traits.h @@ -13,10 +13,20 @@ #ifndef ALBATROSS_CORE_MAGIC_H #define ALBATROSS_CORE_MAGIC_H +#include "cereal/details/traits.hpp" #include namespace albatross { +/* + * This little trick was borrowed from cereal, you an think of it as + * a function that will always return false ... but that doesn't + * get resolved until template instantiation, which when combined + * with a static assert let's you include a static assert that + * only triggers with a particular template parameter is used. + */ +template struct delay_static_assert : std::false_type {}; + /* * This determines whether or not a class has a method defined for, * `operator() (X x, Y y, Z z, ...)` @@ -82,6 +92,47 @@ using fit_type_if_serializable = typename enable_if_serializable::type>::type; +/* + * The following helper functions let you inspect a type and cereal Archive + * and determine if the type has a valid serialization method for that Archive + * type. + */ +template class valid_output_serializer { + template + static typename std::enable_if< + 1 == cereal::traits::detail::count_output_serializers::value, + std::true_type>::type + test(int); + template static std::false_type test(...); + +public: + static constexpr bool value = decltype(test(0))::value; +}; + +template class valid_input_serializer { + template + static typename std::enable_if< + 1 == cereal::traits::detail::count_input_serializers::value, + std::true_type>::type + test(int); + template static std::false_type test(...); + +public: + static constexpr bool value = decltype(test(0))::value; +}; + +template class valid_in_out_serializer { + template + static typename std::enable_if::value && + valid_output_serializer::value, + std::true_type>::type + test(int); + template static std::false_type test(...); + +public: + static constexpr bool value = decltype(test(0))::value; +}; + } // namespace albatross #endif diff --git a/albatross/eigen/serializable_diagonal_matrix.h b/albatross/eigen/serializable_diagonal_matrix.h new file mode 100644 index 00000000..f90a66a0 --- /dev/null +++ b/albatross/eigen/serializable_diagonal_matrix.h @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2018 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef ALBATROSS_EIGEN_SERIALIZABLE_DIAGONAL_MATRIX_H +#define ALBATROSS_EIGEN_SERIALIZABLE_DIAGONAL_MATRIX_H + +/* + * The Eigen::DiagonalMatrix doesn't provide the public methods + * required to reliably serialize the `m_diagonal` private + * member. In order to make the DiagonalMatrix serializable + * we instead inherit from it, giving private access to the + * diagonal elements which in turn allows us to serialize it. + */ + +#include "Eigen/Cholesky" +#include "Eigen/Dense" +#include "cereal/cereal.hpp" +#include + +namespace Eigen { + +template +class SerializableDiagonalMatrix + : public Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> { + using BaseClass = Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime>; + +public: + SerializableDiagonalMatrix() : BaseClass(){}; + + template + inline SerializableDiagonalMatrix(const DiagonalBase &other) + : BaseClass(other){}; + + template void serialize(Archive &archive) { + archive(cereal::make_nvp("diagonal", this->m_diagonal)); + } + + bool operator==(const BaseClass &other) const { + return (this->m_diagonal == other.diagonal()); + } +}; + +} // namesapce Eigen + +#endif diff --git a/tests/test_serialize.cc b/tests/test_serialize.cc index c453415b..417f1024 100644 --- a/tests/test_serialize.cc +++ b/tests/test_serialize.cc @@ -99,6 +99,39 @@ struct EigenMatrixXd : public SerializableType { } }; +struct FullDenseDistribution : public SerializableType { + DenseDistribution create() const override { + Eigen::MatrixXd cov(3, 3); + cov << 1., 2., 3., 4., 5., 6., 7, 8, 9; + Eigen::VectorXd mean = Eigen::VectorXd::Ones(3); + return DenseDistribution(mean, cov); + } +}; + +struct MeanOnlyDenseDistribution : public SerializableType { + DenseDistribution create() const override { + Eigen::MatrixXd mean = Eigen::VectorXd::Ones(3); + return DenseDistribution(mean); + } +}; + +struct FullDiagonalDistribution + : public SerializableType { + DiagonalDistribution create() const override { + Eigen::VectorXd diag = Eigen::VectorXd::Ones(3); + Eigen::VectorXd mean = Eigen::VectorXd::Ones(3); + return DiagonalDistribution(mean, diag.asDiagonal()); + } +}; + +struct MeanOnlyDiagonalDistribution + : public SerializableType { + DiagonalDistribution create() const override { + Eigen::MatrixXd mean = Eigen::VectorXd::Ones(3); + return DiagonalDistribution(mean); + } +}; + struct LDLT : public SerializableType { Eigen::Index n = 3; @@ -281,10 +314,12 @@ struct PolymorphicSerializeTest : public ::testing::Test { typedef ::testing::Types< LDLT, SerializableType, SerializableType, EmptyEigenVectorXd, EigenVectorXd, EmptyEigenMatrixXd, EigenMatrixXd, - ParameterStoreType, SerializableType, UnfitSerializableModel, - FitSerializableModel, FitDirectModel, UnfitDirectModel, - UnfitRegressionModel, FitLinearRegressionModel, - FitLinearSerializablePointer, UnfitGaussianProcess, FitGaussianProcess> + FullDenseDistribution, MeanOnlyDenseDistribution, FullDiagonalDistribution, + MeanOnlyDiagonalDistribution, ParameterStoreType, + SerializableType, UnfitSerializableModel, FitSerializableModel, + FitDirectModel, UnfitDirectModel, UnfitRegressionModel, + FitLinearRegressionModel, FitLinearSerializablePointer, + UnfitGaussianProcess, FitGaussianProcess> ToTest; TYPED_TEST_CASE(PolymorphicSerializeTest, ToTest);