Skip to content

Commit

Permalink
Introduce new factory for creating objects
Browse files Browse the repository at this point in the history
Populate for kernel, machine, features, labels
  • Loading branch information
karlnapf committed Mar 19, 2018
1 parent 38e7e2f commit f2dcf6c
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 37 deletions.
2 changes: 0 additions & 2 deletions src/interfaces/swig/Kernel.i
Expand Up @@ -8,8 +8,6 @@
* Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
*/

%newobject kernel();

#ifdef HAVE_PYTHON
%feature("autodoc", "get_kernel_matrix(self) -> numpy 2dim array of float") get_kernel_matrix;
%feature("autodoc", "get_POIM2(self) -> [] of float") get_POIM2;
Expand Down
1 change: 0 additions & 1 deletion src/interfaces/swig/Machine.i
Expand Up @@ -23,7 +23,6 @@
%newobject apply_structured(CFeatures* data);
%newobject apply_latent();
%newobject apply_latent(CFeatures* data);
%newobject machine();

#if defined(SWIGPYTHON) || defined(SWIGOCTAVE) || defined(SWIGRUBY) || defined(SWIGLUA) || defined(SWIGR)

Expand Down
6 changes: 6 additions & 0 deletions src/interfaces/swig/factory.i
@@ -0,0 +1,6 @@
%{
#include <shogun/util/factory.h>
%}
%include <shogun/util/factory.h>

%template(features) shogun::features<float64_t>;
6 changes: 5 additions & 1 deletion src/interfaces/swig/shogun.i
Expand Up @@ -115,14 +115,18 @@
%include "Boost.i"

%include "ParameterObserver.i"
%include "factory.i"

#if defined(SWIGPERL)
%include "abstract_types_extension.i"
#endif

%pragma(java) moduleimports=%{
import org.jblas.*;
%}

namespace shogun
{

%extend CSGObject
{
template <typename T, typename U= typename std::enable_if_t<std::is_arithmetic<T>::value>>
Expand Down
6 changes: 0 additions & 6 deletions src/shogun/kernel/Kernel.cpp
Expand Up @@ -19,7 +19,6 @@
#include <shogun/lib/Time.h>
#include <shogun/lib/common.h>
#include <shogun/lib/config.h>
#include <shogun/base/class_list.h>

#include <shogun/base/Parallel.h>

Expand Down Expand Up @@ -1379,8 +1378,3 @@ template SGMatrix<float32_t> CKernel::get_kernel_matrix<float32_t>();

template void* CKernel::get_kernel_matrix_helper<float64_t>(void* p);
template void* CKernel::get_kernel_matrix_helper<float32_t>(void* p);

CKernel* shogun::kernel(const char* name)
{
return create_object<CKernel>(name);
}
6 changes: 0 additions & 6 deletions src/shogun/kernel/Kernel.h
Expand Up @@ -1087,11 +1087,5 @@ class CKernel : public CSGObject
CKernelNormalizer* normalizer;
};

/** Creates kernel by its name
*
* @param name the name of the kernel to create
*/
CKernel* kernel(const char* name);

}
#endif /* _KERNEL_H__ */
5 changes: 0 additions & 5 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -287,8 +287,3 @@ rxcpp::subscription CMachine::connect_to_signal_handler()
[this]() { this->on_complete(); });
return get_global_signal()->get_observable()->subscribe(subscriber);
}

CMachine* shogun::machine(const char* name)
{
return create_object<CMachine>(name);
}
5 changes: 0 additions & 5 deletions src/shogun/machine/Machine.h
Expand Up @@ -456,10 +456,5 @@ class CMachine : public CSGObject
std::mutex m_mutex;
};

/** Creates machine by its name
*
* @param name the name of the machine to create
*/
CMachine* machine(const char* name);
}
#endif // _MACHINE_H__
93 changes: 93 additions & 0 deletions src/shogun/util/factory.h
@@ -0,0 +1,93 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann
*/
#ifndef FACTORY_H_
#define FACTORY_H_

#include <shogun/base/class_list.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/DenseLabels.h>
#include <shogun/io/CSVFile.h>
#include <shogun/io/SGIO.h>

