diff --git a/include/albatross/Common b/include/albatross/Common index dfa90408..6fa1368b 100644 --- a/include/albatross/Common +++ b/include/albatross/Common @@ -43,4 +43,8 @@ #include "src/utils/map_utils.hpp" #include "src/cereal/eigen.hpp" +#include "src/details/traits.hpp" +#include "src/details/has_any_macros.hpp" +#include "src/details/error_handling.hpp" + #endif diff --git a/include/albatross/src/cereal/traits.hpp b/include/albatross/src/cereal/traits.hpp index 131e147b..5510c97e 100644 --- a/include/albatross/src/cereal/traits.hpp +++ b/include/albatross/src/cereal/traits.hpp @@ -15,15 +15,6 @@ namespace albatross { -/* - * This little trick was borrowed from cereal, you can think of it as - * a function that will always return false ... but that doesn't - * get resolved until template instantiation, which when combined - * with a static assert let's you include a static assert that - * only triggers with a particular template parameter is used. - */ -template struct delay_static_assert : std::false_type {}; - /* * The following helper functions let you inspect a type and cereal Archive * and determine if the type has a valid serialization method for that Archive diff --git a/include/albatross/src/core/model.hpp b/include/albatross/src/core/model.hpp index 76d24339..ba936811 100644 --- a/include/albatross/src/core/model.hpp +++ b/include/albatross/src/core/model.hpp @@ -19,7 +19,9 @@ using Insights = std::map; template class ModelBase : public ParameterHandlingMixin { - template friend class Prediction; + friend class JointPredictor; + friend class MarginalPredictor; + friend class MeanPredictor; template friend class fit_model_type; @@ -46,7 +48,10 @@ template class ModelBase : public ParameterHandlingMixin { !has_valid_fit::value, int>::type = 0> void _fit(const std::vector &features, - const MarginalDistribution &targets) const = delete; // Invalid fit + const MarginalDistribution &targets) const + ALBATROSS_FAIL(FeatureType, + "The ModelType *almost* has a _fit_impl method for " + "FeatureType, but it appears to be invalid"); template class ModelBase : public ParameterHandlingMixin { !has_valid_fit::value, int>::type = 0> void _fit(const std::vector &features, - const MarginalDistribution &targets) const = delete; + const MarginalDistribution &targets) const + ALBATROSS_FAIL( + FeatureType, + "The ModelType is missing a _fit_impl method for FeatureType."); template < typename PredictFeatureType, typename FitType, typename PredictType, @@ -73,9 +81,12 @@ template class ModelBase : public ParameterHandlingMixin { typename std::enable_if::value, int>::type = 0> - PredictType predict_( - const std::vector &features, const FitType &fit, - PredictTypeIdentity &&) const = delete; // No valid predict. + PredictType predict_(const std::vector &features, + const FitType &fit, + PredictTypeIdentity &&) const + ALBATROSS_FAIL(PredictFeatureType, + "The ModelType is missing a _predict_impl method for " + "PredictFeatureType, FitType, PredictType."); public: /* diff --git a/include/albatross/src/core/prediction.hpp b/include/albatross/src/core/prediction.hpp index 51364f64..a7244ef0 100644 --- a/include/albatross/src/core/prediction.hpp +++ b/include/albatross/src/core/prediction.hpp @@ -19,83 +19,80 @@ namespace albatross { // which behave different conditional on the type of predictions desired. template struct PredictTypeIdentity { typedef T type; }; -template -class Prediction { - +/* + * MeanPredictor is responsible for determining if a valid form of + * predicting exists for a given set of model, feature, fit. The + * primary goal of the class is to consolidate all the logic required + * to decide if different predict types are available. For example, + * by inspecting this class for a _mean method you can determine if + * any valid mean prediction method exists. + */ +class MeanPredictor { public: - Prediction(const ModelType &model, const FitType &fit, - const std::vector &features) - : model_(model), fit_(fit), features_(features) {} - - Prediction(const ModelType &model, const FitType &fit, - std::vector &&features) - : model_(model), fit_(fit), features_(std::move(features)) {} - - // Mean - template ::value, + has_valid_predict_mean::value, int>::type = 0> - Eigen::VectorXd mean() const { - static_assert(std::is_same::value, - "never do prediction.mean()"); - return model_.predict_(features_, fit_, - PredictTypeIdentity()); + Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) const { + return model.predict_(features, fit, + PredictTypeIdentity()); } template < - typename DummyType = FeatureType, + typename ModelType, typename FeatureType, typename FitType, typename std::enable_if< - !has_valid_predict_mean::value && - has_valid_predict_marginal::value, + !has_valid_predict_mean::value && + has_valid_predict_marginal::value, int>::type = 0> - Eigen::VectorXd mean() const { - static_assert(std::is_same::value, - "never do prediction.mean()"); - return model_ - .predict_(features_, fit_, PredictTypeIdentity()) + Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) const { + return model + .predict_(features, fit, PredictTypeIdentity()) .mean; } template < - typename DummyType = FeatureType, + typename ModelType, typename FeatureType, typename FitType, typename std::enable_if< - !has_valid_predict_mean::value && - !has_valid_predict_marginal::value && + !has_valid_predict_marginal::value && - has_valid_predict_joint::value, + has_valid_predict_joint::value, int>::type = 0> - Eigen::VectorXd mean() const { - static_assert(std::is_same::value, - "never do prediction.mean()"); - return model_ - .predict_(features_, fit_, PredictTypeIdentity()) + Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) const { + return model + .predict_(features, fit, PredictTypeIdentity()) .mean; } +}; - // Marginal - template ::value, + ModelType, FeatureType, FitType>::value, int>::type = 0> - MarginalDistribution marginal() const { - static_assert(std::is_same::value, - "never do prediction.marginal()"); - return model_.predict_(features_, fit_, - PredictTypeIdentity()); + MarginalDistribution + _marginal(const ModelType &model, const FitType &fit, + const std::vector &features) const { + return model.predict_(features, fit, + PredictTypeIdentity()); } template < - typename DummyType = FeatureType, + typename ModelType, typename FeatureType, typename FitType, typename std::enable_if< - !has_valid_predict_marginal::value && - has_valid_predict_joint::value, + !has_valid_predict_marginal::value && + has_valid_predict_joint::value, int>::type = 0> - MarginalDistribution marginal() const { - static_assert(std::is_same::value, - "never do prediction.marginal()"); - const auto joint_pred = model_.predict_( - features_, fit_, PredictTypeIdentity()); + MarginalDistribution + _marginal(const ModelType &model, const FitType &fit, + const std::vector &features) const { + const auto joint_pred = + model.predict_(features, fit, PredictTypeIdentity()); if (joint_pred.has_covariance()) { Eigen::VectorXd diag = joint_pred.covariance.diagonal(); return MarginalDistribution(joint_pred.mean, diag.asDiagonal()); @@ -103,45 +100,94 @@ class Prediction { return MarginalDistribution(joint_pred.mean); } } +}; - // Joint - template ::value, + has_valid_predict_joint::value, int>::type = 0> - JointDistribution joint() const { + JointDistribution _joint(const ModelType &model, const FitType &fit, + const std::vector &features) const { + return model.predict_(features, fit, + PredictTypeIdentity()); + } +}; + +template +class Prediction { + +public: + Prediction(const ModelType &model, const FitType &fit, + const std::vector &features) + : model_(model), fit_(fit), features_(features) {} + + Prediction(const ModelType &model, const FitType &fit, + std::vector &&features) + : model_(model), fit_(fit), features_(std::move(features)) {} + + // Mean + template ::value, + int>::type = 0> + Eigen::VectorXd mean() const { static_assert(std::is_same::value, - "never do prediction.joint()"); - return model_.predict_(features_, fit_, - PredictTypeIdentity()); + "never do prediction.mean()"); + return MeanPredictor()._mean(model_, fit_, features_); } - // CATCH FAILURE MODES template < typename DummyType = FeatureType, - typename std::enable_if< - !has_valid_predict_mean::value && - !has_valid_predict_marginal::value && - !has_valid_predict_joint::value, - int>::type = 0> - Eigen::VectorXd mean() const = delete; // No valid predict method found. + typename std::enable_if::value, + int>::type = 0> + void mean() const + ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the " + "mean with FitType and FeatureType."); + // Marginal template < typename DummyType = FeatureType, - typename std::enable_if< - !has_valid_predict_marginal::value && - !has_valid_predict_joint::value, - int>::type = 0> - Eigen::VectorXd - marginal() const = delete; // No valid predict marginal method found. + typename std::enable_if::value, + int>::type = 0> + MarginalDistribution marginal() const { + static_assert(std::is_same::value, + "never do prediction.mean()"); + return MarginalPredictor()._marginal(model_, fit_, features_); + } template ::value, + !can_predict_marginal::value, int>::type = 0> - Eigen::VectorXd - joint() const = delete; // No valid predict joint method found. + void marginal() const + ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the " + "marginal with FitType and FeatureType."); + + // Joint + template < + typename DummyType = FeatureType, + typename std::enable_if::value, + int>::type = 0> + JointDistribution joint() const { + static_assert(std::is_same::value, + "never do prediction.mean()"); + return JointPredictor()._joint(model_, fit_, features_); + } + + template < + typename DummyType = FeatureType, + typename std::enable_if::value, + int>::type = 0> + void joint() const + ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the " + "joint with FitType and FeatureType."); template PredictType get(PredictTypeIdentity = diff --git a/include/albatross/src/core/traits.hpp b/include/albatross/src/core/traits.hpp index 1ca1c621..8d40dbcb 100644 --- a/include/albatross/src/core/traits.hpp +++ b/include/albatross/src/core/traits.hpp @@ -15,25 +15,6 @@ namespace albatross { -/* - * Checks if a class type is complete by using sizeof. - * - * https://stackoverflow.com/questions/25796126/static-assert-that-template-typename-t-is-not-complete - */ -template class is_complete { - template - static std::true_type test(int); - template static std::false_type test(...); - -public: - static constexpr bool value = decltype(test(0))::value; -}; - -template struct is_vector : public std::false_type {}; - -template -struct is_vector> : public std::true_type {}; - /* * This determines whether or not a class, T, has a method, * `std::string T.name() const` @@ -91,21 +72,11 @@ template class has_valid_fit { static constexpr bool value = decltype(test(0))::value; }; -/* - * This determines whether or not a class (T) has a method defined for, - * `Anything _fit_impl(std::vector&, - * MarginalDistribution &)` - */ -template class has_possible_fit { - template ()._fit_impl( - std::declval &>(), - std::declval()))> - static std::true_type test(C *); - template static std::false_type test(...); +HAS_METHOD(_fit_impl); -public: - static constexpr bool value = decltype(test(0))::value; -}; +template +class has_possible_fit : public has__fit_impl &, + MarginalDistribution &> {}; /* * Determines which object would be returned from a call to: @@ -143,29 +114,16 @@ template struct fit_type : public fit_type::type, F> {}; -/* - * A valid predict method has a signature that looks like: - * - * PredictType _predict_impl(const std::vector &, - * const FitType &, - * const PredictTypeIdentity) const; - */ +MAKE_HAS_ANY_TRAIT(_predict_impl); + +HAS_METHOD_WITH_RETURN_TYPE(_predict_impl); + template -class has_valid_predict { - template < - typename C, - typename ReturnType = decltype(std::declval()._predict_impl( - std::declval &>(), - std::declval(), - std::declval>()))> - static typename std::enable_if::value, - std::true_type>::type - test(C *); - template static std::false_type test(...); - -public: - static constexpr bool value = decltype(test(0))::value; +class has_valid_predict + : public has__predict_impl_with_return_type< + T, PredictType, typename const_ref>::type, + typename const_ref::type, PredictTypeIdentity> { }; template @@ -180,6 +138,34 @@ template using has_valid_predict_joint = has_valid_predict; +HAS_METHOD(_mean); + +template +struct can_predict_mean + : public has__mean::type, + typename const_ref::type, + typename const_ref>::type> {}; + +HAS_METHOD(_marginal); + +template +struct can_predict_marginal + : public has__marginal::type, + typename const_ref::type, + typename const_ref>::type> { +}; + +HAS_METHOD(_joint); + +template +struct can_predict_joint + : public has__joint::type, + typename const_ref::type, + typename const_ref>::type> {}; + /* * Methods for inspecting `Prediction` types. */ diff --git a/include/albatross/src/covariance_functions/traits.hpp b/include/albatross/src/covariance_functions/traits.hpp index 725df302..13a04ccc 100644 --- a/include/albatross/src/covariance_functions/traits.hpp +++ b/include/albatross/src/covariance_functions/traits.hpp @@ -15,77 +15,39 @@ namespace albatross { -/* - * In CovarianceFunction we frequently inspect for definitions of - * _call_impl( which MUST be defined for const references to objects - * (so that repeated covariance matrix evaluations return the same thing - * and so the computations are not repeatedly copying.) - * This type conversion utility will turn a type `T` into `const T&` - */ -template struct call_impl_arg_type { - typedef - typename std::add_lvalue_reference::type>::type - type; -}; +MAKE_HAS_ANY_TRAIT(_call_impl); -/* - * This determines whether or not a class has a method defined for, - * `operator() (const X &x, const Y &y, const Z &z, ...)` - * The result of the inspection gets stored in the member `value`. - */ -template class has_call_operator { +// A helper rename to avoid duplicate underscores. +template class has_any_call_impl : public has_any__call_impl {}; - template ()( - std::declval::type>()...))> - static std::true_type test(C *); - template static std::false_type test(...); +HAS_METHOD_WITH_RETURN_TYPE(_call_impl); -public: - static constexpr bool value = decltype(test(0))::value; +template +class has_valid_call_impl : public has__call_impl_with_return_type< + U, double, typename const_ref::type...> { }; +HAS_METHOD(_call_impl); + +template +class has_possible_call_impl : public has__call_impl {}; + /* * This determines whether or not a class has a method defined for, - * `double _call_impl(const X &x, const Y &y, const Z &z, ...)` + * `operator() (const X &x, const Y &y, const Z &z, ...)` * The result of the inspection gets stored in the member `value`. */ -template class has_valid_call_impl { +template class has_call_operator { - template - static typename std::is_same< - decltype(std::declval()._call_impl( - std::declval::type>()...)), - double>::type - test(C *); + template ()( + std::declval::type>()...))> + static std::true_type test(C *); template static std::false_type test(...); public: static constexpr bool value = decltype(test(0))::value; }; -/* - * This determines whether or not a class has a method defined for - * something close to, but not quite, a valid _call_impl(. For example - * if a class has: - * double _call_impl(const X x) - * or - * double _call_impl(X &x) - * or - * int _call_impl(const X &x) - * those are nearly correct but the required `const X &x` in which - * case this trait can be used to warn the user. - */ -template class has_possible_call_impl { - template ()._call_impl( - std::declval()...))> - static std::true_type test(int); - template static std::false_type test(...); - -public: - static constexpr bool value = decltype(test(0))::value; -}; - template class has_invalid_call_impl { public: @@ -93,42 +55,6 @@ template class has_invalid_call_impl { !has_valid_call_impl::value); }; -/* - * This set of trait logic checks if a type has any _call_impl( method - * implemented (including private methods) by hijacking name hiding. - * Namely if a derived class overloads a method the base methods will - * be hidden. So by starting with a base class with a known method - * then extending that class you can determine if the derived class - * included any other methods with that name. - * - * https://stackoverflow.com/questions/1628768/why-does-an-overridden-function-in-the-derived-class-hide-other-overloads-of-the - */ -namespace detail { - -struct DummyType {}; - -struct BaseWithPublicCallImpl { - // This method will be accessible in `MultiInherit` only if - // the class U doesn't contain any methods with the same name. - double _call_impl(const DummyType &) const { return -1.; } -}; - -template -struct MultiInheritCallImpl : public U, public BaseWithPublicCallImpl {}; -} // namespace detail - -template class has_any_call_impl { - template - static typename std::enable_if< - has_valid_call_impl, - detail::DummyType>::value, - std::false_type>::type - test(int); - template static std::true_type test(...); - -public: - static constexpr bool value = decltype(test(0))::value; -}; } // namespace albatross #endif diff --git a/include/albatross/src/details/error_handling.hpp b/include/albatross/src/details/error_handling.hpp new file mode 100644 index 00000000..0e3059cb --- /dev/null +++ b/include/albatross/src/details/error_handling.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef INCLUDE_ALBATROSS_SRC_DETAILS_ERROR_HANDLING_HPP_ +#define INCLUDE_ALBATROSS_SRC_DETAILS_ERROR_HANDLING_HPP_ + +namespace albatross { + +#define ALBATROSS_FAIL(dummy, msg) \ + { static_assert(delay_static_assert::value, msg); } + +/* + * Setting ALBATROSS_FAIL to "= delete" as below will slightly + * change the behavior of failures. In some situations + * inspection of return types can trigger the delay_static_assert + * approach above, while the deleted function approach may + * work fine. In general however the deleted function approach + * leads to slightly more confusing compile errors since it + * isn't possible to include an error message. + */ + +//#define ALBATROSS_FAIL(dummy, msg) = delete + +} // namespace albatross + +#endif /* INCLUDE_ALBATROSS_SRC_DETAILS_ERROR_HANDLING_HPP_ */ diff --git a/include/albatross/src/details/has_any_macros.hpp b/include/albatross/src/details/has_any_macros.hpp new file mode 100644 index 00000000..72cc7890 --- /dev/null +++ b/include/albatross/src/details/has_any_macros.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef INCLUDE_ALBATROSS_SRC_DETAILS_HAS_ANY_MACROS_HPP_ +#define INCLUDE_ALBATROSS_SRC_DETAILS_HAS_ANY_MACROS_HPP_ + +namespace albatross { + +#define HAS_METHOD(fname) \ + template class has_##fname { \ + template ().fname( \ + std::declval()...))> \ + static std::true_type test(C *); \ + template static std::false_type test(...); \ + \ + public: \ + static constexpr bool value = decltype(test(0))::value; \ + }; + +#define HAS_METHOD_WITH_RETURN_TYPE(fname) \ + template \ + class has_##fname##_with_return_type { \ + template ().fname(std::declval()...))> \ + static typename std::is_same::type \ + test(C *); \ + template static std::false_type test(...); \ + \ + public: \ + static constexpr bool value = decltype(test(0))::value; \ + }; + +/* + * This set of macros creates a trait which can check for the existence + * of any method with some name `fname` (including private methods). + * This is done by hijacking name hiding, Namely if a derived class overloads a + * method the base methods will be hidden. So by starting with a base class + * with a known method then extending that class you can determine if the + * derived class included any other methods with that name. + * https://stackoverflow.com/questions/1628768/why-does-an-overridden-function-in-the-derived-class-hide-other-overloads-of-the + */ + +namespace detail { +struct DummyType {}; +} // namespace detail + +/* + * Creates a base class with a public method with name `fname` this is + * included via inheritance to check for name hiding. + */ +#define BASE_WITH_PUBLIC_METHOD(fname) \ + namespace detail { \ + struct BaseWithPublic##fname { \ + DummyType fname() const { return DummyType(); } \ + }; \ + } + +/* + * Creates a templated class which inherits from a given class as well + * as the Base class above. If U contains a method with name `fname` then + * the Base class definition of that function will be hidden. + */ +#define MULTI_INHERIT(fname) \ + namespace detail { \ + template \ + struct MultiInherit##fname : public U, public BaseWithPublic##fname {}; \ + } + +/* + * Creates a trait which checks to see if the dummy implementation in + * the Base class exists or not, used to determine if name hiding is + * active. + */ +#define HAS_DUMMY_DEFINITION(fname) \ + namespace detail { \ + template class has_dummy_definition_##fname { \ + template \ + static typename std::is_same().fname()), \ + DummyType>::type \ + test(C *); \ + template static std::false_type test(...); \ + \ + public: \ + static constexpr bool value = decltype(test(0))::value; \ + }; \ + } + +/* + * This creates the final trait which will check if any method named + * `fname` exists in type U. + */ +#define HAS_ANY_DEFINITION(fname) \ + template class has_any_##fname { \ + template \ + static typename std::enable_if>::value, \ + std::false_type>::type \ + test(int); \ + template static std::true_type test(...); \ + \ + public: \ + static constexpr bool value = decltype(test(0))::value; \ + } + +#define MAKE_HAS_ANY_TRAIT(fname) \ + BASE_WITH_PUBLIC_METHOD(fname); \ + MULTI_INHERIT(fname); \ + HAS_DUMMY_DEFINITION(fname); \ + HAS_ANY_DEFINITION(fname); + +} // namespace albatross + +#endif /* INCLUDE_ALBATROSS_SRC_DETAILS_HAS_ANY_MACROS_HPP_ */ diff --git a/include/albatross/src/details/traits.hpp b/include/albatross/src/details/traits.hpp new file mode 100644 index 00000000..c3e5c123 --- /dev/null +++ b/include/albatross/src/details/traits.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef INCLUDE_ALBATROSS_SRC_DETAILS_TRAITS_HPP_ +#define INCLUDE_ALBATROSS_SRC_DETAILS_TRAITS_HPP_ + +namespace albatross { + +/* + * We frequently inspect for definitions of functions which + * must be defined for const references to objects + * (so that repeated evaluations return the same thing + * and so the computations are not repeatedly copying.) + * This type conversion utility will turn a type `T` into `const T&` + */ +template struct const_ref { + typedef + typename std::add_lvalue_reference::type>::type + type; +}; + +/* + * This little trick was borrowed from cereal, you can think of it as + * a function that will always return false ... but that doesn't + * get resolved until template instantiation, which when combined + * with a static assert let's you include a static assert that + * only triggers with a particular template parameter is used. + */ +template struct delay_static_assert : std::false_type {}; + +/* + * Checks if a class type is complete by using sizeof. + * + * https://stackoverflow.com/questions/25796126/static-assert-that-template-typename-t-is-not-complete + */ +template class is_complete { + template + static std::true_type test(int); + template static std::false_type test(...); + +public: + static constexpr bool value = + decltype(test::type>(0))::value; +}; + +template struct is_vector : public std::false_type {}; + +template +struct is_vector> : public std::true_type {}; + +} // namespace albatross + +#endif /* INCLUDE_ALBATROSS_SRC_DETAILS_TRAITS_HPP_ */ diff --git a/include/albatross/src/models/gp.hpp b/include/albatross/src/models/gp.hpp index c31f286e..a74f35a9 100644 --- a/include/albatross/src/models/gp.hpp +++ b/include/albatross/src/models/gp.hpp @@ -215,8 +215,7 @@ class GaussianProcessBase : public ModelBase { return ss.str(); } - // If the implementing class doesn't have a fit method for this - // FeatureType but the CovarianceFunction does. + // If the CovarianceFunction is defined. template ::value, @@ -227,6 +226,15 @@ class GaussianProcessBase : public ModelBase { return GPFitType(features, cov, targets); } + // If the CovarianceFunction is NOT defined. + template ::value, + int>::type = 0> + auto _fit_impl(const std::vector &features, + const MarginalDistribution &targets) const + ALBATROSS_FAIL(FeatureType, "CovFunc is not defined for FeatureType"); + template < typename FeatureType, typename FitFeaturetype, typename std::enable_if< @@ -284,10 +292,12 @@ class GaussianProcessBase : public ModelBase { !has_call_operator::value || !has_call_operator::value, int>::type = 0> - PredictType _predict_impl(const std::vector &features, - const GPFitType &gp_fit, - PredictTypeIdentity &&) const = - delete; // Covariance Function isn't defined for FeatureType. + auto _predict_impl(const std::vector &features, + const GPFitType &gp_fit, + PredictTypeIdentity &&) const + ALBATROSS_FAIL( + FeatureType, + "CovFunc is not defined for FeatureType and FitFeatureType"); CovFunc get_covariance() const { return covariance_function_; }