Join GitHub today
GitHub is home to over 31 million developers working together to host and review code, manage projects, and build software together.Sign up
[MRG+2] Neighborhood Components Analysis #10058
Hi, this PR is an implementation of the Neighborhood Components Analysis algorithm (NCA), a popular supervised distance metric learning algorithm. As LMNN (cf PR #8602) this algorithm takes as input a labeled dataset, instead of similar/dissimilar pairs like it is the case for most metric learning algorithms, and learns a linear transformation of the space. However, NCA and LMNN have different objective functions: NCA tries to maximise the probability of every sample to be correctly classified based on a stochastic nearest neighbors rule, and therefore does not need to fix in advance a set of target neighbors.
There have been several attempts to implement NCA (2 PRs: #5276 (closed) and #4789 (not closed)). I created a fresh PR for sake of clarity. Indeed, this code is intended to be as similar to LMNN as possible, which should allow the factorisation of some points of code which are the same in both algorithms.
At the time of writing, this algorithm uses scipy's L-BFGS-B solver to solve the optimisation problem, like LMNN. It has the big advantage of avoiding to tune a learning rate parameter.
The remaining tasks are the following:
What is more, some improvements could also be made in a second time:
Feedback is welcome !
Here is a snippet that shows it (on my machine which has 7.7GB of memory):
from sklearn.datasets import load_digits from metric_learn import NCA from sklearn.neighbors import NeighborhoodComponentsAnalysis from sklearn.utils.testing import assert_raises digits = load_digits() X, y = digits.data, digits.target nca_ml = NCA() assert_raises(MemoryError, nca_ml.fit, X, y) nca_sk = NeighborhoodComponentsAnalysis() nca_sk.fit(X, y) # does not raise any error
@@ Coverage Diff @@ ## master #10058 +/- ## ========================================== + Coverage 96.19% 96.21% +0.01% ========================================== Files 336 338 +2 Lines 62740 63025 +285 ========================================== + Hits 60354 60638 +284 - Misses 2386 2387 +1
I benchmarked this PR implementation of NCA against metric-learn one: I plotted the training curves (objective function vs time), for the same initialisation (identity), on
At some points, metric-learn NCA training is interrupted prematurely: this is due to numerical instabilities, and this warning is thrown:
Feb 28, 2019
2 of 7 checks passed
That's great ! Thanks a lot for your reviews and comments @jnothman @agramfort @GaelVaroquaux @bellet, and congrats to @johny-c too ! I'm excited to work on improvements later on. Also to transpose the changes of this PR to LMNN so that it can be merged, or maybe we'll see first how NCA goes for a bit of time and then see what to do for LMNN