Skip to content

Commit

Permalink
draft of CSGObject::put
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Feb 20, 2018
1 parent 99b8ddb commit 9bd758b
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 53 deletions.
42 changes: 42 additions & 0 deletions examples/meta/src/base_api/put_get.sg
@@ -0,0 +1,42 @@
KNN knn()
knn.put("m_k", 2)

Kernel k = kernel("GaussianKernel")
k.put("log_width", 2.1)
k.put("log_width", 2.0)

RealVector vector(2)
vector[0] = 0.0
vector[1] = 0.1

RegressionLabels labels()
labels.put("labels", vector)

RealVector vector2(1)
vector[0] = 0.1
# FIXME Octave treats this as a scalar
#labels.put("labels", vector)

RealMatrix matrix(2,2)
matrix[0,0] = 0.0
matrix[0,1] = 0.1
matrix[1,0] = 0.2
matrix[1,1] = 0.4

RealFeatures features()
features.put("feature_matrix", matrix)

EuclideanDistance distance()
knn.put("distance", distance)

RealMatrix matrix2(2,1)
matrix2[0,0] = 0.0
matrix2[1,0] = 0.2
# FIXME Octave treats this as a vector
#features.put("feature_matrix", matrix2)

# FIXME: Octave: CDynamicObjectArray::append_element_matrix should accept scalars
#RealMatrix matrix3(1,1)
#matrix3[0,0] = 0.1
#features.put("feature_matrix", matrix2)

16 changes: 16 additions & 0 deletions src/interfaces/swig/shogun.i
Expand Up @@ -119,3 +119,19 @@
#if defined(SWIGPERL)
%include "abstract_types_extension.i"
#endif

namespace shogun
{
%template(put) CSGObject::put_scalar<int32_t, int32_t>;
%template(put) CSGObject::put_scalar<int64_t, int64_t>;
%template(put) CSGObject::put_scalar<float64_t, float64_t>;


#ifndef SWIGJAVA
%template(put) CSGObject::put<SGVector<float64_t>, SGVector<float64_t>>;
%template(put) CSGObject::put<SGMatrix<float64_t>, SGMatrix<float64_t>>;
#else // SWIGJAVA
%template(put) CSGObject::put_vector_or_matrix<SGMatrix<float64_t>, SGMatrix<float64_t>>;
#endif // SWIGJAVA

} // namespace shogun
80 changes: 42 additions & 38 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -35,6 +35,11 @@
#include <unordered_map>
#include <memory>

#include <shogun/kernel/Kernel.h>
#include <shogun/labels/Labels.h>
#include <shogun/features/Features.h>
#include <shogun/distance/Distance.h>

namespace shogun
{

Expand Down Expand Up @@ -1013,46 +1018,45 @@ CSGObject* CSGObject::create_empty() const
return object;
}

