Skip to content

Commit

Permalink
Moved get_feature(index) from SGSparseFeatures to SGSparseVector. Add…
Browse files Browse the repository at this point in the history
…ed unit-tests to fix behaviour on duplicate entries.
  • Loading branch information
Thoralf Klein committed Aug 8, 2013
1 parent da5891c commit b629081
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 11 deletions.
14 changes: 3 additions & 11 deletions src/shogun/features/SparseFeatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,11 @@ template<class ST> CFeatures* CSparseFeatures<ST>::duplicate() const

template<class ST> ST CSparseFeatures<ST>::get_feature(int32_t num, int32_t index)
{
ASSERT(index>=0 && index<get_num_features())
ASSERT(num>=0 && num<get_num_vectors())

int32_t i;
SGSparseVector<ST> sv=get_sparse_feature_vector(num);
ST ret = 0 ;

if (sv.features)
{
for (i=0; i<sv.num_feat_entries; i++)
if (sv.features[i].feat_index==index)
ret+=sv.features[i].entry ;
}
ASSERT(index>=0 && index<get_num_features());
ST ret = sv.get_feature(index);

free_sparse_feature_vector(num);

Expand All @@ -100,7 +92,7 @@ template<class ST> SGVector<ST> CSparseFeatures<ST>::get_full_feature_vector(int
dense.zero();

for (int32_t i=0; i<sv.num_feat_entries; i++)
dense.vector[sv.features[i].feat_index]= sv.features[i].entry;
dense.vector[sv.features[i].feat_index] = sv.features[i].entry;
}

free_sparse_feature_vector(num);
Expand Down
13 changes: 13 additions & 0 deletions src/shogun/lib/SGSparseVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@ void SGSparseVector<T>::sort_features()
SG_FREE(sf_orig);
}

template<class T>
T SGSparseVector<T>::get_feature(int32_t index) {
T ret = 0;
if (features)
{
for (index_t i=0; i<num_feat_entries; i++)
if (features[i].feat_index==index)
ret+=features[i].entry ;
}

return ret ;
}

template<class T> void SGSparseVector<T>::load(CFile* loader)
{
ASSERT(loader)
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/lib/SGSparseVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ template <class T> class SGSparseVector : public SGReferencedData
*/
void sort_features();

/**
* get feature value for index
*
* @param index
* @return value
*/
T get_feature(int32_t index);

/** load vector from file
*
* @param loader File object via which to load data
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/lib/SGSparseVector_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,43 @@ TEST(SGSparseVector, dense_dot_float64_int32)
EXPECT_NEAR(dot, 2.0, 1E-19);
}

TEST(SGSparseVector, get_feature_unique)
{
SGSparseVector<float64_t> vec(3);
EXPECT_EQ(vec.num_feat_entries, 3);

vec.features[0].feat_index = 1;
vec.features[0].entry = -3.0;

vec.features[1].feat_index = 2;
vec.features[1].entry = -4.0;

vec.features[2].feat_index = 3;
vec.features[2].entry = -6.0;

EXPECT_NEAR(vec.get_feature(1), -3.0, 1E-19);
EXPECT_NEAR(vec.get_feature(2), -4.0, 1E-19);
EXPECT_NEAR(vec.get_feature(3), -6.0, 1E-19);
}

TEST(SGSparseVector, get_feature_duplicate)
{
SGSparseVector<float64_t> vec(3);
EXPECT_EQ(vec.num_feat_entries, 3);

vec.features[0].feat_index = 3;
vec.features[0].entry = -3.0;

vec.features[1].feat_index = 2;
vec.features[1].entry = -4.0;

vec.features[2].feat_index = 3;
vec.features[2].entry = -6.0;

EXPECT_NEAR(vec.get_feature(2), -4.0, 1E-19);
EXPECT_NEAR(vec.get_feature(3), -9.0, 1E-19);
}

TEST(SGSparseVector, sort_features_empty)
{
SGSparseVector<float64_t> vec(0);
Expand Down Expand Up @@ -120,6 +157,10 @@ TEST(SGSparseVector, sort_features_duplicate)
vec.features[3].feat_index = 4;
vec.features[3].entry = -16.0;

EXPECT_NEAR(vec.get_feature(2), -4.0, 1E-19);
EXPECT_NEAR(vec.get_feature(3), -9.0, 1E-19);
EXPECT_NEAR(vec.get_feature(4), -16.0, 1E-19);

vec.sort_features();
EXPECT_EQ(vec.num_feat_entries, 3);

Expand All @@ -131,4 +172,8 @@ TEST(SGSparseVector, sort_features_duplicate)

EXPECT_EQ(vec.features[2].feat_index, 4);
EXPECT_NEAR(vec.features[2].entry, -16.0, 1E-19);

EXPECT_NEAR(vec.get_feature(2), -4.0, 1E-19);
EXPECT_NEAR(vec.get_feature(3), -9.0, 1E-19);
EXPECT_NEAR(vec.get_feature(4), -16.0, 1E-19);
}

0 comments on commit b629081

Please sign in to comment.