From e88f4529150fcbf5e97b3400d59fb42e2f7535bd Mon Sep 17 00:00:00 2001 From: Saurabh7 Date: Tue, 21 Jun 2016 19:35:00 +0530 Subject: [PATCH] lmnn cookbook --- .../large_margin_nearest_neighbours.rst | 43 +++++++++++++++++++ doc/cookbook/source/references.bib | 9 ++++ .../large_margin_nearest_neighbours.sg | 36 ++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 doc/cookbook/source/examples/multiclass_classifier/large_margin_nearest_neighbours.rst create mode 100644 examples/meta/src/multiclass_classifier/large_margin_nearest_neighbours.sg diff --git a/doc/cookbook/source/examples/multiclass_classifier/large_margin_nearest_neighbours.rst b/doc/cookbook/source/examples/multiclass_classifier/large_margin_nearest_neighbours.rst new file mode 100644 index 00000000000..7ff23a77b4c --- /dev/null +++ b/doc/cookbook/source/examples/multiclass_classifier/large_margin_nearest_neighbours.rst @@ -0,0 +1,43 @@ +=============================== +Large Margin Nearest Neighbours +=============================== + +Large margin nearest neighbours is a metric learning algorithm. It learns a metric that can be used with the :doc:`knn` algorithm. + +The Mahalanobis distance metric which is an instance of :sgclass:`CCustomMahalanobisDistance` is obtained as a result. + +See :cite:`weinberger2009distance` for a detailed introduction. + +------- +Example +------- + +Imagine we have files with training and test data. We create CDenseFeatures (here 64 bit floats aka RealFeatures) and :sgclass:`CMulticlassLabels` as + +.. sgexample:: large_margin_nearest_neighbours.sg:create_features + +We create an instance of :sgclass:`CLMNN` and provide number of nearest neighbours as parameters. + +.. sgexample:: large_margin_nearest_neighbours.sg:create_instance + +Next we train the LMNN algorithm and get the learned metric. + +.. sgexample:: large_margin_nearest_neighbours.sg:train_metric + +Then we train the :sgclass:`CKNN` algorithm using the learned metric and apply it to test data, which here gives :sgclass:`CMulticlassLabels`. + +.. sgexample:: large_margin_nearest_neighbours.sg:train_and_apply + +We can evaluate test performance via e.g. :sgclass:`CMulticlassAccuracy`. + +.. sgexample:: large_margin_nearest_neighbours.sg:evaluate_accuracy + + +---------- +References +---------- +:wiki:`Large_margin_nearest_neighbor` + +.. bibliography:: ../../references.bib + :filter: docname in docnames + diff --git a/doc/cookbook/source/references.bib b/doc/cookbook/source/references.bib index 74597e5eb02..a59f26d04a7 100644 --- a/doc/cookbook/source/references.bib +++ b/doc/cookbook/source/references.bib @@ -109,3 +109,12 @@ @article{Breiman2001 volume={45}, pages={5--32} } +@article{weinberger2009distance, + title={Distance metric learning for large margin nearest neighbor classification}, + author={KQ Weinberger and LK Saul}, + journal={Journal of Machine Learning Research}, + volume={10}, + number={Feb}, + pages={207--244}, + year={2009} +} diff --git a/examples/meta/src/multiclass_classifier/large_margin_nearest_neighbours.sg b/examples/meta/src/multiclass_classifier/large_margin_nearest_neighbours.sg new file mode 100644 index 00000000000..e33c492e7e1 --- /dev/null +++ b/examples/meta/src/multiclass_classifier/large_margin_nearest_neighbours.sg @@ -0,0 +1,36 @@ +CSVFile f_feats_train("../../data/classifier_4class_2d_linear_features_train.dat") +CSVFile f_feats_test("../../data/classifier_4class_2d_linear_features_test.dat") +CSVFile f_labels_train("../../data/classifier_4class_2d_linear_labels_train.dat") +CSVFile f_labels_test("../../data/classifier_4class_2d_linear_labels_test.dat") + +#![create_features] +RealFeatures features_train(f_feats_train) +RealFeatures features_test(f_feats_test) +MulticlassLabels labels_train(f_labels_train) +MulticlassLabels labels_test(f_labels_test) +#![create_features] + +#![create_instance] +int k = 3 +LMNN lmnn(features_train, labels_train, k) +#![create_instance] + +#![train_metric] +lmnn.train() +CustomMahalanobisDistance lmnn_distance = lmnn.get_distance() +#![train_metric] + +#![train_and_apply] +KNN knn(k, lmnn_distance,labels_train) +knn.train() +MulticlassLabels labels_predict = knn.apply_multiclass(features_test) +#![train_and_apply] + +#![evaluate_accuracy] +MulticlassAccuracy acc() +real accuracy = acc.evaluate(labels_predict, labels_test) +#![evaluate_accuracy] + +# additional integration testing variables +RealVector output = labels_predict.get_labels() +