Skip to content

Commit

Permalink
Add Tags framework
Browse files Browse the repository at this point in the history
migrate code from aer and add basetag

add error handling and sgobject.has

add try-catch, has_type() and split unit tests

add all_parameters() and rename unit-tests

add doc, use set in unit-test and modify hashing in basetag

add type() in tag, rephrase doc and fix indentations

add doc and make ctors explicit

remove type.h, add dummy typename in template functions

add MockObject and add()

add any unit-tests

add operator!= in any.h and more unit-tests

refine docs and remove all_parameters()
  • Loading branch information
sanuj committed Jul 3, 2016
1 parent 0939e27 commit 2ac1610
Show file tree
Hide file tree
Showing 14 changed files with 960 additions and 24 deletions.
51 changes: 47 additions & 4 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include <shogun/lib/RefCount.h>

#include <shogun/base/SGObject.h>
#include <shogun/io/SGIO.h>
#include <shogun/base/Version.h>
#include <shogun/base/Parameter.h>
#include <shogun/base/DynArray.h>
Expand All @@ -28,10 +27,33 @@

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unordered_map>

namespace shogun
{
struct CSGObject::Self
{

void set(const BaseTag& tag, const Any& any)
{
map[tag] = any;
}

Any get(const BaseTag& tag) const
{
if(!has(tag))
return Any();
return map.at(tag);
}

bool has(const BaseTag& tag) const
{
return map.find(tag) != map.end();
}

std::unordered_map<BaseTag, Any> map;
};

class Parallel;

extern Parallel* sg_parallel;
Expand Down Expand Up @@ -117,7 +139,7 @@ namespace shogun

using namespace shogun;

CSGObject::CSGObject()
CSGObject::CSGObject() : self()
{
init();
set_global_objects();
Expand All @@ -127,7 +149,7 @@ CSGObject::CSGObject()
}

CSGObject::CSGObject(const CSGObject& orig)
:io(orig.io), parallel(orig.parallel), version(orig.version)
: self(), io(orig.io), parallel(orig.parallel), version(orig.version)
{
init();
set_global_objects();
Expand Down Expand Up @@ -741,3 +763,24 @@ CSGObject* CSGObject::clone()
SG_DEBUG("leaving %s::clone(): Clone successful\n", get_name());
return copy;
}

void CSGObject::set_with_base_tag(const BaseTag& _tag, const Any& any)
{
self->set(_tag, any);
}

Any CSGObject::get_with_base_tag(const BaseTag& _tag) const
{
Any any = self->get(_tag);
if(any.empty())
{
SG_ERROR("There is no parameter called \"%s\" in %s",
_tag.name().c_str(), get_name());
}
return any;
}

bool CSGObject::has_with_base_tag(const BaseTag& _tag) const
{
return self->has(_tag);
}
170 changes: 169 additions & 1 deletion src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
#define __SGOBJECT_H__

#include <shogun/lib/config.h>

#include <shogun/lib/common.h>
#include <shogun/lib/DataType.h>
#include <shogun/lib/ShogunException.h>
#include <shogun/base/Version.h>
#include <shogun/base/unique.h>
#include <shogun/io/SGIO.h>
#include <shogun/lib/tag.h>
#include <shogun/lib/any.h>

/** \namespace shogun
* @brief all of classes and functions are contained in the shogun namespace
Expand Down Expand Up @@ -278,6 +281,116 @@ class CSGObject
*/
void build_gradient_parameter_dictionary(CMap<TParameter*, CSGObject*>* dict);

/** Checks if object has a class parameter identified by a name.
*
* @param name name of the parameter
* @return true if the parameter exists with the input name
*/
bool has(const std::string& name) const
{
BaseTag tag(name);
return has_with_base_tag(tag);
}

/** Checks if object has a class parameter identified by a Tag.
*
* @param tag tag of the parameter containing name and type information
* @return true if the parameter exists with the input tag
*/
template <typename T>
bool has(const Tag<T>& tag) const
{
return has<T>(tag.name());
}

/** Checks if a type exists for a class parameter identified by a name.
*
* @param name name of the parameter
* @return true if the parameter exists with the input name and type
*/
template <typename T, typename U=void>
bool has(const std::string& name) const
{
BaseTag tag(name);
if(!has_with_base_tag(tag))
return false;
const Any value = get_with_base_tag(tag);
return value.sameType<T>();
}

/** Setter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name and type information of parameter
* @param value value of the parameter
*/
template <typename T>
void set(const Tag<T>& _tag, const T& value)
{
if(has_with_base_tag(_tag))
{
if(has<T>(_tag.name()))
set_with_base_tag(_tag, erase_type(value));
else
{
SG_ERROR("Type for parameter with name \"%s\" is not correct.\n",
_tag.name().c_str());
}
}
else
{
SG_ERROR("\"%s\" does not have a parameter with name \"%s\".\n",
_tag.name().c_str(), get_name());
}
}

/** Setter for a class parameter, identified by a name.
* Throws an exception if the class does not have such a parameter.
*
* @param name name of the parameter
* @param value value of the parameter along with type information
*/
template <typename T, typename U=void>
void set(const std::string& name, const T& value)
{
Tag<T> tag(name);
set(tag, value);
}

/** Getter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name and type information of parameter
* @return value of the parameter identified by the input tag
*/
template <typename T>
T get(const Tag<T>& _tag) const
{
const Any value = get_with_base_tag(_tag);
try
{
return recall_type<T>(value);
}
catch(std::logic_error)
{
SG_ERROR("Type for parameter with name \"%s\" is not correct in \"%s\".\n",
_tag.name().c_str(), get_name());
}
}

/** Getter for a class parameter, identified by a name.
* Throws an exception if the class does not have such a parameter.
*
* @param name name of the parameter
* @return value of the parameter corresponding to the input name and type
*/
template <typename T, typename U=void>
T get(const std::string& name) const
{
Tag<T> tag(name);
return get(tag);
}

protected:
/** Can (optionally) be overridden to pre-initialize some member
* variables which are not PARAMETER::ADD'ed. Make sure that at
Expand Down Expand Up @@ -315,6 +428,33 @@ class CSGObject
*/
virtual void save_serializable_post() throw (ShogunException);

/** Registers a class parameter which is identified by a tag.
* This enables the parameter to be modified by set() and retrieved by get().
* Parameters can be registered in the constructor of the class.
*
* @param _tag name and type information of parameter
* @param value value of the parameter
*/
template <typename T>
void register_param(Tag<T>& _tag, const T& value)
{
set_with_base_tag(_tag, erase_type(value));
}

/** Registers a class parameter which is identified by a name.
* This enables the parameter to be modified by set() and retrieved by get().
* Parameters can be registered in the constructor of the class.
*
* @param name name of the parameter
* @param value value of the parameter along with type information
*/
template <typename T>
void register_param(const std::string& name, const T& value)
{
BaseTag tag(name);
set_with_base_tag(tag, erase_type(value));
}

public:
/** Updates the hash of current parameter combination */
virtual void update_parameter_hash();
Expand Down Expand Up @@ -353,6 +493,31 @@ class CSGObject
void unset_global_objects();
void init();

/** Checks if object has a parameter identified by a BaseTag.
* This only checks for name and not type information.
* See its usage in has() and has<T>().
*
* @param _tag name information of parameter
* @return true if the parameter exists with the input tag
*/
bool has_with_base_tag(const BaseTag& _tag) const;

/** Registers and modifies a class parameter, identified by a BaseTag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name information of parameter
* @param any value without type information of the parameter
*/
void set_with_base_tag(const BaseTag& _tag, const Any& any);

/** Getter for a class parameter, identified by a BaseTag.
* Throws an exception if the class does not have such a parameter.
*
* @param _tag name information of parameter
* @return value of the parameter identified by the input tag
*/
Any get_with_base_tag(const BaseTag& _tag) const;

/** Gets an incremental hash of all parameters as well as the parameters of
* CSGObject children of the current object's parameters.
*
Expand All @@ -364,6 +529,9 @@ class CSGObject
void get_parameter_incremental_hash(uint32_t& hash, uint32_t& carry,
uint32_t& total_length);

class Self;
Unique<Self> self;

public:
/** io */
SGIO* io;
Expand Down
Loading

0 comments on commit 2ac1610

Please sign in to comment.