diff --git a/src/interfaces/swig/SGBase.i b/src/interfaces/swig/SGBase.i index 3a0d1655218..b5bfafb7e4d 100644 --- a/src/interfaces/swig/SGBase.i +++ b/src/interfaces/swig/SGBase.i @@ -525,3 +525,4 @@ copy_reg._reconstructor=_sg_reconstructor SUPPORT_TAG(String, string, std::string) SUPPORT_TAG(Float64, float, float64_t) SUPPORT_TAG(Int64, int, int64_t) +SUPPORT_TAG(Object, object, CSGObject*) diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index 8418c2f508b..c831811170b 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -909,3 +909,17 @@ void CSGObject::list_observable_parameters() x.second.second.c_str()); } } + +bool CSGObject::has(const std::string& name) const +{ + return has_parameter(BaseTag(name)); +} + +void CSGObject::ref_value(CSGObject* const* value) +{ + SG_REF(*value); +} + +void CSGObject::ref_value(...) +{ +} diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 7432b631294..a46f04f6c94 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -300,10 +300,7 @@ class CSGObject * @param name name of the parameter * @return true if the parameter exists with the input name */ - bool has(const std::string& name) const - { - return has_parameter(BaseTag(name)); - } + bool has(const std::string& name) const; /** Checks if object has a class parameter identified by a Tag. * @@ -343,7 +340,10 @@ class CSGObject if (has_parameter(_tag)) { if(has(_tag.name())) + { + ref_value(&value); update_parameter(_tag, erase_type(value)); + } else { SG_ERROR("Type for parameter with name \"%s\" is not correct.\n", @@ -552,6 +552,9 @@ class CSGObject void unset_global_objects(); void init(); + static void ref_value(CSGObject* const* value); + static void ref_value(...); + /** Checks if object has a parameter identified by a BaseTag. * This only checks for name and not type information. * See its usage in has() and has(). diff --git a/src/shogun/base/class_list.h b/src/shogun/base/class_list.h index 3eeaedc536f..883b18fc82e 100644 --- a/src/shogun/base/class_list.h +++ b/src/shogun/base/class_list.h @@ -46,6 +46,7 @@ namespace shogun { delete object; SG_SERROR("Type mismatch"); } + cast->ref(); return cast; } } diff --git a/tests/unit/base/MockObject.h b/tests/unit/base/MockObject.h index 4af938a8884..bcdab1f12e1 100644 --- a/tests/unit/base/MockObject.h +++ b/tests/unit/base/MockObject.h @@ -16,6 +16,11 @@ namespace shogun init_params(); } + virtual ~CMockObject() + { + SG_UNREF(m_object); + } + const char* get_name() const { return "MockObject"; @@ -31,6 +36,11 @@ namespace shogun return m_watched; } + CSGObject* get_object() const + { + return m_object; + } + protected: void init_params() { @@ -43,10 +53,17 @@ namespace shogun "watched_int", &m_watched, AnyParameterProperties( MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE)); + + watch_param( + "watched_object", (CSGObject**)&m_object, + AnyParameterProperties( + MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE)); } private: int32_t m_integer = 0; int32_t m_watched = 0; + + CSGObject* m_object = 0; }; } diff --git a/tests/unit/base/SGObject_unittest.cc b/tests/unit/base/SGObject_unittest.cc index 86340eba7fb..34c267987a1 100644 --- a/tests/unit/base/SGObject_unittest.cc +++ b/tests/unit/base/SGObject_unittest.cc @@ -396,3 +396,16 @@ TEST(SGObject, watched_parameter) EXPECT_EQ(obj->get("watched_int"), 12); EXPECT_EQ(obj->get("watched_int"), obj->get_watched()); } + +TEST(SGObject, watched_parameter_object) +{ + auto obj = some(); + Some other_obj = some(); + + EXPECT_EQ(other_obj->ref_count(), 1); + obj->put("watched_object", dynamic_cast(other_obj.get())); + EXPECT_EQ(other_obj->ref_count(), 2); + EXPECT_TRUE(other_obj->equals(obj)); + obj = nullptr; + EXPECT_EQ(other_obj->ref_count(), 1); +}