Skip to content

Commit

Permalink
added remaining factory functions to Python API (shogun-toolbox#4470)
Browse files Browse the repository at this point in the history
* added remaining factory functions to Python API
* add factory wrapper in loop
* fixes error handling inside _swig_monkey_patch
* simplified adding factory functions
* fixed string check logic
* changed python.json to accept kwargs as is, changes are described in shogun-toolbox#4128: Changing from translation type A) to B)

* minor changes to pipeline example to work with new python API
* keep track of getters automatically
  • Loading branch information
Gil authored and vigsterkr committed Mar 9, 2019
1 parent 0785b1e commit ea24d4c
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 60 deletions.
14 changes: 7 additions & 7 deletions examples/meta/generator/targets/python.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
"Statement": "$statement\n",
"Comment": "#$comment\n",
"Init": {
"Construct": "$name = $typeName($arguments)$kwargs",
"Copy": "$name = $expr$kwargs",
"Construct": "$name = $typeName($arguments$kwargs)",
"Copy": "$name = $expr",
"KeywordArguments": {
"List": "\n$elements",
"Element": "$name.put(\"$keyword\", $expr)",
"Separator": "\n",
"InitialSeperatorWhenArgs>0": false
"List": "$elements",
"Element": "$keyword=$expr",
"Separator": ", ",
"InitialSeperatorWhenArgs>0": true
},
"BoolVector": "$name = np.zeros( ($arguments), dtype='bool')",
"CharVector": "$name = np.zeros( ($arguments), dtype='|S1')",
Expand Down Expand Up @@ -64,7 +64,7 @@
"get_real_matrix": "$object.get($arguments)"
},
"StaticCall": "$typeName.$method($arguments)",
"GlobalCall": "$method($arguments)",
"GlobalCall": "$method($arguments$kwargs)",
"Identifier": "$identifier",
"Enum":"$value"
},
Expand Down
4 changes: 2 additions & 2 deletions examples/meta/src/pipeline/pipeline.sg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Transformer pca = transformer("PCA")

#![create_machine]
Distance d = distance("EuclideanDistance", lhs=feats_train, rhs=feats_train)
KMeans kmeans(distance=d, k=2)
Machine kmeans = machine("KMeans", distance=d, k=2)
#![create_machine]


Expand All @@ -31,6 +31,6 @@ Labels labels_predict = pipeline.apply(feats_train)

# additional integration variables
#![extract_centers_and_radius]
RealMatrix c = kmeans.get_cluster_centers()
RealMatrix c = kmeans.get_real_matrix("cluster_centers")
RealVector r = kmeans.get_real_vector("radiuses")
#![extract_centers_and_radius]
112 changes: 109 additions & 3 deletions src/interfaces/python/swig_typemaps.i
Original file line number Diff line number Diff line change
Expand Up @@ -1292,16 +1292,85 @@ TYPEMAP_SPARSEFEATURES_OUT(PyObject, NPY_OBJECT)
PyErr_SetString(PyExc_SystemError, $1.what());
SWIG_fail;
}
%rename(_kernel) kernel;

%feature("nothread") _rename_python_function;
%feature("docstring", "Renames a Python function in the given module or class. \n"
"Similar functionality to SWIG's %rename.") _rename_python_function;

%typemap(out) void _rename_python_function "$result = PyErr_Occurred() ? NULL : SWIG_Py_Void();"
%inline %{
static void _rename_python_function(PyObject *type, PyObject *old_name, PyObject *new_name) {
PyObject *dict = NULL,
*func_obj = NULL;
#if PY_VERSION_HEX>=0x03000000
if (!PyUnicode_Check(old_name) || !PyUnicode_Check(new_name))
#else
if (!PyString_Check(old_name) || !PyString_Check(new_name))
#endif
{
PyErr_SetString(PyExc_TypeError, "'old_name' and 'new_name' have to be strings");
return;
}
if (PyType_Check(type)) {
PyTypeObject *pytype = (PyTypeObject *)type;
dict = pytype->tp_dict;
func_obj = PyDict_GetItem(dict, old_name);
if (func_obj == NULL) {
PyErr_SetString(PyExc_ValueError, "'old_name' name does not exist in the given type");
return;
}
}
else if ( PyModule_Check(type)) {
dict = PyModule_GetDict(type);
func_obj = PyDict_GetItem(dict, old_name);
if (func_obj == NULL) {
PyErr_SetString(PyExc_ValueError, "'old_name' does not exist in the given module");
return;
}
}
else {
PyErr_SetString(PyExc_ValueError, "'type' is neither a module or a Python type");
return;
}
if (PyDict_Contains(dict, new_name))
{
PyErr_SetString(PyExc_ValueError, "new_name already exists in the given scope");
return;
}
PyDict_SetItem(dict, new_name, func_obj);
PyDict_DelItem(dict, old_name);
}
%}

