From 2a59acebc3036476dd00e9cbf4b16d7102a980a3 Mon Sep 17 00:00:00 2001 From: Fernando Iglesias Date: Mon, 8 Jul 2013 18:01:38 +0200 Subject: [PATCH] Add method in KNN to obtain nearest neighbors. Internally, the apply now calls the new method to obtain the nearest neighbors first and then computes the labels. --- src/shogun/multiclass/KNN.cpp | 98 ++++++++++++++++++++--------------- src/shogun/multiclass/KNN.h | 11 ++++ 2 files changed, 67 insertions(+), 42 deletions(-) diff --git a/src/shogun/multiclass/KNN.cpp b/src/shogun/multiclass/KNN.cpp index 79956e0f3ea..0d2cf9e35fd 100644 --- a/src/shogun/multiclass/KNN.cpp +++ b/src/shogun/multiclass/KNN.cpp @@ -107,12 +107,56 @@ bool CKNN::train_machine(CFeatures* data) return true; } +SGMatrix CKNN::nearest_neighbors() +{ + //number of examples to which kNN is applied + int32_t n=distance->get_num_vec_rhs(); + //distances to train data + float64_t* dists=SG_MALLOC(float64_t, m_train_labels.vlen); + //indices to train data + int32_t* train_idxs=SG_MALLOC(int32_t, m_train_labels.vlen); + //pre-allocation of the nearest neighbors + SGMatrix NN(m_k, n); + + //for each test example + for (int32_t i=0; i NN = nearest_neighbors(); + + //from the indices to the nearest neighbors, compute the class labels for (int32_t i=0; iset_label(i, out_idx + m_min_label); } @@ -266,8 +282,6 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data) SG_FREE(classes); SG_FREE(train_lab); - if ( ! m_use_covertree ) - SG_FREE(dists); return output; } diff --git a/src/shogun/multiclass/KNN.h b/src/shogun/multiclass/KNN.h index 20ebb63341d..10bcaa480b2 100644 --- a/src/shogun/multiclass/KNN.h +++ b/src/shogun/multiclass/KNN.h @@ -75,6 +75,17 @@ class CKNN : public CDistanceMachine */ virtual EMachineType get_classifier_type() { return CT_KNN; } + /** + * for each example in the rhs features of the distance member, find the m_k + * nearest neighbors among the vectors in the lhs features + * + * @return matrix with indices to the nearest neighbors, the dimensions of the + * matrix are k rows and n columns, where n is the number of feature vectors in rhs; + * among the nearest neighbors, the closest are in the first row, and the furthest + * in the last one + */ + SGMatrix nearest_neighbors(); + /** classify objects * * @param data (test)data to be classified