namespace shogun
{

CKernel* kernel(const std::string& name);
CMachine* machine(const std::string& name);
CFeatures* features(CFile* file, EPrimitiveType primitive_type=PT_FLOAT64);
template <class T> CFeatures* features(SGMatrix<T> mat);
CLabels* labels(CFile* file);

CFile* csv_file(std::string fname, char rw='r');

#define BASE_CLASS_FACTORY(T, factory_name) \
T* factory_name(const std::string& name) \
{ \
return create_object<T>(name.c_str()); \
}

BASE_CLASS_FACTORY(CKernel, kernel)
BASE_CLASS_FACTORY(CMachine, machine)

template <class T>
CFeatures* features(SGMatrix<T> mat)
{
CFeatures* features = new CDenseFeatures<T>(mat);
SG_REF(features);
return features;
}

CFeatures* features(CFile* file, EPrimitiveType primitive_type)
{
REQUIRE(file, "No file provided.\n");
CFeatures* result = nullptr;

if (dynamic_cast<CCSVFile*>(file))
{
switch (primitive_type)
{
case PT_FLOAT64:
result = new CDenseFeatures<float64_t>();
break;
default:
SG_SNOTIMPLEMENTED
}
result->load(file);
}
else
SG_SERROR("Cannot load features from %s.\n", file->get_name());

SG_REF(result);
return result;
}

CLabels* labels(CFile* file)
{
REQUIRE(file, "No file provided.\n");
CLabels* result = nullptr;

if (dynamic_cast<CCSVFile*>(file))
{
CDenseLabels* result_ = new CDenseLabels();
result_->load(file);
result = result_;
}
else
SG_SERROR("Cannot load labels from file %s.\n", file->get_name());

SG_REF(result);
return result;
}

CFile* csv_file(std::string fname, char rw)
{
CFile* result = new CCSVFile(fname.c_str(), rw);
SG_REF(result);
return result;
}

}
#endif // FACTORY_H_
15 changes: 4 additions & 11 deletions tests/unit/base/create_unittest.cc
Expand Up @@ -42,18 +42,11 @@ TEST(CreateObject,create)
delete obj;
}

TEST(CreateObject,create_kernel)
TEST(CreateObject,create_with_ptype)
{
auto* obj = kernel("GaussianKernel");
auto* obj = create_object<CDenseFeatures<float64_t>>("DenseFeatures", PT_FLOAT64);
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CKernel*>(obj) != nullptr);
EXPECT_TRUE(dynamic_cast<CDenseFeatures<float64_t>*>(obj) != nullptr);
EXPECT_EQ(obj->get_generic(), PT_FLOAT64);
delete obj;
}

TEST(CreateObject, create_machine)
{
auto* obj = machine("LibSVM");
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CMachine*>(obj) != nullptr);
delete obj;
}
79 changes: 79 additions & 0 deletions tests/unit/utils/factory_unittest.cc
@@ -0,0 +1,79 @@
#include <shogun/base/class_list.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/classifier/svm/LibSVM.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/util/factory.h>
#include "utils/Utils.h"

#include <gtest/gtest.h>

using namespace shogun;

class GaussianKernel;

TEST(Factory,kernel)
{
auto* obj = kernel("GaussianKernel");
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CGaussianKernel*>(obj) != nullptr);
delete obj;
}

TEST(Factory, machine)
{
auto* obj = machine("LibSVM");
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CLibSVM*>(obj) != nullptr);
delete obj;
}

TEST(Factory, features_from_matrix)
{
SGMatrix<float64_t> mat(2,3);
auto* obj = features(mat);
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CDenseFeatures<float64_t>*>(obj) != nullptr);
delete obj;
}

TEST(Factory, features_dense_from_file)
{
std::string filename = "Factory-features_from_file.XXXXXX";

SGMatrix<float64_t> mat(2,3);
mat.set_const(1);
generate_temp_filename(const_cast<char*>(filename.c_str()));
auto file_save=some<CCSVFile>(filename.c_str(), 'w');
mat.save(file_save);
file_save->close();

auto file_load = some<CCSVFile>(filename.c_str(), 'r');
auto* obj = features(file_load, PT_FLOAT64);
EXPECT_TRUE(obj != nullptr);
auto* cast = dynamic_cast<CDenseFeatures<float64_t>*>(obj);
ASSERT_TRUE(cast != nullptr);
auto loaded_mat = cast->get_feature_matrix();
EXPECT_TRUE(loaded_mat.equals(mat));
delete obj;
}

TEST(Factory, labels_from_file)
{
std::string filename = "Factory-labels_from_file.XXXXXX";

SGVector<float64_t> vec(3);
vec.set_const(1);
generate_temp_filename(const_cast<char*>(filename.c_str()));
auto file_save=some<CCSVFile>(filename.c_str(), 'w');
vec.save(file_save);
file_save->close();

auto file_load = some<CCSVFile>(filename.c_str(), 'r');
auto* obj = labels(file_load);
EXPECT_TRUE(obj != nullptr);
auto* cast = dynamic_cast<CDenseLabels*>(obj);
ASSERT_TRUE(cast != nullptr);
auto loaded_vec = cast->get_labels();
EXPECT_TRUE(loaded_vec.equals(vec));
delete obj;
}

0 comments on commit f2dcf6c

Please sign in to comment.