ENH Improve initialization and learning rate in t-SNE #19491

Merged — 71 commits, Apr 26, 2021 (changes shown from 32 commits)

Commits
bc1df0c
Rescale PCA initialization to std=1e-4
dkobak Feb 18, 2021
699d66b
Add future warning about PCA init
dkobak Feb 18, 2021
06d69d5
Add learning_rate='auto' and comment on factor 4
dkobak Feb 18, 2021
2ca7375
Add future warning for the learning rate
dkobak Feb 18, 2021
bda63c0
Add whitespace
dkobak Feb 18, 2021
0d13285
Remove whitespaces
dkobak Feb 18, 2021
ac33824
Update sklearn/manifold/_t_sne.py
dkobak Feb 19, 2021
84cc82c
Update sklearn/manifold/_t_sne.py
dkobak Feb 19, 2021
66405d1
Future warning for PCA init
dkobak Feb 19, 2021
869cd80
Do not overwrite default attributes
dkobak Feb 22, 2021
6496935
Remove trailing whitespaces
dkobak Feb 22, 2021
22e6870
Ignore PCA/init future warnings in tests
dkobak Feb 23, 2021
1b2988a
Fix whitespaces before inline comments
dkobak Feb 23, 2021
fac85bd
Fix more whitespaces
dkobak Feb 23, 2021
e527e1a
Add tests for new future warnings and for lr=auto
dkobak Feb 23, 2021
5082626
Fix a bug in one test
dkobak Feb 23, 2021
88f08f9
Fix another bug
dkobak Feb 23, 2021
9378f10
Split one test into two
dkobak Feb 23, 2021
1cd2c80
Rename new tests
dkobak Feb 23, 2021
69ab5da
Correctly handle None to invoke default
dkobak Feb 23, 2021
53aa8a6
Fix the auto learning rate text
dkobak Feb 23, 2021
e8fc787
Fix random state for assert_allclose
dkobak Feb 23, 2021
b252ea1
Fix max bug
dkobak Feb 23, 2021
0c18279
Fix max bug
dkobak Feb 23, 2021
8120814
Ignore future warnings in test
dkobak Feb 23, 2021
652b4ef
Ignore future warnings in more tests
dkobak Feb 23, 2021
1730bd1
Import ignore_warnings
dkobak Feb 23, 2021
0281899
Avoid future warnings in docstring test
dkobak Feb 23, 2021
be28ba3
Describe t-SNE changes in whats_new
dkobak Feb 24, 2021
5ce0eab
Update the docstring example to avoid warnings
dkobak Feb 24, 2021
eb6589d
Fix newline for doctest
dkobak Feb 24, 2021
9cb6bfb
Delete the slow test
dkobak Feb 25, 2021
d416f3f
Add a test for negative learning rate
dkobak Feb 25, 2021
cb71422
Add whitespace
dkobak Feb 25, 2021
04f88f2
Remove unused variable
dkobak Feb 25, 2021
3fe3763
Use pytest.raises
dkobak Feb 25, 2021
cff6b2c
Remove obsolete param
dkobak Feb 25, 2021
f64878f
Merge remote-tracking branch 'upstream/main' into tsne-defaults
dkobak Feb 26, 2021
fd09082
Ignore future warnings in check_n_features_in_after_fitting
dkobak Feb 26, 2021
4dfa098
Use random init for precomputed metric in tests
dkobak Mar 1, 2021
d8312c7
Merge branch 'tsne-defaults' of https://github.com/dkobak/scikit-lear…
dkobak Mar 1, 2021
7b00863
Remove double arg
dkobak Mar 1, 2021
6b34891
Fix one test
dkobak Mar 1, 2021
c076ce3
Init can be ndarray, so compare properly
dkobak Mar 1, 2021
1fbd14f
error message on sparse input with pca init
dkobak Mar 8, 2021
56d1b79
linting fix
dkobak Mar 8, 2021
f7f9680
error type fix
dkobak Mar 8, 2021
41f641e
future warning fix
dkobak Mar 8, 2021
cb25483
Update sklearn/manifold/_t_sne.py
dkobak Apr 7, 2021
34c19d4
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 7, 2021
bbdcb92
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 7, 2021
2b66d21
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 7, 2021
f2890bf
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 7, 2021
09e8c9a
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 7, 2021
3d946cb
Add TODO comments and docstrings for tests
dkobak Apr 7, 2021
486b578
Merge remote-tracking branch 'upstream/main' into tsne-defaults
dkobak Apr 7, 2021
795caa4
Describe learning_rate=auto in t-SNE doc
dkobak Apr 7, 2021
37a25ff
Fix ignoring future warning
dkobak Apr 7, 2021
fa36a12
Update sklearn/manifold/_t_sne.py
dkobak Apr 16, 2021
861a439
Update sklearn/manifold/_t_sne.py
dkobak Apr 16, 2021
d80ad13
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 16, 2021
4c6fa58
Update sklearn/manifold/tests/test_t_sne.py
dkobak Apr 16, 2021
8ea825b
linting fix
dkobak Apr 16, 2021
d0c5696
format references
dkobak Apr 16, 2021
682ebde
linting fix
dkobak Apr 16, 2021
3bf131a
replace ignore_warnings with filterwarning
dkobak Apr 16, 2021
a35bc2d
handle future warnings in tests in other files
dkobak Apr 16, 2021
78dc3f0
import pytest
dkobak Apr 16, 2021
6cfa7f9
roll back because of failing test
dkobak Apr 16, 2021
245d72e
Merge remote-tracking branch 'upstream/main' into pr/19491
thomasjpfan Apr 26, 2021
f1ed1d5
CLN Remove comment for check_n_features_in_after_fitting
thomasjpfan Apr 26, 2021
6 changes: 6 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -142,6 +142,12 @@ Changelog
during affinity matrix computation for :class:`manifold.TSNE`.
:pr:`19472` by :user:`Dmitry Kobak <dkobak>`.

- |Enhancement| Implement `'auto'` heuristic for the `learning_rate` in
:class:`manifold.TSNE`. It will become default in 1.2. The default
initialization will change to `pca` in 1.2. PCA initialization will
be scaled to have standard deviation 1e-4 in 1.2.
:pr:`19491` by :user:`Dmitry Kobak <dkobak>`.

:mod:`sklearn.metrics`
......................

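As a quick illustration of the changelog entry above, here is a minimal usage sketch assuming a scikit-learn release that includes this PR (1.0 or later); the toy data is made up for the example.

```python
import numpy as np
from sklearn.manifold import TSNE

X = np.random.RandomState(0).randn(100, 20)  # toy data: 100 samples, 20 features

# Opt in explicitly to what becomes the default in 1.2:
# the 'auto' learning-rate heuristic plus PCA initialization.
X_embedded = TSNE(n_components=2, learning_rate="auto", init="pca",
                  random_state=0).fit_transform(X)
print(X_embedded.shape)  # (100, 2)
```

Passing both options explicitly also sidesteps the FutureWarnings this PR adds for the old defaults.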
62 changes: 51 additions & 11 deletions sklearn/manifold/_t_sne.py
@@ -517,13 +517,21 @@ class TSNE(BaseEstimator):
optimization, the early exaggeration factor or the learning rate
might be too high.

learning_rate : float, default=200.0
learning_rate : float or 'auto', default=200.0
The learning rate for t-SNE is usually in the range [10.0, 1000.0]. If
the learning rate is too high, the data may look like a 'ball' with any
point approximately equidistant from its nearest neighbours. If the
learning rate is too low, most points may look compressed in a dense
cloud with few outliers. If the cost function gets stuck in a bad local
minimum increasing the learning rate may help.
Note that many other t-SNE implementations (bhtsne, FIt-SNE, openTSNE,
etc.) use a definition of learning_rate that is 4 times smaller than
ours. So our learning_rate=200 corresponds to learning_rate=800 in
those other implementations.
The 'auto' option sets the learning_rate to N / early_exaggeration / 4,
where N is the sample size, following Belkina et al. 2019 and
Kobak et al. 2019, Nature Communications (or to 50.0, if
N / early_exaggeration / 4 < 50). This will become default in 1.2.
Member (review comment): Can we move these references into the References section below and link them here?

Contributor (author): Makes sense, fixed.

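As a small worked example of the 'auto' heuristic described in the learning_rate docstring above (the helper function is hypothetical, introduced only for this sketch):

```python
def auto_learning_rate(n_samples, early_exaggeration=12.0):
    # learning_rate='auto': N / early_exaggeration / 4, floored at 50.
    return max(n_samples / early_exaggeration / 4.0, 50.0)

# With N = 10_000 and the default early_exaggeration = 12:
# 10_000 / 12 / 4 ≈ 208.3 in scikit-learn's convention, i.e. roughly
# N / 12 ≈ 833 in the 4x-smaller convention of bhtsne/FIt-SNE/openTSNE.
print(auto_learning_rate(10_000))  # 208.33...
print(auto_learning_rate(100))     # 50.0 -- the floor kicks in
```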
n_iter : int, default=1000
Maximum number of iterations for the optimization. Should be at
@@ -559,7 +567,8 @@ class TSNE(BaseEstimator):
Initialization of embedding. Possible options are 'random', 'pca',
and a numpy array of shape (n_samples, n_components).
PCA initialization cannot be used with precomputed distances and is
usually more globally stable than random initialization.
usually more globally stable than random initialization. It will
become default in 1.2.

verbose : int, default=0
Verbosity level.
@@ -631,7 +640,8 @@ class TSNE(BaseEstimator):
>>> import numpy as np
>>> from sklearn.manifold import TSNE
>>> X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
>>> X_embedded = TSNE(n_components=2).fit_transform(X)
>>> X_embedded = TSNE(n_components=2, learning_rate='auto',
... init='random').fit_transform(X)
>>> X_embedded.shape
(4, 2)

@@ -656,9 +666,9 @@ class TSNE(BaseEstimator):

@_deprecate_positional_args
def __init__(self, n_components=2, *, perplexity=30.0,
early_exaggeration=12.0, learning_rate=200.0, n_iter=1000,
early_exaggeration=12.0, learning_rate="warn", n_iter=1000,
n_iter_without_progress=300, min_grad_norm=1e-7,
metric="euclidean", init="random", verbose=0,
metric="euclidean", init="warn", verbose=0,
random_state=None, method='barnes_hut', angle=0.5,
n_jobs=None, square_distances='legacy'):
self.n_components = n_components
@@ -681,12 +691,35 @@ def __init__(self, n_components=2, *, perplexity=30.0,
def _fit(self, X, skip_num_points=0):
"""Private function to fit the model using X as training data."""

if self.init == 'warn':
# See issue #18018
warnings.warn("The default initialization in TSNE will change "
"from 'random' to 'pca' in 1.2.", FutureWarning)
self._init = 'random'
else:
self._init = self.init
if self.learning_rate == 'warn':
# See issue #18018
warnings.warn("The default learning rate in TSNE will change "
"from 200.0 to 'auto' in 1.2.", FutureWarning)
self._learning_rate = 200.0
else:
self._learning_rate = self.learning_rate

if self.method not in ['barnes_hut', 'exact']:
raise ValueError("'method' must be 'barnes_hut' or 'exact'")
if self.angle < 0.0 or self.angle > 1.0:
raise ValueError("'angle' must be between 0.0 - 1.0")
if self.square_distances not in [True, 'legacy']:
raise ValueError("'square_distances' must be True or 'legacy'.")
if self._learning_rate == 'auto':
# See issue #18018
self._learning_rate = X.shape[0] / self.early_exaggeration / 4
self._learning_rate = np.maximum(self._learning_rate, 50)
else:
if not (self._learning_rate > 0):
raise ValueError("'learning_rate' must be a positive number "
"or 'auto'.")
if self.metric != "euclidean" and self.square_distances is not True:
warnings.warn(
"'square_distances' has been introduced in 0.24 to help phase "
@@ -706,7 +739,7 @@ def _fit(self, X, skip_num_points=0):
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'],
dtype=[np.float32, np.float64])
if self.metric == "precomputed":
if isinstance(self.init, str) and self.init == 'pca':
if isinstance(self._init, str) and self._init == 'pca':
raise ValueError("The parameter init=\"pca\" cannot be "
"used with metric=\"precomputed\".")
if X.shape[0] != X.shape[1]:
@@ -817,13 +850,20 @@ def _fit(self, X, skip_num_points=0):
P = _joint_probabilities_nn(distances_nn, self.perplexity,
self.verbose)

if isinstance(self.init, np.ndarray):
X_embedded = self.init
elif self.init == 'pca':
if isinstance(self._init, np.ndarray):
X_embedded = self._init
elif self._init == 'pca':
pca = PCA(n_components=self.n_components, svd_solver='randomized',
random_state=random_state)
X_embedded = pca.fit_transform(X).astype(np.float32, copy=False)
elif self.init == 'random':
# PCA is rescaled so that PC1 has standard deviation 1e-4 which is
# the default value for random initialization. See issue #18018.
warnings.warn("The PCA initialization in TSNE will change to "
"have the standard deviation of PC1 equal to 1e-4 "
"in 1.2. This will ensure better convergence.",
FutureWarning)
# X_embedded = X_embedded / np.std(X_embedded[:, 0]) * 1e-4
elif self._init == 'random':
# The embedding is initialized with iid samples from Gaussians with
# standard deviation 1e-4.
X_embedded = 1e-4 * random_state.randn(
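The FutureWarning and the commented-out line above describe the rescaling planned for 1.2: after PCA, the first embedding dimension is scaled to standard deviation 1e-4, matching the spread of the random (Gaussian, std 1e-4) initialization. A standalone sketch of that rescaling on made-up data, outside the diff:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
X = rng.randn(200, 10)  # toy data

# PCA initialization as in the diff, then the rescaling announced for 1.2:
X_embedded = PCA(n_components=2, svd_solver="randomized",
                 random_state=rng).fit_transform(X).astype(np.float32)
X_embedded = X_embedded / np.std(X_embedded[:, 0]) * 1e-4
print(np.std(X_embedded[:, 0]))  # ~1e-4
```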
@@ -857,7 +897,7 @@ def _tsne(self, P, degrees_of_freedom, n_samples, X_embedded,
"it": 0,
"n_iter_check": self._N_ITER_CHECK,
"min_grad_norm": self.min_grad_norm,
"learning_rate": self.learning_rate,
"learning_rate": self._learning_rate,
"verbose": self.verbose,
"kwargs": dict(skip_num_points=skip_num_points),
"args": [P, degrees_of_freedom, n_samples, self.n_components],