Skip to content

Commit

Permalink
use SG_ADD to register objects without upcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Feb 23, 2018
1 parent 5158986 commit d5c26ed
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 35 deletions.
8 changes: 8 additions & 0 deletions src/shogun/base/Parameter.h
Expand Up @@ -333,6 +333,14 @@ class Parameter
*/
void add(CSGObject** param,
const char* name, const char* description="");

template <typename T, std::enable_if_t<std::is_base_of<CSGObject, T>::value, T>* = nullptr>
void add(T** param,const char* name, const char* description="")
{
TSGDataType type(CT_SCALAR, ST_NONE, PT_SGOBJECT);
add_type(&type, (CSGObject**)param, name, description);
}

/** add param
* @param param parameter itself
* @param name name of parameter
Expand Down
9 changes: 0 additions & 9 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -845,15 +845,6 @@ bool CSGObject::has(const std::string& name) const
return has_parameter(BaseTag(name));
}

void CSGObject::ref_value(CSGObject* const* value)
{
SG_REF(*value);
}

void CSGObject::ref_value(...)
{
}

class ToStringVisitor : public AnyVisitor
{
public:
Expand Down
17 changes: 8 additions & 9 deletions src/shogun/base/SGObject.h
Expand Up @@ -358,7 +358,7 @@ class CSGObject
get_name(), _tag.name().c_str(),
exc.actual().c_str(), exc.expected().c_str());
}
ref_value(&value);
ref_value(value);
update_parameter(_tag, make_any(value));
}
else
Expand All @@ -374,12 +374,6 @@ class CSGObject
{
if (dynamic_cast<T*>(value))
{
if (has<T*>(name))
{
SG_REF(value);
T* old = get<T*>(name);
SG_UNREF(old);
}
put(Tag<T*>(name), (T*) value);
return true;
}
Expand Down Expand Up @@ -681,8 +675,13 @@ class CSGObject
void unset_global_objects();
void init();

static void ref_value(CSGObject* const* value);
static void ref_value(...);
/** Overloaded helper to increase reference counter */
static void ref_value(CSGObject* value) { SG_REF(value); }

/** Overloaded helper to increase reference counter
* Here a no-op for non CSGobject pointer parameters */
template <typename T, std::enable_if_t<!std::is_base_of<CSGObject, typename std::remove_pointer<T>::type>::value, T>* = nullptr>
static void ref_value(T value) {}

/** Checks if object has a parameter identified by a BaseTag.
* This only checks for name and not type information.
Expand Down
3 changes: 1 addition & 2 deletions src/shogun/machine/DistanceMachine.cpp
Expand Up @@ -35,8 +35,7 @@ void CDistanceMachine::init()
set_store_model_features(true);

distance=NULL;
m_parameters->add((CSGObject**)&distance, "distance", "Distance to use");
watch_param("distance", &distance, AnyParameterProperties("Distance to use"));
SG_ADD(&distance, "distance", "Distance to use", MS_AVAILABLE);
}

void CDistanceMachine::distances_lhs(SGVector<float64_t>& result, index_t idx_a1, index_t idx_a2, index_t idx_b)
Expand Down
7 changes: 1 addition & 6 deletions src/shogun/machine/KernelMachine.cpp
Expand Up @@ -634,12 +634,7 @@ void CKernelMachine::init()
use_linadd=true;
use_bias=true;

// SG_ADD((CSGObject**) &kernel, "kernel", "", MS_AVAILABLE);
m_parameters->add((CSGObject**) &kernel, "kernel", "");
m_model_selection_parameters->add((CSGObject**) &kernel, "kernel", "");

watch_param("kernel", &kernel, AnyParameterProperties("", MS_AVAILABLE));

SG_ADD(&kernel, "kernel", "", MS_AVAILABLE);
SG_ADD((CSGObject**) &m_custom_kernel, "custom_kernel", "Custom kernel for"
" data lock", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**) &m_kernel_backup, "kernel_backup",
Expand Down
7 changes: 2 additions & 5 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -25,11 +25,8 @@ CMachine::CMachine()
SG_ADD((machine_int_t*) &m_solver_type, "solver_type",
"Type of solver.", MS_NOT_AVAILABLE);

// SG_ADD((CSGObject**) &m_labels, "labels",
// "Labels to be used.", MS_NOT_AVAILABLE);
m_parameters->add((CSGObject**)&m_labels, "labels", "Labels to be used.");
watch_param("labels", &m_labels, AnyParameterProperties("Labels to be used."));

SG_ADD(&m_labels, "labels",
"Labels to be used.", MS_NOT_AVAILABLE);
SG_ADD(&m_store_model_features, "store_model_features",
"Should feature data of model be stored after training?", MS_NOT_AVAILABLE);
SG_ADD(&m_data_locked, "data_locked",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/base/MockObject.h
Expand Up @@ -297,7 +297,7 @@ namespace shogun
"Integer", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE));

watch_param(
"watched_object", (CSGObject**)&m_object,
"watched_object", &m_object,
AnyParameterProperties(
"Object", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE));
}
Expand All @@ -306,6 +306,6 @@ namespace shogun
int32_t m_integer = 0;
int32_t m_watched = 0;

CSGObject* m_object = 0;
CMockObject* m_object = nullptr;
};
}
4 changes: 2 additions & 2 deletions tests/unit/base/SGObject_unittest.cc
Expand Up @@ -528,10 +528,10 @@ TEST(SGObject, watched_parameter)
TEST(SGObject, watched_parameter_object)
{
auto obj = some<CMockObject>();
Some<CMockObject> other_obj = some<CMockObject>();
auto other_obj = some<CMockObject>();

EXPECT_EQ(other_obj->ref_count(), 1);
obj->put("watched_object", dynamic_cast<CSGObject*>(other_obj.get()));
obj->put(Tag<CMockObject*>("watched_object"), other_obj.get());
EXPECT_EQ(other_obj->ref_count(), 2);
EXPECT_FALSE(other_obj->equals(obj));
obj = nullptr;
Expand Down

0 comments on commit d5c26ed

Please sign in to comment.