Skip to content

Commit

Permalink
Support get/put for SGObject (#4066)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jan 4, 2018
1 parent ee3aa74 commit 21e50fa
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/interfaces/swig/SGBase.i
Expand Up @@ -525,3 +525,4 @@ copy_reg._reconstructor=_sg_reconstructor
SUPPORT_TAG(String, string, std::string)
SUPPORT_TAG(Float64, float, float64_t)
SUPPORT_TAG(Int64, int, int64_t)
SUPPORT_TAG(Object, object, CSGObject*)
14 changes: 14 additions & 0 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -909,3 +909,17 @@ void CSGObject::list_observable_parameters()
x.second.second.c_str());
}
}

bool CSGObject::has(const std::string& name) const
{
return has_parameter(BaseTag(name));
}

void CSGObject::ref_value(CSGObject* const* value)
{
SG_REF(*value);
}

void CSGObject::ref_value(...)
{
}
11 changes: 7 additions & 4 deletions src/shogun/base/SGObject.h
Expand Up @@ -300,10 +300,7 @@ class CSGObject
* @param name name of the parameter
* @return true if the parameter exists with the input name
*/
bool has(const std::string& name) const
{
return has_parameter(BaseTag(name));
}
bool has(const std::string& name) const;

/** Checks if object has a class parameter identified by a Tag.
*
Expand Down Expand Up @@ -343,7 +340,10 @@ class CSGObject
if (has_parameter(_tag))
{
if(has<T>(_tag.name()))
{
ref_value(&value);
update_parameter(_tag, erase_type(value));
}
else
{
SG_ERROR("Type for parameter with name \"%s\" is not correct.\n",
Expand Down Expand Up @@ -552,6 +552,9 @@ class CSGObject
void unset_global_objects();
void init();

static void ref_value(CSGObject* const* value);
static void ref_value(...);

/** Checks if object has a parameter identified by a BaseTag.
* This only checks for name and not type information.
* See its usage in has() and has<T>().
Expand Down
1 change: 1 addition & 0 deletions src/shogun/base/class_list.h
Expand Up @@ -46,6 +46,7 @@ namespace shogun {
delete object;
SG_SERROR("Type mismatch");
}
cast->ref();
return cast;
}
}
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/base/MockObject.h
Expand Up @@ -16,6 +16,11 @@ namespace shogun
init_params();
}

virtual ~CMockObject()
{
SG_UNREF(m_object);
}

const char* get_name() const
{
return "MockObject";
Expand All @@ -31,6 +36,11 @@ namespace shogun
return m_watched;
}

CSGObject* get_object() const
{
return m_object;
}

protected:
void init_params()
{
Expand All @@ -43,10 +53,17 @@ namespace shogun
"watched_int", &m_watched,
AnyParameterProperties(
MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE));

watch_param(
"watched_object", (CSGObject**)&m_object,
AnyParameterProperties(
MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE));
}

private:
int32_t m_integer = 0;
int32_t m_watched = 0;

CSGObject* m_object = 0;
};
}
13 changes: 13 additions & 0 deletions tests/unit/base/SGObject_unittest.cc
Expand Up @@ -396,3 +396,16 @@ TEST(SGObject, watched_parameter)
EXPECT_EQ(obj->get<int32_t>("watched_int"), 12);
EXPECT_EQ(obj->get<int32_t>("watched_int"), obj->get_watched());
}

TEST(SGObject, watched_parameter_object)
{
auto obj = some<CMockObject>();
Some<CMockObject> other_obj = some<CMockObject>();

EXPECT_EQ(other_obj->ref_count(), 1);
obj->put("watched_object", dynamic_cast<CSGObject*>(other_obj.get()));
EXPECT_EQ(other_obj->ref_count(), 2);
EXPECT_TRUE(other_obj->equals(obj));
obj = nullptr;
EXPECT_EQ(other_obj->ref_count(), 1);
}

0 comments on commit 21e50fa

Please sign in to comment.