Skip to content

Commit

Permalink
Add method in KNN to obtain nearest neighbors.
Browse files Browse the repository at this point in the history
Internally, apply_multiclass now calls the new method to obtain the nearest
neighbors first and then computes the labels from them.
  • Loading branch information
iglesias committed Jul 8, 2013
1 parent 583fc6c commit 2a59ace
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
98 changes: 56 additions & 42 deletions src/shogun/multiclass/KNN.cpp
Expand Up @@ -107,12 +107,56 @@ bool CKNN::train_machine(CFeatures* data)
return true;
}

/**
 * For every vector on the rhs of the distance member, collect the indices of
 * its m_k closest vectors among the training (lhs) vectors.
 *
 * @return matrix of training indices with m_k rows and one column per rhs
 * vector; row order within a column follows the order produced by
 * CMath::qsort_index on the distances
 */
SGMatrix<int32_t> CKNN::nearest_neighbors()
{
	//how many query examples kNN is applied to
	const int32_t num_queries=distance->get_num_vec_rhs();
	//size of the per-query scratch buffers (one slot per training label)
	const int32_t num_train=m_train_labels.vlen;

	//scratch buffer: distances from the current query to the train data
	float64_t* dist_buf=SG_MALLOC(float64_t, num_train);
	//scratch buffer: train indices, permuted alongside the distances
	int32_t* order=SG_MALLOC(int32_t, num_train);
	//result, allocated up front: column q belongs to query q
	SGMatrix<int32_t> neighbors(m_k, num_queries);

	//process the queries one by one, stopping early on cancellation
	for (int32_t q=0; q<num_queries && (!CSignal::cancel_computations()); q++)
	{
		SG_PROGRESS(q, 0, num_queries)

		//distances between lhs idx 0..num_train-1 and rhs idx q
		distances_lhs(dist_buf, 0, num_train-1, q);

		//reset the index buffer to the identity 0, 1, ..., num_train-1
		for (int32_t t=0; t<num_train; t++)
			order[t]=t;

		//sort the distances, carrying the train indices along
		CMath::qsort_index(dist_buf, order, num_train);

#ifdef DEBUG_KNN
		SG_PRINT("\nQuick sort query %d\n", q)
		for (int32_t r=0; r<m_k; r++)
			SG_PRINT("%d ", order[r])
		SG_PRINT("\n")
#endif

		//the first m_k sorted indices form column q of the result
		for (int32_t r=0; r<m_k; r++)
			neighbors(r,q) = order[r];
	}

	SG_FREE(order);
	SG_FREE(dist_buf);

	return neighbors;
}

CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)
{
if (data)
init_distance(data);

// redirecting to fast (without sorting) classify if k==1
//redirecting to fast (without sorting) classify if k==1
if (m_k == 1)
return classify_NN();

Expand All @@ -125,24 +169,13 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)

CMulticlassLabels* output=new CMulticlassLabels(num_lab);

float64_t* dists = NULL;
int32_t* train_lab = NULL;

//distances to train data and working buffer of m_train_labels
if ( ! m_use_covertree )
{
dists=SG_MALLOC(float64_t, m_train_labels.vlen);
train_lab=SG_MALLOC(int32_t, m_train_labels.vlen);
}
else
{
train_lab=SG_MALLOC(int32_t, m_k);
}
//labels of the k nearest neighbors
int32_t* train_lab=SG_MALLOC(int32_t, m_k);

SG_INFO("%d test examples\n", num_lab)
CSignal::clear_cancel();

///histogram of classes and returned output
//histogram of classes and returned output
float64_t* classes=SG_MALLOC(float64_t, m_num_classes);

#ifdef BENCHMARK_KNN
Expand All @@ -152,36 +185,19 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)

if ( ! m_use_covertree )
{
//get the k nearest neighbors of each example
SGMatrix<int32_t> NN = nearest_neighbors();

//from the indices to the nearest neighbors, compute the class labels
for (int32_t i=0; i<num_lab && (!CSignal::cancel_computations()); i++)
{
SG_PROGRESS(i, 0, num_lab)

#ifdef DEBUG_KNN
distances_lhs(dists,0,m_train_labels.vlen-1,i);

for (int32_t j=0; j<m_train_labels.vlen; j++)
train_lab[j]=j;

CMath::qsort_index(dists, train_lab, m_train_labels.vlen);

SG_PRINT("\nQuick sort query %d\n", i)
//write the labels of the k nearest neighbors from their indices
for (int32_t j=0; j<m_k; j++)
SG_PRINT("%d ", train_lab[j])
SG_PRINT("\n")
#endif

//lhs idx 1..n and rhs idx i
distances_lhs(dists,0,m_train_labels.vlen-1,i);

for (int32_t j=0; j<m_train_labels.vlen; j++)
train_lab[j]=m_train_labels.vector[j];

//sort the distance vector for test example j to all
//train examples
CMath::qsort_index(dists, train_lab, m_train_labels.vlen);
train_lab[j] = m_train_labels[ NN(j,i) ];

// Get the index of the 'nearest' class
//get the index of the 'nearest' class
int32_t out_idx = choose_class(classes, train_lab);
//write the label of 'nearest' in the output
output->set_label(i, out_idx + m_min_label);
}

Expand Down Expand Up @@ -266,8 +282,6 @@ CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)

SG_FREE(classes);
SG_FREE(train_lab);
if ( ! m_use_covertree )
SG_FREE(dists);

return output;
}
Expand Down
11 changes: 11 additions & 0 deletions src/shogun/multiclass/KNN.h
Expand Up @@ -75,6 +75,17 @@ class CKNN : public CDistanceMachine
*/
virtual EMachineType get_classifier_type() { return CT_KNN; }

/**
* For each example in the rhs features of the distance member, find the m_k
* nearest neighbors among the vectors in the lhs features.
*
* @return matrix with the indices of the nearest neighbors; the matrix has
* k rows and n columns, where n is the number of feature vectors in rhs.
* Among the nearest neighbors of an example, the closest is in the first row
* and the furthest in the last one.
*
* NOTE(review): the implementation appears to assume that m_k does not exceed
* the number of training labels -- confirm against callers.
*/
SGMatrix<int32_t> nearest_neighbors();

/** classify objects
*
* @param data (test)data to be classified
Expand Down

0 comments on commit 2a59ace

Please sign in to comment.