Skip to content

Commit

Permalink
Add get_step() method such to avoid calling put() with a step paramater.
Browse files Browse the repository at this point in the history
The get_step() method wills search for a registered parameter called
current_iteration. If it is found, it will be used such as a step value
for an observation (otherwise -1 will be used).
  • Loading branch information
geektoni committed Mar 31, 2019
1 parent 43469f4 commit 44b80c2
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/shogun/base/SGObject.h
Expand Up @@ -362,7 +362,7 @@ class CSGObject
template <typename T,
typename std::enable_if_t<is_string<T>::value>* = nullptr>
void
put(const Tag<T>& _tag, const T& value, int64_t step = -1) noexcept(false)
put(const Tag<T>& _tag, const T& value) noexcept(false)
{
std::string val_string(value);

Expand All @@ -384,7 +384,7 @@ class CSGObject

machine_int_t enum_value = string_to_enum[val_string];

put(Tag<machine_int_t>(_tag.name()), enum_value, step);
put(Tag<machine_int_t>(_tag.name()), enum_value);
}
#endif

Expand All @@ -397,9 +397,9 @@ class CSGObject
template <class T,
class X = typename std::enable_if<is_sg_base<T>::value>::type,
class Z = void>
void put(const std::string& name, T* value, int64_t step = -1)
void put(const std::string& name, T* value)
{
put(Tag<T*>(name), value, step);
put(Tag<T*>(name), value);
}

/** Typed appender for an object class parameter of a Shogun base class
Expand Down Expand Up @@ -968,6 +968,19 @@ class CSGObject
void register_observable(
const std::string& name, const std::string& description);

/**
* Get the current step for the observed values.
*/
SG_FORCED_INLINE int64_t get_step()
{
int64_t step = -1;
Tag<int64_t> tag("current_iteration");
if (has(tag)) {
step = get(tag);
}
return step;
}

/** mapping from strings to enum for SWIG interface */
stringToEnumMapType m_string_to_enum_map;

Expand Down Expand Up @@ -1118,7 +1131,7 @@ class ObservedValueTemplated : public ObservedValue

template <typename T,
typename std::enable_if_t<!is_string<T>::value>* = nullptr>
void CSGObject::put(const Tag<T>& _tag, const T& value, int64_t step) noexcept(
void CSGObject::put(const Tag<T>& _tag, const T& value) noexcept(
false)
{
if (has_parameter(_tag))
Expand All @@ -1145,7 +1158,7 @@ void CSGObject::put(const Tag<T>& _tag, const T& value, int64_t step) noexcept(
ref_value(value);
update_parameter(_tag, make_any(value));

observe<T>(step, _tag.name());
observe<T>(this->get_step(), _tag.name());
}
else
{
Expand Down

0 comments on commit 44b80c2

Please sign in to comment.