Skip to content

Commit

Permalink
tag based equals implementation [WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Jan 5, 2018
1 parent 4b4e380 commit 44421a7
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 105 deletions.
10 changes: 10 additions & 0 deletions src/shogun/base/AnyParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ namespace shogun
return m_properties;
}

inline bool operator==(const AnyParameter& other)
{
return m_value == other.get_value();
}

inline bool operator!=(const AnyParameter& other)
{
return !(*this == other);
}

private:
Any m_value;
AnyParameterProperties m_properties;
Expand Down
135 changes: 56 additions & 79 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <stdlib.h>
#include <stdio.h>
#include <typeinfo>

#include <rxcpp/operators/rx-filter.hpp>
#include <rxcpp/rx-lite.hpp>
Expand Down Expand Up @@ -670,84 +671,6 @@ void CSGObject::build_gradient_parameter_dictionary(CMap<TParameter*, CSGObject*
}
}

bool CSGObject::equals(CSGObject* other, float64_t accuracy, bool tolerant)
{
SG_DEBUG("entering %s::equals()\n", get_name());

if (other==this)
{
SG_DEBUG("leaving %s::equals(): other object is me\n", get_name());
return true;
}

if (!other)
{
SG_DEBUG("leaving %s::equals(): other object is NULL\n", get_name());
return false;
}

SG_DEBUG("comparing \"%s\" to \"%s\"\n", get_name(), other->get_name());

/* a crude type check based on the get_name */
if (strcmp(other->get_name(), get_name()))
{
SG_INFO("leaving %s::equals(): name of other object differs\n", get_name());
return false;
}

/* should not be necessary but just ot be sure that type has not changed.
* Will assume that parameters are in same order with same name from here */
if (m_parameters->get_num_parameters()!=other->m_parameters->get_num_parameters())
{
SG_INFO("leaving %s::equals(): number of parameters of other object "
"differs\n", get_name());
return false;
}

for (index_t i=0; i<m_parameters->get_num_parameters(); ++i)
{
SG_DEBUG("comparing parameter %d\n", i);

TParameter* this_param=m_parameters->get_parameter(i);
TParameter* other_param=other->m_parameters->get_parameter(i);

/* some checks to make sure parameters have same order and names and
* are not NULL. Should never be the case but check anyway. */
if (!this_param && !other_param)
continue;

if (!this_param && other_param)
{
SG_DEBUG("leaving %s::equals(): parameter %d is NULL where other's "
"parameter \"%s\" is not\n", get_name(), other_param->m_name);
return false;
}

if (this_param && !other_param)
{
SG_DEBUG("leaving %s::equals(): parameter %d is \"%s\" where other's "
"parameter is NULL\n", get_name(), this_param->m_name);
return false;
}

SG_DEBUG("comparing parameter \"%s\" to other's \"%s\"\n",
this_param->m_name, other_param->m_name);

/* use equals method of TParameter from here */
if (!this_param->equals(other_param, accuracy, tolerant))
{
SG_INFO("leaving %s::equals(): parameters at position %d with name"
" \"%s\" differs from other object parameter with name "
"\"%s\"\n",
get_name(), i, this_param->m_name, other_param->m_name);
return false;
}
}

SG_DEBUG("leaving %s::equals(): object are equal\n", get_name());
return true;
}

