Skip to content

Commit

Permalink
move put_scalar to swig interface for good
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Feb 19, 2018
1 parent bc6fafe commit 339a3a5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
30 changes: 27 additions & 3 deletions src/interfaces/swig/shogun.i
Expand Up @@ -122,9 +122,33 @@

namespace shogun
{
%template(put) CSGObject::put_scalar<int32_t, int32_t>;
%template(put) CSGObject::put_scalar<int64_t, int64_t>;
%template(put) CSGObject::put_scalar<float64_t, float64_t>;

%extend CSGObject
{
template <typename T, typename U= typename std::enable_if_t<std::is_arithmetic<T>::value>>
void put_scalar_dispatcher(const std::string& name, T value)
{
Tag<T> tag_t(name);
Tag<int32_t> tag_int32(name);
Tag<int64_t> tag_int64(name);
Tag<float64_t> tag_float64(name);

if ($self->has(tag_t))
$self->put(tag_t, value);
else if ($self->has(tag_int32))
$self->put(tag_int32, (int32_t)value);
else if ($self->has(tag_int64))
$self->put(tag_int64, (int64_t)value);
else if ($self->has(tag_float64))
$self->put(tag_float64, (float64_t)value);
else
$self->put(tag_t, value);
}
}

%template(put) CSGObject::put_scalar_dispatcher<int32_t, int32_t>;
%template(put) CSGObject::put_scalar_dispatcher<int64_t, int64_t>;
%template(put) CSGObject::put_scalar_dispatcher<float64_t, float64_t>;


#ifndef SWIGJAVA
Expand Down
26 changes: 0 additions & 26 deletions src/shogun/base/SGObject.h
Expand Up @@ -426,32 +426,6 @@ class CSGObject
template <typename T, typename T2 = typename std::enable_if<!std::is_base_of<CSGObject, typename std::remove_pointer<T>::type>::value, T>::type>
void put_vector_or_matrix(const std::string& name, T value);

/** Untyped setter for a scalar class parameter, identified by a name.
* Will attempt to convert passed scalar to appropriate type.
*
* @param name name of the parameter
* @param value value of the parameter
*/
template <typename T, typename U= typename std::enable_if_t<std::is_arithmetic<T>::value>>
void put_scalar(const std::string& name, T value)
{
Tag<T> tag_t(name);
Tag<int32_t> tag_int32(name);
Tag<int64_t> tag_int64(name);
Tag<float64_t> tag_float64(name);

if (has(tag_t))
put(tag_t, value);
else if (has(tag_int32))
put(tag_int32, (int32_t)value);
else if (has(tag_int64))
put(tag_int64, (int64_t)value);
else if (has(tag_float64))
put(tag_float64, (float64_t)value);
else
put(tag_t, value);
}

/** Getter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
*
Expand Down

0 comments on commit 339a3a5

Please sign in to comment.