FEAT Large Margin Nearest Neighbor implementation #8602

Open
wants to merge 201 commits into base: main
Changes from 194 commits
Commits (201)
b016b70
first commit of lmnn
johny-c Mar 16, 2017
e233427
add import in neighbors/__init__.py
johny-c Mar 16, 2017
d325d6d
[MRG] add lmnn example in neighbors
johny-c Mar 16, 2017
f51641a
[MRG] Large Margin Nearest Neighbor Implementation
johny-c Mar 16, 2017
81594c5
setup logger in init, not in fit
johny-c Mar 16, 2017
6444d66
remove logger, add disp and warnings
johny-c Mar 16, 2017
81fd15a
[MRG] lmnn pep8 compliance
johny-c Mar 16, 2017
6dc7829
[MRG] Large Margin Nearest Neighbor example
johny-c Mar 16, 2017
d38fb0b
[MRG] Large Margin Nearest Neighbor implementation
johny-c Mar 16, 2017
6a08089
[MRG] Large Margin Nearest Neighbor implementation
johny-c Mar 17, 2017
a840e15
[MRG] Large Margin Nearest Neighbor implementation, fix utf8-encoding…
johny-c Mar 17, 2017
e26c0ca
[MRG] Large Margin Nearest Neighbor implementation, remove empty line…
johny-c Mar 17, 2017
62e4413
[MRG] Large Margin Nearest Neighbor implementation, import argpartiti…
johny-c Mar 17, 2017
08b0f08
[MRG] Large Margin Nearest Neighbor implementation, pep8 fix
johny-c Mar 17, 2017
197bfdd
[MRG] Large Margin Nearest Neighbor implementation, fix potential fai…
johny-c Mar 17, 2017
617eaae
add LargeMarginNearestNeighbor to check_non_transformer_estimators_n_…
johny-c Mar 17, 2017
13db483
catch fmin_lbfgs_b exceptions of old scipy
johny-c Mar 17, 2017
1e58e83
add args to fmin_lbfgs_b for intent(inout)
johny-c Mar 17, 2017
942c2a5
add order F flag in check_X for fmin_lbfgs_b for intent(inout)
johny-c Mar 18, 2017
191d9da
raise exception for old version of scipy with lbfgs
johny-c Mar 18, 2017
eef115e
catch exception and terminate for old version of scipy with lbfgs
johny-c Mar 18, 2017
ca2dad8
many changes, according to first comments on PR
johny-c Mar 19, 2017
f83a358
catch - reraise exceptions caused by old Scipy versions
johny-c Mar 20, 2017
d42e05b
add SkipTest for old versions of SciPy
johny-c Mar 20, 2017
18a03f5
merge exceptions caused by old SciPy versions
johny-c Mar 20, 2017
27a1980
Merge branch 'master' into pylmnn
johny-c Mar 20, 2017
3bfbc4c
found bug in sum_outer_products
johny-c Mar 21, 2017
c594a69
found bug in sum_outer_products from old numpy/scipy versions
johny-c Mar 21, 2017
b89aa5f
remove unused try,except, add test_lmnn: partial copy of test_neighbors
johny-c Mar 21, 2017
40f477c
add validate_params(), add test_lmnn: partial copy of test_neighbors
johny-c Mar 22, 2017
cd4b4e5
small fix in checks, fix lmnn test file
johny-c Mar 22, 2017
f3b30c1
fix some tests in lmnn test file
johny-c Mar 22, 2017
9c12cfd
fix some tests in lmnn test file
johny-c Mar 22, 2017
dd6f3c9
fix random.choice from utils.random for numpy < 1.7.0
johny-c Mar 22, 2017
1911f05
fix rng -> rng_, decrease iterations in test_digits
johny-c Mar 22, 2017
00af167
add more tests to check params
johny-c Mar 22, 2017
de1ff46
first commit on out of PR branch
johny-c Mar 23, 2017
b6a88f7
ensure min_samples=2, should be 4 (2+2)
johny-c Mar 23, 2017
2a4c8c5
make verbosity more consistent
johny-c Mar 23, 2017
a302532
add empty char in print message
johny-c Mar 23, 2017
e2bb486
changes according to further comments on PR
johny-c Mar 31, 2017
8149c4c
add some more tests
johny-c Apr 2, 2017
24a87c4
Merge branch 'master' into pylmnn
johny-c Apr 2, 2017
38b727d
move print doc after import
johny-c Apr 2, 2017
3c79c3b
fix test_use_pca
johny-c Apr 2, 2017
4b62e7c
add some more tests
johny-c Apr 2, 2017
291e1d3
add more documentation
johny-c Apr 2, 2017
a6e1778
add doc raises NotFittedError in predict funcs
johny-c Apr 2, 2017
226bb7a
many small improvements in speed and memory, replace unique_pairs wit…
johny-c Apr 9, 2017
9b72731
use sklearn's PCA and NearestNeighbors, check-ignore singleton classes
johny-c May 1, 2017
247bff7
make copy of inputs in case of singleton classes
johny-c May 1, 2017
9c211b4
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c May 1, 2017
2be4d00
small changes
johny-c May 1, 2017
ad8c8c7
use csc_matrix for summing columns
johny-c May 2, 2017
46d63c8
minor changes
johny-c May 2, 2017
c0bec16
pass random_state to PCA
johny-c May 3, 2017
bc2548a
add test for callable and early termination message
johny-c May 3, 2017
aa50245
fix test train-test split order
johny-c May 3, 2017
4cedbcb
Update test_lmnn.py
johny-c May 3, 2017
3b55816
add algorithm parameter
johny-c May 24, 2017
0c7188c
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c May 24, 2017
35a156b
speedup by avoiding sparse nonzero(), restore pairs_distances_batch
johny-c May 28, 2017
8eb30cd
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c May 28, 2017
7370ac2
attempt to fix random_state test for some failing tests
johny-c May 28, 2017
b795acc
add first narrative docs
johny-c Jun 9, 2017
09102c0
convert LMNN to Transformer, add doc file, add dim reduction example
johny-c Jun 14, 2017
4ae2e1e
Merge branch 'master' into pylmnn
johny-c Jun 14, 2017
44cd16f
bincount was removed from utils.fixes
johny-c Jun 14, 2017
62ee31b
remove utils.random.choice usage, fix tests, fix doc typo
johny-c Jun 14, 2017
62cadeb
fix PEP8 in example plot, fix some tests
johny-c Jun 14, 2017
cd78724
fix pep8 typos
johny-c Jun 14, 2017
cb1aeea
corrections in documentation
johny-c Jun 18, 2017
aff5d08
matching MNIST accuracy as original code with consistent margin
johny-c Jul 28, 2017
26b1787
rename use_pca -> init_pca
johny-c Jul 28, 2017
40357cc
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c Jul 28, 2017
8bbb484
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c Jul 28, 2017
0b41d73
fix variable names, default params
johny-c Jul 28, 2017
68a0261
add flush to stdout after some print statements, reorder params
johny-c Jul 29, 2017
0a9d6f1
major speedup by using ravel in numpy where
johny-c Jul 30, 2017
06317b2
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Jul 31, 2017
5f7e66e
replace np.where(loss > 0) with loss > 0 (bool indexing)
johny-c Jul 31, 2017
9ff8619
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Jul 31, 2017
27849ae
remove if clauses for scipy < 0.13
johny-c Aug 22, 2017
eada78f
hard-code euclidean_distances batch version in find_impostors_batch a…
johny-c Aug 24, 2017
8f9c848
hard-code euclidean_distances batch version in find_impostors_batch a…
johny-c Aug 24, 2017
692044e
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c Aug 24, 2017
564a973
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Aug 24, 2017
9d28498
add targets parameter, expose select_targets function, add test_targets
johny-c Aug 25, 2017
4de9082
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c Aug 25, 2017
91fba3f
linearize indices of impostor-pairs (small speedup)
johny-c Aug 26, 2017
fee84b6
fix list.extend(non-iterable)
johny-c Aug 26, 2017
c9e678b
Try replacing empty with zeros to avoid randomness
jnothman Aug 27, 2017
87bf2d0
In tests, replace assert_array_equal with assert_allclose for compari…
johny-c Aug 27, 2017
276a096
Try explicitly using np.int64 as dtype of targets
johny-c Aug 28, 2017
c907d3d
check targets dtype if given, add training time to the stored info dict
johny-c Aug 29, 2017
1f9b8fd
try loading data astype(float) in case incosistent distances are retu…
johny-c Aug 31, 2017
826c766
fix flakes error (line too long)
johny-c Aug 31, 2017
ed00abf
try digits instead of iris dataset (same targets in 32 vs 64bit) in t…
johny-c Aug 31, 2017
00b2d69
add a few comments on plot_lmnn_dim_reduction.py
bellet Sep 5, 2017
e395d36
update LMM doc
bellet Sep 7, 2017
92a97f2
corrected shape of L
bellet Sep 8, 2017
ca5b980
reduce init_pca, init_transformation and warm_start to single paramet…
johny-c Sep 11, 2017
6b95ede
restate warm_start as a separate boolean parameter
johny-c Sep 11, 2017
caaf5c0
a bit better error message
johny-c Sep 11, 2017
4eb156c
merge change in initialization params
johny-c Sep 11, 2017
a472941
fix docstring output
johny-c Sep 11, 2017
a7820c3
covered some more validation tests and written better error messages
johny-c Sep 12, 2017
5114a39
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Sep 12, 2017
c1d9f65
extend docstring in validate_params to describe targets and init
johny-c Sep 13, 2017
2be1c01
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Sep 14, 2017
5554c5a
remove classes from attributes and warm_start check, cleanup and bett…
johny-c Sep 14, 2017
a44d21b
avoid PCA in test_random_state to ensure reproducibility of transform…
johny-c Sep 17, 2017
84522a1
rename Lx to X_embedded
johny-c Sep 17, 2017
938e382
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Sep 17, 2017
3322d97
add reproducibility warning in Notes, change example to classificatio…
johny-c Sep 18, 2017
f0d21f8
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Sep 18, 2017
3990888
fix multiline code in docstring
johny-c Sep 18, 2017
7a68665
add changes according to AG review
johny-c Sep 20, 2017
98d4bc5
Merge branch 'pylmnn' of github.com:johny-c/scikit-learn into pylmnn
johny-c Sep 20, 2017
8c2bf26
Merge branch 'PR' of github.com:johny-c/scikit-learn into PR
johny-c Sep 20, 2017
2b6be78
return imp_row, imp_col as int from find_impostors when use_sparse=False
johny-c Sep 20, 2017
d8351f6
many changes (with benchmark) according to review by AG
johny-c Sep 21, 2017
c76cf2b
dont print transformation in the docstring (not reproducible)
johny-c Sep 21, 2017
f6f520f
add LMNN illustration script, better figure for dim-redux, result_ ->…
johny-c Sep 24, 2017
5205e6b
remove networkx dependency from plot_lmnn_illustration
johny-c Sep 24, 2017
c5a7f19
remove uncaught (?) flake8 error
johny-c Sep 24, 2017
ea36c85
add feature space classification example (iris)
johny-c Sep 25, 2017
7125640
minor fixes in lmnn.rst
johny-c Sep 25, 2017
cfcee16
minor changes in lmnn.py
johny-c Sep 25, 2017
a367478
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
johny-c Sep 27, 2017
2c29c0a
update whats_new
johny-c Sep 27, 2017
66ed49e
made batch function and variable names consistent with PR 7979 -> blo…
johny-c Sep 28, 2017
f26d827
print optimization header line in start of lbfgs_callback
johny-c Sep 28, 2017
b1de00b
simplify dim. reduction example script
johny-c Sep 29, 2017
b741062
remove unnecessary unit margin_radii, print num of active constraints
johny-c Sep 29, 2017
0235d27
correct paths to images in lmnn.rst
johny-c Sep 29, 2017
f0eb7e7
fix appearance of documentation plots
johny-c Sep 29, 2017
3893dd2
fix appearance of documentation plots
johny-c Sep 29, 2017
2865ef1
minor fixes in narrative
johny-c Sep 30, 2017
c9f5d50
minor fixes in narrative
johny-c Sep 30, 2017
01e1004
fill margin with color in illustration example
johny-c Oct 1, 2017
9a47e5a
separate push_loss computation, use impostors_graph
johny-c Oct 2, 2017
3db67b4
some pep8 fixes
johny-c Oct 2, 2017
9a9ca4a
some pep8 fixes
johny-c Oct 2, 2017
e22658b
rename max_constraints to max_impostors
johny-c Oct 2, 2017
5d980ed
remove unnecessary try-except clauses for list.extend(np.array)
johny-c Oct 3, 2017
5909a80
minor performance improvement by precomputing sample_range in _comput…
johny-c Oct 5, 2017
8206ee1
make changes requested by GV
johny-c Oct 6, 2017
4d48370
reset HEAD to a good commit
johny-c Oct 24, 2017
97bc63a
merge latest master
johny-c Oct 24, 2017
2993572
small fix in whats_new/v0.20.rst to match master
johny-c Oct 24, 2017
1164a03
make indices np.intp, rename targets as target_neighbors to avoid con…
johny-c Oct 25, 2017
444aff7
add LMNNClassifier class
johny-c Oct 25, 2017
ce67781
add LMNNClassifier and make_lmnn_pipeline
johny-c Oct 26, 2017
ad4746d
dont override kneighbors methods
johny-c Oct 26, 2017
c8c6172
Merge branch 'master' into pylmnn
johny-c Oct 26, 2017
da2dedb
add missing +ELLIPSIS for doctest
johny-c Oct 26, 2017
c98232f
remove p, metric, metric_params from make_lmnn_pipeline
johny-c Oct 31, 2017
c8eb3ac
remove LMNNClassifier class, keep make_lmnn_pipeline function
johny-c Oct 31, 2017
8670d8d
add kwargs in make_lmnn_pipeline for Pipeline args, improve examples …
johny-c Oct 31, 2017
23bfd07
fix whitespace typo in docs
johny-c Oct 31, 2017
476d76f
fix wrong indentation in docs
johny-c Oct 31, 2017
44636b6
remove classes_inverse_non_singleton from class attributes
johny-c Nov 14, 2017
5337752
pass classes to _select_target_neighbors, add test for neighbors_params
johny-c Nov 14, 2017
e35d78e
small speedup (~3.5%) from _sum_weighted_outer_differences
johny-c Nov 15, 2017
eade527
remove spdiags import
johny-c Nov 15, 2017
4b3cf21
add documentation for 'clip' argument of euclidean_distances and writ…
johny-c Nov 15, 2017
17abaeb
add tradeoff parameter 'mu' = weight_push_loss
johny-c Nov 15, 2017
8ddcafe
better docstrings for weight_push_loss (<- @bellet), move LMNN params…
johny-c Nov 16, 2017
8dd6cd4
merge master
johny-c Nov 28, 2017
8d8c3d6
merge master
johny-c Jan 2, 2018
57eb704
a few changes due to JN comments
johny-c Jan 3, 2018
7afa6bb
changes according to latest JN review
johny-c Jan 10, 2018
7e68321
add opt_result attributes description, move euclidean_distances_witho…
johny-c Jan 13, 2018
97d0b49
merge master
johny-c Feb 28, 2019
5342c2f
bring PR up to date, remove make_lmnn_pipeline
johny-c Feb 28, 2019
467e896
remove deprecated assert statements
johny-c Feb 28, 2019
c5e4659
fix some doctests by adding +ELLIPSIS
johny-c Feb 28, 2019
8649035
add missing whitespace
johny-c Mar 1, 2019
981d139
fix flake8 errors
johny-c Mar 1, 2019
78ee381
address comments by banilo
johny-c Mar 1, 2019
5647432
numpy 1.16 renamed dims to shape in unravel_index()
johny-c Mar 1, 2019
7dc57de
address more comments by @banilo
johny-c Mar 2, 2019
90c53b6
add missing tests
johny-c Mar 2, 2019
294d84a
changes according to review by @wdevazelhes
johny-c Mar 9, 2019
f6ac9b8
allow utilizing working memory from sklearn config
johny-c Mar 15, 2019
b37c02b
fix flake8 error
johny-c Mar 15, 2019
f71d01e
merge master
johny-c Dec 18, 2020
c378e15
bring PR up to date with sklearn changes
johny-c Jan 13, 2021
6100c26
fix merge conflict of whats_new v1.0
johny-c Jan 13, 2021
c2d4a9e
remove doctest ELLIPSIS from comments
johny-c Jan 13, 2021
d3de965
remove six import and fix some flake errors
johny-c Jan 13, 2021
cc6eafc
fix triggering errors in documentation
johny-c Jan 13, 2021
f8c8005
address @cmarmo 's comments
johny-c Jan 14, 2021
d10a90e
Merge branch 'master' into pylmnn
johny-c Jan 14, 2021
dd257a1
fix-remove ConvergenceWarning of dim reduction example
johny-c Jan 14, 2021
87cde6b
small improvements in documentation
johny-c Jan 15, 2021
835a86b
Merge remote-tracking branch 'origin/main' into pr/johny-c/8602
glemaitre Jul 29, 2022
316b122
update class docstring
glemaitre Jul 29, 2022
0085610
udpate more docstring
glemaitre Jul 29, 2022
4d7559c
use parameter validation framework
glemaitre Jul 29, 2022
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1324,6 +1324,7 @@ Model validation
neighbors.NearestCentroid
neighbors.NearestNeighbors
neighbors.NeighborhoodComponentsAnalysis
neighbors.LargeMarginNearestNeighbor

.. autosummary::
:toctree: generated/
3 changes: 2 additions & 1 deletion doc/modules/decomposition.rst
@@ -971,4 +971,5 @@ when data can be fetched sequentially.
H. F. Kaiser, 1958

See also :ref:`nca_dim_reduction` for dimensionality reduction with
Neighborhood Components Analysis.
Neighborhood Components Analysis or :ref:`lmnn_dim_reduction` for
dimensionality reduction with Large Margin Nearest Neighbor.
211 changes: 211 additions & 0 deletions doc/modules/neighbors.rst
@@ -820,3 +820,214 @@ added space complexity in the operation.

`Wikipedia entry on Neighborhood Components Analysis
<https://en.wikipedia.org/wiki/Neighbourhood_components_analysis>`_


.. _lmnn:

Large Margin Nearest Neighbor
=============================

.. sectionauthor:: John Chiotellis <johnyc.code@gmail.com>

Large Margin Nearest Neighbor (LMNN, :class:`LargeMarginNearestNeighbor`) is
a metric learning algorithm which aims to improve the accuracy of
nearest neighbors classification compared to the standard Euclidean distance.

.. |lmnn_illustration_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_001.png
:target: ../auto_examples/neighbors/plot_lmnn_illustration.html
:scale: 50

.. |lmnn_illustration_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_002.png
:target: ../auto_examples/neighbors/plot_lmnn_illustration.html
:scale: 50

.. centered:: |lmnn_illustration_1| |lmnn_illustration_2|


For each training sample, the algorithm fixes :math:`k` "target neighbors",
namely the :math:`k`-nearest training samples (as measured by the Euclidean
distance) that share the same label. Given these target neighbors, LMNN
learns a linear transformation of the data by optimizing a trade-off between
two goals. The first one is to make each (transformed) point closer to its
target neighbors than to any differently-labeled point by a large margin,
thereby enclosing the target neighbors in a sphere around the reference
sample. Data samples from different classes that violate this margin are
called "impostors". The second goal is to minimize the distances of each
sample to its target neighbors, which can be seen as a form of regularization.
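
A rough sketch of the target-neighbor selection step described above, in
plain scikit-learn (the helper name and details here are illustrative; the
PR's internal ``_select_target_neighbors`` may differ):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def select_target_neighbors(X, y, n_neighbors):
    # For each sample, find its k nearest same-class neighbors (Euclidean).
    # Assumes every class has at least n_neighbors + 1 members.
    target_neighbors = np.empty((X.shape[0], n_neighbors), dtype=np.intp)
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        nn = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X[idx])
        # Column 0 of the result is each sample itself, so drop it.
        neigh = nn.kneighbors(X[idx], return_distance=False)[:, 1:]
        target_neighbors[idx] = idx[neigh]
    return target_neighbors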

Classification
--------------

Combined with a nearest neighbors classifier (:class:`KNeighborsClassifier`),
this method is attractive for classification because it can naturally
handle multi-class problems without any increase in the model size, and only
a single parameter (``n_neighbors``) has to be selected by the user before
training.

Contributor: Not sure this statement here is fully bullet-proof. As a
memory-based algorithm, kNN-type procedures use the training data as the
'model', which is why the notion of 'model size' may be less immediately
clear in this context.

Author: That's true, but even in that case, it holds that nothing needs to
be modified in the algorithm to handle multi-class problems (in contrast to
e.g. SVMs). Should we maybe change "any increase in the model size" to "any
change in the algorithm"? I think this paragraph was by @bellet?

Contributor: Well, the idea is that the model size of LMNN (as well as kNN)
does not depend on the number of classes: you can have as many labels as
you want in the training data, and the number of parameters to learn in
LMNN remains the same (and the model size/complexity of kNN remains the
same for a fixed training set size).

Large Margin Nearest Neighbor classification has been shown to work well in
practice for data sets of varying size and difficulty. In contrast to
related methods such as Linear Discriminant Analysis, LMNN does not make any
assumptions about the class distributions. The nearest neighbor classification
can naturally produce highly irregular decision boundaries.

To use this model for classification, one needs to combine a :class:`LargeMarginNearestNeighbor`
instance that learns the optimal transformation with a :class:`KNeighborsClassifier`
instance that performs the classification in the embedded space. Here is an
example using these two classes:

>>> from sklearn.neighbors import LargeMarginNearestNeighbor
>>> from sklearn.neighbors import KNeighborsClassifier
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... stratify=y, test_size=0.7, random_state=42)
>>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
>>> lmnn.fit(X_train, y_train)
LargeMarginNearestNeighbor(...)
>>> # Apply the learned transformation when using KNeighborsClassifier
>>> knn = KNeighborsClassifier(n_neighbors=3)
>>> knn.fit(lmnn.transform(X_train), y_train)
KNeighborsClassifier(...)
>>> print(knn.score(lmnn.transform(X_test), y_test))
0.971428...

Alternatively, one can create a :class:`sklearn.pipeline.Pipeline` instance
that automatically applies the transformation when fitting or predicting:

>>> from sklearn.pipeline import Pipeline
>>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
>>> knn = KNeighborsClassifier(n_neighbors=3)
>>> lmnn_pipe = Pipeline([('lmnn', lmnn), ('knn', knn)])
>>> lmnn_pipe.fit(X_train, y_train)
Pipeline(...)
>>> print(lmnn_pipe.score(X_test, y_test))
0.971428...

.. |lmnn_classification_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_001.png
:target: ../auto_examples/neighbors/plot_lmnn_classification.html
:scale: 50

.. |lmnn_classification_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_002.png
:target: ../auto_examples/neighbors/plot_lmnn_classification.html
:scale: 50

.. centered:: |lmnn_classification_1| |lmnn_classification_2|


The plot shows decision boundaries for nearest neighbor classification and
large margin nearest neighbor classification.

.. _lmnn_dim_reduction:

Dimensionality reduction
------------------------

Member: This should be referenced in doc/modules/decomposition.rst. In
fact, I'm starting to think this basically...

Author: ...should be in sklearn/decomposition? I think decomposition
implies unsupervised methods, no?

Member: Just mention it there, but I'm a bit ambivalent. Often people use
decomposition despite available supervision, so informing them of
supervised alternatives seems helpful.

Contributor: Similar cases probably include linear discriminant analysis,
CCA, PLS, principal component regression, ...

:class:`LargeMarginNearestNeighbor` can be used to perform supervised
dimensionality reduction. The input data are mapped to a linear subspace
consisting of the directions which minimize the LMNN objective. Unlike
unsupervised methods which aim to maximize the uncorrelatedness (PCA) or even
independence (ICA) of the components, LMNN aims to find components that
maximize the nearest neighbors classification accuracy of the transformed
inputs. The desired output dimensionality can be set using the parameter
``n_components``. For instance, the following shows a comparison of
dimensionality reduction with Principal Component Analysis
(:class:`sklearn.decomposition.PCA`), Linear Discriminant Analysis
(:class:`sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) and
Large Margin Nearest Neighbor (:class:`LargeMarginNearestNeighbor`) on the
Olivetti faces dataset, a dataset of size :math:`n_{samples} = 400` and
:math:`n_{features} = 64 \times 64 = 4096`. The data set is split into a
training and a test set of equal size. For evaluation, the 3-nearest-neighbors
classification accuracy is computed on the 2-dimensional embedding found by
each method. Each data sample belongs to one of 40 classes.
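
A minimal sketch of this evaluation, assuming the estimator API proposed in
this PR (``n_neighbors``, ``n_components``, ``random_state``); the exact
accuracy depends on the split and the optimization:

from sklearn.datasets import fetch_olivetti_faces
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, LargeMarginNearestNeighbor

X, y = fetch_olivetti_faces(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.5, random_state=42)

# Learn a (2, 4096) transformation, then classify in the 2D embedding.
lmnn = LargeMarginNearestNeighbor(n_neighbors=3, n_components=2,
                                  random_state=42)
lmnn.fit(X_train, y_train)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(lmnn.transform(X_train), y_train)
print(knn.score(lmnn.transform(X_test), y_test))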

.. |lmnn_dim_reduction_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_001.png
:target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
:width: 32%

.. |lmnn_dim_reduction_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_002.png
:target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
:width: 32%

.. |lmnn_dim_reduction_3| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_003.png
:target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
:width: 32%

.. centered:: |lmnn_dim_reduction_1| |lmnn_dim_reduction_2| |lmnn_dim_reduction_3|


Mathematical formulation
------------------------

LMNN learns a linear transformation matrix :math:`L` of
size ``(n_components, n_features)``. The objective function consists of
two competing terms, the pull loss that pulls target neighbors closer to
their reference sample and the push loss that pushes impostors away:

.. math::
    \varepsilon_{\text{pull}} (L) = \sum_{i, j \rightsquigarrow i} ||L(x_i - x_j)||^2,

.. math::
    \varepsilon_{\text{push}} (L) = \sum_{i, j \rightsquigarrow i} \sum_{l}
    (1 - y_{il}) \, [1 + ||L(x_i - x_j)||^2 - ||L(x_i - x_l)||^2]_+,

where :math:`y_{il} = 1` if :math:`y_i = y_l` and :math:`0` otherwise,
:math:`[x]_+ = \max(0, x)` is the hinge loss, and :math:`j \rightsquigarrow i`
means that the :math:`j^{th}` sample is a target neighbor of the
:math:`i^{th}` sample.

LMNN solves the following (nonconvex) minimization problem:

.. math::
    \min_L \varepsilon(L) = (1 - \mu) \, \varepsilon_{\text{pull}} (L) +
    \mu \, \varepsilon_{\text{push}} (L), \quad \mu \in [0, 1].

The parameter :math:`\mu` (``weight_push_loss``) calibrates the trade-off
between penalizing large distances to target neighbors and penalizing margin
violations by impostors. In practice, the two terms are usually weighted
equally (:math:`\mu = 0.5`).
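
The objective can be written down directly in NumPy. The following is an
illustrative sketch of the hinge form shown above (the PR itself uses a
squared hinge, see Implementation below), not the estimator's internals:

import numpy as np

def lmnn_objective(L, X, y, target_neighbors, mu=0.5):
    LX = X @ L.T                                 # embedded samples
    pull, push = 0.0, 0.0
    for i, neighbors_i in enumerate(target_neighbors):
        d_i = np.sum((LX[i] - LX) ** 2, axis=1)  # squared distances from i
        for j in neighbors_i:                    # j ~> i (target neighbor)
            pull += d_i[j]
            margins = 1.0 + d_i[j] - d_i         # 1 + d(i, j) - d(i, l)
            # (1 - y_il) keeps differently labeled l; [.]_+ keeps violations
            push += margins[(y != y[i]) & (margins > 0)].sum()
    return (1 - mu) * pull + mu * push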


Mahalanobis distance
^^^^^^^^^^^^^^^^^^^^

LMNN can be seen as learning a (squared) Mahalanobis distance metric:

.. math::
    ||L(x_i - x_j)||^2 = (x_i - x_j)^T M (x_i - x_j),

where :math:`M = L^T L` is a symmetric positive semi-definite matrix of size
``(n_features, n_features)``. The objective function of LMNN can be
rewritten and solved with respect to :math:`M` directly. This results in a
convex but constrained problem (since :math:`M` must be symmetric positive
semi-definite). See the journal paper in the References for more details.
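
A quick numerical check of this identity (a standalone sketch):

import numpy as np

rng = np.random.RandomState(0)
L = rng.randn(2, 5)              # shape (n_components, n_features)
M = L.T @ L                      # symmetric positive semi-definite
d = rng.randn(5)                 # some difference vector x_i - x_j
assert np.isclose(np.sum((L @ d) ** 2), d @ M @ d)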


Implementation
--------------

This implementation closely follows the MATLAB implementation found at
https://bitbucket.org/mlcircus/lmnn, which solves the unconstrained problem.
It finds a linear transformation :math:`L` by L-BFGS optimization instead of
solving the constrained problem for the globally optimal distance metric.
Unlike the paper, this implementation uses the *squared* hinge loss, which
makes the objective differentiable.
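
Schematically, the fit step then reduces to unconstrained optimization over
the flattened matrix; the glue code below is hypothetical and only meant to
illustrate the solver choice, not the PR's actual fit loop:

import numpy as np
from scipy.optimize import minimize

def fit_transformation(objective_and_grad, L0):
    # objective_and_grad(flat_L) -> (loss, gradient), both w.r.t. flat_L.
    res = minimize(objective_and_grad, L0.ravel(), jac=True,
                   method='L-BFGS-B')
    return res.x.reshape(L0.shape)   # shape (n_components, n_features)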

See the examples below and the docstring of :meth:`LargeMarginNearestNeighbor.fit`
for further information.

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_classification.py`
* :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_dim_reduction.py`


.. topic:: References:

* `"Distance Metric Learning for Large Margin Nearest Neighbor Classification"
<http://jmlr.csail.mit.edu/papers/volume10/weinberger09a/weinberger09a.pdf>`_,
Weinberger, Kilian Q., and Lawrence K. Saul, Journal of Machine Learning
Research, Vol. 10, Feb. 2009, pp. 207-244.

* `Wikipedia entry on Large Margin Nearest Neighbor
<https://en.wikipedia.org/wiki/Large_margin_nearest_neighbor>`_
10 changes: 10 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -44,6 +44,15 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.


:mod:`sklearn.neighbors`
........................

- |MajorFeature| A metric learning algorithm:
:class:`neighbors.LargeMarginNearestNeighbor`, which implements the
Large Margin Nearest Neighbor algorithm described in Weinberger et al.
(2006). :pr:`8602` by :user:`John Chiotellis <johny-c>`.

:mod:`sklearn.feature_extraction`
.................................

@@ -87,6 +96,7 @@ Changelog
Use ``var_`` instead.
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.


Code and Documentation Contributors
-----------------------------------

85 changes: 85 additions & 0 deletions examples/neighbors/plot_lmnn_classification.py
@@ -0,0 +1,85 @@
"""
==========================================================================
Comparing Nearest Neighbors with and without Large Margin Nearest Neighbor
==========================================================================

This example compares nearest neighbors classification with and without
Large Margin Nearest Neighbor.

It will plot the decision boundaries for each class determined by a simple
Nearest Neighbors classifier against the decision boundaries determined by a
Large Margin Nearest Neighbor classifier. The latter aims to find a distance
metric that maximizes the nearest neighbor classification accuracy on a given
training set.
"""

# Author: John Chiotellis <johnyc.code@gmail.com>
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, LargeMarginNearestNeighbor
from sklearn.pipeline import Pipeline


print(__doc__)

n_neighbors = 3

# import some data to play with
iris = datasets.load_iris()

# we only take the first two features. We could avoid this ugly
# slicing by using a two-dim dataset
X = iris.data[:, :2]
y = iris.target

X_train, X_test, y_train, y_test = \
train_test_split(X, y, stratify=y, test_size=0.7, random_state=42)

h = .01 # step size in the mesh

# Create color maps
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

names = ['K-Nearest Neighbors', 'Large Margin Nearest Neighbor']

classifiers = [KNeighborsClassifier(n_neighbors=n_neighbors),
Pipeline([('lmnn', LargeMarginNearestNeighbor(
n_neighbors=n_neighbors, random_state=42)),
('knn', KNeighborsClassifier(n_neighbors))
])
]

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))

for name, clf in zip(names, classifiers):

clf.fit(X_train, y_train)
score = clf.score(X_test, y_test)

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=.8, shading='auto')

# Plot also the training and testing points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("{} (k = {})".format(name, n_neighbors))
plt.text(0.9, 0.1, '{:.2f}'.format(score), size=15,
ha='center', va='center', transform=plt.gca().transAxes)

plt.show()
Member: I think that the example should demo the use of the transform
method.

Contributor: You could show the feature space before and after learning
the metric.

Author: Sure, I can do that.

Contributor: Building on Gael's comment, it would be nice to somehow plot
a form of the linear subspace for the sake of intuitive appeal to the
user...