diff --git a/src/shogun/base/class_list.h b/src/shogun/base/class_list.h index 9394eca32ee..3d7502b747f 100644 --- a/src/shogun/base/class_list.h +++ b/src/shogun/base/class_list.h @@ -7,9 +7,11 @@ #ifndef __SG_CLASS_LIST_H__ #define __SG_CLASS_LIST_H__ +#include #include #include +#include #include @@ -37,19 +39,24 @@ namespace shogun { * */ template - T* create_object(const char* name) + T* create_object(const char* name, EPrimitiveType pt=PT_NOT_GENERIC) throw (ShogunException) { - auto* object = create(name, PT_NOT_GENERIC); + auto* object = create(name, pt); if (!object) { - SG_SERROR("No such class %s", name); + SG_SERROR("Class %s with primitive type %s does not exist.\n", name, ptype(pt).c_str()); } - auto* cast = dynamic_cast(object); - if (!cast) + T* cast = nullptr; + try + { + cast = object->as(); + } + catch (const ShogunException& e) { delete_object(object); - SG_SERROR("Type mismatch"); + throw e; } + cast->ref(); return cast; } diff --git a/tests/unit/base/create_unittest.cc b/tests/unit/base/create_unittest.cc index 27b1929ea4f..4d702ac4461 100644 --- a/tests/unit/base/create_unittest.cc +++ b/tests/unit/base/create_unittest.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include @@ -17,6 +18,16 @@ TEST(CreateObject,create_wrong_type) EXPECT_THROW(create_object("GaussianKernel"), ShogunException); } +TEST(CreateObject,create_wrong_ptype) +{ + EXPECT_THROW(create_object("GaussianKernel", PT_FLOAT64), ShogunException); +} + +TEST(CreateObject,create_wrong_ptype2) +{ + EXPECT_THROW(create_object("DenseFeatures"), ShogunException); +} + TEST(CreateObject,create_wrong_type_wrong_name) { EXPECT_THROW(create_object("GoussianKernel"), ShogunException); @@ -27,6 +38,7 @@ TEST(CreateObject,create) auto* obj = create_object("GaussianKernel"); EXPECT_TRUE(obj != nullptr); EXPECT_TRUE(dynamic_cast(obj) != nullptr); + EXPECT_EQ(obj->get_generic(), PT_NOT_GENERIC); delete obj; }