From 50e2efe9ccd2285683380efdd93b8322d13f19f5 Mon Sep 17 00:00:00 2001 From: Heiko Strathmann Date: Tue, 14 May 2013 12:13:37 +0100 Subject: [PATCH] =?UTF-8?q?-fixed=20floatmax=5Ft=20issues=20from=20S=C3=B6?= =?UTF-8?q?ren's=20commit=20-moved=20num=5Flements=20check=20of=20DynamicO?= =?UTF-8?q?bjectArray=20to=20CSGObject=20(cleaner)=20-added=20some=20new?= =?UTF-8?q?=20unit=20tests=20for=20num=5Felements=20and=20some=20NULL=20ca?= =?UTF-8?q?ses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/shogun/base/Parameter.cpp | 32 +++++++------- src/shogun/base/SGObject.cpp | 17 +++++++ tests/unit/base/SGObject_unittest.cc | 66 ++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 16 deletions(-) diff --git a/src/shogun/base/Parameter.cpp b/src/shogun/base/Parameter.cpp index a04c9956ed3..4daaa334663 100644 --- a/src/shogun/base/Parameter.cpp +++ b/src/shogun/base/Parameter.cpp @@ -2877,12 +2877,6 @@ bool TParameter::equals(TParameter* other, float64_t accuracy) case CT_SCALAR: { SG_SDEBUG("CT_SCALAR\n"); - if (strcmp("m_name", "num_elements")) - { - SG_SDEBUG("Ignoring num_elements field\n"); - break; - } - if (!TParameter::compare_stype(m_datatype.m_stype, m_datatype.m_ptype, m_datatype.sizeof_ptype(), m_parameter, other->m_parameter, @@ -3147,7 +3141,6 @@ bool TParameter::compare_ptype(EPrimitiveType ptype, void* data1, void* data2, { floatmax_t casted1=*((floatmax_t*)data1); floatmax_t casted2=*((floatmax_t*)data2); - if (CMath::abs(casted1-casted2)>accuracy) { SG_SDEBUG("leaving TParameter::compare_ptype(): PT_FLOATMAX: " @@ -3168,18 +3161,25 @@ bool TParameter::compare_ptype(EPrimitiveType ptype, void* data1, void* data2, return true; } - if (casted1 && !(casted1->equals(casted2, accuracy))) + /* make sure to not call NULL methods */ + if (casted1) { - SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT " - "equals returned false\n"); - return false; + if (!(casted1->equals(casted2, accuracy))) + { + SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT " + "equals returned false\n"); + return false; + } } - - if (casted2 && !(casted2->equals(casted1, accuracy))) + else { - SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT " - "equals returned false\n"); - return false; + if (!(casted2->equals(casted1, accuracy))) + { + SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT " + "equals returned false\n"); + return false; + } + } break; } diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index a13ad2f9bc1..c7e9b4eb7d6 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -1233,12 +1233,20 @@ bool CSGObject::equals(CSGObject* other, float64_t accuracy) { SG_DEBUG("entering %s::equals()\n", get_name()); + if (other==this) + { + SG_DEBUG("leaving %s::equals(): other object is me\n", get_name()); + return true; + } + if (!other) { SG_DEBUG("leaving %s::equals(): other object is NULL\n", get_name()); return false; } + SG_SPRINT("comparing \"%s\" to \"%s\"\n", get_name(), other->get_name()); + /* a crude type check based on the get_name */ if (strcmp(other->get_name(), get_name())) { @@ -1284,6 +1292,15 @@ bool CSGObject::equals(CSGObject* other, float64_t accuracy) SG_DEBUG("comparing parameter \"%s\" to other's \"%s\"\n", this_param->m_name, other_param->m_name); + /* hard-wired exception for DynamicObjectArray parameter num_elements */ + if (!strcmp("DynamicObjectArray", get_name()) && + !strcmp(this_param->m_name, "num_elements") && + !strcmp(other_param->m_name, "num_elements")) + { + SG_DEBUG("Ignoring DynamicObjectArray::num_elements field\n"); + continue; + } + /* use equals method of TParameter from here */ if (!this_param->equals(other_param, accuracy)) { diff --git a/tests/unit/base/SGObject_unittest.cc b/tests/unit/base/SGObject_unittest.cc index 264fc8242f8..93dff049783 100644 --- a/tests/unit/base/SGObject_unittest.cc +++ b/tests/unit/base/SGObject_unittest.cc @@ -19,10 +19,35 @@ #include #endif #include +#include #include using namespace shogun; +TEST(SGObject,equals_same) +{ + CGaussianKernel* kernel=new CGaussianKernel(); + EXPECT_TRUE(kernel->equals(kernel)); + SG_UNREF(kernel); +} + +TEST(SGObject,equals_NULL_parameter) +{ + SGMatrix data(3,10); + for (index_t i=0; i* feats=new CDenseFeatures(data); + CGaussianKernel* kernel=new CGaussianKernel(); + CQuadraticTimeMMD* mmd=new CQuadraticTimeMMD(kernel, feats, 5); + CQuadraticTimeMMD* mmd2=new CQuadraticTimeMMD(NULL, feats, 5); + + mmd->equals(mmd2); + + SG_UNREF(mmd); + SG_UNREF(mmd2); +} + TEST(SGObject,equals_null) { CBinaryLabels* labels=new CBinaryLabels(10); @@ -43,6 +68,47 @@ TEST(SGObject,equals_different_name) SG_UNREF(labels2); } +TEST(SGObject,equals_DynamicObjectArray_equal) +{ + CDynamicObjectArray* array1=new CDynamicObjectArray(); + CDynamicObjectArray* array2=new CDynamicObjectArray(); + + EXPECT_TRUE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2)); + + SG_UNREF(array1); + SG_UNREF(array2); +} + +TEST(SGObject,equals_DynamicObjectArray_equal_after_resize) +{ + CDynamicObjectArray* array1=new CDynamicObjectArray(); + CDynamicObjectArray* array2=new CDynamicObjectArray(); + + /* enforce a resize */ + for (index_t i=0; i<1000; ++i) + array1->append_element(new CGaussianKernel()); + + array1->reset_array(); + + EXPECT_TRUE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2)); + + SG_UNREF(array1); + SG_UNREF(array2); +} + +TEST(SGObject,equals_DynamicObjectArray_different) +{ + CDynamicObjectArray* array1=new CDynamicObjectArray(); + CDynamicObjectArray* array2=new CDynamicObjectArray(); + + array1->append_element(new CGaussianKernel()); + + EXPECT_FALSE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2)); + + SG_UNREF(array1); + SG_UNREF(array2); +} + #ifdef HAVE_EIGEN3 TEST(SGObject,equals_complex_equal) {