Skip to content

Commit

Permalink
introduce "machine" factory and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Feb 28, 2018
1 parent 28112f1 commit 5863cda
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 28 deletions.
3 changes: 0 additions & 3 deletions examples/meta/src/base_api/kernel.sg

This file was deleted.

4 changes: 4 additions & 0 deletions examples/meta/src/base_api/objects.sg
@@ -0,0 +1,4 @@
Machine lib_svm = machine("LibSVM")
Machine lda = machine("LDA")
Kernel kernel_gaussian = kernel("GaussianKernel")
Kernel kernel_linear = kernel("LinearKernel")
2 changes: 1 addition & 1 deletion examples/meta/src/binary/kernel_support_vector_machine.sg
Expand Up @@ -15,7 +15,7 @@ Kernel k = kernel("GaussianKernel", log_width=1.0074515102711323)
#![set_parameters]

#![create_instance]
KernelMachine svm = kernel_machine("LibSVM", C1=1.0, C2=1.0, kernel=k, labels=labels_train, epsilon=0.001)
Machine svm = machine("LibSVM", C1=1.0, C2=1.0, kernel=k, labels=labels_train, epsilon=0.001)
#![create_instance]

#![train_and_apply]
Expand Down
5 changes: 2 additions & 3 deletions examples/meta/src/meta_api/kwargs.sg
Expand Up @@ -27,6 +27,5 @@
# ------------

Kernel k = kernel("GaussianKernel", log_width=2.0)
KernelMachine svm = kernel_machine("LibSVM", C1=1.1, kernel=k)

KernelMachine svm2 = kernel_machine("LibSVM", kernel=kernel("GaussianKernel"))
Machine svm = machine("LibSVM", C1=1.1, kernel=k)
Machine svm2 = machine("LibSVM", kernel=kernel("GaussianKernel"))
3 changes: 3 additions & 0 deletions src/interfaces/swig/Kernel.i
Expand Up @@ -7,6 +7,9 @@
* Written (W) 2009 Soeren Sonnenburg
* Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
*/

%newobject kernel();

#ifdef HAVE_PYTHON
%feature("autodoc", "get_kernel_matrix(self) -> numpy 2dim array of float") get_kernel_matrix;
%feature("autodoc", "get_POIM2(self) -> [] of float") get_POIM2;
Expand Down
1 change: 1 addition & 0 deletions src/interfaces/swig/Machine.i
Expand Up @@ -23,6 +23,7 @@
%newobject apply_structured(CFeatures* data);
%newobject apply_latent();
%newobject apply_latent(CFeatures* data);
%newobject machine();

#if defined(SWIGPYTHON) || defined(SWIGOCTAVE) || defined(SWIGRUBY) || defined(SWIGLUA) || defined(SWIGR)

Expand Down
6 changes: 0 additions & 6 deletions src/shogun/machine/KernelMachine.cpp
Expand Up @@ -6,7 +6,6 @@
* Fernando Iglesias, Thoralf Klein
*/

#include <shogun/base/class_list.h>
#include <shogun/base/progress.h>
#include <shogun/io/SGIO.h>
#include <shogun/labels/RegressionLabels.h>
Expand Down Expand Up @@ -653,8 +652,3 @@ bool CKernelMachine::supports_locking() const
{
return true;
}

CKernelMachine* shogun::kernel_machine(const char* name)
{
return create_object<CKernelMachine>(name);
}
6 changes: 0 additions & 6 deletions src/shogun/machine/KernelMachine.h
Expand Up @@ -328,11 +328,5 @@ class CKernelMachine : public CMachine
/** array of ``support vectors'' (indices of feature objects) */
SGVector<int32_t> m_svs;
};

/** Creates kernel machine by its name
*
* @param name the name of the kernel machine to create
*/
CKernelMachine* kernel_machine(const char* name);
}
#endif /* _KERNEL_MACHINE_H__ */
5 changes: 5 additions & 0 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -287,3 +287,8 @@ rxcpp::subscription CMachine::connect_to_signal_handler()
[this]() { this->on_complete(); });
return get_global_signal()->get_observable()->subscribe(subscriber);
}

CMachine* shogun::machine(const char* name)
{
return create_object<CMachine>(name);
}
18 changes: 12 additions & 6 deletions src/shogun/machine/Machine.h
Expand Up @@ -11,16 +11,16 @@
#ifndef _MACHINE_H__
#define _MACHINE_H__

#include <shogun/lib/config.h>

#include <shogun/lib/common.h>
#include <shogun/base/SGObject.h>
#include <shogun/base/class_list.h>
#include <shogun/features/Features.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/labels/LatentLabels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/labels/StructuredLabels.h>
#include <shogun/labels/LatentLabels.h>
#include <shogun/features/Features.h>
#include <shogun/lib/common.h>
#include <shogun/lib/config.h>

#include <condition_variable>
#include <mutex>
Expand Down Expand Up @@ -455,5 +455,11 @@ class CMachine : public CSGObject
/** Mutex used to pause threads */
std::mutex m_mutex;
};

/** Creates machine by its name
*
* @param name the name of the machine to create
*/
CMachine* machine(const char* name);
}
#endif // _MACHINE_H__
6 changes: 3 additions & 3 deletions tests/unit/base/create_unittest.cc
Expand Up @@ -38,10 +38,10 @@ TEST(CreateObject,create_kernel)
delete obj;
}

TEST(CreateObject, create_kernel_machine)
TEST(CreateObject, create_machine)
{
auto* obj = kernel_machine("LibSVM");
auto* obj = machine("LibSVM");
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CKernelMachine*>(obj) != nullptr);
EXPECT_TRUE(dynamic_cast<CMachine*>(obj) != nullptr);
delete obj;
}

0 comments on commit 5863cda

Please sign in to comment.