Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kwargs ~works~ for objects ~already~ #4154

Merged
merged 4 commits into from Feb 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 6 additions & 9 deletions examples/meta/src/meta_api/kwargs.sg
Expand Up @@ -22,17 +22,14 @@
# This also fails (keywords not allowed outside initialisation variables)
#kernel_factory("GaussianKernel", a=glob_fun(ordinary_argument))

# doesnt work in ruby (TODO, don't allow in meta grammar)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sorig we will need to block this, ruby doesn't like it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does ruby say? I can disable nested global calls, but that may disable some nested calls that are actually useful.

Copy link
Member

@sorig sorig Feb 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're just using the wrong syntax. See here. Testing in #4158

# KernelMachine svm2 = kernel_machine("LibSVM", kernel=kernel("GaussianKernel"))



# Real example
# ------------

# The following lines fail in Octave because of a float literal issue
#KernelMachine svm = kernel_machine("LibSVM", C1=2.0)
#GaussianKernel k(log_width=2.0)
#k.put("log_width", 4.0)

# The following line fails in Python because of a bool literal issue
#GaussianKernel k2(log_width=14.0, lhs_equals_rhs=True)
Kernel k = kernel("GaussianKernel", log_width=2.0)
KernelMachine svm = kernel_machine("LibSVM", C1=1.1, kernel=k)

# The following line currently gives a type error
#svm.put("kernel", k)
19 changes: 0 additions & 19 deletions src/interfaces/swig/SGBase.i
Expand Up @@ -508,22 +508,3 @@ copy_reg._reconstructor=_sg_reconstructor
%}

#endif /* SWIGPYTHON */

%include <shogun/lib/basetag.h>
%include <shogun/lib/tag.h>
%include <shogun/base/SGObject.h>

%define SUPPORT_TAG(camel_type, short_type, type)
%template(Tag ## camel_type) shogun::Tag<type>;
%template(put) shogun::CSGObject::put<type>;
%template(put) shogun::CSGObject::put<type, void>;
%template(put) shogun::CSGObject::put<type>;
%template(get_ ## short_type) shogun::CSGObject::get<type, void>;
%template(has) shogun::CSGObject::has<type>;
%template(has_ ## short_type) shogun::CSGObject::has<type, void>;
%enddef

SUPPORT_TAG(String, string, std::string)
SUPPORT_TAG(Float64, float, float64_t)
SUPPORT_TAG(Int64, int, int64_t)
SUPPORT_TAG(Object, object, CSGObject*)
44 changes: 44 additions & 0 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -1012,3 +1012,47 @@ CSGObject* CSGObject::create_empty() const
SG_REF(object);
return object;
}

namespace shogun
{
#define SGOBJECT_PUT_DEFINE(T) \
void CSGObject::put(const std::string& name, T const& value) throw( \
ShogunException) \
{ \
Tag<T> tag(name); \
put(tag, value); \
}

SGOBJECT_PUT_DEFINE(SGVector<int32_t>)
SGOBJECT_PUT_DEFINE(SGVector<float64_t>)
SGOBJECT_PUT_DEFINE(CSGObject*)

#define PUT_DEFINE_CHECK_AND_CAST(T) \
else if (has(Tag<T>(name))) put(Tag<T>(name), (T)value);

/* Some target languages have problems with scalar numeric types, so allow to
* convert all int/float types into each other.
*
* For example, Octave treats a=1.0 as an integer, and b=1.1 as a float.
* Furthermore, if a user wants to set a registered 16bit integer using a
* literal obj.put("16-bit-var", 2), might complain about a wrong type since
* internally the int literal is represented at a different word length. */
#define SGOBJECT_PUT_DEFINE_WITH_CONVERSION(numeric_t) \
void CSGObject::put( \
const std::string& name, \
numeric_t const& value) throw(ShogunException) \
{ \
/* use correct type of possible, otherwise cast-convert */ \
if (has(Tag<numeric_t>(name))) \
put(Tag<numeric_t>(name), value); \
PUT_DEFINE_CHECK_AND_CAST(int32_t) \
PUT_DEFINE_CHECK_AND_CAST(float32_t) \
PUT_DEFINE_CHECK_AND_CAST(float64_t) \
else /* if nothing works, moan about original type */ \
put(Tag<numeric_t>(name), value); \
}

SGOBJECT_PUT_DEFINE_WITH_CONVERSION(int32_t)
SGOBJECT_PUT_DEFINE_WITH_CONVERSION(float32_t)
SGOBJECT_PUT_DEFINE_WITH_CONVERSION(float64_t)
};
27 changes: 15 additions & 12 deletions src/shogun/base/SGObject.h
Expand Up @@ -364,18 +364,21 @@ class CSGObject
}
}

/** Setter for a class parameter, identified by a name.
* Throws an exception if the class does not have such a parameter.
*
* @param name name of the parameter
* @param value value of the parameter along with type information
*/
template <typename T, typename U = void>
void put(const std::string& name, const T& value) throw(ShogunException)
{
Tag<T> tag(name);
put(tag, value);
}
#define SGOBJECT_PUT_DECLARE(T) \
/** Setter for a class parameter, identified by a name. \
* Throws an exception if the class does not have such a parameter. \
* \
* @param name name of the parameter \
* @param value value of the parameter along with type information \
*/ \
void put(const std::string& name, T const& value) throw(ShogunException);

SGOBJECT_PUT_DECLARE(int32_t)
SGOBJECT_PUT_DECLARE(float32_t)
SGOBJECT_PUT_DECLARE(float64_t)
SGOBJECT_PUT_DECLARE(SGVector<int32_t>)
SGOBJECT_PUT_DECLARE(SGVector<float64_t>)
SGOBJECT_PUT_DECLARE(CSGObject*)

/** Getter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
Expand Down