diff --git a/src/shogun/lib/SGMatrix.cpp b/src/shogun/lib/SGMatrix.cpp index adde27c41c5..c38a7cab81e 100644 --- a/src/shogun/lib/SGMatrix.cpp +++ b/src/shogun/lib/SGMatrix.cpp @@ -394,7 +394,7 @@ void SGMatrix::create_diagonal_matrix(T* matrix, T* v,int32_t size) } template -SGMatrix SGMatrix::submatrix(index_t col_start, index_t col_end) const +SGMatrix SGMatrix::slice(index_t col_start, index_t col_end) const { assert_on_cpu(); return SGMatrix( diff --git a/src/shogun/lib/SGMatrix.h b/src/shogun/lib/SGMatrix.h index f66bc1f5976..d99c848c27c 100644 --- a/src/shogun/lib/SGMatrix.h +++ b/src/shogun/lib/SGMatrix.h @@ -151,7 +151,7 @@ template class SGMatrix : public SGReferencedData * @param col_end column index (excluded) * @return the submatrix */ - SGMatrix submatrix(index_t col_start, index_t col_end) const; + SGMatrix slice(index_t col_start, index_t col_end) const; /** Map a column to a SGVector * \warning The returned SGVector is non-owning! diff --git a/src/shogun/lib/SGVector.cpp b/src/shogun/lib/SGVector.cpp index a5a913b45f9..06ea93cab73 100644 --- a/src/shogun/lib/SGVector.cpp +++ b/src/shogun/lib/SGVector.cpp @@ -309,6 +309,15 @@ void SGVector::resize_vector(int32_t n) vlen=n; } +template +SGVector SGVector::slice(index_t l, index_t h) const +{ + assert_on_cpu(); + ASSERT(l >= 0 && h <= vlen); + + return SGVector(vector, h - l, l); +} + /** addition operator */ template SGVector SGVector::operator+ (SGVector x) diff --git a/src/shogun/lib/SGVector.h b/src/shogun/lib/SGVector.h index bfb8b6bd0dc..dbbcf887121 100644 --- a/src/shogun/lib/SGVector.h +++ b/src/shogun/lib/SGVector.h @@ -270,6 +270,13 @@ template class SGVector : public SGReferencedData */ void resize_vector(int32_t n); + /** Returns a view of the vector from l inclusive to h exclusive + * + * @param l slice start index (inclusive) + * @param h slice end index (exclusive) + */ + SGVector slice(index_t l, index_t h) const; + /** Operator overload for vector read only access * * @param index dimension to access diff --git a/src/shogun/multiclass/QDA.cpp b/src/shogun/multiclass/QDA.cpp index 0c6d3a9617b..114ce69d335 100644 --- a/src/shogun/multiclass/QDA.cpp +++ b/src/shogun/multiclass/QDA.cpp @@ -144,7 +144,7 @@ CMulticlassLabels* CQDA::apply_multiclass(CFeatures* data) rf->free_feature_vector(vec, i); } - Map Em_M(m_M.submatrix(m_dim * k, m_dim * (k + 1))); + Map Em_M(m_M.slice(m_dim * k, m_dim * (k + 1))); A = X*Em_M; for (int i = 0; i < num_vecs; i++) diff --git a/tests/unit/lib/SGMatrix_unittest.cc b/tests/unit/lib/SGMatrix_unittest.cc index bd7c97b9e59..ad0bd34f07d 100644 --- a/tests/unit/lib/SGMatrix_unittest.cc +++ b/tests/unit/lib/SGMatrix_unittest.cc @@ -636,7 +636,7 @@ TEST(SGMatrixTest, max_single) EXPECT_GE(max, mat.matrix[i]); } -TEST(SGMatrixTest, get_submatrix) +TEST(SGMatrixTest, get_slice) { const index_t n_rows = 6, n_cols = 8; const index_t start_col = 2, end_col = 5; @@ -646,7 +646,7 @@ TEST(SGMatrixTest, get_submatrix) for (index_t i = 0; i < n_rows * n_cols; ++i) mat[i] = CMath::randn_double(); - auto sub = mat.submatrix(start_col, end_col); + auto sub = mat.slice(start_col, end_col); EXPECT_EQ(sub.num_rows, mat.num_rows); EXPECT_EQ(sub.num_cols, end_col - start_col); diff --git a/tests/unit/lib/SGVector_unittest.cc b/tests/unit/lib/SGVector_unittest.cc index 4510a0f0add..7e0dc9a6b00 100644 --- a/tests/unit/lib/SGVector_unittest.cc +++ b/tests/unit/lib/SGVector_unittest.cc @@ -468,3 +468,46 @@ TEST(SGVectorTest, as) EXPECT_EQ((int32_t)data[i], vec_int[i]); } } + +TEST(SGVectorTest, slice) +{ + index_t vlen = 100; + index_t l = 20; + index_t h = 50; + index_t l2 = 10; + index_t h2 = 20; + index_t c1 = 0; + index_t c2 = 1; + + SGVector vec(vlen); + vec.range_fill(); + + auto sliced_vec = vec.slice(l, h); + EXPECT_EQ(sliced_vec.size(), h - l); + for (index_t i = 0; i < h - l; i++) + { + EXPECT_EQ(vec[i + l], sliced_vec[i]); + } + + auto sliced_vec2 = sliced_vec.slice(l2, h2); + EXPECT_EQ(sliced_vec2.size(), h2 - l2); + for (index_t i = 0; i < h2 - l2; i++) + { + EXPECT_EQ(sliced_vec[i + l2], sliced_vec2[i]); + } + + sliced_vec.set_const(c1); + auto sliced_vec3 = sliced_vec2.slice(0, sliced_vec2.vlen); + EXPECT_EQ(sliced_vec2, sliced_vec3); + sliced_vec3.set_const(c2); + + for (index_t i = 0; i < vlen; i++) + { + if (i < l || i >= h) + EXPECT_EQ(vec[i], i); + else if (i < l + l2 || i >= l + h2) + EXPECT_EQ(vec[i], sliced_vec[i - l]); + else + EXPECT_EQ(vec[i], sliced_vec2[i - l - l2]); + } +}