Skip to content

Commit

Permalink
Add SGSparseMatrix equals and operator== (#4075)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and karlnapf committed Jan 20, 2018
1 parent c5f2733 commit de36675
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 8 deletions.
41 changes: 38 additions & 3 deletions src/shogun/lib/SGSparseMatrix.cpp
@@ -1,9 +1,10 @@
#include <shogun/base/range.h>
#include <shogun/io/File.h>
#include <shogun/io/LibSVMFile.h>
#include <shogun/io/SGIO.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGSparseMatrix.h>
#include <shogun/lib/SGSparseVector.h>
#include <shogun/io/File.h>
#include <shogun/io/SGIO.h>
#include <shogun/io/LibSVMFile.h>

namespace shogun {

Expand Down Expand Up @@ -282,6 +283,40 @@ template<class T> void SGSparseMatrix<T>::from_dense(SGMatrix<T> full)
SG_FREE(num_feat_entries);
}

template <class T>
bool SGSparseMatrix<T>::operator==(const SGSparseMatrix<T>& other) const
{
if (num_vectors != other.num_vectors)
return false;

if (num_features != other.num_features)
return false;

if (sparse_matrix != other.sparse_matrix)
return false;

return true;
}

template <class T>
bool SGSparseMatrix<T>::equals(const SGSparseMatrix<T>& other) const
{
// same instance
if (*this == other)
return true;

// different size
if (num_vectors != other.num_vectors || num_features != other.num_features)
return false;

for (auto i : range(num_vectors))
{
if (!sparse_matrix[i].equals(other.sparse_matrix[i]))
return false;
}
return true;
}

template class SGSparseMatrix<bool>;
template class SGSparseMatrix<char>;
template class SGSparseMatrix<int8_t>;
Expand Down
14 changes: 13 additions & 1 deletion src/shogun/lib/SGSparseMatrix.h
Expand Up @@ -183,8 +183,20 @@ template <class T> class SGSparseMatrix : public SGReferencedData
/** sort the indices of the sparse matrix such that they are in ascending order */
void sort_features();

protected:
/** Pointer identify comparison.
* @return true iff number of vectors and features and pointer are
* equal
*/
bool operator==(const SGSparseMatrix<T>& other) const;

/** Equals method up to precision for matrix (element-wise)
* @param other matrix to compare with
* @return false if any element differs or if shapes are different,
* true otherwise
*/
bool equals(const SGSparseMatrix<T>& other) const;

protected:
/** copy data */
virtual void copy_data(const SGReferencedData& orig);

Expand Down
65 changes: 61 additions & 4 deletions tests/unit/lib/SGSparseMatrix_unittest.cc
Expand Up @@ -10,12 +10,13 @@

#include <gtest/gtest.h>

#include <shogun/lib/common.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGSparseVector.h>
#include <shogun/lib/SGSparseMatrix.h>
#include <shogun/base/range.h>
#include <shogun/io/LibSVMFile.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGSparseMatrix.h>
#include <shogun/lib/SGSparseVector.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/common.h>
#include <shogun/mathematics/Random.h>

using namespace shogun;
Expand Down Expand Up @@ -286,3 +287,59 @@ TEST(SGSparseMatrix, transposed_square_matrix)
EXPECT_NEAR(sparse_matrix(feat_index,vec_index), sparse_matrix_t(vec_index,feat_index), 1E-14);
}
}

TEST(SGSparseMatrix, equals_same_shape)
{
const index_t number_of_features = 2;
const index_t number_of_vectors = 2;

SGSparseMatrix<float64_t> m1(number_of_vectors, number_of_features);
SGSparseMatrix<float64_t> m2(number_of_vectors, number_of_features);

for (auto i : range(number_of_vectors))
{
m1.sparse_matrix[i] = SGSparseVector<float64_t>(number_of_vectors);
m2.sparse_matrix[i] = SGSparseVector<float64_t>(number_of_vectors);

for (auto j : range(number_of_features))
{
m1.sparse_matrix[i].features[j].feat_index = 0;
m1.sparse_matrix[i].features[j].entry = 1;
m2.sparse_matrix[i].features[j].feat_index = 0;
m2.sparse_matrix[i].features[j].entry = 1;
}
}

EXPECT_TRUE(m1.equals(m1));
EXPECT_TRUE(m1.equals(m2));
EXPECT_TRUE(m2.equals(m1));

m1.sparse_matrix[0].features[0].feat_index = 1;
EXPECT_FALSE(m1.equals(m2));
EXPECT_FALSE(m2.equals(m1));
m1.sparse_matrix[0].features[0].feat_index = 0;

m1.sparse_matrix[1].features[1].entry = 2;
EXPECT_FALSE(m1.equals(m2));
EXPECT_FALSE(m2.equals(m1));
}

TEST(SGSparseMatrix, equals_different_shape)
{
SGSparseMatrix<float64_t> m1(2, 2);
SGSparseMatrix<float64_t> m2(2, 3);
EXPECT_FALSE(m1.equals(m2));
EXPECT_FALSE(m2.equals(m1));
}

TEST(SGSparseMatrix, pointer_equal)
{
SGSparseMatrix<float64_t> m1(2, 2);
SGSparseMatrix<float64_t> m2(2, 3);
EXPECT_FALSE(m1 == m2);
EXPECT_TRUE(m1 == m1);
EXPECT_TRUE(m2 == m2);

SGSparseMatrix<float64_t> m3(2, 2);
EXPECT_FALSE(m1 == m3);
}

0 comments on commit de36675

Please sign in to comment.