namespace shogun
void CSGObject::put(const std::string& name, CSGObject* value)
{
#define SGOBJECT_PUT_DEFINE(T) \
void CSGObject::put(const std::string& name, T const& value) throw( \
ShogunException) \
{ \
Tag<T> tag(name); \
put(tag, value); \
REQUIRE(value, "Cannot set %s::%s, no object provided.\n", get_name(), name.c_str());


if (dynamic_cast<CKernel*>(value))
put(Tag<CKernel*>(name), (CKernel*) value);
else if (dynamic_cast<CDistance*>(value))
{
if (has<CDistance*>(name))
{
SG_REF(value);
CDistance* old = get<CDistance*>(name);
SG_UNREF(old);
}
put(Tag<CDistance*>(name), (CDistance*) value);
}
else
{
SG_WARNING("Could not match %s with any base-type when putting %s::%s, trying as SGObject.\n",value->get_name(),get_name(), name.c_str());
put(Tag<CSGObject*>(name), value);
}
}

SGOBJECT_PUT_DEFINE(SGVector<int32_t>)
SGOBJECT_PUT_DEFINE(SGVector<float64_t>)
SGOBJECT_PUT_DEFINE(CSGObject*)
namespace shogun
{
// TODO: Move this to SGBase.i and make it SGMatrix<T> (T=float64) rather than T=SGMatrix<float64>
template<>
void CSGObject::put_vector_or_matrix(const std::string& name, SGMatrix<float64_t> value)
{
Tag<SGVector<float64_t>> tag_vec(name);
Tag<SGMatrix<float64_t>> tag_mat(name);

#define PUT_DEFINE_CHECK_AND_CAST(T) \
else if (has(Tag<T>(name))) put(Tag<T>(name), (T)value);
if (has(tag_mat))
put(tag_mat, value);
else if ((value.num_rows==1 || value.num_cols==1) && has(tag_vec))
put(tag_vec, SGVector<float64_t>(value.data()));
else
put(tag_mat, value);
}
}

/* Some target languages have problems with scalar numeric types, so allow to
* convert all int/float types into each other.
*
* For example, Octave treats a=1.0 as an integer, and b=1.1 as a float.
* Furthermore, if a user wants to set a registered 16bit integer using a
* literal obj.put("16-bit-var", 2), might complain about a wrong type since
* internally the int literal is represented at a different word length. */
#define SGOBJECT_PUT_DEFINE_WITH_CONVERSION(numeric_t) \
void CSGObject::put( \
const std::string& name, \
numeric_t const& value) throw(ShogunException) \
{ \
/* use correct type of possible, otherwise cast-convert */ \
if (has(Tag<numeric_t>(name))) \
put(Tag<numeric_t>(name), value); \
PUT_DEFINE_CHECK_AND_CAST(int32_t) \
PUT_DEFINE_CHECK_AND_CAST(float32_t) \
PUT_DEFINE_CHECK_AND_CAST(float64_t) \
else /* if nothing works, moan about original type */ \
put(Tag<numeric_t>(name), value); \
}

SGOBJECT_PUT_DEFINE_WITH_CONVERSION(int32_t)
SGOBJECT_PUT_DEFINE_WITH_CONVERSION(float32_t)
SGOBJECT_PUT_DEFINE_WITH_CONVERSION(float64_t)
};
64 changes: 49 additions & 15 deletions src/shogun/base/SGObject.h
Expand Up @@ -368,21 +368,55 @@ class CSGObject
}
}

#define SGOBJECT_PUT_DECLARE(T) \
/** Setter for a class parameter, identified by a name. \
* Throws an exception if the class does not have such a parameter. \
* \
* @param name name of the parameter \
* @param value value of the parameter along with type information \
*/ \
void put(const std::string& name, T const& value) throw(ShogunException);

SGOBJECT_PUT_DECLARE(int32_t)
SGOBJECT_PUT_DECLARE(float32_t)
SGOBJECT_PUT_DECLARE(float64_t)
SGOBJECT_PUT_DECLARE(SGVector<int32_t>)
SGOBJECT_PUT_DECLARE(SGVector<float64_t>)
SGOBJECT_PUT_DECLARE(CSGObject*)
/** Untyped setter for an object class parameter, identified by a name.
* Will attempt to convert passed object to appropriate type.
*
* @param name name of the parameter
* @param value value of the parameter
*/
void put(const std::string& name, CSGObject* value);

/** Typed setter for a non-object class parameter, identified by a name.
*
* @param name name of the parameter
* @param value value of the parameter along with type information
*/
template <typename T, typename T2 = typename std::enable_if<!std::is_base_of<CSGObject, typename std::remove_pointer<T>::type>::value, T>::type>
void put(const std::string& name, T value)
{
put(Tag<T>(name), value);
}

// FIXME: move to swig interface.i, can be moved once the typemaps match
// also should be void put_vector_or_matrix(const std::string& name, SGMatrix<T> value);
template <typename T, typename T2 = typename std::enable_if<!std::is_base_of<CSGObject, typename std::remove_pointer<T>::type>::value, T>::type>
void put_vector_or_matrix(const std::string& name, T value);

/** Untyped setter for a scalar class parameter, identified by a name.
* Will attempt to convert passed scalar to appropriate type.
*
* @param name name of the parameter
* @param value value of the parameter
*/
template <typename T, typename U= typename std::enable_if_t<std::is_arithmetic<T>::value>>
void put_scalar(const std::string& name, T value)
{
Tag<T> tag_t(name);
Tag<int32_t> tag_int32(name);
Tag<int64_t> tag_int64(name);
Tag<float64_t> tag_float64(name);

if (has(tag_t))
put(tag_t, value);
else if (has(tag_int32))
put(tag_int32, (int32_t)value);
else if (has(tag_int64))
put(tag_int64, (int64_t)value);
else if (has(tag_float64))
put(tag_float64, (float64_t)value);
else
put(tag_t, value);
}

/** Getter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
Expand Down

0 comments on commit 9bd758b

Please sign in to comment.