diff --git a/albatross/core/parameter_handling_mixin.h b/albatross/core/parameter_handling_mixin.h index 6903bf2a..d009e828 100644 --- a/albatross/core/parameter_handling_mixin.h +++ b/albatross/core/parameter_handling_mixin.h @@ -39,7 +39,7 @@ struct Parameter { ParameterPrior prior; Parameter() : value(), prior(nullptr){}; - Parameter(ParameterValue value_) : value(value_) {} + Parameter(ParameterValue value_) : value(value_), prior(nullptr) {} Parameter(ParameterValue value_, const ParameterPrior &prior_) : value(value_), prior(prior_){}; /* @@ -162,7 +162,14 @@ class ParameterHandlingMixin { } } - void set_param(const ParameterKey &key, const ParameterValue &value) { + void set_param_values(const std::map &values) { + for (const auto &pair : values) { + check_param_key(pair.first); + unchecked_set_param(pair.first, pair.second); + } + } + + void set_param_value(const ParameterKey &key, const ParameterValue &value) { check_param_key(key); unchecked_set_param(key, value); } @@ -172,6 +179,14 @@ class ParameterHandlingMixin { unchecked_set_param(key, param); } + // This just avoids the situation where a user would call `set_param` + // with a double, which may then be viewed by the compiler as the + // initialization argument for a `Parameter` which would then + // inadvertently overwrite the prior. + void set_param(const ParameterKey &key, const ParameterValue &value) { + set_param_value(key, value); + } + void set_prior(const ParameterKey &key, const ParameterPrior &prior) { check_param_key(key); unchecked_set_prior(key, prior); @@ -242,17 +257,17 @@ class ParameterHandlingMixin { } } - ParameterValue get_param_value(const std::string &name) const { + ParameterValue get_param_value(const ParameterKey &name) const { return get_params().at(name).value; } - void unchecked_set_param(const std::string &name, + void unchecked_set_param(const ParameterKey &name, const ParameterValue value) { Parameter param = {value, get_params()[name].prior}; unchecked_set_param(name, param); } - void unchecked_set_prior(const std::string &name, + void unchecked_set_prior(const ParameterKey &name, const ParameterPrior &prior) { Parameter param = {get_params()[name].value, prior}; unchecked_set_param(name, param); @@ -281,7 +296,7 @@ class ParameterHandlingMixin { virtual ParameterStore get_params() const { return params_; } - virtual void unchecked_set_param(const std::string &name, + virtual void unchecked_set_param(const ParameterKey &name, const Parameter ¶m) { params_[name] = param; } diff --git a/albatross/covariance_functions/covariance_functions.h b/albatross/covariance_functions/covariance_functions.h index a794b6e1..2cf50d44 100644 --- a/albatross/covariance_functions/covariance_functions.h +++ b/albatross/covariance_functions/covariance_functions.h @@ -74,8 +74,13 @@ template struct CovarianceFunction { inline auto set_params(const ParameterStore ¶ms) { return term.set_params(params); }; - inline auto set_param(const ParameterKey &key, const ParameterValue &value) { - return term.set_param(key, value); + inline auto + set_param_values(const std::map &values) { + return term.set_param_values(values); + }; + inline auto set_param_value(const ParameterKey &key, + const ParameterValue &value) { + return term.set_param_value(key, value); }; inline auto set_param(const ParameterKey &key, const Parameter ¶m) { return term.set_param(key, param); diff --git a/albatross/covariance_functions/covariance_term.h b/albatross/covariance_functions/covariance_term.h index 910387e8..b52ccb2e 100644 --- a/albatross/covariance_functions/covariance_term.h +++ b/albatross/covariance_functions/covariance_term.h @@ -71,7 +71,7 @@ class CombinationOfCovarianceTerms : public CovarianceTerm { return map_join(lhs_.get_params(), rhs_.get_params()); } - void unchecked_set_param(const std::string &name, + void unchecked_set_param(const ParameterKey &name, const Parameter ¶m) override { if (map_contains(lhs_.get_params(), name)) { lhs_.set_param(name, param); diff --git a/albatross/covariance_functions/scaling_function.h b/albatross/covariance_functions/scaling_function.h index cb9cf214..adb45ce6 100644 --- a/albatross/covariance_functions/scaling_function.h +++ b/albatross/covariance_functions/scaling_function.h @@ -78,6 +78,10 @@ template class ScalingTerm : public CovarianceTerm { scaling_function_.set_params(params); } + void set_param_values(const std::map &values) { + scaling_function_.set_param_values(values); + } + virtual ParameterStore get_params() const override { return scaling_function_.get_params(); } @@ -94,7 +98,7 @@ template class ScalingTerm : public CovarianceTerm { archive(cereal::make_nvp("scaling_function", scaling_function_)); } - void unchecked_set_param(const std::string &name, + void unchecked_set_param(const ParameterKey &name, const Parameter ¶m) override { scaling_function_.set_param(name, param); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dcd1103a..401eb054 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,19 +5,21 @@ set(CMAKE_CXX_FLAGS "-Wshadow -Wswitch-default -Wswitch-enum -Wundef -Wuninitial add_executable(albatross_unit_tests EXCLUDE_FROM_ALL +test_core_distribution.cc test_core_model.cc test_covariance_functions.cc test_distance_metrics.cc test_evaluate.cc test_map_utils.cc -test_serialize.cc -test_parameter_handling_mixin.cc +test_model_adapter.cc test_models.cc -test_core_distribution.cc +test_parameter_handling_mixin.cc +test_radial.cc +test_scaling_function.cc +test_serializable_ldlt.cc +test_serialize.cc test_tune.cc test_tuning_metrics.cc -test_serializable_ldlt.cc -test_radial.cc ) add_dependencies(albatross_unit_tests diff --git a/tests/test_model_adapter.cc b/tests/test_model_adapter.cc index 8276fd6b..d15aea77 100644 --- a/tests/test_model_adapter.cc +++ b/tests/test_model_adapter.cc @@ -31,9 +31,9 @@ class TestAdaptedModel : public TestAdaptedModelBase { std::string get_name() const override { return "test_adapted"; }; - const Eigen::VectorXd convert_feature(const double &x) const { + const Eigen::VectorXd convert_feature(const double &x) const override { Eigen::VectorXd converted(2); - converted << 1., (x - this->params_.at("center")); + converted << 1., (x - this->get_param_value("center")); return converted; } @@ -59,7 +59,7 @@ void test_get_set(albatross::RegressionModel &model, const std::string &key) { // Make sure a key exists, then modify it and make sure it // takes on the new value. - const auto orig = model.get_params().at(key); + const auto orig = model.get_param_value(key); model.set_param(key, orig + 1.); EXPECT_EQ(model.get_params().at(key), orig + 1.); } diff --git a/tests/test_parameter_handling_mixin.cc b/tests/test_parameter_handling_mixin.cc index 8b46b97f..f99eaee8 100644 --- a/tests/test_parameter_handling_mixin.cc +++ b/tests/test_parameter_handling_mixin.cc @@ -127,4 +127,26 @@ TEST(test_parameter_handler, test_set_prior) { } }; +TEST(test_parameter_handler, test_set_param_values_doesnt_overwrite_prior) { + auto p = TestParameterHandler(); + + const auto orig_params = p.get_params(); + + std::map new_params; + for (const auto &pair : orig_params) { + new_params[pair.first] = pair.second.value + 1.; + } + + p.set_param_values(new_params); + + for (const auto &pair : orig_params) { + // Make sure we changed the parameter value + const auto new_param = p.get_params().at(pair.first); + EXPECT_NE(new_param.value, pair.second.value); + // but not the prior. + EXPECT_TRUE(!pair.second.has_prior() || + (pair.second.prior == new_param.prior)); + } +}; + } // namespace albatross diff --git a/tests/test_scaling_function.cc b/tests/test_scaling_function.cc index 6592793d..02fa65c2 100644 --- a/tests/test_scaling_function.cc +++ b/tests/test_scaling_function.cc @@ -111,8 +111,8 @@ TEST(test_scaling_functions, test_predicts) { auto model = gp_from_covariance(covariance_function); auto folds = leave_one_out(dataset); - auto cv_scores = - cross_validated_scores(root_mean_square_error, folds, &model); + auto cv_scores = cross_validated_scores( + evaluation_metrics::root_mean_square_error, folds, &model); EXPECT_LE(cv_scores.mean(), 0.01); }