Skip to content

Commit

Permalink
add support for generic "get" for both cpp and swig
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Feb 28, 2018
1 parent 0ab6ee5 commit 87d6615
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 32 deletions.
4 changes: 3 additions & 1 deletion examples/meta/generator/targets/cpp.json
Expand Up @@ -92,7 +92,9 @@
"FloatLiteral": "${number}f",
"MethodCall": {
"Default": "$object->$method($arguments)",
"get_real": "$object->get<float64_t>($arguments)"
"get_real": "$object->get<float64_t>($arguments)",
"get_real_vector": "$object->get<SGVector<float64_t>>($arguments)",
"get_real_matrix": "$object->get<SGMatrix<float64_t>>($arguments)"
},
"StaticCall": "C$typeName::$method($arguments)",
"GlobalCall": "$method($arguments)",
Expand Down
25 changes: 11 additions & 14 deletions examples/meta/src/base_api/put_get.sg
Expand Up @@ -4,18 +4,16 @@ knn.put("m_k", 2)
Kernel k = kernel("GaussianKernel")
k.put("log_width", 2.1)
k.put("log_width", 2.0)
real log_width = k.get_real("log_width")

RealVector vector(2)
vector[0] = 0.0
vector[1] = 0.1

RegressionLabels labels()
labels.put("labels", vector)

RealVector vector2(1)
vector[0] = 0.1
# FIXME Octave treats this as a scalar
#labels.put("labels", vector)
RealVector vector2 = labels.get_real_vector("labels")
labels.put("labels", vector2)

RealMatrix matrix(2,2)
matrix[0,0] = 0.0
Expand All @@ -25,18 +23,17 @@ matrix[1,1] = 0.4

RealFeatures features()
features.put("feature_matrix", matrix)
RealMatrix matrix2 = features.get_real_matrix("feature_matrix")
features.put("feature_matrix", matrix2)

EuclideanDistance distance()
knn.put("distance", distance)

RealMatrix matrix2(2,1)
matrix2[0,0] = 0.0
matrix2[1,0] = 0.2
# FIXME Octave treats this as a vector
#features.put("feature_matrix", matrix2)
SGObject distance2 = knn.get("distance")
knn.put("distance", distance2)

# FIXME: Octave: CDynamicObjectArray::append_element_matrix should accept scalars
#RealMatrix matrix3(1,1)
#matrix3[0,0] = 0.1
#features.put("feature_matrix", matrix2)
LibSVM svm()
svm.put("kernel", k)
SGObject k2 = svm.get("kernel")
svm.put("kernel", k2)

29 changes: 22 additions & 7 deletions src/interfaces/swig/shogun.i
Expand Up @@ -133,9 +133,7 @@ namespace shogun
Tag<int64_t> tag_int64(name);
Tag<float64_t> tag_float64(name);

if ($self->has(tag_t))
$self->put(tag_t, value);
else if ($self->has(tag_int32))
if ($self->has(tag_int32))
$self->put(tag_int32, (int32_t)value);
else if ($self->has(tag_int64))
$self->put(tag_int64, (int64_t)value);
Expand All @@ -152,13 +150,22 @@ namespace shogun
Tag<SGVector<float64_t>> tag_vec(name);
Tag<SGMatrix<float64_t>> tag_mat(name);

if ($self->has(tag_mat))
$self->put(tag_mat, value);
else if ((value.num_rows==1 || value.num_cols==1) && $self->has(tag_vec))
$self->put(tag_vec, SGVector<float64_t>(value.data()));
if ((value.num_rows==1 || value.num_cols==1) && $self->has(tag_vec))
{
SGVector<float64_t> vec(value.data(), value.size(), false);
$self->put(tag_vec, vec);
}
else
$self->put(tag_mat, value);
}

template <typename T, typename T2 = typename std::enable_if<std::is_same<SGMatrix<float64_t>, T>::value, T>::type>
T get_vector_as_matrix_dispatcher(const std::string& name)
{
SGVector<float64_t> vec = $self->get<SGVector<float64_t>>(name);
SGMatrix<float64_t> mat(vec.data(), 1, vec.vlen, false);
return mat;
}
#endif // SWIGJAVA
}

Expand All @@ -176,4 +183,12 @@ namespace shogun
%template(put) CSGObject::put_vector_or_matrix_dispatcher<SGMatrix<float64_t>, SGMatrix<float64_t>>;
#endif // SWIGJAVA

