Skip to content

Commit

Permalink
add arugment for primitive type to create_object
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Mar 19, 2018
1 parent 278ad1b commit ae8a7b9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/shogun/base/class_list.h
Expand Up @@ -7,9 +7,11 @@
#ifndef __SG_CLASS_LIST_H__
#define __SG_CLASS_LIST_H__

#include <shogun/base/SGObject.h>
#include <shogun/lib/config.h>

#include <shogun/lib/DataType.h>
#include <shogun/lib/ShogunException.h>

#include <shogun/io/SGIO.h>

Expand Down Expand Up @@ -37,19 +39,24 @@ namespace shogun {
*
*/
template <class T>
T* create_object(const char* name)
T* create_object(const char* name, EPrimitiveType pt=PT_NOT_GENERIC) throw (ShogunException)
{
auto* object = create(name, PT_NOT_GENERIC);
auto* object = create(name, pt);
if (!object)
{
SG_SERROR("No such class %s", name);
SG_SERROR("Class %s with primitive type %s does not exist.\n", name, ptype(pt).c_str());
}
auto* cast = dynamic_cast<T*>(object);
if (!cast)
T* cast = nullptr;
try
{
cast = object->as<T>();
}
catch (const ShogunException& e)
{
delete_object(object);
SG_SERROR("Type mismatch");
throw e;
}

cast->ref();
return cast;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/base/create_unittest.cc
Expand Up @@ -2,6 +2,7 @@
#include <shogun/kernel/Kernel.h>
#include <shogun/machine/KernelMachine.h>
#include <shogun/machine/Machine.h>
#include <shogun/features/DenseFeatures.h>

#include <gtest/gtest.h>

Expand All @@ -17,6 +18,16 @@ TEST(CreateObject,create_wrong_type)
EXPECT_THROW(create_object<CMachine>("GaussianKernel"), ShogunException);
}

TEST(CreateObject,create_wrong_ptype)
{
EXPECT_THROW(create_object<CMachine>("GaussianKernel", PT_FLOAT64), ShogunException);
}

TEST(CreateObject,create_wrong_ptype2)
{
EXPECT_THROW(create_object<CMachine>("DenseFeatures"), ShogunException);
}

TEST(CreateObject,create_wrong_type_wrong_name)
{
EXPECT_THROW(create_object<CMachine>("GoussianKernel"), ShogunException);
Expand All @@ -27,6 +38,7 @@ TEST(CreateObject,create)
auto* obj = create_object<CKernel>("GaussianKernel");
EXPECT_TRUE(obj != nullptr);
EXPECT_TRUE(dynamic_cast<CKernel*>(obj) != nullptr);
EXPECT_EQ(obj->get_generic(), PT_NOT_GENERIC);
delete obj;
}

Expand Down

0 comments on commit ae8a7b9

Please sign in to comment.