Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kwargs ~works~ for objects ~already~ #4154

Merged
merged 4 commits into from
Feb 7, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 6 additions & 9 deletions examples/meta/src/meta_api/kwargs.sg
Expand Up @@ -22,17 +22,14 @@
# This also fails (keywords not allowed outside initialisation variables)
#kernel_factory("GaussianKernel", a=glob_fun(ordinary_argument))

# doesnt work in ruby (TODO, don't allow in meta grammar)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sorig we will need to block this, ruby doesn't like it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does ruby say? I can disable nested global calls, but that may disable some nested calls that are actually useful.

Copy link
Member

@sorig sorig Feb 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're just using the wrong syntax. See here. Testing in #4158

# KernelMachine svm2 = kernel_machine("LibSVM", kernel=kernel("GaussianKernel"))



# Real example
# ------------

# The following lines fail in Octave because of a float literal issue
#KernelMachine svm = kernel_machine("LibSVM", C1=2.0)
#GaussianKernel k(log_width=2.0)
#k.put("log_width", 4.0)

# The following line fails in Python because of a bool literal issue
#GaussianKernel k2(log_width=14.0, lhs_equals_rhs=True)
Kernel k = kernel("GaussianKernel", log_width=2.0)
KernelMachine svm = kernel_machine("LibSVM", C1=1.0, kernel=k)

# The following line currently gives a type error
#svm.put("kernel", k)
19 changes: 0 additions & 19 deletions src/interfaces/swig/SGBase.i
Expand Up @@ -508,22 +508,3 @@ copy_reg._reconstructor=_sg_reconstructor
%}

#endif /* SWIGPYTHON */

%include <shogun/lib/basetag.h>
%include <shogun/lib/tag.h>
%include <shogun/base/SGObject.h>

%define SUPPORT_TAG(camel_type, short_type, type)
%template(Tag ## camel_type) shogun::Tag<type>;
%template(put) shogun::CSGObject::put<type>;
%template(put) shogun::CSGObject::put<type, void>;
%template(put) shogun::CSGObject::put<type>;
%template(get_ ## short_type) shogun::CSGObject::get<type, void>;
%template(has) shogun::CSGObject::has<type>;
%template(has_ ## short_type) shogun::CSGObject::has<type, void>;
%enddef

SUPPORT_TAG(String, string, std::string)
SUPPORT_TAG(Float64, float, float64_t)
SUPPORT_TAG(Int64, int, int64_t)
SUPPORT_TAG(Object, object, CSGObject*)
37 changes: 37 additions & 0 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -1012,3 +1012,40 @@ CSGObject* CSGObject::create_empty() const
SG_REF(object);
return object;
}

namespace shogun
{
#define SGOBJECT_PUT_DEFINE(T) \
void CSGObject::put(const std::string& name, T const & value) throw(ShogunException)\
{ \
Tag<T> tag(name); \
put(tag, value); \
}

SGOBJECT_PUT_DEFINE(SGVector<int32_t>)
SGOBJECT_PUT_DEFINE(SGVector<float64_t>)
SGOBJECT_PUT_DEFINE(CSGObject*)

#define PUT_DEFINE_CHECK_AND_CAST(T) \
else if (has(Tag<T>(name))) \
put(Tag<T>(name), (T)value);

#define SGOBJECT_PUT_DEFINE_NUMBER(numeric_t) \
void CSGObject::put(const std::string& name, numeric_t const & value) throw(ShogunException)\
{ \
/* use correct type of possible, otherwise cast-convert */ \
if (has(Tag<numeric_t>(name))) \
put(Tag<numeric_t>(name), value); \
PUT_DEFINE_CHECK_AND_CAST(int32_t) \
PUT_DEFINE_CHECK_AND_CAST(float32_t) \
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lisitsyn since it is only octave that cannot decide between float and int, maybe some of this could be done in swig....though I don't know how that would work....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PUT_DEFINE_CHECK_AND_CAST(float64_t) \
else \
/* if nothing works, moan about original type */ \
put(Tag<numeric_t>(name), value); \
}

SGOBJECT_PUT_DEFINE_NUMBER(int32_t)
SGOBJECT_PUT_DEFINE_NUMBER(float32_t)
SGOBJECT_PUT_DEFINE_NUMBER(float64_t)

};
28 changes: 16 additions & 12 deletions src/shogun/base/SGObject.h
Expand Up @@ -364,18 +364,22 @@ class CSGObject
}
}

/** Setter for a class parameter, identified by a name.
* Throws an exception if the class does not have such a parameter.
*
* @param name name of the parameter
* @param value value of the parameter along with type information
*/
template <typename T, typename U = void>
void put(const std::string& name, const T& value) throw(ShogunException)
{
Tag<T> tag(name);
put(tag, value);
}

#define SGOBJECT_PUT_DECLARE(T) \
/** Setter for a class parameter, identified by a name. \
* Throws an exception if the class does not have such a parameter. \
* \
* @param name name of the parameter \
* @param value value of the parameter along with type information \
*/ \
void put(const std::string& name, T const & value) throw(ShogunException);

SGOBJECT_PUT_DECLARE(int32_t)
SGOBJECT_PUT_DECLARE(float32_t)
SGOBJECT_PUT_DECLARE(float64_t)
SGOBJECT_PUT_DECLARE(SGVector<int32_t>)
SGOBJECT_PUT_DECLARE(SGVector<float64_t>)
SGOBJECT_PUT_DECLARE(CSGObject*)

/** Getter for a class parameter, identified by a Tag.
* Throws an exception if the class does not have such a parameter.
Expand Down