diff --git a/src/shogun/lib/SGSparseVector.cpp b/src/shogun/lib/SGSparseVector.cpp index 7bb408b9359..69fd239ee6c 100644 --- a/src/shogun/lib/SGSparseVector.cpp +++ b/src/shogun/lib/SGSparseVector.cpp @@ -358,6 +358,42 @@ SGSparseVector SGSparseVector::clone() const return SGSparseVector(copy, num_feat_entries); } +template +inline bool +SGSparseVector::operator==(const SGSparseVector& other) const +{ + if (num_feat_entries != other.num_feat_entries) + return false; + + if (features != other.features) + return false; + + return true; +} + +template +bool SGSparseVector::equals(const SGSparseVector& other) const +{ + /* same instance */ + if (*this==other) + return true; + + // both empty + if (!(num_feat_entries || other.num_feat_entries)) + return true; + + // only one empty + if (!num_feat_entries || !other.num_feat_entries) + return false; + + // different size + if (num_feat_entries != other.num_feat_entries) + return false; + + // content + return std::equal(features, features+num_feat_entries, other.features); +} + template void SGSparseVector::load(CFile * loader) { ASSERT(loader) @@ -615,6 +651,43 @@ void SGSparseVector::display_vector(const char * name, const char SG_SPRINT("%s]\n", prefix); } + +template +bool SGSparseVectorEntry::operator==(const SGSparseVectorEntry& other) const +{ + if (feat_index != other.feat_index) + return false; + + return entry == other.entry; +} + +#ifndef REAL_SPARSE_EQUALS +#define REAL_SPARSE_EQUALS(real_t) \ +template <> \ +bool SGSparseVectorEntry::operator==(const SGSparseVectorEntry& other) const \ +{ \ + if (feat_index != other.feat_index) \ + return false; \ + \ + return CMath::fequals(entry, other.entry, std::numeric_limits::epsilon()); \ +} + +REAL_SPARSE_EQUALS(float32_t) +REAL_SPARSE_EQUALS(float64_t) +REAL_SPARSE_EQUALS(floatmax_t) +#undef REAL_SPARSE_EQUALS +#endif // REAL_SPARSE_EQUALS + +template <> +bool SGSparseVectorEntry::operator==(const SGSparseVectorEntry& other) const +{ + if (feat_index != other.feat_index) + return false; + + return CMath::fequals(entry.real(), other.entry.real(), LDBL_EPSILON) && + CMath::fequals(entry.imag(), other.entry.imag(), LDBL_EPSILON); +} + template class SGSparseVector; template class SGSparseVector; template class SGSparseVector; diff --git a/src/shogun/lib/SGSparseVector.h b/src/shogun/lib/SGSparseVector.h index 7da4799ef88..e8b1e367eac 100644 --- a/src/shogun/lib/SGSparseVector.h +++ b/src/shogun/lib/SGSparseVector.h @@ -30,6 +30,11 @@ template struct SGSparseVectorEntry index_t feat_index; /** entry ... */ T entry; + + /** Comparson of entry + * @return true iff index and value (numerically for floats) are equal + */ + inline bool operator==(const SGSparseVectorEntry& other) const; }; /** @brief template class SGSparseVector @@ -189,6 +194,15 @@ template class SGSparseVector : public SGReferencedData void display_vector(const char* name="vector", const char* prefix=""); + + /** Pointer identify comparison. + * @return true iff length and pointer are equal + */ + inline bool + operator==(const SGSparseVector& other) const; + + bool equals(const SGSparseVector& other) const; + protected: virtual void copy_data(const SGReferencedData& orig); @@ -216,13 +230,6 @@ template class SGSparseVector : public SGReferencedData }; -template -inline bool -operator==(const SGSparseVector& lhs, const SGSparseVector& rhs) -{ - SG_SERROR("Comparison is not implemented for sparse vectors"); - return false; -} } #endif // __SGSPARSEVECTOR_H__ diff --git a/tests/unit/lib/SGSparseVector_unittest.cc b/tests/unit/lib/SGSparseVector_unittest.cc index cff89d38270..15bdce4cc4e 100644 --- a/tests/unit/lib/SGSparseVector_unittest.cc +++ b/tests/unit/lib/SGSparseVector_unittest.cc @@ -506,3 +506,52 @@ TEST(SGSparseVector, sparse_dot_not_sorted_features_different_length_last_index_ EXPECT_EQ(4, SGSparseVector::sparse_dot(v1, v2)); EXPECT_EQ(4, SGSparseVector::sparse_dot(v2, v1)); } + +/** @brief Fixture class template for typed tests of equals method */ +template +class SGSparseVectorEquals : public ::testing::Test +{ +public: + SGSparseVector v1_ = SGSparseVector(1); + SGSparseVector v2_ = SGSparseVector(1); +}; +typedef ::testing::Types SGSparseVectorEqualsTypes; +TYPED_TEST_CASE(SGSparseVectorEquals, SGSparseVectorEqualsTypes); + +TYPED_TEST(SGSparseVectorEquals, equals_same_dim) +{ + auto& v1=this->v1_; + auto& v2=this->v2_; + + v1.features[0].feat_index = 1; + v1.features[0].entry = 1; + v2.features[0].feat_index = 1; + v2.features[0].entry = 1; + EXPECT_TRUE(v1.equals(v1)); + EXPECT_TRUE(v1.equals(v2)); + EXPECT_TRUE(v2.equals(v1)); + + v1.features[0].feat_index = 1; + v1.features[0].entry = 1; + v2.features[0].feat_index = 1; + v2.features[0].entry = 0; + EXPECT_FALSE(v1.equals(v2)); + EXPECT_FALSE(v2.equals(v1)); + + v1.features[0].feat_index = 1; + v1.features[0].entry = 1; + v2.features[0].feat_index = 0; + v2.features[0].entry = 1; + EXPECT_FALSE(v1.equals(v2)); + EXPECT_FALSE(v2.equals(v1)); +} + +TYPED_TEST(SGSparseVectorEquals, equals_different_dim) +{ + auto& v1=this->v1_; + auto& v2=this->v2_; + + EXPECT_FALSE(v1.equals(v2)); + EXPECT_FALSE(v2.equals(v1)); +} +