%pythoncode %{
import sys

_GETTERS = ["get",
"get_real",
"get_int",
"get_real_matrix",
"get_real_vector",
"get_int_vector"
]

_FACTORIES = ["distance",
"evaluation",
"kernel",
"machine",
"multiclass_strategy",
"ecoc_encoder",
"ecoc_decoder",
"transformer",
"layer"
]

def _internal_factory_wrapper(object_name, new_name, docstring=None):
"""
A wrapper that returns a generic factory that
accepts kwargs and passes them to shogun.object_name
via .put
"""
_obj = getattr(_shogun, object_name)
_obj = getattr(sys.modules[__name__], object_name)
def _internal_factory(name, **kwargs):

new_obj = _obj(name)
Expand All @@ -1312,10 +1381,47 @@ def _internal_factory_wrapper(object_name, new_name, docstring=None):
_internal_factory.__doc__ = docstring
else:
_internal_factory.__doc__ = _obj.__doc__.replace(object_name, new_name)
_internal_factory.__qualname__ = new_name

return _internal_factory

kernel = _internal_factory_wrapper("_kernel", "kernel")
for factory in _FACTORIES:
# renames function in the current module (shogun) from `factory` to "_" + `factory`
# which "hides" it from the user
factory_private_name = "_{}".format(factory)
_rename_python_function(sys.modules[__name__], factory, factory_private_name)
# adds a new function called `factory` to the shogun module which is a wrapper
# that passes **kwargs to objects via .put (see _internal_factory_wrapper)
_swig_monkey_patch(sys.modules[__name__], factory, _internal_factory_wrapper(factory_private_name, factory))

# makes all the SGObject getters defined in _GETTERS private
_internal_getter_methods = []
for getter in _GETTERS:
_private_getter = "_{}".format(getter)
_rename_python_function(_shogun.SGObject, getter, _private_getter)
_internal_getter_methods.append(_shogun.SGObject.__dict__[_private_getter])

def _internal_get_param(self, name):
"""
Returns the value of the given parameter.
The return type depends on the parameter,
e.g. could be a builtin scalar or a
numpy array representing a vector or matrix
"""

for getter in _internal_getter_methods:
try:
return getter(self, name)
except SystemError:
pass
except Exception:
raise
if name in self.parameter_names():
raise ValueError("The current Python API does not have a getter for '{}'".format(name))
else:
raise KeyError("There is no parameter called '{}' in {}".format(name, self.get_name()))

_swig_monkey_patch(SGObject, "get", _internal_get_param)
%}

#endif /* HAVE_PYTHON */
72 changes: 35 additions & 37 deletions src/interfaces/swig/SGBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -237,29 +237,49 @@ public void readExternal(java.io.ObjectInput in) throws java.io.IOException, jav
}
%}

%feature("nothread") _swig_monkey_patch;
%feature("docstring", "Adds a Python object (such as a function) \n"
"to a class (method) or to a module. \n"
"If the name of the function conflicts with \n"
"another Python object in the same scope\n"
"raises a TypeError.") _swig_monkey_patch;

