Skip to content

Commit

Permalink
lmnn cookbook
Browse files Browse the repository at this point in the history
  • Loading branch information
Saurabh7 committed Jul 13, 2016
1 parent 79be8f1 commit e88f452
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
@@ -0,0 +1,43 @@
===============================
Large Margin Nearest Neighbours
===============================

Large margin nearest neighbours is a metric learning algorithm. It learns a metric that can be used with the :doc:`knn` algorithm.

The Mahalanobis distance metric which is an instance of :sgclass:`CCustomMahalanobisDistance` is obtained as a result.

See :cite:`weinberger2009distance` for a detailed introduction.

-------
Example
-------

Imagine we have files with training and test data. We create CDenseFeatures (here 64 bit floats aka RealFeatures) and :sgclass:`CMulticlassLabels` as

.. sgexample:: large_margin_nearest_neighbours.sg:create_features

We create an instance of :sgclass:`CLMNN` and provide number of nearest neighbours as parameters.

.. sgexample:: large_margin_nearest_neighbours.sg:create_instance

Next we train the LMNN algorithm and get the learned metric.

.. sgexample:: large_margin_nearest_neighbours.sg:train_metric

Then we train the :sgclass:`CKNN` algorithm using the learned metric and apply it to test data, which here gives :sgclass:`CMulticlassLabels`.

.. sgexample:: large_margin_nearest_neighbours.sg:train_and_apply

We can evaluate test performance via e.g. :sgclass:`CMulticlassAccuracy`.

.. sgexample:: large_margin_nearest_neighbours.sg:evaluate_accuracy


----------
References
----------
:wiki:`Large_margin_nearest_neighbor`

.. bibliography:: ../../references.bib
:filter: docname in docnames

9 changes: 9 additions & 0 deletions doc/cookbook/source/references.bib
Expand Up @@ -109,3 +109,12 @@ @article{Breiman2001
volume={45},
pages={5--32}
}
@article{weinberger2009distance,
title={Distance metric learning for large margin nearest neighbor classification},
author={KQ Weinberger and LK Saul},
journal={Journal of Machine Learning Research},
volume={10},
number={Feb},
pages={207--244},
year={2009}
}
@@ -0,0 +1,36 @@
CSVFile f_feats_train("../../data/classifier_4class_2d_linear_features_train.dat")
CSVFile f_feats_test("../../data/classifier_4class_2d_linear_features_test.dat")
CSVFile f_labels_train("../../data/classifier_4class_2d_linear_labels_train.dat")
CSVFile f_labels_test("../../data/classifier_4class_2d_linear_labels_test.dat")

#![create_features]
RealFeatures features_train(f_feats_train)
RealFeatures features_test(f_feats_test)
MulticlassLabels labels_train(f_labels_train)
MulticlassLabels labels_test(f_labels_test)
#![create_features]

#![create_instance]
int k = 3
LMNN lmnn(features_train, labels_train, k)
#![create_instance]

#![train_metric]
lmnn.train()
CustomMahalanobisDistance lmnn_distance = lmnn.get_distance()
#![train_metric]

#![train_and_apply]
KNN knn(k, lmnn_distance,labels_train)
knn.train()
MulticlassLabels labels_predict = knn.apply_multiclass(features_test)
#![train_and_apply]

#![evaluate_accuracy]
MulticlassAccuracy acc()
real accuracy = acc.evaluate(labels_predict, labels_test)
#![evaluate_accuracy]

# additional integration testing variables
RealVector output = labels_predict.get_labels()

0 comments on commit e88f452

Please sign in to comment.