diff --git a/src/shogun/lib/SGMatrix.cpp b/src/shogun/lib/SGMatrix.cpp index 0c8b22e1649..faca6544e29 100644 --- a/src/shogun/lib/SGMatrix.cpp +++ b/src/shogun/lib/SGMatrix.cpp @@ -932,7 +932,7 @@ void SGMatrix::load(CFile* loader) { ASSERT(loader) unref(); - + SG_SET_LOCALE_C; SGMatrix mat; loader->get_matrix(mat.matrix, mat.num_rows, mat.num_cols); @@ -974,6 +974,20 @@ SGVector SGMatrix::get_row_vector(index_t row) const return rowv; } +template +SGVector SGMatrix::get_diagonal_vector() const +{ + index_t diag_vlen=CMath::min(num_cols, num_rows); + SGVector diag(diag_vlen); + + for (index_t i=0; i; template class SGMatrix; template class SGMatrix; diff --git a/src/shogun/lib/SGMatrix.h b/src/shogun/lib/SGMatrix.h index e5087dd1e99..a0ebda1f07d 100644 --- a/src/shogun/lib/SGMatrix.h +++ b/src/shogun/lib/SGMatrix.h @@ -61,6 +61,12 @@ template class SGMatrix : public SGReferencedData */ SGVector get_row_vector(index_t row) const; + /** get a main diagonal vector. Matrix is not required to be square. + * + * @return main diagonal vector + */ + SGVector get_diagonal_vector() const; + /** operator overload for matrix read only access * @param i_row * @param i_col @@ -95,7 +101,7 @@ template class SGMatrix : public SGReferencedData return matrix[index]; } - /** + /** * get the matrix (no copying is done here) * * @return the refcount increased matrix @@ -180,7 +186,7 @@ template class SGMatrix : public SGReferencedData static double* compute_eigenvectors(double* matrix, int n, int m); /** compute few eigenpairs of a symmetric matrix using LAPACK DSYEVR method - * (Relatively Robust Representations). + * (Relatively Robust Representations). * Has at least O(n^3/3) complexity * @param matrix_ symmetric matrix * @param eigenvalues contains iu-il+1 eigenvalues in ascending order (to be free'd) diff --git a/tests/unit/lib/SGMatrix_unittest.cc b/tests/unit/lib/SGMatrix_unittest.cc index 3a4f0691154..e395bd342f5 100644 --- a/tests/unit/lib/SGMatrix_unittest.cc +++ b/tests/unit/lib/SGMatrix_unittest.cc @@ -1,4 +1,5 @@ #include +#include #include #include @@ -53,14 +54,14 @@ TEST(SGMatrixTest,equals_equal) a(1,1)=4; a(2,0)=5; a(2,1)=6; - + b(0,0)=1; b(0,1)=2; b(1,0)=3; b(1,1)=4; b(2,0)=5; b(2,1)=6; - + EXPECT_TRUE(a.equals(b)); } @@ -74,14 +75,14 @@ TEST(SGMatrixTest,equals_different) a(1,1)=4; a(2,0)=5; a(2,1)=6; - + b(0,0)=1; b(0,1)=2; b(1,0)=3; b(1,1)=4; b(2,0)=5; b(2,1)=7; - + EXPECT_FALSE(a.equals(b)); } @@ -91,6 +92,44 @@ TEST(SGMatrixTest,equals_different_size) SGMatrix b(2,2); a.zero(); b.zero(); - + EXPECT_FALSE(a.equals(b)); } + +TEST(SGMatrixTest,get_diagonal_vector_square_matrix) +{ + SGMatrix mat(3, 3); + + mat(0,0)=8; + mat(0,1)=1; + mat(0,2)=6; + mat(1,0)=3; + mat(1,1)=5; + mat(1,2)=7; + mat(2,0)=4; + mat(2,1)=9; + mat(2,2)=2; + + SGVector diag=mat.get_diagonal_vector(); + + EXPECT_EQ(diag[0], 8); + EXPECT_EQ(diag[1], 5); + EXPECT_EQ(diag[2], 2); +} + +TEST(SGMatrixTest,get_diagonal_vector_rectangular_matrix) +{ + SGMatrix mat(3, 2); + + mat(0,0)=8; + mat(0,1)=1; + mat(1,0)=3; + mat(1,1)=5; + mat(2,0)=4; + mat(2,1)=9; + + SGVector diag=mat.get_diagonal_vector(); + + EXPECT_EQ(diag[0], 8); + EXPECT_EQ(diag[1], 5); +}