// taken from https://github.com/swig/swig/issues/723#issuecomment-230178855
%typemap(out) void _swig_monkey_patch "$result = PyErr_Occurred() ? NULL : SWIG_Py_Void();"
%inline %{
static void _swig_monkey_patch(PyObject *type, PyObject *name, PyObject *object) {
if (PyType_Check(type)) {
static void _swig_monkey_patch(PyObject *type, PyObject *name, PyObject *object) {
PyObject *dict = NULL;
#if PY_VERSION_HEX>=0x03000000
if (PyUnicode_Check(name))
if (!PyUnicode_Check(name))
#else
if (PyString_Check(name))
if (!PyString_Check(name))
#endif
{
PyTypeObject *pytype = (PyTypeObject *)type;
PyDict_SetItem(pytype->tp_dict, name, object);
}
else
PyErr_SetString(PyExc_TypeError, "name is not a string");
}
else
PyErr_SetString(PyExc_TypeError, "type is not a Python type");
}
{
PyErr_SetString(PyExc_TypeError, "name is not a string");
return;
}

if (PyType_Check(type)) {
PyTypeObject *pytype = (PyTypeObject *)type;
dict = pytype->tp_dict;
}
else if (PyModule_Check(type)) {
dict = PyModule_GetDict(type);
}
else {
PyErr_SetString(PyExc_TypeError, "type is not a Python type or module");
return;
}
if (PyDict_Contains(dict, name))
{
PyErr_SetString(PyExc_ValueError, "function name already exists in the given scope");
return;
}
PyDict_SetItem(dict, name, object);

}
%}


%typemap(out) PyObject* __reduce_ex__(int proto)
{
return PyObject_CallMethod(self, (char*) "__reduce__", (char*) "");
Expand Down Expand Up @@ -330,9 +350,6 @@ public void readExternal(java.io.ObjectInput in) throws java.io.IOException, jav
%ignore sg_print_error;
%ignore sg_cancel_computations;

#ifdef SWIGPYTHON
%rename(_get) get(const std::string&);
#endif // SWIGPYTHON
%rename(SGObject) CSGObject;

%include <shogun/lib/common.h>
Expand Down Expand Up @@ -472,25 +489,6 @@ namespace shogun
}

%pythoncode %{
def _internal_get_param(self, name):

for f in (self._get,
self._get_real,
self._get_int,
self._get_real_matrix,
self._get_real_vector,
self._get_int_vector
):
try:
return f(name)
except SystemError:
pass
except Exception:
raise
raise KeyError("There is no parameter called '{}' in {}".format(name, self.get_name()))

_swig_monkey_patch(SGObject, "get", _internal_get_param)

try:
import copy_reg
except ImportError:
Expand Down
11 changes: 0 additions & 11 deletions src/interfaces/swig/shogun.i
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,13 @@ namespace shogun
#else // SWIGJAVA
%template(put) CSGObject::put_vector_or_matrix_from_double_matrix_dispatcher<SGMatrix<float64_t>, float64_t>;
#endif // SWIGJAVA
#ifdef SWIGPYTHON
%template(_get_real) CSGObject::get<float64_t, void>;
%template(_get_int) CSGObject::get<int32_t, void>;
%template(_get_real_matrix) CSGObject::get<SGMatrix<float64_t>, void>;
#else // SWIGPYTHON
%template(get_real) CSGObject::get<float64_t, void>;
%template(get_int) CSGObject::get<int32_t, void>;
%template(get_real_matrix) CSGObject::get<SGMatrix<float64_t>, void>;
#endif // SWIGPYTHON

#ifndef SWIGJAVA
#ifdef SWIGPYTHON
%template(_get_real_vector) CSGObject::get<SGVector<float64_t>, void>;
%template(_get_int_vector) CSGObject::get<SGVector<int32_t>, void>;
#else // SWIGPYTHON
%template(get_real_vector) CSGObject::get<SGVector<float64_t>, void>;
%template(get_int_vector) CSGObject::get<SGVector<int32_t>, void>;
#endif // SWIGPYTHON
#else // SWIGJAVA
%template(get_real_vector) CSGObject::get_vector_as_matrix_dispatcher<SGMatrix<float64_t>, float64_t>;
%template(get_int_vector) CSGObject::get_vector_as_matrix_dispatcher<SGMatrix<int32_t>, int32_t>;
Expand Down

0 comments on commit ea24d4c

Please sign in to comment.