forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nca.py
531 lines (422 loc) · 19.4 KB
/
nca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# coding: utf-8
"""
Neighborhood Components Analysis
"""
# License: BSD 3 Clause
from __future__ import print_function
import numpy as np
import sys
import time
from scipy.misc import logsumexp
from scipy.optimize import minimize
from sklearn.preprocessing import OneHotEncoder
from ..base import BaseEstimator, TransformerMixin
from ..preprocessing import LabelEncoder
from ..decomposition import PCA
from ..utils.multiclass import check_classification_targets
from ..utils.random import check_random_state
from ..utils.validation import check_is_fitted, check_array, check_X_y
from ..externals.six import integer_types
class NeighborhoodComponentsAnalysis(BaseEstimator, TransformerMixin):
    """Neighborhood Components Analysis

    Parameters
    ----------
    n_features_out : int, optional (default=None)
        Preferred dimensionality of the embedding.

    init : string or numpy array, optional (default='pca')
        Initialization of the linear transformation. Possible options are
        'pca', 'identity', 'random', and a numpy array of shape
        (n_features_a, n_features_b).

        pca:
            ``n_features_out`` many principal components of the inputs passed
            to :meth:`fit` will be used to initialize the transformation.

        identity:
            If ``n_features_out`` is strictly smaller than the
            dimensionality of the inputs passed to :meth:`fit`, the identity
            matrix will be truncated to the first ``n_features_out`` rows.

        random:
            The initial transformation will be a random array of shape
            (n_features_out, n_features). Each value is sampled from the
            standard normal distribution.

        numpy array:
            n_features_b must match the dimensionality of the inputs passed to
            :meth:`fit` and n_features_a must be less than or equal to that.
            If ``n_features_out`` is not None, n_features_a must match it.

    max_iter : int, optional (default=50)
        Maximum number of iterations in the optimization.

    tol : float, optional (default=1e-5)
        Convergence tolerance for the optimization.

    callback : callable, optional (default=None)
        If not None, this function is called after every iteration of the
        optimizer, taking as arguments the current solution (the flattened
        transformation, of shape (n_features_out * n_features,)) and the
        number of iterations. This might be useful in case one wants
        to examine or store the transformation found after each iteration.

    store_opt_result : bool, optional (default=False)
        If True, the :class:`scipy.optimize.OptimizeResult` object returned by
        :meth:`minimize` of `scipy.optimize` will be stored as attribute
        ``opt_result_``.

    verbose : int, optional (default=0)
        If 0, no progress messages will be printed.
        If 1, progress messages will be printed to stdout.
        If > 1, progress messages will be printed and the ``iprint``
        parameter of :meth:`_minimize_lbfgsb` of `scipy.optimize` will be set
        to ``verbose - 2``.

    random_state : int or numpy.RandomState or None, optional (default=None)
        A pseudo random number generator object or a seed for it if int. If
        ``init='random'``, ``random_state`` is used to initialize the random
        transformation. If ``init='pca'``, ``random_state`` is passed as an
        argument to PCA when initializing the transformation.

    Attributes
    ----------
    transformation_ : array, shape (n_features_out, n_features)
        The linear transformation learned during fitting.

    n_iter_ : int
        Counts the number of iterations performed by the optimizer.

    opt_result_ : scipy.optimize.OptimizeResult (optional)
        A dictionary of information representing the optimization result.
        This is stored only if ``store_opt_result`` was True.

    Examples
    --------
    >>> from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis
    >>> from sklearn.neighbors import KNeighborsClassifier
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import train_test_split
    >>> X, y = load_iris(return_X_y=True)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
    ... stratify=y, test_size=0.7, random_state=42)
    >>> nca = NeighborhoodComponentsAnalysis(random_state=42)
    >>> nca.fit(X_train, y_train) # doctest: +ELLIPSIS
    NeighborhoodComponentsAnalysis(...)
    >>> knn = KNeighborsClassifier(n_neighbors=3)
    >>> knn.fit(X_train, y_train) # doctest: +ELLIPSIS
    KNeighborsClassifier(...)
    >>> print(knn.score(X_test, y_test))
    0.933333333333
    >>> knn.fit(nca.transform(X_train), y_train) # doctest: +ELLIPSIS
    KNeighborsClassifier(...)
    >>> print(knn.score(nca.transform(X_test), y_test))
    0.961904761905

    Notes
    -----
    Neighborhood Components Analysis (NCA) is a machine learning algorithm for
    metric learning. It learns a linear transformation in a supervised fashion
    to improve the classification accuracy of a stochastic nearest neighbors
    rule in the transformed space.

    References
    ----------
    .. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov.
           "Neighbourhood Components Analysis". Advances in Neural Information
           Processing Systems. 17, 513-520, 2005.
           http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf

    .. [2] Wikipedia entry on Neighborhood Components Analysis
           https://en.wikipedia.org/wiki/Neighbourhood_components_analysis
    """

    def __init__(self, n_features_out=None, init='pca', max_iter=50,
                 tol=1e-5, callback=None, store_opt_result=False, verbose=0,
                 random_state=None):

        # Parameters
        self.n_features_out = n_features_out
        self.init = init
        self.max_iter = max_iter
        self.tol = tol
        self.callback = callback
        self.store_opt_result = store_opt_result
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        """Fit the model according to the given training data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The training samples.

        y : array-like, shape (n_samples,)
            The corresponding training labels.

        Returns
        -------
        self : object
            returns a trained NeighborhoodComponentsAnalysis model.
        """

        # Verify inputs X and y and NCA parameters, and transform a copy if
        # needed
        X_valid, y_valid, init = self._validate_params(X, y)

        # Initialize the random generator
        self.random_state_ = check_random_state(self.random_state)

        # Measure the total training time
        t_train = time.time()

        # Compute arrays that stay fixed during optimization:

        # Mask for fast lookup of same-class samples: masks[j, k] is True iff
        # sample j belongs to class k. y_valid is label-encoded to
        # range(n_classes), so column k lines up with class code k, which is
        # what the ``masks[:, y[i]]`` lookup in _loss_grad_lbfgs relies on.
        n_classes = y_valid.max() + 1
        masks = y_valid[:, np.newaxis] == np.arange(n_classes)[np.newaxis, :]

        # pairwise differences, shape (n_samples, n_samples, n_features)
        diffs = X_valid[:, np.newaxis] - X_valid[np.newaxis]

        # Initialize the transformation
        transformation = self._initialize(X_valid, init)

        # Create a dictionary of parameters to be passed to the optimizer.
        # The sign argument -1.0 makes the optimizer minimize the negated
        # objective, i.e. maximize it.
        disp = self.verbose - 2 if self.verbose > 1 else -1
        optimizer_params = {'method': 'L-BFGS-B',
                            'fun': self._loss_grad_lbfgs,
                            'args': (X_valid, y_valid, diffs, masks, -1.0),
                            'jac': True,
                            # scipy.optimize.minimize requires a 1-D x0;
                            # _loss_grad_lbfgs reshapes it back to 2-D.
                            'x0': np.ravel(transformation),
                            'tol': self.tol,
                            'options': dict(maxiter=self.max_iter, disp=disp),
                            'callback': self._callback
                            }

        # Call the optimizer
        self.n_iter_ = 0
        opt_result = minimize(**optimizer_params)

        # Reshape the solution found by the optimizer
        self.transformation_ = opt_result.x.reshape(-1, X_valid.shape[1])

        # Stop timer
        t_train = time.time() - t_train
        if self.verbose:
            print('[{}] Training took {:8.2f}s.'.format(
                self.__class__.__name__, t_train))

        # Optionally store information returned by the optimizer
        if self.store_opt_result:
            self.opt_result_ = opt_result

        return self

    def transform(self, X):
        """Applies the learned transformation to the given data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data samples.

        Returns
        -------
        X_embedded: array, shape (n_samples, n_features_out)
            The data samples transformed.

        Raises
        ------
        NotFittedError
            If :meth:`fit` has not been called before.
        """

        check_is_fitted(self, ['transformation_'])
        X = check_array(X)

        return np.dot(X, self.transformation_.T)

    def _validate_params(self, X, y):
        """Validate parameters as soon as :meth:`fit` is called.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The training samples.

        y : array-like, shape (n_samples,)
            The corresponding training labels.

        Returns
        -------
        X_valid : array, shape (n_samples, n_features)
            The validated training samples.

        y_valid : array, shape (n_samples,)
            The validated training labels, encoded to be integers in
            the range(0, n_classes).

        init : string or numpy array of shape (n_features_a, n_features_b)
            The validated initialization of the linear transformation.

        Raises
        ------
        TypeError
            If a parameter is not an instance of the desired type.

        ValueError
            If a parameter's value violates its legal value range or if the
            combination of two or more given parameters is incompatible.
        """

        # Validate the inputs X and y, and converts y to numerical classes.
        X_valid, y_valid = check_X_y(X, y, ensure_min_samples=2)
        check_classification_targets(y_valid)
        y_valid = LabelEncoder().fit_transform(y_valid)

        # Check the preferred embedding dimensionality. Use the validated
        # array: the raw X may be a list or other array-like without .shape.
        if self.n_features_out is not None:
            _check_scalar(self.n_features_out, 'n_features_out',
                          integer_types, 1)

            if self.n_features_out > X_valid.shape[1]:
                raise ValueError('The preferred embedding dimensionality '
                                 '`n_features_out` ({}) cannot be greater '
                                 'than the given data dimensionality ({})!'
                                 .format(self.n_features_out,
                                         X_valid.shape[1]))

        _check_scalar(self.max_iter, 'max_iter', integer_types, 1)
        _check_scalar(self.tol, 'tol', float, 0.)
        _check_scalar(self.verbose, 'verbose', integer_types, 0)

        if self.callback is not None:
            if not callable(self.callback):
                raise ValueError('`callback` is not callable.')

        # Check how the linear transformation should be initialized
        init = self.init

        if isinstance(init, np.ndarray):
            init = check_array(init)

            # Assert that init.shape[1] = X.shape[1]
            if init.shape[1] != X_valid.shape[1]:
                raise ValueError(
                    'The input dimensionality ({}) of the given '
                    'linear transformation `init` must match the '
                    'dimensionality of the given inputs `X` ({}).'
                    .format(init.shape[1], X_valid.shape[1]))

            # Assert that init.shape[0] <= init.shape[1]
            if init.shape[0] > init.shape[1]:
                raise ValueError(
                    'The output dimensionality ({}) of the given '
                    'linear transformation `init` cannot be '
                    'greater than its input dimensionality ({}).'
                    .format(init.shape[0], init.shape[1]))

            if self.n_features_out is not None:
                # Assert that self.n_features_out = init.shape[0]
                if self.n_features_out != init.shape[0]:
                    raise ValueError(
                        'The preferred embedding dimensionality '
                        '`n_features_out` ({}) does not match '
                        'the output dimensionality of the given '
                        'linear transformation `init` ({})!'
                        .format(self.n_features_out,
                                init.shape[0]))
        elif init in ['pca', 'identity', 'random']:
            pass
        else:
            raise ValueError(
                "`init` must be 'pca', 'identity', 'random' or a numpy "
                "array of shape (n_features_out, n_features).")

        return X_valid, y_valid, init

    def _initialize(self, X, init):
        """Initialize the transformation.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data samples.

        init : string or numpy array of shape (n_features_a, n_features_b)
            The validated initialization of the linear transformation.

        Returns
        -------
        transformation : array, shape (n_features_out, n_features)
            The initialized linear transformation.
        """

        transformation = init
        if isinstance(init, np.ndarray):
            pass
        else:
            # n_features_out was validated to be None or >= 1, so ``or``
            # only falls back to the input dimensionality for None.
            n_features_out = self.n_features_out or X.shape[1]
            if init == 'identity':
                transformation = np.eye(n_features_out, X.shape[1])
            elif init == 'random':
                transformation = self.random_state_.randn(n_features_out,
                                                          X.shape[1])
            elif init == 'pca':
                pca = PCA(n_components=n_features_out,
                          random_state=self.random_state_)
                t_pca = time.time()
                if self.verbose:
                    print('Finding principal components... ', end='')
                    sys.stdout.flush()

                pca.fit(X)
                if self.verbose:
                    print('done in {:5.2f}s'.format(time.time() - t_pca))

                transformation = pca.components_
        return transformation

    def _callback(self, transformation):
        """Called after each iteration of the optimizer.

        Forwards the current solution and ``self.n_iter_`` to the user
        ``callback`` (if any), then increments ``self.n_iter_``.

        Parameters
        ----------
        transformation : array, shape (n_features_out * n_features,)
            The flattened solution computed by the optimizer in this
            iteration.
        """
        if self.callback is not None:
            self.callback(transformation, self.n_iter_)

        self.n_iter_ += 1

    def _loss_grad_lbfgs(self, transformation, X, y, diffs,
                         masks, sign=1.0):
        """Compute the loss and the loss gradient w.r.t. ``transformation``.

        Parameters
        ----------
        transformation : array, shape (n_features_out * n_features,)
            The flattened linear transformation on which to compute loss and
            evaluate gradient. It is reshaped to (n_features_out, n_features)
            internally.

        X : array, shape (n_samples, n_features)
            The training samples.

        y : array, shape (n_samples,)
            The corresponding training labels, as integer class codes.

        diffs : array, shape (n_samples, n_samples, n_features)
            Pairwise differences between training samples.

        masks : array, shape (n_samples, n_classes)
            One-hot encoding of y.

        sign : float, optional (default=1.0)
            Multiplier applied to both the loss and the gradient; ``fit``
            passes -1.0 so that a minimizer maximizes the NCA objective.

        Returns
        -------
        loss : float
            The loss computed for the given transformation.

        gradient : array, shape (n_features_out * n_features,)
            The new (flattened) gradient of the loss.
        """

        # NOTE(review): this first-call bump, combined with the increment in
        # _callback, leaves n_iter_ one greater than the number of callback
        # invocations after fit -- confirm whether that is intended.
        if self.n_iter_ == 0:
            self.n_iter_ += 1
            if self.verbose:
                header_fields = ['Iteration', 'Objective Value', 'Time(s)']
                header_fmt = '{:>10} {:>20} {:>10}'
                header = header_fmt.format(*header_fields)
                cls_name = self.__class__.__name__
                print('[{}]'.format(cls_name))
                print('[{}] {}\n[{}] {}'.format(cls_name, header,
                                                cls_name, '-' * len(header)))

        t_funcall = time.time()

        transformation = transformation.reshape(-1, X.shape[1])
        loss = 0
        gradient = np.zeros(transformation.shape)
        X_embedded = transformation.dot(X.T).T

        # for every sample x_i, compute its contribution to loss and gradient
        for i in range(X.shape[0]):
            # compute squared distances to x_i in embedded space
            diff_embedded = X_embedded[i] - X_embedded
            dist_embedded = np.einsum('ij,ij->i', diff_embedded,
                                      diff_embedded)
            # a sample never picks itself as neighbor: exp(-inf) == 0
            dist_embedded[i] = np.inf

            # compute exponentiated distances (use the log-sum-exp trick to
            # avoid numerical instabilities)
            exp_dist_embedded = np.exp(-dist_embedded -
                                       logsumexp(-dist_embedded))
            ci = masks[:, y[i]]  # samples that are in the same class as x_i
            p_i_j = exp_dist_embedded[ci]
            diff_ci = diffs[i, ci, :]
            diff_not_ci = diffs[i, ~ci, :]
            sum_ci = diff_ci.T.dot(
                (p_i_j[:, np.newaxis] * diff_embedded[ci, :]))
            sum_not_ci = diff_not_ci.T.dot((exp_dist_embedded[~ci][:,
                                                                   np.newaxis] *
                                            diff_embedded[~ci, :]))
            p_i = np.sum(p_i_j)  # probability of x_i to be correctly
            # classified
            gradient += 2 * (p_i * (sum_ci.T + sum_not_ci.T) - sum_ci.T)
            loss += p_i

        if self.verbose:
            t_funcall = time.time() - t_funcall
            values_fmt = '[{}] {:>10} {:>20.6e} {:>10.2f}'
            print(values_fmt.format(self.__class__.__name__, self.n_iter_,
                                    loss, t_funcall))
            sys.stdout.flush()

        return sign * loss, sign * gradient.ravel()
##########################
# Some helper functions #
#########################
def _check_scalar(x, name, target_type, min_val=None, max_val=None):
"""Validate scalar parameters type and value.
Parameters
----------
x : object
The scalar parameter to validate.
name : str
The name of the parameter to be printed in error messages.
target_type : type or tuple
Acceptable data types for the parameter.
min_val : float or int, optional (default=None)
The minimum value value the parameter can take. If None (default) it
is implied that the parameter does not have a lower bound.
max_val: float or int, optional (default=None)
The maximum valid value the parameter can take. If None (default) it
is implied that the parameter does not have an upper bound.
Raises
-------
TypeError
If the parameter's type does not match the desired type.
ValueError
If the parameter's value violates the given bounds.
"""
if not isinstance(x, target_type):
raise TypeError('`{}` must be an instance of {}, not {}.'
.format(name, target_type, type(x)))
if min_val is not None and x < min_val:
raise ValueError('`{}`= {}, must be >= {}.'.format(name, x, min_val))
if max_val is not None and x > max_val:
raise ValueError('`{}`= {}, must be <= {}.'.format(name, x, max_val))