Skip to content

Commit

Permalink
add get_diagonal_vector() method to SGMatrix class
Browse files Browse the repository at this point in the history
  • Loading branch information
votjakovr committed Aug 21, 2013
1 parent 6c21825 commit d711ce4
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 8 deletions.
16 changes: 15 additions & 1 deletion src/shogun/lib/SGMatrix.cpp
Expand Up @@ -932,7 +932,7 @@ void SGMatrix<T>::load(CFile* loader)
{
ASSERT(loader)
unref();

SG_SET_LOCALE_C;
SGMatrix<T> mat;
loader->get_matrix(mat.matrix, mat.num_rows, mat.num_cols);
Expand Down Expand Up @@ -974,6 +974,20 @@ SGVector<T> SGMatrix<T>::get_row_vector(index_t row) const
return rowv;
}

template<class T>
SGVector<T> SGMatrix<T>::get_diagonal_vector() const
{
index_t diag_vlen=CMath::min(num_cols, num_rows);
SGVector<T> diag(diag_vlen);

for (index_t i=0; i<diag_vlen; i++)
{
diag[i]=matrix[i*num_rows+i];
}

return diag;
}

template class SGMatrix<bool>;
template class SGMatrix<char>;
template class SGMatrix<int8_t>;
Expand Down
10 changes: 8 additions & 2 deletions src/shogun/lib/SGMatrix.h
Expand Up @@ -61,6 +61,12 @@ template<class T> class SGMatrix : public SGReferencedData
*/
SGVector<T> get_row_vector(index_t row) const;

/** get a main diagonal vector. Matrix is not required to be square.
*
* @return main diagonal vector
*/
SGVector<T> get_diagonal_vector() const;

/** operator overload for matrix read only access
* @param i_row
* @param i_col
Expand Down Expand Up @@ -95,7 +101,7 @@ template<class T> class SGMatrix : public SGReferencedData
return matrix[index];
}

/**
/**
* get the matrix (no copying is done here)
*
* @return the refcount increased matrix
Expand Down Expand Up @@ -180,7 +186,7 @@ template<class T> class SGMatrix : public SGReferencedData
static double* compute_eigenvectors(double* matrix, int n, int m);

/** compute few eigenpairs of a symmetric matrix using LAPACK DSYEVR method
* (Relatively Robust Representations).
* (Relatively Robust Representations).
* Has at least O(n^3/3) complexity
* @param matrix_ symmetric matrix
* @param eigenvalues contains iu-il+1 eigenvalues in ascending order (to be free'd)
Expand Down
49 changes: 44 additions & 5 deletions tests/unit/lib/SGMatrix_unittest.cc
@@ -1,4 +1,5 @@
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>
#include <shogun/mathematics/Math.h>
#include <gtest/gtest.h>

Expand Down Expand Up @@ -53,14 +54,14 @@ TEST(SGMatrixTest,equals_equal)
a(1,1)=4;
a(2,0)=5;
a(2,1)=6;

b(0,0)=1;
b(0,1)=2;
b(1,0)=3;
b(1,1)=4;
b(2,0)=5;
b(2,1)=6;

EXPECT_TRUE(a.equals(b));
}

Expand All @@ -74,14 +75,14 @@ TEST(SGMatrixTest,equals_different)
a(1,1)=4;
a(2,0)=5;
a(2,1)=6;

b(0,0)=1;
b(0,1)=2;
b(1,0)=3;
b(1,1)=4;
b(2,0)=5;
b(2,1)=7;

EXPECT_FALSE(a.equals(b));
}

Expand All @@ -91,6 +92,44 @@ TEST(SGMatrixTest,equals_different_size)
SGMatrix<float64_t> b(2,2);
a.zero();
b.zero();

EXPECT_FALSE(a.equals(b));
}

TEST(SGMatrixTest,get_diagonal_vector_square_matrix)
{
SGMatrix<int32_t> mat(3, 3);

mat(0,0)=8;
mat(0,1)=1;
mat(0,2)=6;
mat(1,0)=3;
mat(1,1)=5;
mat(1,2)=7;
mat(2,0)=4;
mat(2,1)=9;
mat(2,2)=2;

SGVector<int32_t> diag=mat.get_diagonal_vector();

EXPECT_EQ(diag[0], 8);
EXPECT_EQ(diag[1], 5);
EXPECT_EQ(diag[2], 2);
}

TEST(SGMatrixTest,get_diagonal_vector_rectangular_matrix)
{
SGMatrix<int32_t> mat(3, 2);

mat(0,0)=8;
mat(0,1)=1;
mat(1,0)=3;
mat(1,1)=5;
mat(2,0)=4;
mat(2,1)=9;

SGVector<int32_t> diag=mat.get_diagonal_vector();

EXPECT_EQ(diag[0], 8);
EXPECT_EQ(diag[1], 5);
}

0 comments on commit d711ce4

Please sign in to comment.