Skip to content

Commit

Permalink
Add equals method to SGSparseVector
Browse files Browse the repository at this point in the history
Via operator== of SGSparseVectorEntry. Test for all types.
  • Loading branch information
karlnapf committed Jan 8, 2018
1 parent 8dec3e0 commit 95a0bae
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 7 deletions.
73 changes: 73 additions & 0 deletions src/shogun/lib/SGSparseVector.cpp
Expand Up @@ -358,6 +358,42 @@ SGSparseVector<T> SGSparseVector<T>::clone() const
return SGSparseVector<T>(copy, num_feat_entries);
}

template <class T>
inline bool
SGSparseVector<T>::operator==(const SGSparseVector<T>& other) const
{
if (num_feat_entries != other.num_feat_entries)
return false;

if (features != other.features)
return false;

return true;
}

template <class T>
bool SGSparseVector<T>::equals(const SGSparseVector<T>& other) const
{
/* same instance */
if (*this==other)
return true;

// both empty
if (!(num_feat_entries || other.num_feat_entries))
return true;

// only one empty
if (!num_feat_entries || !other.num_feat_entries)
return false;

// different size
if (num_feat_entries != other.num_feat_entries)
return false;

// content
return std::equal(features, features+num_feat_entries, other.features);
}

template<class T> void SGSparseVector<T>::load(CFile * loader)
{
ASSERT(loader)
Expand Down Expand Up @@ -615,6 +651,43 @@ void SGSparseVector<complex128_t>::display_vector(const char * name, const char
SG_SPRINT("%s]\n", prefix);
}


template <class T>
bool SGSparseVectorEntry<T>::operator==(const SGSparseVectorEntry<T>& other) const
{
if (feat_index != other.feat_index)
return false;

return entry == other.entry;
}

#ifndef REAL_SPARSE_EQUALS
#define REAL_SPARSE_EQUALS(real_t) \
template <> \
bool SGSparseVectorEntry<real_t>::operator==(const SGSparseVectorEntry<real_t>& other) const \
{ \
if (feat_index != other.feat_index) \
return false; \
\
return CMath::fequals<real_t>(entry, other.entry, std::numeric_limits<real_t>::epsilon()); \
}

REAL_SPARSE_EQUALS(float32_t)
REAL_SPARSE_EQUALS(float64_t)
REAL_SPARSE_EQUALS(floatmax_t)
#undef REAL_SPARSE_EQUALS
#endif // REAL_SPARSE_EQUALS

template <>
bool SGSparseVectorEntry<complex128_t>::operator==(const SGSparseVectorEntry<complex128_t>& other) const
{
if (feat_index != other.feat_index)
return false;

return CMath::fequals<float64_t>(entry.real(), other.entry.real(), LDBL_EPSILON) &&
CMath::fequals<float64_t>(entry.imag(), other.entry.imag(), LDBL_EPSILON);
}

template class SGSparseVector<bool>;
template class SGSparseVector<char>;
template class SGSparseVector<int8_t>;
Expand Down
21 changes: 14 additions & 7 deletions src/shogun/lib/SGSparseVector.h
Expand Up @@ -30,6 +30,11 @@ template <class T> struct SGSparseVectorEntry
index_t feat_index;
/** entry ... */
T entry;

/** Comparson of entry
* @return true iff index and value (numerically for floats) are equal
*/
inline bool operator==(const SGSparseVectorEntry<T>& other) const;
};

/** @brief template class SGSparseVector
Expand Down Expand Up @@ -189,6 +194,15 @@ template <class T> class SGSparseVector : public SGReferencedData
void display_vector(const char* name="vector",
const char* prefix="");


/** Pointer identify comparison.
* @return true iff length and pointer are equal
*/
inline bool
operator==(const SGSparseVector<T>& other) const;

bool equals(const SGSparseVector<T>& other) const;

protected:
virtual void copy_data(const SGReferencedData& orig);

Expand Down Expand Up @@ -216,13 +230,6 @@ template <class T> class SGSparseVector : public SGReferencedData

};

template <class T>
inline bool
operator==(const SGSparseVector<T>& lhs, const SGSparseVector<T>& rhs)
{
SG_SERROR("Comparison is not implemented for sparse vectors");
return false;
}
}

#endif // __SGSPARSEVECTOR_H__
49 changes: 49 additions & 0 deletions tests/unit/lib/SGSparseVector_unittest.cc
Expand Up @@ -506,3 +506,52 @@ TEST(SGSparseVector, sparse_dot_not_sorted_features_different_length_last_index_
EXPECT_EQ(4, SGSparseVector<int32_t>::sparse_dot(v1, v2));
EXPECT_EQ(4, SGSparseVector<int32_t>::sparse_dot(v2, v1));
}

/** @brief Fixture class template for typed tests of equals method */
template <typename T>
class SGSparseVectorEquals : public ::testing::Test
{
public:
SGSparseVector<T> v1_ = SGSparseVector<T>(1);
SGSparseVector<T> v2_ = SGSparseVector<T>(1);
};
typedef ::testing::Types<int16_t, int32_t, int64_t, float32_t, float64_t, floatmax_t, complex128_t> SGSparseVectorEqualsTypes;
TYPED_TEST_CASE(SGSparseVectorEquals, SGSparseVectorEqualsTypes);

TYPED_TEST(SGSparseVectorEquals, equals_same_dim)
{
auto& v1=this->v1_;
auto& v2=this->v2_;

v1.features[0].feat_index = 1;
v1.features[0].entry = 1;
v2.features[0].feat_index = 1;
v2.features[0].entry = 1;
EXPECT_TRUE(v1.equals(v1));
EXPECT_TRUE(v1.equals(v2));
EXPECT_TRUE(v2.equals(v1));

v1.features[0].feat_index = 1;
v1.features[0].entry = 1;
v2.features[0].feat_index = 1;
v2.features[0].entry = 0;
EXPECT_FALSE(v1.equals(v2));
EXPECT_FALSE(v2.equals(v1));

v1.features[0].feat_index = 1;
v1.features[0].entry = 1;
v2.features[0].feat_index = 0;
v2.features[0].entry = 1;
EXPECT_FALSE(v1.equals(v2));
EXPECT_FALSE(v2.equals(v1));
}

TYPED_TEST(SGSparseVectorEquals, equals_different_dim)
{
auto& v1=this->v1_;
auto& v2=this->v2_;

EXPECT_FALSE(v1.equals(v2));
EXPECT_FALSE(v2.equals(v1));
}

0 comments on commit 95a0bae

Please sign in to comment.