Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ObservedValue. #4552

Merged
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
dfe6f08
Removed SG_OBS_VALUE_TYPE enum.
geektoni Feb 21, 2019
7ade42f
Remove observable enum from SGObject
geektoni Feb 22, 2019
d7c01f9
Fix SGObject definition of ObsParamList
geektoni Feb 22, 2019
1e1aea3
Fix ParameterObserverInterface.
geektoni Feb 22, 2019
4dc16bc
Fix CrossValidation.
geektoni Feb 22, 2019
ad5bf08
Fix ParameterObserverCV.
geektoni Feb 22, 2019
fbfa32d
Fix typo inside ParameterObserverCV.
geektoni Feb 22, 2019
7c2971f
Fix TBOutputFormat unittest.
geektoni Feb 26, 2019
b1b10ba
Make ObservedValue inherith from SGObject.
geektoni Mar 5, 2019
40dd84e
Split ObservedValue into header and cpp.
geektoni Mar 6, 2019
d2e86d1
Fix errors.
geektoni Mar 6, 2019
0ee5027
Split ObservedValue into a templated version.
geektoni Mar 7, 2019
590c914
Rename method to get observable parameters.
geektoni Mar 11, 2019
bfc7d9b
Extends get() such to work also with char* (in case it is not an opti…
geektoni Mar 11, 2019
1f60f9e
Switch from SG_ADD to watch_param inside ObservedValue.
geektoni Mar 12, 2019
68590be
Add unit test for ObservedValue.
geektoni Mar 12, 2019
b264edd
Fix variable naming.
geektoni Mar 12, 2019
c8a2121
Fix wrong initialization of ObservedValue parameters.
geektoni Mar 12, 2019
3251a5d
Rename register_observable_param() to register_observable().
geektoni Mar 12, 2019
d390976
Enabled TBOutputFormat vector tests and change SG_ADD with watch_para…
geektoni Mar 12, 2019
bc029a9
Change SG_SPRINT with SG_PRINT inside ParameterObserverCV.
geektoni Mar 12, 2019
b504148
Add more documentation.
geektoni Mar 12, 2019
5a8a543
Fix code style.
geektoni Mar 12, 2019
0c9d1a9
Fix minor issues and styles problems.
geektoni Mar 14, 2019
6332560
Change CParameterObserverCV::get_num_observations() return type to si…
geektoni Mar 18, 2019
ff79e6a
Add utils function to convert from size_t to int32_t.
geektoni Mar 18, 2019
2acba5e
Revert "Add utils function to convert from size_t to int32_t."
geektoni Mar 19, 2019
f295ae5
Use the new safe_convert functions to return the number of observations.
geektoni Mar 19, 2019
c8d30c4
Make ObservedValue keep also an Any version of the stored value.
geektoni Mar 19, 2019
f1d327e
Last fixes for ObservedValue.h
geektoni Mar 20, 2019
87827dd
Fix bug when initializing ObservedValue.
geektoni Mar 21, 2019
1d5b0a3
Add a description for the observed value.
geektoni Mar 21, 2019
fce9d10
Fix wrong parameter order.
geektoni Mar 21, 2019
e601ae3
Fix initialization error inside unit tests.
geektoni Mar 21, 2019
a5f75e7
Fix code style.
geektoni Mar 21, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 15 additions & 42 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ namespace shogun
{
geektoni marked this conversation as resolved.
Show resolved Hide resolved

typedef std::map<BaseTag, AnyParameter> ParametersMap;
typedef std::unordered_map<
std::string, std::pair<SG_OBS_VALUE_TYPE, std::string>>
ObsParamsList;
typedef std::unordered_map<std::string, std::string> ObsParamsList;

class CSGObject::Self
{
Expand Down Expand Up @@ -774,44 +772,24 @@ void CSGObject::subscribe_to_parameters(ParameterObserverInterface* obs)
// Create an observable which emits values only if they are about
// parameters selected by the observable.
auto subscription = m_observable_params
->filter([obs](ObservedValue v) {
return obs->filter(v.get_name());
->filter([obs](Some<ObservedValue> v) {
return obs->filter(v->get<std::string>("name"));
})
.timestamp()
.subscribe(sub);
}

void CSGObject::observe(const ObservedValue value)
void CSGObject::observe(const Some<ObservedValue> value)
{
m_subscriber_params->on_next(value);
}

class CSGObject::ParameterObserverList
{
public:
void register_param(
const std::string& name, const SG_OBS_VALUE_TYPE type,
const std::string& description)
void register_param(const std::string& name, const std::string& description)
{
m_list_obs_params[name] = std::make_pair(type, description);
}

std::string type_name(SG_OBS_VALUE_TYPE type)
{
std::string value;
switch (type)
{
case TENSORBOARD:
value = std::string("Tensorboard");
break;
case CROSSVALIDATION:
value = std::string("CrossValidation");
break;
default:
value = std::string("Unknown");
break;
}
return value;
m_list_obs_params[name] = description;
}

ObsParamsList get_list() const
Expand All @@ -824,24 +802,19 @@ class CSGObject::ParameterObserverList
ObsParamsList m_list_obs_params;
};

void CSGObject::register_observable_param(
const std::string& name, const SG_OBS_VALUE_TYPE type,
const std::string& description)
void CSGObject::register_observable(
geektoni marked this conversation as resolved.
Show resolved Hide resolved
const std::string& name, const std::string& description)
{
param_obs_list->register_param(name, type, description);
param_obs_list->register_param(name, description);
}

void CSGObject::list_observable_parameters()
std::vector<std::string> CSGObject::observable_names()
{
SG_INFO("List of observable parameters of object %s\n", get_name());
SG_PRINT("------");
for (auto const& x : param_obs_list->get_list())
{
SG_PRINT(
"%s [%s]: %s\n", x.first.c_str(),
param_obs_list->type_name(x.second.first).c_str(),
x.second.second.c_str());
geektoni marked this conversation as resolved.
Show resolved Hide resolved
}
std::vector<std::string> list;
std::transform(param_obs_list->get_list().begin(),
param_obs_list->get_list().end(), list.begin(),
[](auto const &x ) {return x.first;});
return list;
}

bool CSGObject::has(const std::string& name) const
Expand Down
37 changes: 24 additions & 13 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <shogun/lib/common.h>
#include <shogun/lib/config.h>
#include <shogun/lib/exception/ShogunException.h>
#include <shogun/lib/parameter_observers/ObservedValue.h>
#include <shogun/lib/tag.h>

#include <map>
Expand All @@ -43,6 +42,7 @@ class Parallel;
class Parameter;
class CSerializableFile;
class ParameterObserverInterface;
class ObservedValue;
class CDynamicObjectArray;

template <class T, class K> class CMap;
Expand Down Expand Up @@ -124,14 +124,15 @@ class CSGObject
{
public:
/** Definition of observed subject */
typedef rxcpp::subjects::subject<ObservedValue> SGSubject;
typedef rxcpp::subjects::subject<Some<ObservedValue>> SGSubject;
/** Definition of observable */
typedef rxcpp::observable<ObservedValue,
rxcpp::dynamic_observable<ObservedValue>>
typedef rxcpp::observable<Some<ObservedValue>,
rxcpp::dynamic_observable<Some<ObservedValue>>>
SGObservable;
/** Definition of subscriber */
typedef rxcpp::subscriber<
ObservedValue, rxcpp::observer<ObservedValue, void, void, void, void>>
Some<ObservedValue>,
rxcpp::observer<Some<ObservedValue>, void, void, void, void>>
SGSubscriber;

/** default constructor */
Expand Down Expand Up @@ -568,9 +569,20 @@ class CSGObject
{
if (m_string_to_enum_map.find(_tag.name()) == m_string_to_enum_map.end())
{
SG_ERROR(
"There are no options for parameter %s::%s", get_name(),
_tag.name().c_str());
const Any value = get_parameter(_tag).get_value();
try
{
return any_cast<T>(value);
}
catch (const TypeMismatchException& exc)
{
SG_ERROR(
"Cannot get parameter %s::%s of type %s, incompatible "
"requested type %s or there are no options for parameter "
"%s::%s.\n",
get_name(), _tag.name().c_str(), exc.actual().c_str(),
exc.expected().c_str(), get_name(), _tag.name().c_str());
}
}
return string_enum_reverse_lookup(_tag.name(), get<machine_int_t>(_tag.name()));
}
Expand Down Expand Up @@ -646,7 +658,7 @@ class CSGObject
void subscribe_to_parameters(ParameterObserverInterface* obs);

/** Print to stdout a list of observable parameters */
void list_observable_parameters();
std::vector<std::string> observable_names();

/** Get string to enum mapping */
stringToEnumMapType get_string_to_enum_map() const
Expand Down Expand Up @@ -952,17 +964,16 @@ class CSGObject
* Observe a parameter value and emit them to observer.
* @param value Observed parameter's value
*/
void observe(const ObservedValue value);
void observe(const Some<ObservedValue> value);

/**
* Register which params this object can emit.
* @param name the param name
* @param type the param type
* @param description a user oriented description
*/
void register_observable_param(
const std::string& name, const SG_OBS_VALUE_TYPE type,
const std::string& description);
void register_observable(
const std::string& name, const std::string& description);

/** mapping from strings to enum for SWIG interface */
stringToEnumMapType m_string_to_enum_map;
Expand Down
10 changes: 6 additions & 4 deletions src/shogun/evaluation/CrossValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <shogun/evaluation/Evaluation.h>
#include <shogun/evaluation/SplittingStrategy.h>
#include <shogun/lib/List.h>
//#include <shogun/lib/parameter_observers/ObservedValue.h>
#include <shogun/lib/parameter_observers/ObservedValueTemplated.h>
#include <shogun/machine/Machine.h>
#include <shogun/mathematics/Statistics.h>

Expand Down Expand Up @@ -108,10 +110,10 @@ CEvaluationResult* CCrossValidation::evaluate_impl()
SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])

/* Emit the value*/
std::string obs_value_name{"cross_validation_run"};
ObservedValue cv_data{i, obs_value_name, make_any(storage),
CROSSVALIDATION};
observe(cv_data);
observe(
ObservedValue::make_observation<CrossValidationStorage*>(
geektoni marked this conversation as resolved.
Show resolved Hide resolved
i, "cross_validation_run", storage));

SG_UNREF(storage)
}

Expand Down
29 changes: 10 additions & 19 deletions src/shogun/io/TBOutputFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,14 @@

using namespace shogun;

#define CHECK_TYPE(type) \
else if ( \
value.first.get_value().type_info().hash_code() == \
geektoni marked this conversation as resolved.
Show resolved Hide resolved
typeid(type).hash_code()) \
{ \
summaryValue->set_simple_value( \
any_cast<type>(value.first.get_value())); \
}

#define CHECK_TYPE_HISTO(type) \
else if ( \
value.first.get_value().type_info().hash_code() == \
value.first->get_any().type_info().hash_code() == \
typeid(type).hash_code()) \
{ \
tensorflow::histogram::Histogram h; \
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); \
auto v = any_cast<type>(value.first.get_value()); \
auto v = any_cast<type>(value.first->get_any()); \
for (auto value_v : v) \
h.Add(value_v); \
h.EncodeToProto(hp, true); \
Expand All @@ -77,18 +68,18 @@ tensorflow::Event TBOutputFormat::convert_scalar(
tensorflow::Event e;
std::time_t now_t = convert_to_millis(value.second);
e.set_wall_time(now_t);
e.set_step(value.first.get_step());
e.set_step(value.first->get<int64_t>("step"));

tensorflow::Summary* summary = e.mutable_summary();
auto summaryValue = summary->add_value();
summaryValue->set_tag(value.first.get_name());
summaryValue->set_tag(value.first->get<std::string>("name"));
summaryValue->set_node_name(node_name);

auto write_summary = [&summaryValue=summaryValue](auto val) {
summaryValue->set_simple_value(val);
};

sg_any_dispatch(value.first.get_value(), sg_all_typemap, write_summary);
sg_any_dispatch(value.first->get_any(), sg_all_typemap, write_summary);

return e;
}
Expand All @@ -99,19 +90,19 @@ tensorflow::Event TBOutputFormat::convert_vector(
tensorflow::Event e;
std::time_t now_t = convert_to_millis(value.second);
e.set_wall_time(now_t);
e.set_step(value.first.get_step());
e.set_step(value.first->get<int64_t>("step"));

tensorflow::Summary* summary = e.mutable_summary();
auto summaryValue = summary->add_value();
summaryValue->set_tag(value.first.get_name());
summaryValue->set_tag(value.first->get<std::string>("name"));
geektoni marked this conversation as resolved.
Show resolved Hide resolved
summaryValue->set_node_name(node_name);

if (value.first.get_value().type_info().hash_code() ==
if (value.first->get_any().type_info().hash_code() ==
typeid(std::vector<int8_t>).hash_code())
{
tensorflow::histogram::Histogram h;
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto();
auto v = any_cast<std::vector<int8_t>>(value.first.get_value());
auto v = any_cast<std::vector<int8_t>>(value.first->get_any());
for (auto value_v : v)
h.Add(value_v);
h.EncodeToProto(hp, true);
Expand All @@ -131,7 +122,7 @@ tensorflow::Event TBOutputFormat::convert_vector(
else
{
SG_ERROR(
"Unsupported type %s", value.first.get_value().type_info().name());
"Unsupported type %s", value.first->get_any().type_info().name());
}

return e;
Expand Down
20 changes: 20 additions & 0 deletions src/shogun/lib/parameter_observers/ObservedValue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Giovanni De Toni
*
*/

#include <shogun/base/Parameter.h>
#include <shogun/lib/parameter_observers/ObservedValue.h>

using namespace shogun;

ObservedValue::ObservedValue(int64_t step, std::string name)
: CSGObject(), m_step(step), m_name(name), m_any_value(make_any(nullptr))
{
SG_ADD(&m_step, "step", "Step");
this->watch_param(
"name", &m_name,
AnyParameterProperties("Name of the observed value"));
}
Loading