From 3dde105e7e4afd72a37eed8b48a5315ec553b6be Mon Sep 17 00:00:00 2001 From: Sergey Lisitsyn Date: Sun, 24 Dec 2017 19:10:17 +0100 Subject: [PATCH] Make SG_ADD register parameters into the new parameter map --- src/shogun/base/SGObject.h | 35 ++++++---- src/shogun/lib/SGSparseVector.h | 7 ++ .../modelselection/GradientModelSelection.cpp | 6 +- tests/unit/base/MockObject.h | 68 ++++++++++++------- tests/unit/base/SGObject_unittest.cc | 12 ++++ 5 files changed, 88 insertions(+), 40 deletions(-) diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 42764d1ab14..2a3ccb4396b 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -76,19 +76,23 @@ template class SGStringList; #define VARARG_IMPL(base, count, ...) VARARG_IMPL2(base, count, __VA_ARGS__) #define VARARG(base, ...) VARARG_IMPL(base, VA_NARGS(__VA_ARGS__), __VA_ARGS__) -#define SG_ADD4(param, name, description, ms_available) {\ - m_parameters->add(param, name, description);\ - if (ms_available)\ - m_model_selection_parameters->add(param, name, description);\ -} +#define SG_ADD4(param, name, description, ms_available) \ + { \ + m_parameters->add(param, name, description); \ + watch_param(name, param); \ + if (ms_available) \ + m_model_selection_parameters->add(param, name, description); \ + } -#define SG_ADD5(param, name, description, ms_available, gradient_available) {\ - m_parameters->add(param, name, description);\ - if (ms_available)\ - m_model_selection_parameters->add(param, name, description);\ - if (gradient_available)\ - m_gradient_parameters->add(param, name, description);\ -} +#define SG_ADD5(param, name, description, ms_available, gradient_available) \ + { \ + m_parameters->add(param, name, description); \ + watch_param(name, param); \ + if (ms_available) \ + m_model_selection_parameters->add(param, name, description); \ + if (gradient_available) \ + m_gradient_parameters->add(param, name, description); \ + } #define SG_ADD(...) VARARG(SG_ADD, __VA_ARGS__) @@ -490,6 +494,13 @@ class CSGObject type_erased_put(tag, erase_type(value)); } + template + void watch_param(const std::string& name, T* value) + { + BaseTag tag(name); + type_erased_put(tag, erase_type_non_owning(value)); + } + public: /** Updates the hash of current parameter combination */ virtual void update_parameter_hash(); diff --git a/src/shogun/lib/SGSparseVector.h b/src/shogun/lib/SGSparseVector.h index 31a54812f5c..7da4799ef88 100644 --- a/src/shogun/lib/SGSparseVector.h +++ b/src/shogun/lib/SGSparseVector.h @@ -216,6 +216,13 @@ template class SGSparseVector : public SGReferencedData }; +template +inline bool +operator==(const SGSparseVector& lhs, const SGSparseVector& rhs) +{ + SG_SERROR("Comparison is not implemented for sparse vectors"); + return false; +} } #endif // __SGSPARSEVECTOR_H__ diff --git a/src/shogun/modelselection/GradientModelSelection.cpp b/src/shogun/modelselection/GradientModelSelection.cpp index b1cc7aa1174..f670ed3866c 100644 --- a/src/shogun/modelselection/GradientModelSelection.cpp +++ b/src/shogun/modelselection/GradientModelSelection.cpp @@ -89,10 +89,12 @@ class GradientModelSelectionCostFunction: public FirstOrderCostFunction "obj in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE); m_func_data = NULL; m_val = SGVector(); - SG_ADD(m_val, "GradientModelSelectionCostFunction__m_val", + SG_ADD( + &m_val, "GradientModelSelectionCostFunction__m_val", "val in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE); m_grad = SGVector(); - SG_ADD(m_grad, "GradientModelSelectionCostFunction__m_grad", + SG_ADD( + &m_grad, "GradientModelSelectionCostFunction__m_grad", "grad in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE); } diff --git a/tests/unit/base/MockObject.h b/tests/unit/base/MockObject.h index ff79131cf91..84c6ea1cae2 100644 --- a/tests/unit/base/MockObject.h +++ b/tests/unit/base/MockObject.h @@ -4,30 +4,46 @@ namespace shogun { -/** @brief Used to test the tags-parameter framework - * Allows testing of registering new member and avoiding - * non-registered member variables using tags framework. - */ -class CMockObject : public CSGObject -{ -public: - CMockObject() : CSGObject() - { - init_params(); - } - - const char* get_name() const { return "MockObject"; } - -protected: - void init_params() - { - float64_t decimal = 0.0; - register_param("vector", SGVector()); - register_param("int", m_integer); - register_param("float", decimal); - } - -private: - int32_t m_integer = 0; -}; + /** @brief Used to test the tags-parameter framework + * Allows testing of registering new member and avoiding + * non-registered member variables using tags framework. + */ + class CMockObject : public CSGObject + { + public: + CMockObject() : CSGObject() + { + init_params(); + } + + const char* get_name() const + { + return "MockObject"; + } + + void set_watched(int32_t value) + { + m_watched = value; + } + + int32_t get_watched() const + { + return m_watched; + } + + protected: + void init_params() + { + float64_t decimal = 0.0; + register_param("vector", SGVector()); + register_param("int", m_integer); + register_param("float", decimal); + + watch_param("watched_int", &m_watched); + } + + private: + int32_t m_integer = 0; + int32_t m_watched = 0; + }; } diff --git a/tests/unit/base/SGObject_unittest.cc b/tests/unit/base/SGObject_unittest.cc index e74684aac9b..86340eba7fb 100644 --- a/tests/unit/base/SGObject_unittest.cc +++ b/tests/unit/base/SGObject_unittest.cc @@ -384,3 +384,15 @@ TEST(SGObject, tags_has) EXPECT_EQ(obj->has("foo"), false); EXPECT_EQ(obj->has(Tag("foo")), false); } + +TEST(SGObject, watched_parameter) +{ + auto obj = some(); + + obj->put("watched_int", 89); + EXPECT_EQ(obj->get("watched_int"), 89); + EXPECT_EQ(obj->get("watched_int"), obj->get_watched()); + obj->set_watched(12); + EXPECT_EQ(obj->get("watched_int"), 12); + EXPECT_EQ(obj->get("watched_int"), obj->get_watched()); +}