Skip to content

Commit

Permalink
Merge pull request #3925 from geektoni/add_observable_param_list
Browse files Browse the repository at this point in the history
[ShogunBoard] Add methods to print a list of parameters which can be observed.
  • Loading branch information
vigsterkr committed Jul 12, 2017
2 parents 9be1025 + d08660b commit 52f9f20
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
54 changes: 44 additions & 10 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -31,19 +31,13 @@
#include <rxcpp/operators/rx-filter.hpp>
#include <rxcpp/rx-lite.hpp>

#ifdef HAVE_CXX11
#include <unordered_map>
#else
#include <map>
#endif

namespace shogun
{
#ifdef HAVE_CXX11
typedef std::unordered_map<BaseTag, Any> ParametersMap;
#else
typedef std::map<BaseTag, Any> ParametersMap;
#endif
typedef std::map<std::string, std::pair<std::string, std::string>>
ObsParamsList;

class CSGObject::Self
{
Expand Down Expand Up @@ -153,7 +147,7 @@ namespace shogun

using namespace shogun;

CSGObject::CSGObject() : self()
CSGObject::CSGObject() : self(), param_obs_list()
{
init();
set_global_objects();
Expand All @@ -163,7 +157,8 @@ CSGObject::CSGObject() : self()
}

CSGObject::CSGObject(const CSGObject& orig)
: self(), io(orig.io), parallel(orig.parallel), version(orig.version)
: self(), param_obs_list(), io(orig.io), parallel(orig.parallel),
version(orig.version)
{
init();
set_global_objects();
Expand Down Expand Up @@ -837,3 +832,42 @@ void CSGObject::observe_scalar(
auto tmp = std::make_pair(step, std::make_pair(name, value));
m_subscriber_params->on_next(tmp);
}

class CSGObject::ParameterObserverList
{
public:
void register_param(
const std::string& name, const std::string& type,
const std::string& description)
{
m_list_obs_params[name] = std::make_pair(type, description);
}

ObsParamsList get_list() const
{
return m_list_obs_params;
}

private:
/** List of observable parameters (name, description) */
ObsParamsList m_list_obs_params;
};

void CSGObject::register_observable_param(
const std::string& name, const std::string& type,
const std::string& description)
{
param_obs_list->register_param(name, type, description);
}

void CSGObject::list_observable_parameters()
{
SG_INFO("List of observable parameters of object %s\n", get_name());
SG_PRINT("------");
for (auto const& x : param_obs_list->get_list())
{
SG_PRINT(
"%s [%s]: %s\n", x.first.c_str(), x.second.first.c_str(),
x.second.second.c_str());
}
}
13 changes: 13 additions & 0 deletions src/shogun/base/SGObject.h
Expand Up @@ -424,6 +424,9 @@ class CSGObject
/** Subscribe a parameter observer to watch over params */
void subscribe_to_parameters(ParameterObserverInterface* obs);

/** Print to stdout a list of observable parameters */
void list_observable_parameters();

protected:
/** Can (optionally) be overridden to pre-initialize some member
* variables which are not PARAMETER::ADD'ed. Make sure that at
Expand Down Expand Up @@ -575,13 +578,23 @@ class CSGObject
class Self;
Unique<Self> self;

class ParameterObserverList;
Unique<ParameterObserverList> param_obs_list;

protected:
/** Observe the parameter and emits a value using the
* observable object
*/
void observe_scalar(
const int64_t step, const std::string& name, const Any& value);

/**
* Register a parameter as observable
*/
void register_observable_param(
const std::string& name, const std::string& type,
const std::string& description);

public:
/** io */
SGIO* io;
Expand Down

0 comments on commit 52f9f20

Please sign in to comment.