Skip to content

Commit

Permalink
Merge pull request #3641 from MikeLing/clean_up_KNN_3
Browse files Browse the repository at this point in the history
add tests for KNN and fix an error in KDTree solver
  • Loading branch information
vigsterkr committed Mar 1, 2017
2 parents 7ec1bf1 + 48a1352 commit 9cbd1a1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
11 changes: 7 additions & 4 deletions src/shogun/multiclass/KDTreeKNNsolver.cpp
Expand Up @@ -40,7 +40,7 @@ CMulticlassLabels* CKDTREEKNNSolver::classify_objects(CDistance* distance, const
output->set_label(i, out_idx + m_min_label);
}
SG_UNREF(query);

SG_UNREF(kd_tree);
return output;
}

Expand All @@ -49,7 +49,7 @@ int32_t* CKDTREEKNNSolver::classify_objects_k(CDistance* distance, const int32_t
int32_t* output=SG_MALLOC(int32_t, m_k*num_lab);

//allocation for distances to nearest neighbors
float64_t* dists=SG_MALLOC(float64_t, m_k);
SGVector<float64_t> dists(m_k);

CFeatures* lhs = distance->get_lhs();
CKDTree* kd_tree = new CKDTree(m_leaf_size);
Expand All @@ -65,12 +65,15 @@ int32_t* CKDTREEKNNSolver::classify_objects_k(CDistance* distance, const int32_t
for (index_t j=0; j<m_k; j++)
{
train_lab[j] = m_train_labels[ NN(j,i) ];
dists[j] = distance->distance(i, NN(j,i));
dists[j] = distance->distance(NN(j,i), i);
}
CMath::qsort_index(dists, train_lab, m_k);
CMath::qsort_index(dists.vector, train_lab, m_k);

choose_class_for_multiple_k(output+i, classes, train_lab, num_lab);
}

SG_UNREF(data);
SG_UNREF(kd_tree);

return output;
}
55 changes: 51 additions & 4 deletions tests/unit/multiclass/KNN_unittest.cc
Expand Up @@ -91,7 +91,7 @@ TEST(KNN, kdtree_solver)

int32_t k=4;
CEuclideanDistance* distance = new CEuclideanDistance();
CKNN* knn=new CKNN (k, distance, labels, KNN_BRUTE);
CKNN* knn=new CKNN (k, distance, labels, KNN_KDTREE);
SG_REF(knn);

features->add_subset(train);
Expand Down Expand Up @@ -159,6 +159,53 @@ TEST(KNN, lsh_solver)
SG_UNREF(knn);
}

// Exercises CKNN::classify_for_multiple_k() with the KNN_BRUTE solver:
// trains on a random subset of generated data, classifies a second random
// subset, and expects every column of the result (one per neighbourhood
// size up to k) to equal the reference label of the corresponding sample.
TEST(KNN, classify_multiple_brute)
{
    int32_t per_class = 50;
    int32_t dims = 2;
    int32_t n_classes = 3;
    int32_t total = n_classes * per_class;

    SGVector<float64_t> label_data(total);
    SGMatrix<float64_t> feature_data(dims, total);
    generate_knn_data(feature_data, label_data, per_class, n_classes, dims);

    // Index sets for the train/test split; both are drawn at random from
    // the full index range, train first and test second.
    SGVector<index_t> train_idx(int32_t(total * 0.75));
    SGVector<index_t> test_idx(int32_t(total * 0.25));
    train_idx.random(0, total - 1);
    test_idx.random(0, total - 1);

    CMulticlassLabels* train_labels = new CMulticlassLabels(label_data);
    CDenseFeatures<float64_t>* train_feats =
        new CDenseFeatures<float64_t>(feature_data);

    // Independent copies for the test side, taken before any subset is set
    // on the originals.
    CFeatures* test_feats = (CFeatures*) train_feats->clone();
    CLabels* test_labels = (CLabels*) train_labels->clone();

    int32_t max_k = 4;
    CEuclideanDistance* train_distance = new CEuclideanDistance();
    CKNN* classifier = new CKNN(max_k, train_distance, train_labels, KNN_BRUTE);
    SG_REF(classifier);

    // Restrict the training objects to the training indices and fit.
    train_feats->add_subset(train_idx);
    train_labels->add_subset(train_idx);
    classifier->train(train_feats);

    // classify for multiple k
    test_feats->add_subset(test_idx);
    test_labels->add_subset(test_idx);

    // Distance between the training features (lhs) and the test features
    // (rhs) replaces the one used during training.
    CEuclideanDistance* test_distance =
        new CEuclideanDistance(train_feats, ((CDotFeatures*)test_feats));
    classifier->set_distance(test_distance);
    SGMatrix<int32_t> predictions = classifier->classify_for_multiple_k();
    test_feats->remove_subset();

    // The label subset is intentionally still active here, so index `row`
    // presumably addresses the row-th test sample.
    CMulticlassLabels* expected = (CMulticlassLabels*)test_labels;
    for (index_t row = 0; row < test_labels->get_num_labels(); ++row)
    {
        for (index_t col = 0; col < max_k; ++col)
        {
            EXPECT_EQ(predictions(row, col), expected->get_label(row));
        }
    }

    SG_UNREF(classifier);
    SG_UNREF(test_feats);
    SG_UNREF(test_labels);
}


TEST(KNN, classify_multiple_kdtree)
{
Expand All @@ -169,7 +216,7 @@ TEST(KNN, classify_multiple_kdtree)

SGVector< float64_t > lab(classes*num);
SGMatrix< float64_t > feat(feats, classes*num);

generate_knn_data(feat, lab, num, classes, feats);
SGVector<index_t> train (int32_t(num*classes*0.75));
SGVector<index_t> test (int32_t(num*classes*0.25));
Expand All @@ -183,7 +230,7 @@ TEST(KNN, classify_multiple_kdtree)

int32_t k=4;
CEuclideanDistance* distance = new CEuclideanDistance();
CKNN* knn=new CKNN (k, distance, labels, KNN_BRUTE);
CKNN* knn=new CKNN (k, distance, labels, KNN_KDTREE);
SG_REF(knn);

features->add_subset(train);
Expand All @@ -193,7 +240,7 @@ TEST(KNN, classify_multiple_kdtree)
// classify for multiple k
features_test->add_subset(test);
labels_test->add_subset(test);
CEuclideanDistance* dist = new CEuclideanDistance(features, ((CDotFeatures*)features_test));
CEuclideanDistance* dist = new CEuclideanDistance(features, ((CDotFeatures*)features_test));
knn->set_distance(dist);
SGMatrix<int32_t> out_mat =knn->classify_for_multiple_k();
features_test->remove_subset();
Expand Down

0 comments on commit 9cbd1a1

Please sign in to comment.