%template(get_real) CSGObject::get<float64_t, void>;
%template(get_real_matrix) CSGObject::get<SGMatrix<float64_t>, void>;
#ifndef SWIGJAVA
%template(get_real_vector) CSGObject::get<SGVector<float64_t>, void>;
#else // SWIGJAVA
%template(get_real_vector) CSGObject::get_vector_as_matrix_dispatcher<SGMatrix<float64_t>, SGMatrix<float64_t>>;
#endif // SWIGJAVA

} // namespace shogun
30 changes: 27 additions & 3 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -1028,14 +1028,38 @@ CSGObject* CSGObject::create_empty() const

void CSGObject::put(const std::string& name, CSGObject* value)
{
REQUIRE(value, "Cannot set %s::%s, no object provided.\n", get_name(), name.c_str());
REQUIRE(value, "Cannot put %s::%s, no object provided.\n", get_name(), name.c_str());

if (put_sgobject_type_dispatcher<CKernel>(name, value))
return;
if (put_sgobject_type_dispatcher<CDistance>(name, value))
return;
if (put_sgobject_type_dispatcher<CFeatures>(name, value))
return;
if (put_sgobject_type_dispatcher<CLabels>(name, value))
return;


SG_ERROR("Cannot put object %s as parameter %s::%s of type %s, type does not match.\n",
value->get_name(), get_name(), name.c_str(),
self->map[BaseTag(name)].get_value().type().c_str());

SG_WARNING("Could not match %s with any base-type when putting %s::%s, trying as SGObject.\n",value->get_name(),get_name(), name.c_str());
put(Tag<CSGObject*>(name), value);
}

CSGObject* CSGObject::get(const std::string& name)
{
if (auto* result = get_sgobject_type_dispatcher<CDistance>(name))
return result;
if (auto* result = get_sgobject_type_dispatcher<CKernel>(name))
return result;
if (auto* result = get_sgobject_type_dispatcher<CFeatures>(name))
return result;
if (auto* result = get_sgobject_type_dispatcher<CLabels>(name))
return result;


SG_ERROR("Cannot get parameter %s::%s of type %s as object, not object type.\n",
get_name(), name.c_str(), self->map[BaseTag(name)].get_value().type().c_str());
return nullptr;
}

33 changes: 26 additions & 7 deletions src/shogun/base/SGObject.h
Expand Up @@ -354,17 +354,16 @@ class CSGObject
catch (const TypeMismatchException& exc)
{
SG_ERROR(
"Setting parameter %s::%s failed. Provided type is %s, but "
"actual type is %s.\n",
get_name(), _tag.name().c_str(), exc.expected().c_str(),
exc.actual().c_str());
"Cannot set parameter %s::%s of type %s, incompatible provided type %s.\n",
get_name(), _tag.name().c_str(),
exc.actual().c_str(), exc.expected().c_str());
}
ref_value(&value);
update_parameter(_tag, make_any(value));
}
else
{
SG_ERROR("\"%s\" does not have a parameter with name \"%s\".\n",
SG_ERROR("Parameter %s::%s does not exist.\n",
get_name(), _tag.name().c_str());
}
}
Expand All @@ -386,6 +385,19 @@ class CSGObject
}
return false;
}

template <typename T>
CSGObject* get_sgobject_type_dispatcher(const std::string& name)
{
if (has<T*>(name))
{
T* result = get<T*>(name);
SG_REF(result)
return (CSGObject*)result;
}

return nullptr;
}
#endif // SWIG

/** Untyped setter for an object class parameter, identified by a name.
Expand All @@ -396,6 +408,14 @@ class CSGObject
*/
void put(const std::string& name, CSGObject* value);

/** Untyped getter for an object class parameter, identified by a name.
* Will attempt to get specified object of appropriate internal type.
*
* @param name name of the parameter
* @return object parameter
*/
CSGObject* get(const std::string& name);

#ifndef SWIG
/** Untyped setter for an object class parameter, identified by a name.
* Will attempt to convert passed object to appropriate type.
Expand Down Expand Up @@ -438,8 +458,7 @@ class CSGObject
catch (const TypeMismatchException& exc)
{
SG_ERROR(
"Getting parameter %s::%s failed. Requested type is %s, "
"but actual type is %s.\n",
"Cannot get parameter %s::%s of type %s, incompatible requested type %s.\n",
get_name(), _tag.name().c_str(), exc.actual().c_str(),
exc.expected().c_str());
}
Expand Down

0 comments on commit 87d6615

Please sign in to comment.