Skip to content

Commit

Permalink
Make SG_ADD register parameters into the new parameter map
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Dec 24, 2017
1 parent c7af75e commit 3dde105
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 40 deletions.
35 changes: 23 additions & 12 deletions src/shogun/base/SGObject.h
Expand Up @@ -76,19 +76,23 @@ template <class T> class SGStringList;
#define VARARG_IMPL(base, count, ...) VARARG_IMPL2(base, count, __VA_ARGS__)
#define VARARG(base, ...) VARARG_IMPL(base, VA_NARGS(__VA_ARGS__), __VA_ARGS__)

#define SG_ADD4(param, name, description, ms_available) {\
m_parameters->add(param, name, description);\
if (ms_available)\
m_model_selection_parameters->add(param, name, description);\
}
#define SG_ADD4(param, name, description, ms_available) \
{ \
m_parameters->add(param, name, description); \
watch_param(name, param); \
if (ms_available) \
m_model_selection_parameters->add(param, name, description); \
}

#define SG_ADD5(param, name, description, ms_available, gradient_available) {\
m_parameters->add(param, name, description);\
if (ms_available)\
m_model_selection_parameters->add(param, name, description);\
if (gradient_available)\
m_gradient_parameters->add(param, name, description);\
}
#define SG_ADD5(param, name, description, ms_available, gradient_available) \
{ \
m_parameters->add(param, name, description); \
watch_param(name, param); \
if (ms_available) \
m_model_selection_parameters->add(param, name, description); \
if (gradient_available) \
m_gradient_parameters->add(param, name, description); \
}

#define SG_ADD(...) VARARG(SG_ADD, __VA_ARGS__)

Expand Down Expand Up @@ -490,6 +494,13 @@ class CSGObject
type_erased_put(tag, erase_type(value));
}

template <typename T>
void watch_param(const std::string& name, T* value)
{
BaseTag tag(name);
type_erased_put(tag, erase_type_non_owning(value));
}

public:
/** Updates the hash of current parameter combination */
virtual void update_parameter_hash();
Expand Down
7 changes: 7 additions & 0 deletions src/shogun/lib/SGSparseVector.h
Expand Up @@ -216,6 +216,13 @@ template <class T> class SGSparseVector : public SGReferencedData

};

template <class T>
inline bool
operator==(const SGSparseVector<T>& lhs, const SGSparseVector<T>& rhs)
{
SG_SERROR("Comparison is not implemented for sparse vectors");
return false;
}
}

#endif // __SGSPARSEVECTOR_H__
6 changes: 4 additions & 2 deletions src/shogun/modelselection/GradientModelSelection.cpp
Expand Up @@ -89,10 +89,12 @@ class GradientModelSelectionCostFunction: public FirstOrderCostFunction
"obj in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE);
m_func_data = NULL;
m_val = SGVector<float64_t>();
SG_ADD(m_val, "GradientModelSelectionCostFunction__m_val",
SG_ADD(
&m_val, "GradientModelSelectionCostFunction__m_val",
"val in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE);
m_grad = SGVector<float64_t>();
SG_ADD(m_grad, "GradientModelSelectionCostFunction__m_grad",
SG_ADD(
&m_grad, "GradientModelSelectionCostFunction__m_grad",
"grad in GradientModelSelectionCostFunction", MS_NOT_AVAILABLE);
}

Expand Down
68 changes: 42 additions & 26 deletions tests/unit/base/MockObject.h
Expand Up @@ -4,30 +4,46 @@
namespace shogun
{

/** @brief Used to test the tags-parameter framework
* Allows testing of registering new member and avoiding
* non-registered member variables using tags framework.
*/
class CMockObject : public CSGObject
{
public:
CMockObject() : CSGObject()
{
init_params();
}

const char* get_name() const { return "MockObject"; }

protected:
void init_params()
{
float64_t decimal = 0.0;
register_param("vector", SGVector<float64_t>());
register_param("int", m_integer);
register_param("float", decimal);
}

private:
int32_t m_integer = 0;
};
/** @brief Used to test the tags-parameter framework
* Allows testing of registering new member and avoiding
* non-registered member variables using tags framework.
*/
class CMockObject : public CSGObject
{
public:
CMockObject() : CSGObject()
{
init_params();
}

const char* get_name() const
{
return "MockObject";
}

void set_watched(int32_t value)
{
m_watched = value;
}

int32_t get_watched() const
{
return m_watched;
}

protected:
void init_params()
{
float64_t decimal = 0.0;
register_param("vector", SGVector<float64_t>());
register_param("int", m_integer);
register_param("float", decimal);

watch_param("watched_int", &m_watched);
}

private:
int32_t m_integer = 0;
int32_t m_watched = 0;
};
}
12 changes: 12 additions & 0 deletions tests/unit/base/SGObject_unittest.cc
Expand Up @@ -384,3 +384,15 @@ TEST(SGObject, tags_has)
EXPECT_EQ(obj->has<int32_t>("foo"), false);
EXPECT_EQ(obj->has(Tag<int32_t>("foo")), false);
}

TEST(SGObject, watched_parameter)
{
auto obj = some<CMockObject>();

obj->put("watched_int", 89);
EXPECT_EQ(obj->get<int32_t>("watched_int"), 89);
EXPECT_EQ(obj->get<int32_t>("watched_int"), obj->get_watched());
obj->set_watched(12);
EXPECT_EQ(obj->get<int32_t>("watched_int"), 12);
EXPECT_EQ(obj->get<int32_t>("watched_int"), obj->get_watched());
}

0 comments on commit 3dde105

Please sign in to comment.