CSGObject* CSGObject::clone()
{
SG_DEBUG("Constructing an empty instance of %s\n", get_name());
Expand Down Expand Up @@ -820,7 +743,7 @@ AnyParameter CSGObject::get_parameter(const BaseTag& _tag) const
const auto& parameter = self->get(_tag);
if (parameter.get_value().empty())
{
SG_ERROR("There is no parameter called \"%s\" in %s",
SG_ERROR("There is no parameter called \"%s\" in %s\n",
_tag.name().c_str(), get_name());
}
return parameter;
Expand Down Expand Up @@ -1016,3 +939,57 @@ std::string CSGObject::to_string() const
ss << ")";
return ss.str();
}

bool CSGObject::equals(CSGObject* other, float64_t accuracy, bool tolerant)
{
if (other==this)
return true;

if (other==nullptr)
{
SG_DEBUG("No object to compare to provided.\n");
return false;
}

/* Assumption: can use SGObject::get_name to distinguish types */
if (strcmp(this->get_name(), other->get_name()))
{
SG_DEBUG("Own type %s differs from provided %s.\n",
get_name(), other->get_name());
return false;
}

/* Assumption: objects of same type have same set of tags. */
for (auto it=self->map.cbegin(); it!=self->map.cend(); it++)
{
auto tag = it->first;
auto own = it->second;
auto given = other->get_parameter(tag);

SG_SDEBUG("Comparing parameter %s::%s.\n", this->get_name(),
tag.name().c_str());
if (own != given)
{
if (io->get_loglevel()<=MSG_DEBUG)
{
std::stringstream ss;
std::unique_ptr<AnyVisitor> visitor(new ToStringVisitor(&ss));

ss << "Own parameter " << this->get_name() << "::" << tag.name() << "=";
own.get_value().visit(visitor.get());

ss << " different from provided=" << other->get_name() << "::";
given.get_value().visit(visitor.get());

SG_SDEBUG("%s\n", ss.str().c_str());
}

return false;

}
}

SG_SDEBUG("All parameters of %s equal.\n", this->get_name());
return true;
}

73 changes: 47 additions & 26 deletions tests/unit/base/SGObject_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
* Written (W) 2015 Wu Lin
*/

#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/kernel/LinearKernel.h>
#include <shogun/regression/GaussianProcessRegression.h>
#include <shogun/machine/gp/ExactInferenceMethod.h>
#include <shogun/machine/gp/ZeroMean.h>
Expand All @@ -28,14 +27,56 @@

using namespace shogun;

TEST(SGObject,equals_same)
TEST(SGObject, equals_same_instance)
{
CGaussianKernel* kernel=new CGaussianKernel();
auto kernel = some<CGaussianKernel>();
EXPECT_TRUE(kernel->equals(kernel));
SG_UNREF(kernel);
}

TEST(SGObject,equals_NULL_parameter)
TEST(SGObject, equals_null)
{
auto kernel = some<CGaussianKernel>();
EXPECT_FALSE(kernel->equals(nullptr));
}

TEST(SGObject, equals_different_type)
{
auto kernel = some<CGaussianKernel>();
auto kernel2 = some<CLinearKernel>();

EXPECT_FALSE(kernel->equals(kernel2));
}

TEST(SGObject, equals_basic_member)
{
auto kernel = some<CGaussianKernel>(1);
auto kernel2 = some<CGaussianKernel>(1);
EXPECT_TRUE(kernel->equals(kernel2));

kernel->set_width(2);
EXPECT_FALSE(kernel->equals(kernel2));
}

TEST(SGObject, equals_object_member)
{
SGMatrix<float64_t> data(3,10);
SGMatrix<float64_t> data2(3,10);
auto feats=some<CDenseFeatures<float64_t>>(data);
auto feats2=some<CDenseFeatures<float64_t>>(data2);

auto kernel = some<CGaussianKernel>();
auto kernel2 = some<CGaussianKernel>();

kernel->init(feats, feats);
kernel2->init(feats2, feats2);

EXPECT_TRUE(kernel->equals(kernel2));

data(1,1) = 1;
EXPECT_FALSE(kernel->equals(kernel2));
}

TEST(SGObject,equals_other_has_NULL_parameter)
{
SGMatrix<float64_t> data(3,10);
for (index_t i=0; i<data.num_rows*data.num_cols; ++i)
Expand Down Expand Up @@ -80,26 +121,6 @@ TEST(SGObject,ref_unref_simple)
EXPECT_TRUE(labs == NULL);
}

TEST(SGObject,equals_null)
{
CBinaryLabels* labels=new CBinaryLabels(10);

EXPECT_FALSE(labels->equals(NULL));

SG_UNREF(labels);
}

TEST(SGObject,equals_different_name)
{
CBinaryLabels* labels=new CBinaryLabels(10);
CRegressionLabels* labels2=new CRegressionLabels(10);

EXPECT_FALSE(labels->equals(labels2));

SG_UNREF(labels);
SG_UNREF(labels2);
}

TEST(SGObject,equals_DynamicObjectArray_equal)
{
CDynamicObjectArray* array1=new CDynamicObjectArray();
Expand Down

0 comments on commit 44421a7

Please sign in to comment.