Skip to content

Commit

Permalink
Support MS and gradient availability in parameter map (#4060)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Dec 30, 2017
1 parent 805ec8f commit 0bd36ee
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 62 deletions.
92 changes: 92 additions & 0 deletions src/shogun/base/AnyParameter.h
@@ -0,0 +1,92 @@
#ifndef __ANYPARAMETER_H__
#define __ANYPARAMETER_H__

#include <shogun/lib/any.h>

namespace shogun
{

/** model selection availability */
enum EModelSelectionAvailability
{
MS_NOT_AVAILABLE = 0,
MS_AVAILABLE = 1,
};

/** gradient availability */
enum EGradientAvailability
{
GRADIENT_NOT_AVAILABLE = 0,
GRADIENT_AVAILABLE = 1
};

class AnyParameterProperties
{
public:
AnyParameterProperties()
: m_model_selection(MS_NOT_AVAILABLE),
m_gradient(GRADIENT_NOT_AVAILABLE)
{
}
AnyParameterProperties(
EModelSelectionAvailability model_selection,
EGradientAvailability gradient)
: m_model_selection(model_selection), m_gradient(gradient)
{
}
AnyParameterProperties(const AnyParameterProperties& other)
: m_model_selection(other.m_model_selection),
m_gradient(other.m_gradient)
{
}

EModelSelectionAvailability get_model_selection() const
{
return m_model_selection;
}

EGradientAvailability get_gradient() const
{
return m_gradient;
}

private:
EModelSelectionAvailability m_model_selection;
EGradientAvailability m_gradient;
};

class AnyParameter
{
public:
AnyParameter() : m_value(), m_properties()
{
}
explicit AnyParameter(const Any& value) : m_value(value), m_properties()
{
}
AnyParameter(const Any& value, AnyParameterProperties properties)
: m_value(value), m_properties(properties)
{
}
AnyParameter(const AnyParameter& other)
: m_value(other.m_value), m_properties(other.m_properties)
{
}

Any get_value() const
{
return m_value;
}

AnyParameterProperties get_properties() const
{
return m_properties;
}

private:
Any m_value;
AnyParameterProperties m_properties;
};
}

#endif
46 changes: 13 additions & 33 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -36,27 +36,6 @@

namespace shogun
{
class AnyParameter
{
public:
AnyParameter() : m_value()
{
}
explicit AnyParameter(const Any& value) : m_value(value)
{
}
explicit AnyParameter(const AnyParameter& other)
: m_value(other.m_value)
{
}
Any value() const
{
return m_value;
}

private:
Any m_value;
};

typedef std::map<BaseTag, AnyParameter> ParametersMap;
typedef std::unordered_map<std::string,
Expand All @@ -66,16 +45,16 @@ namespace shogun
class CSGObject::Self
{
public:
void put(const BaseTag& tag, const Any& any)
void put(const BaseTag& tag, const AnyParameter& parameter)
{
map[tag] = AnyParameter(any);
map[tag] = parameter;
}

Any get(const BaseTag& tag) const
AnyParameter get(const BaseTag& tag) const
{
if(!has(tag))
return Any();
return map.at(tag).value();
return AnyParameter();
return map.at(tag);
}

bool has(const BaseTag& tag) const
Expand Down Expand Up @@ -809,23 +788,24 @@ bool CSGObject::clone_parameters(CSGObject* other)
return true;
}

void CSGObject::type_erased_put(const BaseTag& _tag, const Any& any)
void CSGObject::put_parameter(
const BaseTag& _tag, const AnyParameter& parameter)
{
self->put(_tag, any);
self->put(_tag, parameter);
}

Any CSGObject::type_erased_get(const BaseTag& _tag) const
AnyParameter CSGObject::get_parameter(const BaseTag& _tag) const
{
Any any = self->get(_tag);
if(any.empty())
const auto& parameter = self->get(_tag);
if (parameter.get_value().empty())
{
SG_ERROR("There is no parameter called \"%s\" in %s",
_tag.name().c_str(), get_name());
}
return any;
return parameter;
}

bool CSGObject::type_erased_has(const BaseTag& _tag) const
bool CSGObject::has_parameter(const BaseTag& _tag) const
{
return self->has(_tag);
}
Expand Down
50 changes: 22 additions & 28 deletions src/shogun/base/SGObject.h
Expand Up @@ -13,6 +13,7 @@
#ifndef __SGOBJECT_H__
#define __SGOBJECT_H__

#include <shogun/base/AnyParameter.h>
#include <shogun/base/Version.h>
#include <shogun/base/unique.h>
#include <shogun/io/SGIO.h>
Expand Down Expand Up @@ -79,15 +80,19 @@ template <class T> class SGStringList;
#define SG_ADD4(param, name, description, ms_available) \
{ \
m_parameters->add(param, name, description); \
watch_param(name, param); \
watch_param( \
name, param, \
AnyParameterProperties(ms_available, GRADIENT_NOT_AVAILABLE)); \
if (ms_available) \
m_model_selection_parameters->add(param, name, description); \
}

#define SG_ADD5(param, name, description, ms_available, gradient_available) \
{ \
m_parameters->add(param, name, description); \
watch_param(name, param); \
watch_param( \
name, param, \
AnyParameterProperties(ms_available, gradient_available)); \
if (ms_available) \
m_model_selection_parameters->add(param, name, description); \
if (gradient_available) \
Expand All @@ -100,19 +105,6 @@ template <class T> class SGStringList;
* End of macros for registering parameters/model selection parameters
******************************************************************************/

/** model selection availability */
enum EModelSelectionAvailability {
MS_NOT_AVAILABLE=0,
MS_AVAILABLE=1,
};

/** gradient availability */
enum EGradientAvailability
{
GRADIENT_NOT_AVAILABLE=0,
GRADIENT_AVAILABLE=1
};

/** @brief Class SGObject is the base class of all shogun objects.
*
* Apart from dealing with reference counting that is used to manage shogung
Expand Down Expand Up @@ -307,7 +299,7 @@ class CSGObject
*/
bool has(const std::string& name) const
{
return type_erased_has(BaseTag(name));
return has_parameter(BaseTag(name));
}

/** Checks if object has a class parameter identified by a Tag.
Expand All @@ -330,9 +322,9 @@ class CSGObject
bool has(const std::string& name) const
{
BaseTag tag(name);
if(!type_erased_has(tag))
if (!has_parameter(tag))
return false;
const Any value = type_erased_get(tag);
const Any value = get_parameter(tag).get_value();
return value.same_type<T>();
}

Expand All @@ -345,10 +337,10 @@ class CSGObject
template <typename T>
void put(const Tag<T>& _tag, const T& value)
{
if(type_erased_has(_tag))
if (has_parameter(_tag))
{
if(has<T>(_tag.name()))
type_erased_put(_tag, erase_type(value));
put_parameter(_tag, AnyParameter(erase_type(value)));
else
{
SG_ERROR("Type for parameter with name \"%s\" is not correct.\n",
Expand Down Expand Up @@ -384,7 +376,7 @@ class CSGObject
template <typename T>
T get(const Tag<T>& _tag) const
{
const Any value = type_erased_get(_tag);
const Any value = get_parameter(_tag).get_value();
try
{
return recall_type<T>(value);
Expand Down Expand Up @@ -476,7 +468,7 @@ class CSGObject
template <typename T>
void register_param(Tag<T>& _tag, const T& value)
{
type_erased_put(_tag, erase_type(value));
put_parameter(_tag, AnyParameter(erase_type(value)));
}

/** Registers a class parameter which is identified by a name.
Expand All @@ -491,14 +483,16 @@ class CSGObject
void register_param(const std::string& name, const T& value)
{
BaseTag tag(name);
type_erased_put(tag, erase_type(value));
put_parameter(tag, AnyParameter(erase_type(value)));
}

template <typename T>
void watch_param(const std::string& name, T* value)
void watch_param(
const std::string& name, T* value, AnyParameterProperties properties)
{
BaseTag tag(name);
type_erased_put(tag, erase_type_non_owning(value));
put_parameter(
tag, AnyParameter(erase_type_non_owning(value), properties));
}

public:
Expand Down Expand Up @@ -556,23 +550,23 @@ class CSGObject
* @param _tag name information of parameter
* @return true if the parameter exists with the input tag
*/
bool type_erased_has(const BaseTag& _tag) const;
bool has_parameter(const BaseTag& _tag) const;

/** Registers and modifies a class parameter, identified by a BaseTag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name information of parameter
* @param any value without type information of the parameter
*/
void type_erased_put(const BaseTag& _tag, const Any& any);
void put_parameter(const BaseTag& _tag, const AnyParameter& any);

/** Getter for a class parameter, identified by a BaseTag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name information of parameter
* @return value of the parameter identified by the input tag
*/
Any type_erased_get(const BaseTag& _tag) const;
AnyParameter get_parameter(const BaseTag& _tag) const;

/** Gets an incremental hash of all parameters as well as the parameters of
* CSGObject children of the current object's parameters.
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/base/MockObject.h
Expand Up @@ -39,7 +39,10 @@ namespace shogun
register_param("int", m_integer);
register_param("float", decimal);

watch_param("watched_int", &m_watched);
watch_param(
"watched_int", &m_watched,
AnyParameterProperties(
MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE));
}

private:
Expand Down

0 comments on commit 0bd36ee

Please sign in to comment.