From 87d66152dd82ffd3ab646b91b645619abdb6a11d Mon Sep 17 00:00:00 2001 From: Heiko Strathmann Date: Tue, 20 Feb 2018 23:08:53 +0000 Subject: [PATCH] add support for generic "get" for both cpp and swig --- examples/meta/generator/targets/cpp.json | 4 ++- examples/meta/src/base_api/put_get.sg | 25 ++++++++---------- src/interfaces/swig/shogun.i | 29 ++++++++++++++++----- src/shogun/base/SGObject.cpp | 30 ++++++++++++++++++--- src/shogun/base/SGObject.h | 33 +++++++++++++++++++----- 5 files changed, 89 insertions(+), 32 deletions(-) diff --git a/examples/meta/generator/targets/cpp.json b/examples/meta/generator/targets/cpp.json index 2281990c9c8..b9b02afc9df 100644 --- a/examples/meta/generator/targets/cpp.json +++ b/examples/meta/generator/targets/cpp.json @@ -92,7 +92,9 @@ "FloatLiteral": "${number}f", "MethodCall": { "Default": "$object->$method($arguments)", - "get_real": "$object->get($arguments)" + "get_real": "$object->get($arguments)", + "get_real_vector": "$object->get>($arguments)", + "get_real_matrix": "$object->get>($arguments)" }, "StaticCall": "C$typeName::$method($arguments)", "GlobalCall": "$method($arguments)", diff --git a/examples/meta/src/base_api/put_get.sg b/examples/meta/src/base_api/put_get.sg index 3e8f4984022..7be6d09a3df 100644 --- a/examples/meta/src/base_api/put_get.sg +++ b/examples/meta/src/base_api/put_get.sg @@ -4,6 +4,7 @@ knn.put("m_k", 2) Kernel k = kernel("GaussianKernel") k.put("log_width", 2.1) k.put("log_width", 2.0) +real log_width = k.get_real("log_width") RealVector vector(2) vector[0] = 0.0 @@ -11,11 +12,8 @@ vector[1] = 0.1 RegressionLabels labels() labels.put("labels", vector) - -RealVector vector2(1) -vector[0] = 0.1 -# FIXME Octave treats this as a scalar -#labels.put("labels", vector) +RealVector vector2 = labels.get_real_vector("labels") +labels.put("labels", vector2) RealMatrix matrix(2,2) matrix[0,0] = 0.0 @@ -25,18 +23,17 @@ matrix[1,1] = 0.4 RealFeatures features() features.put("feature_matrix", matrix) +RealMatrix matrix2 = features.get_real_matrix("feature_matrix") +features.put("feature_matrix", matrix2) EuclideanDistance distance() knn.put("distance", distance) -RealMatrix matrix2(2,1) -matrix2[0,0] = 0.0 -matrix2[1,0] = 0.2 -# FIXME Octave treats this as a vector -#features.put("feature_matrix", matrix2) +SGObject distance2 = knn.get("distance") +knn.put("distance", distance2) -# FIXME: Octave: CDynamicObjectArray::append_element_matrix should accept scalars -#RealMatrix matrix3(1,1) -#matrix3[0,0] = 0.1 -#features.put("feature_matrix", matrix2) +LibSVM svm() +svm.put("kernel", k) +SGObject k2 = svm.get("kernel") +svm.put("kernel", k2) diff --git a/src/interfaces/swig/shogun.i b/src/interfaces/swig/shogun.i index f1ce36cb7f8..24091fd48a0 100644 --- a/src/interfaces/swig/shogun.i +++ b/src/interfaces/swig/shogun.i @@ -133,9 +133,7 @@ namespace shogun Tag tag_int64(name); Tag tag_float64(name); - if ($self->has(tag_t)) - $self->put(tag_t, value); - else if ($self->has(tag_int32)) + if ($self->has(tag_int32)) $self->put(tag_int32, (int32_t)value); else if ($self->has(tag_int64)) $self->put(tag_int64, (int64_t)value); @@ -152,13 +150,22 @@ namespace shogun Tag> tag_vec(name); Tag> tag_mat(name); - if ($self->has(tag_mat)) - $self->put(tag_mat, value); - else if ((value.num_rows==1 || value.num_cols==1) && $self->has(tag_vec)) - $self->put(tag_vec, SGVector(value.data())); + if ((value.num_rows==1 || value.num_cols==1) && $self->has(tag_vec)) + { + SGVector vec(value.data(), value.size(), false); + $self->put(tag_vec, vec); + } else $self->put(tag_mat, value); } + + template , T>::value, T>::type> + T get_vector_as_matrix_dispatcher(const std::string& name) + { + SGVector vec = $self->get>(name); + SGMatrix mat(vec.data(), 1, vec.vlen, false); + return mat; + } #endif // SWIGJAVA } @@ -176,4 +183,12 @@ namespace shogun %template(put) CSGObject::put_vector_or_matrix_dispatcher, SGMatrix>; #endif // SWIGJAVA +%template(get_real) CSGObject::get; +%template(get_real_matrix) CSGObject::get, void>; +#ifndef SWIGJAVA +%template(get_real_vector) CSGObject::get, void>; +#else // SWIGJAVA +%template(get_real_vector) CSGObject::get_vector_as_matrix_dispatcher, SGMatrix>; +#endif // SWIGJAVA + } // namespace shogun diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index def57d14184..47a21ace3c5 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -1028,14 +1028,38 @@ CSGObject* CSGObject::create_empty() const void CSGObject::put(const std::string& name, CSGObject* value) { - REQUIRE(value, "Cannot set %s::%s, no object provided.\n", get_name(), name.c_str()); + REQUIRE(value, "Cannot put %s::%s, no object provided.\n", get_name(), name.c_str()); if (put_sgobject_type_dispatcher(name, value)) return; if (put_sgobject_type_dispatcher(name, value)) return; + if (put_sgobject_type_dispatcher(name, value)) + return; + if (put_sgobject_type_dispatcher(name, value)) + return; + + + SG_ERROR("Cannot put object %s as parameter %s::%s of type %s, type does not match.\n", + value->get_name(), get_name(), name.c_str(), + self->map[BaseTag(name)].get_value().type().c_str()); - SG_WARNING("Could not match %s with any base-type when putting %s::%s, trying as SGObject.\n",value->get_name(),get_name(), name.c_str()); - put(Tag(name), value); +} + +CSGObject* CSGObject::get(const std::string& name) +{ + if (auto* result = get_sgobject_type_dispatcher(name)) + return result; + if (auto* result = get_sgobject_type_dispatcher(name)) + return result; + if (auto* result = get_sgobject_type_dispatcher(name)) + return result; + if (auto* result = get_sgobject_type_dispatcher(name)) + return result; + + + SG_ERROR("Cannot get parameter %s::%s of type %s as object, not object type.\n", + get_name(), name.c_str(), self->map[BaseTag(name)].get_value().type().c_str()); + return nullptr; } diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index bb46e9be948..1d5f6a06fc4 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -354,17 +354,16 @@ class CSGObject catch (const TypeMismatchException& exc) { SG_ERROR( - "Setting parameter %s::%s failed. Provided type is %s, but " - "actual type is %s.\n", - get_name(), _tag.name().c_str(), exc.expected().c_str(), - exc.actual().c_str()); + "Cannot set parameter %s::%s of type %s, incompatible provided type %s.\n", + get_name(), _tag.name().c_str(), + exc.actual().c_str(), exc.expected().c_str()); } ref_value(&value); update_parameter(_tag, make_any(value)); } else { - SG_ERROR("\"%s\" does not have a parameter with name \"%s\".\n", + SG_ERROR("Parameter %s::%s does not exist.\n", get_name(), _tag.name().c_str()); } } @@ -386,6 +385,19 @@ class CSGObject } return false; } + + template + CSGObject* get_sgobject_type_dispatcher(const std::string& name) + { + if (has(name)) + { + T* result = get(name); + SG_REF(result) + return (CSGObject*)result; + } + + return nullptr; + } #endif // SWIG /** Untyped setter for an object class parameter, identified by a name. @@ -396,6 +408,14 @@ class CSGObject */ void put(const std::string& name, CSGObject* value); + /** Untyped getter for an object class parameter, identified by a name. + * Will attempt to get specified object of appropriate internal type. + * + * @param name name of the parameter + * @return object parameter + */ + CSGObject* get(const std::string& name); + #ifndef SWIG /** Untyped setter for an object class parameter, identified by a name. * Will attempt to convert passed object to appropriate type. @@ -438,8 +458,7 @@ class CSGObject catch (const TypeMismatchException& exc) { SG_ERROR( - "Getting parameter %s::%s failed. Requested type is %s, " - "but actual type is %s.\n", + "Cannot get parameter %s::%s of type %s, incompatible requested type %s.\n", get_name(), _tag.name().c_str(), exc.actual().c_str(), exc.expected().c_str()); }