Skip to content

Commit

Permalink
Merge pull request #1094 from karlnapf/develop
Browse files Browse the repository at this point in the history
a bunch of equals fixes
  • Loading branch information
karlnapf committed May 14, 2013
2 parents 2d30a5b + 50e2efe commit 3105db8
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/shogun/base/Parameter.cpp
Expand Up @@ -2877,12 +2877,6 @@ bool TParameter::equals(TParameter* other, float64_t accuracy)
case CT_SCALAR:
{
SG_SDEBUG("CT_SCALAR\n");
if (strcmp("m_name", "num_elements"))
{
SG_SDEBUG("Ignoring num_elements field\n");
break;
}

if (!TParameter::compare_stype(m_datatype.m_stype,
m_datatype.m_ptype, m_datatype.sizeof_ptype(), m_parameter,
other->m_parameter,
Expand Down Expand Up @@ -3147,7 +3141,6 @@ bool TParameter::compare_ptype(EPrimitiveType ptype, void* data1, void* data2,
{
floatmax_t casted1=*((floatmax_t*)data1);
floatmax_t casted2=*((floatmax_t*)data2);

if (CMath::abs(casted1-casted2)>accuracy)
{
SG_SDEBUG("leaving TParameter::compare_ptype(): PT_FLOATMAX: "
Expand All @@ -3168,18 +3161,25 @@ bool TParameter::compare_ptype(EPrimitiveType ptype, void* data1, void* data2,
return true;
}

if (casted1 && !(casted1->equals(casted2, accuracy)))
/* make sure to not call NULL methods */
if (casted1)
{
SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT "
"equals returned false\n");
return false;
if (!(casted1->equals(casted2, accuracy)))
{
SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT "
"equals returned false\n");
return false;
}
}

if (casted2 && !(casted2->equals(casted1, accuracy)))
else
{
SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT "
"equals returned false\n");
return false;
if (!(casted2->equals(casted1, accuracy)))
{
SG_SDEBUG("leaving TParameter::compare_ptype(): PT_SGOBJECT "
"equals returned false\n");
return false;
}

}
break;
}
Expand Down
17 changes: 17 additions & 0 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -1233,12 +1233,20 @@ bool CSGObject::equals(CSGObject* other, float64_t accuracy)
{
SG_DEBUG("entering %s::equals()\n", get_name());

if (other==this)
{
SG_DEBUG("leaving %s::equals(): other object is me\n", get_name());
return true;
}

if (!other)
{
SG_DEBUG("leaving %s::equals(): other object is NULL\n", get_name());
return false;
}

SG_SPRINT("comparing \"%s\" to \"%s\"\n", get_name(), other->get_name());

/* a crude type check based on the get_name */
if (strcmp(other->get_name(), get_name()))
{
Expand Down Expand Up @@ -1284,6 +1292,15 @@ bool CSGObject::equals(CSGObject* other, float64_t accuracy)
SG_DEBUG("comparing parameter \"%s\" to other's \"%s\"\n",
this_param->m_name, other_param->m_name);

/* hard-wired exception for DynamicObjectArray parameter num_elements */
if (!strcmp("DynamicObjectArray", get_name()) &&
!strcmp(this_param->m_name, "num_elements") &&
!strcmp(other_param->m_name, "num_elements"))
{
SG_DEBUG("Ignoring DynamicObjectArray::num_elements field\n");
continue;
}

/* use equals method of TParameter from here */
if (!this_param->equals(other_param, accuracy))
{
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/base/SGObject_unittest.cc
Expand Up @@ -19,10 +19,35 @@
#include <shogun/regression/gp/GaussianLikelihood.h>
#endif
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/statistics/QuadraticTimeMMD.h>
#include <gtest/gtest.h>

using namespace shogun;

TEST(SGObject,equals_same)
{
CGaussianKernel* kernel=new CGaussianKernel();
EXPECT_TRUE(kernel->equals(kernel));
SG_UNREF(kernel);
}

TEST(SGObject,equals_NULL_parameter)
{
SGMatrix<float64_t> data(3,10);
for (index_t i=0; i<data.num_rows*data.num_cols; ++i)
data.matrix[i]=CMath::randn_double();

CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(data);
CGaussianKernel* kernel=new CGaussianKernel();
CQuadraticTimeMMD* mmd=new CQuadraticTimeMMD(kernel, feats, 5);
CQuadraticTimeMMD* mmd2=new CQuadraticTimeMMD(NULL, feats, 5);

mmd->equals(mmd2);

SG_UNREF(mmd);
SG_UNREF(mmd2);
}

TEST(SGObject,equals_null)
{
CBinaryLabels* labels=new CBinaryLabels(10);
Expand All @@ -43,6 +68,47 @@ TEST(SGObject,equals_different_name)
SG_UNREF(labels2);
}

TEST(SGObject,equals_DynamicObjectArray_equal)
{
CDynamicObjectArray* array1=new CDynamicObjectArray();
CDynamicObjectArray* array2=new CDynamicObjectArray();

EXPECT_TRUE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2));

SG_UNREF(array1);
SG_UNREF(array2);
}

TEST(SGObject,equals_DynamicObjectArray_equal_after_resize)
{
CDynamicObjectArray* array1=new CDynamicObjectArray();
CDynamicObjectArray* array2=new CDynamicObjectArray();

/* enforce a resize */
for (index_t i=0; i<1000; ++i)
array1->append_element(new CGaussianKernel());

array1->reset_array();

EXPECT_TRUE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2));

SG_UNREF(array1);
SG_UNREF(array2);
}

TEST(SGObject,equals_DynamicObjectArray_different)
{
CDynamicObjectArray* array1=new CDynamicObjectArray();
CDynamicObjectArray* array2=new CDynamicObjectArray();

array1->append_element(new CGaussianKernel());

EXPECT_FALSE(TParameter::compare_ptype(PT_SGOBJECT, &array1, &array2));

SG_UNREF(array1);
SG_UNREF(array2);
}

#ifdef HAVE_EIGEN3
TEST(SGObject,equals_complex_equal)
{
Expand Down

0 comments on commit 3105db8

Please sign in to comment.