
[MRG+1] Reducing t-SNE memory usage #9032

Merged
merged 55 commits into scikit-learn:master from tomMoral:optim_tsne on Jul 12, 2017

Conversation

@tomMoral
Contributor

tomMoral commented Jun 7, 2017

The Barnes-Hut algorithm for t-SNE currently has O(N^2) memory complexity; it could use O(uN). This PR intends to improve the memory usage (see issue scikit-learn/scikit-learn#7089).
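
For context, a minimal sketch of the core idea on toy data (the exact implementation lives in sklearn/manifold/t_sne.py): keep only each sample's k nearest neighbors instead of the dense (n, n) distance matrix.

import numpy as np
from sklearn.neighbors import NearestNeighbors

n, perplexity = 10000, 30
X = np.random.RandomState(0).rand(n, 50).astype(np.float32)

# LvdM uses 3 * perplexity as the number of neighbors (see the review below)
k = min(n - 1, int(3. * perplexity))
knn = NearestNeighbors(n_neighbors=k).fit(X)
distances_nn, neighbors_nn = knn.kneighbors(None)  # shape (n, k), not (n, n)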

Steps to proceed

  • Only compute the nearest neighbors distances
  • Validate _barnes_hut2 with _barnes_hut functions
  • Check memory usage
  • Set default values of optimizer hyperparams to match the reference implementation and check that the results are qualitatively and quantitatively matching (compute trustworthiness; see the sketch after this list)
  • Make TSNE raise a ValueError when n_components > 3 or n_components < 2 with the BH solver enabled.
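
For reference, a minimal sketch of the trustworthiness check mentioned above (sklearn.manifold.trustworthiness is an existing helper; the data here is purely illustrative):

import numpy as np
from sklearn.manifold import trustworthiness

X = np.random.RandomState(42).rand(200, 30)
X_embedded = X[:, :2]  # stand-in for a t-SNE embedding
# 1.0 means the local neighborhoods of X are perfectly preserved
print(trustworthiness(X, X_embedded, n_neighbors=5))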

@vene

Member

vene commented Jun 7, 2017

Is there a reason why you opted to make a new Cython file rather than change the existing one and rely on unit tests for ensuring the same behaviour? In particular, did you spot any differences in your refactoring that could turn to regression tests?

Just saying because the current file layout makes reviewing somewhat more difficult.

(Answered IRL, my bad for rushing into this PR.)

In sklearn/manifold/t_sne.py:
# set the neighbors to n - 1
distances_nn, neighbors_nn = knn.kneighbors(
    X, n_neighbors=k + 1)
distances_nn = distances_nn[:, 1:]

@jnothman

jnothman Jun 7, 2017

Member

This isn't quite doing the right thing. If two instances are identical, you can't be certain which will be output first. Use kneighbors(None) instead.
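
A minimal sketch of the suggested call on toy data (not the PR's code): passing None makes kneighbors query the training set itself and exclude each point from its own neighbor list, which is robust even when two instances are identical.

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(20, 5)
knn = NearestNeighbors(n_neighbors=3).fit(X)

# Each training point is excluded from its own neighbors, so there is no
# need to ask for k + 1 neighbors and slice off the first column.
distances_nn, neighbors_nn = knn.kneighbors(None)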

In sklearn/manifold/t_sne.py:
if self.metric == 'precomputed':
    # Use the precomputed distances to find
    # the k nearest neighbors and their distances
    neighbors_nn = np.argsort(distances, axis=1)[:, :k]

@jnothman

jnothman Jun 7, 2017

Member

You can just use NearestNeighbors(metric='precomputed') for this.
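
A minimal sketch of that suggestion (toy distance matrix, not the PR's code):

import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(10, 4)
D = pairwise_distances(X)  # square matrix, as metric='precomputed' expects

nn = NearestNeighbors(n_neighbors=3, metric='precomputed').fit(D)
distances_nn, neighbors_nn = nn.kneighbors(None)  # self excluded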

In sklearn/manifold/t_sne.py:
else:
    # Find the nearest neighbors for every point
    # TODO: argument for class knn_estimator=None
    # TODO: assert that the knn metric is euclidean

@jnothman

jnothman Jun 7, 2017

Member

Can't we use another metric?

In sklearn/manifold/_barnes_hut_tsne.pyx:
printf("[t-SNE] [d=%i] Inserting pos %i [%f, %f] duplicate_count=%i "
"into child %p\n", depth, point_index, pos[0], pos[1],
printf("[t-SNE] [d=%li] Inserting pos %li [%f, %f] duplicate_count=%li"
" into child %p\n", depth, point_index, pos[0], pos[1],

@tomMoral

tomMoral Jun 8, 2017

Contributor

Using %li is required when printing long integers, to avoid compiler warnings.

In sklearn/manifold/_barnes_hut_tsne.pyx:
     float[:,:] pos_reference,
-    np.int64_t[:,:] neighbors,
+    np.int64_t[:] neighbors,
+    np.int64_t[:] indptr,

@jnothman

jnothman Jun 8, 2017

Member

If we really do have n_neighbors non-zero values in each row, I think the previous approach with a 2d array of (samples, neighbors) was better than having an indptr design. Why did you change it?

@tomMoral

tomMoral Jun 8, 2017

Contributor

This makes it easy to efficiently symmetrize the conditional_P matrix using scipy operations for sparse matrices.
Symmetrization is done in the reference implementation, so I tried to stay as close as possible to it, although I haven't reviewed it all yet.

In sklearn/manifold/t_sne.py:
    range(0, n_samples * K + 1, K)),
    shape=(n_samples, n_samples))
P = P + P.T

@tomMoral

tomMoral Jun 8, 2017

Contributor

Sparse symmetrization is done here.

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

Indicate it as a comment in the code, rather than a comment in the PR :)
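
For reference, a minimal sketch of the CSR construction and symmetrization being discussed (toy values mirroring the quoted snippet, not the exact PR code):

import numpy as np
from scipy.sparse import csr_matrix

n_samples, K = 6, 2
rng = np.random.RandomState(0)
conditional_P = rng.rand(n_samples, K)  # P(j|i) restricted to the K neighbors
neighbors_nn = rng.randint(n_samples, size=(n_samples, K))

# One row per sample with K stored values, hence the indptr with step K.
P = csr_matrix((conditional_P.ravel(), neighbors_nn.ravel(),
                range(0, n_samples * K + 1, K)),
               shape=(n_samples, n_samples))
P = P + P.T  # sparse symmetrization: P[i, j] becomes P(j|i) + P(i|j)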

@tomMoral

Contributor

tomMoral commented Jun 8, 2017

Some benchmark results on MNIST against the reference implementation. The script to reproduce is included in the benchmarks folder.
I have not yet run it on the full MNIST dataset, as it should take more than 1h30 per implementation, but the numbers look okay on subsamples and the memory usage no longer grows quadratically with the number of samples.

There is probably room for improvement, but maybe in the next PR? :)

$ python bench_tsne_mnist.py 
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.226s
Fitting bhtsne on 100 samples took 0.797s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 4.462s
Fitting bhtsne on 1000 samples took 6.774s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 133.299s
Fitting bhtsne on 10000 samples took 158.046s

Here is the memory usage reported with memory_profiler.

[memory profile plot: bench_mnist_tsne]

The memory grows most when NearestNeighbors builds its ball tree, then it drops a bit. In any case, it stays around 300MB for 10000 samples.

EDIT: a run of this bench with master gives:

$ python bench_tsne_mnist.py 
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.269s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 4.811s
Fitting TSNE on 5000 samples...
Fitting T-SNE on 5000 samples took 62.745s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 271.673s

So this PR is a bit faster.

@@ -79,34 +81,20 @@ cpdef np.ndarray[np.float32_t, ndim=2] _binary_search_perplexity(
# Compute current entropy and corresponding probabilities
# computed just over the nearest neighbors or over all data
# if we're not using neighbors

@vene

vene Jun 8, 2017

Member

Comment seems outdated: what does "if we're not using neighbors" mean now? Sorry, I misread.

@ogrisel

Member

ogrisel commented Jun 8, 2017

FYI, I am currently running a memory benchmark on the full (70000 samples) MNIST dataset. The memory usage of my python process is constant at 1.2GB.

Update: running MNIST on 70000 samples took ~52 minutes. Here is the memory profile:

[memory profile plot: tsne_mnist]

Most of the time is spent in the ball tree.

Update 2: running MNIST on 70000 samples took ~53 minutes with the reference implementation and 1.3GB of RAM (basically the same behavior as our Cython implementation in this PR).

@jnothman

Member

jnothman commented Jun 9, 2017

Yes, I was going to ask how much of the time was in KNN.

@jnothman

If ball tree is the slowest part:

  • we should consider offering n_jobs if multithreaded ball tree queries help reduce runtime (see the sketch after this list)
  • we might check how the current implementation compares to other single-processor implementations, to check that our timings are in the ballpark of the state of the art. I just realised you did that above
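
A minimal sketch of the n_jobs idea on the kNN step (NearestNeighbors already exposes n_jobs; whether and how to surface it through TSNE is the open question):

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(5000, 50)
# n_jobs=-1 parallelizes the neighbor queries over all available cores
knn = NearestNeighbors(n_neighbors=91, n_jobs=-1).fit(X)
distances_nn, neighbors_nn = knn.kneighbors(None)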
In sklearn/manifold/t_sne.py:
    metric=self.metric)
knn.fit(X)
# LvdM uses 3 * perplexity as the number of neighbors
# And we add one to not count the data point itself

@jnothman

jnothman Jun 9, 2017

Member

we do not add 1; you could explain why we pass None, or more explicitly use neighbors.kneighbors_graph(X, include_self=False).

In sklearn/manifold/t_sne.py:
# And we add one to not count the data point itself
# In the event that we have very small # of points
# set the neighbors to n - 1
distances_nn, neighbors_nn = knn.kneighbors(

@jnothman

jnothman Jun 9, 2017

Member

we could use kneighbors_graph which will already return a sparse matrix representation. This will allow us to rely on kneighbors_graph being as memory efficient as possible, rather than putting it into a sparse matrix in _joint_probabilities_nn.

@tomMoral

tomMoral Jun 9, 2017

Contributor

We do not put distances_nn in a sparse matrix in _joint_probabilities_nn, but we put conditional_P in a sparse matrix for symmetrization purposes.
Using kneighbors_graph would increase the memory needed to store range(0, n_samples * K + 1, K). I don't see a need to use it unless it is more efficient internally.
What do you think?

@ogrisel

ogrisel Jun 9, 2017

Member

We don't need any of the features of the scipy sparse matrix API for the neighbors info itself (we call into Cython code directly after its computation).

Only the conditional probability matrix benefits from being represented as a scipy sparse matrix, which makes the symmetrization a one-liner Python snippet.
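
A minimal sketch of the kneighbors_graph alternative discussed above (toy data): it returns a CSR matrix of neighbor distances directly, with self-edges excluded by default.

import numpy as np
from sklearn.neighbors import kneighbors_graph

X = np.random.RandomState(0).rand(100, 10)
G = kneighbors_graph(X, n_neighbors=5, mode='distance', include_self=False)
print(G.shape, G.nnz)  # (100, 100) with 100 * 5 stored distances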

@tomMoral

Contributor

tomMoral commented Jun 9, 2017

I ran the benchmark with n_jobs=6

Fitting TSNE on 100 samples...
    Fitting KNN: 0.0008363723754882812
    Predict KNN: 0.1035609245300293
Fitting T-SNE on 100 samples took 0.377s
Fitting TSNE on 1000 samples...
    Fitting KNN: 0.02005767822265625
    Predict KNN: 0.30463409423828125
Fitting T-SNE on 1000 samples took 3.960s
Fitting TSNE on 5000 samples...
    Fitting KNN: 0.14362692832946777
    Predict KNN: 5.516682147979736
Fitting T-SNE on 5000 samples took 17.994s
Fitting TSNE on 10000 samples...
    Fitting KNN: 0.3776566982269287
    Predict KNN: 23.55470323562622
Fitting T-SNE on 10000 samples took 53.554s
Fitting TSNE on 70000 samples...
    Fitting KNN: 13.219616413116455
    Predict KNN: 1376.9010944366455
Fitting T-SNE on 70000 samples took 1579.028s

There is a non-negligible speedup, as it took a bit less than 30 minutes to fit the full dataset.
The most expensive computation is still the kNN.

@ogrisel

Member

ogrisel commented Jun 9, 2017

I am currently running the original bhtsne code on 70000 MNIST, and the memory usage of the Python process is 1.3GB, so I think we can officially declare that our implementation is as memory efficient as it should be 🎉

Update: the original bhtsne code on 70000 MNIST took 53min, which is the same as our Cython code: 🎉²

@ogrisel

Member

ogrisel commented Jun 9, 2017

Here is a snippet to plot the resulting embedding:

from sklearn.datasets import fetch_mldata
from sklearn.utils import check_array
import matplotlib.pyplot as plt
import numpy as np

X_embedded = np.load('mnist_tsne_70000.npy')


def load_data(dtype=np.float32, order='C'):
    """Load the data, then cache and memmap the train/test split"""
    print("Loading dataset...")
    data = fetch_mldata('MNIST original')
    X = check_array(data['data'], dtype=dtype, order=order)
    y = data["target"]

    # Normalize features
    X /= 255
    return X, y

_, y = load_data()


plt.figure(figsize=(12, 12))
for c in np.unique(y):
    X_c = X_embedded[y == c]
    plt.scatter(X_c[:, 0], X_c[:, 1], alpha=0.1, label=int(c))
plt.legend(loc='best')

With my 52min run with the code of this PR, this yields:

[image: 2D embedding scatter plot from this PR]

This might have converged to a local minimum, but it seems to work well enough.

Update: here is the output of the reference implementation on the same 70000-sample MNIST dataset:

[image: 2D embedding scatter plot from the reference implementation]

We probably have discrepancies in the hyperparameters that are worth investigating more thoroughly before merging this PR.

@ogrisel

Member

ogrisel commented Jun 9, 2017

I ran another session with PCA preprocessing of the MNIST 70000 dataset, keeping 50 principal components:

Preprocessing the data with 50 dim PCA...
PCA took 4.054s
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.164s
Fitting bhtsne on 100 samples took 0.346s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 2.488s
Fitting bhtsne on 1000 samples took 4.500s
Fitting TSNE on 5000 samples...
Fitting T-SNE on 5000 samples took 9.055s
Fitting bhtsne on 5000 samples took 26.767s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 26.286s
Fitting bhtsne on 10000 samples took 65.405s
Fitting TSNE on 70000 samples...
Fitting T-SNE on 70000 samples took 665.962s
Fitting bhtsne on 70000 samples took 870.480s

Our PR (reported as "T-SNE" in the benchmark script output):

[image: embedding from this PR]

Reference implementation (bhtsne):

[image: embedding from bhtsne]

I have not changed anything in the way we set the default hyperparameters: this is still to investigate. However, it shows that we should probably do X = PCA(n_components=50).fit_transform(X) in the benchmark script, as the results are visually similar to applying t-SNE directly to the original 784-dimensional data.

It's interesting to note that in that lower dimensional regime we are significantly faster than the reference implementation.

Also I think we can probably do an MNIST example with 5000 samples by default in the scikit-learn doc (with 50-dim PCA preprocessing).

In sklearn/manifold/_barnes_hut_tsne.pyx:
float C = 0.0
float exponent = (dof + 1.0) / -2.0
cdef clock_t t1, t2
cdef float* buff = <float*> malloc(sizeof(float) * n_dimensions)

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

I would rather avoid manual memory allocation in the scikit-learn codebase (it easily leads to bugs). It should be possible to use an array here: http://docs.cython.org/en/latest/src/tutorial/array.html
or a memoryview:
http://cython.readthedocs.io/en/latest/src/userguide/memoryviews.html

You should benchmark to check that it doesn't introduce overhead.

@ogrisel

ogrisel Jun 10, 2017

Member

It's a very tiny array though, with 2 or 3 elements. Maybe we can just use

cdef float[3] buff

and let the compiler handle the memory management based on the local variable scoping.


@ogrisel

ogrisel Jun 10, 2017

Member

To clarify: n_dimensions can only be 2 or 3 when using the Barnes-Hut approximation.


@ogrisel

ogrisel Jun 10, 2017

Member

Actually the code does not raise an error when n_components > 3 with the Barnes Hut solver. This is a real bug. Let me add it to the TODO.


@tomMoral

tomMoral Jun 22, 2017

Contributor

It can also be in dimension 1 (and there is a test enforcing that).
So now TSNE raises a ValueError when using barnes_hut with n_components > 3.
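
A minimal sketch of that guard (illustrative; the exact message and placement in the merged code may differ):

method, n_components = 'barnes_hut', 4  # example constructor arguments
if method == 'barnes_hut' and n_components > 3:
    raise ValueError("'n_components' should be inferior to 4 for the "
                     "barnes_hut algorithm as it relies on quad-tree "
                     "or oct-tree.")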

In sklearn/manifold/t_sne.py:
    Distances of samples to its K nearest neighbors.
neighbors : array, shape (n_samples, K)
    K nearest-neighbors for each samples.

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

What does it contain? The indices of the neighbors? It might be useful to indicate it.

-P : array, shape (n_samples * (n_samples-1) / 2,)
-    Condensed joint probability matrix.
+P : csr sparse matrix, shape (n_samples, n_samples)
+    Condensed joint probability matrix with only nearest neighbors.

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

Are both halves (above and below the diagonal) of the array useful, or only one? In the latter case, it would be useful to indicate it.

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

OK, forget it. Sorry

In sklearn/manifold/t_sne.py:
[ 0.00009501, -0.00001388]])
>>> X_embedded = model.fit_transform(X)
>>> X_embedded.shape
(4, 2)

@GaelVaroquaux

GaelVaroquaux Jun 10, 2017

Member

Why the change?

@ogrisel

ogrisel Jun 10, 2017

Member

Because the values depend on the fact that we fixed the use of the "early exaggeration" parameter in this PR. Furthermore, we are likely to change the default values of the optimizer hyperparams to match the reference implementation, which would change the actual absolute values of the embedding (which is fine as long as the trustworthiness score is good). What is important for the reader is that the data has been reduced to 2D and is therefore suitable for visualization as a scatter plot.

@jnothman

jnothman Jul 2, 2017

Member

The range of values may be of interest to the user, and worth outputting to illustrate.

@GaelVaroquaux

Member

GaelVaroquaux commented Jun 10, 2017

I made a pass with comments. I will wait for the hyperparameters to be set before giving a 👍.

@ogrisel

Member

ogrisel commented Jun 10, 2017

I am integrating @AlexanderFabisch's change in the default value of the learning_rate from #8752 as well as trying to make the optimization schedule follow the one used in the reference implementation.

Here is the current state of the output of the benchmark script:

PCA preprocessing down to 50 dimensions took 5.155s
Fitting sklearn TSNE on 100 samples...
Fitting sklearn TSNE on 100 samples took 0.517s in 274 iterations, trustworthiness: 0.966
Fitting lvdmaaten/bhtsne on 100 samples...
Fitting lvdmaaten/bhtsne on 100 samples took 0.528s in -1 iterations, trustworthiness: 0.967
Fitting sklearn TSNE on 1000 samples...
Fitting sklearn TSNE on 1000 samples took 2.890s in 124 iterations, trustworthiness: 0.979
Fitting lvdmaaten/bhtsne on 1000 samples...
Fitting lvdmaaten/bhtsne on 1000 samples took 6.718s in -1 iterations, trustworthiness: 0.981
Fitting sklearn TSNE on 5000 samples...
Fitting sklearn TSNE on 5000 samples took 13.441s in 99 iterations, trustworthiness: 0.966
Fitting lvdmaaten/bhtsne on 5000 samples...
Fitting lvdmaaten/bhtsne on 5000 samples took 57.042s in -1 iterations, trustworthiness: 0.992
Fitting sklearn TSNE on 10000 samples...
Fitting sklearn TSNE on 10000 samples took 47.039s in 149 iterations, trustworthiness: 0.980
Fitting lvdmaaten/bhtsne on 10000 samples...
Fitting lvdmaaten/bhtsne on 10000 samples took 142.284s in -1 iterations, trustworthiness: 0.996

There is still a significant discrepancy. I need to find a way to get the effective number of iterations done by the reference implementation and get a better understanding of the stopping criterion in both implementations.

@ogrisel

Member

ogrisel commented Jun 10, 2017

Our reporting of the number of iterations is also broken: it should always be higher than 250...

@tomMoral

Contributor

tomMoral commented Jun 11, 2017

I don't think the reporting of n_iter is broken. The number of iterations can be smaller than 250 if the first step of optimization converges before the 250th iteration.

@jnothman

Member

jnothman commented Jun 12, 2017

Is this still WIP? I'd like to see it in 0.19.

@jnothman jnothman added the Bug label Jun 12, 2017

@jnothman jnothman added this to the 0.19 milestone Jun 12, 2017

@jnothman

Member

jnothman commented Jun 12, 2017

Sorry, I'd not seen some of @ogrisel's comments regarding making this the one-pr-to-fix-them-all.

@jnothman

Member

jnothman commented Jun 13, 2017

Intentional or otherwise, I can't produce #8992's error on this branch.

@ogrisel

Member

ogrisel commented Jun 14, 2017

Here is a status report:

  • the current PR has fixed the memory usage issue;
  • in doing so, @tomMoral found out that the early_exaggeration param was not properly used by the optimizer. Fixing this causes a behavioral change in the optimizer, but this is a bugfix, so I don't think we should try to implement a deprecation cycle for it;
  • while reading the code, running the benchmarks on MNIST and comparing to the reference implementation, we found out that both the optimizer scheduling and the stopping criterion were behaving very differently: this leads to a behavioral bug, as our optimizer would often stop far before convergence, leading to underwhelming quantitative (trustworthiness) and qualitative (blurry scatter plots on MNIST) results;
  • I made some experimental changes in this branch to follow the behavior and default parametrization of the reference implementation as much as possible: I think this branch is now optimizing the model correctly (I will re-run a bench on the full MNIST data to confirm);
  • I can still trigger assertion failures similar to #8992 from time to time. I am not entirely sure whether those checks are required or not. Those assertion failures are probably more likely to happen because we now optimize for longer. We need to investigate more before being able to merge this PR;
  • we also need to fix the n_components > 3 bug I reported in this PR description.

Speed-wise, our stopping criterion no longer stops the computation too early, as was the case before. With the correct stopping criterion, the gradient descent part of the fitting procedure now dominates (as is the case in the reference implementation), and we are twice as slow as the reference implementation. There are useless mallocs and frees of large temporary buffers in the inner gradient routine. There are probably other things to optimize, but those are better tackled in a dedicated PR.

Also note: the gradient descent computation in Cython is embarrassingly parallel and would probably benefit from thread-based multicore parallelism using OpenMP via Cython prange (once loky is integrated in joblib and sklearn). This will also be explored in a later PR.

@ogrisel

Member

ogrisel commented Jun 14, 2017

There are a bunch of broken tests to fix too, but I will wait for the results on the full MNIST before doing so.

@ogrisel

Member

ogrisel commented Jun 14, 2017

My run has crashed with a memory error because of the safe_sparse_dot in the trustworthiness computation. It needs to be chunked to be memory efficient.
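
A minimal sketch of the chunking idea (an illustrative helper, not the PR's trustworthiness code): process the pairwise-distance matrix in row blocks so that only a (chunk, n_samples) slice is in memory at any time.

import numpy as np
from sklearn.metrics import pairwise_distances

def nearest_neighbor_indices_chunked(X, n_neighbors, chunk=1000):
    """Row-chunked argsort of pairwise distances, skipping self (column 0)."""
    ind = np.empty((X.shape[0], n_neighbors), dtype=np.int64)
    for start in range(0, X.shape[0], chunk):
        stop = min(start + chunk, X.shape[0])
        D = pairwise_distances(X[start:stop], X)  # one (chunk, n) block
        ind[start:stop] = np.argsort(D, axis=1)[:, 1:n_neighbors + 1]
    return ind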

@tomMoral

Contributor

tomMoral commented Jun 18, 2017

Status update: I just finished reimplementing the QuadTree as discussed with @ogrisel.
There is still some work to do to properly test it and to enable serialization, so it is WIP.
The tests for t-SNE seem to pass and I do not see the InsertionError anymore, so this could fix #8992.

@amueller

Member

amueller commented Jun 19, 2017

I think I'll retag for 0.20. I'd rather release sooner than later and with well-tested features...

@amueller amueller modified the milestones: 0.20, 0.19 Jun 19, 2017

@ogrisel

Member

ogrisel commented Jun 20, 2017

I propose we add a note in the docstring that bhtsne is currently more expensive than it should be, and also that it currently does not converge (the optimization is stopped too early in master).

@ogrisel

Member

ogrisel commented Jun 20, 2017

Note: the AppVeyor failure is caused by a broken network connection. All tests pass under Windows.

ogrisel added some commits Jul 11, 2017

FIX various optimization schedule issues in TSNE
Respect class level _EXPLORATION_N_ITER.

Disable min_error_diff.

Fix docstring about min_grad_norm.
@ogrisel

Member

ogrisel commented Jul 11, 2017

For some reason GitHub did not detect my last commit to fix the broken test. I just rebased on current master and force-pushed.

In the meantime I have run the MNIST 70000 benchmark and it works as expected now.

@ogrisel

Member

ogrisel commented Jul 11, 2017

Actually there is still something odd, here is the output of the verbose mode on 70000 MNIST:

[t-SNE] Computed conditional probabilities in 5.937s
[t-SNE] Iteration 50: error = 20.6228523, gradient norm = 0.0012072 (50 iterations in 109.420s)
[t-SNE] Iteration 100: error = 20.6228523, gradient norm = 0.0003245 (50 iterations in 112.185s)
[t-SNE] Iteration 150: error = 20.6228523, gradient norm = 0.0003467 (50 iterations in 109.183s)
[t-SNE] Iteration 200: error = 20.6228523, gradient norm = 0.0002613 (50 iterations in 108.042s)
[t-SNE] Iteration 250: error = 20.6228523, gradient norm = 0.0001693 (50 iterations in 108.003s)
[t-SNE] KL divergence after 250 iterations with early exaggeration: 20.622852
[t-SNE] Iteration 300: error = 0.2076470, gradient norm = 0.0011019 (50 iterations in 93.393s)
[t-SNE] Iteration 350: error = 0.2076470, gradient norm = 0.0006730 (50 iterations in 85.729s)
[t-SNE] Iteration 400: error = 0.2076470, gradient norm = 0.0004486 (50 iterations in 82.922s)
[t-SNE] Iteration 450: error = 0.2076470, gradient norm = 0.0003258 (50 iterations in 80.475s)
[t-SNE] Iteration 500: error = 0.2076470, gradient norm = 0.0002516 (50 iterations in 78.311s)
[t-SNE] Iteration 550: error = 0.2076470, gradient norm = 0.0002036 (50 iterations in 77.857s)
[t-SNE] Iteration 600: error = 0.2076470, gradient norm = 0.0001707 (50 iterations in 78.732s)
[t-SNE] Iteration 650: error = 0.2076470, gradient norm = 0.0001466 (50 iterations in 78.704s)
[t-SNE] Iteration 650: did not make any progress during the last 300 episodes. Finished.
[t-SNE] Error after 650 iterations: 0.207647
Fitting sklearn TSNE on 70000 samples took 2147.127s in 649 iterations, nn accuracy: 0.112

The gradient is non-zero, but the reported error is constant. I suspect an issue in the way the error is computed.

This is weird because the final nn accuracy and the 2D plot look as good as the reference bhtsne implementation.

@ogrisel

Member

ogrisel commented Jul 11, 2017

@tomMoral A similar problem appears with 10000 MNIST but only during the early exaggeration phase.

[t-SNE] Iteration 50: error = 42.3276405, gradient norm = 0.0032737 (50 iterations in 9.243s)
[t-SNE] Iteration 100: error = 42.3276405, gradient norm = 0.0007196 (50 iterations in 9.502s)
[t-SNE] Iteration 150: error = 42.3276405, gradient norm = 0.0002795 (50 iterations in 9.690s)
[t-SNE] Iteration 200: error = 42.3276405, gradient norm = 0.0012699 (50 iterations in 9.808s)
[t-SNE] Iteration 250: error = 42.3276405, gradient norm = 0.0004984 (50 iterations in 9.983s)
[t-SNE] KL divergence after 250 iterations with early exaggeration: 42.327641
[t-SNE] Iteration 300: error = 1.3027184, gradient norm = 0.0012434 (50 iterations in 6.936s)
[t-SNE] Iteration 350: error = 1.3027184, gradient norm = 0.0005483 (50 iterations in 6.948s)
[t-SNE] Iteration 400: error = 1.2963731, gradient norm = 0.0003284 (50 iterations in 7.496s)
[t-SNE] Iteration 450: error = 1.1954679, gradient norm = 0.0002261 (50 iterations in 7.911s)
[t-SNE] Iteration 500: error = 1.1003660, gradient norm = 0.0001684 (50 iterations in 7.460s)
[t-SNE] Iteration 550: error = 1.0248337, gradient norm = 0.0001321 (50 iterations in 7.088s)
[t-SNE] Iteration 600: error = 0.9644930, gradient norm = 0.0001083 (50 iterations in 6.812s)
[t-SNE] Iteration 650: error = 0.9163133, gradient norm = 0.0000917 (50 iterations in 6.825s)
[t-SNE] Iteration 700: error = 0.8773406, gradient norm = 0.0000800 (50 iterations in 6.628s)
[t-SNE] Iteration 750: error = 0.8454365, gradient norm = 0.0000717 (50 iterations in 6.623s)
[t-SNE] Iteration 800: error = 0.8192183, gradient norm = 0.0000644 (50 iterations in 7.418s)
[t-SNE] Iteration 850: error = 0.7974554, gradient norm = 0.0000588 (50 iterations in 7.031s)
[t-SNE] Iteration 900: error = 0.7794618, gradient norm = 0.0000551 (50 iterations in 7.582s)
[t-SNE] Iteration 950: error = 0.7651656, gradient norm = 0.0000527 (50 iterations in 8.875s)
[t-SNE] Iteration 1000: error = 0.7536564, gradient norm = 0.0000512 (50 iterations in 10.508s)
[t-SNE] Error after 1000 iterations: 0.753656
Fitting sklearn TSNE on 10000 samples took 170.083s in 999 iterations, nn accuracy: 0.505
@tomMoral

Contributor

tomMoral commented Jul 11, 2017

@ogrisel I am not sure it is a problem if the results are good. It means that we are in a local minimum and that the optimization is more stable, no?
I spent a lot of time making the error computation more robust to numerical errors. There might still be some mistakes, but I think we are computing the right quantity.

@ogrisel

Member

ogrisel commented Jul 11, 2017

It means that we are in a local minimum and that the optimization is more stable, no?

It's not necessarily a local or a global minimum. I also monitor the norms of the updates and they are quite large (between 0.01 and 2). I find it weird to get an error value that stays constant with large updates, that's fishy. Setting verbose to 20 does not seem to output more detailed logs even after I changed the following:

diff --git a/sklearn/manifold/_barnes_hut_tsne.pyx b/sklearn/manifold/_barnes_hut_tsne.pyx
index 5041c1d..a132935 100644
--- a/sklearn/manifold/_barnes_hut_tsne.pyx
+++ b/sklearn/manifold/_barnes_hut_tsne.pyx
@@ -236,7 +236,7 @@ def gradient(float[:] val_P,
     assert n == indptr.shape[0] - 1, m
     if verbose > 10:
         printf("[t-SNE] Initializing tree of n_dimensions %i\n", n_dimensions)
-    cdef quad_tree._QuadTree qt = quad_tree._QuadTree(pos_output.shape[1], 0)
+    cdef quad_tree._QuadTree qt = quad_tree._QuadTree(pos_output.shape[1], verbose)
     if verbose > 10:
         printf("[t-SNE] Inserting %li points\n", pos_output.shape[0])
     qt.build_tree(pos_output)

ogrisel and others added some commits Jul 11, 2017

In sklearn/manifold/t_sne.py:
@@ -803,6 +803,8 @@ def _tsne(self, P, degrees_of_freedom, n_samples, random_state, X_embedded,
         if self.method == 'barnes_hut':
             obj_func = _kl_divergence_bh
             opt_args['kwargs']['angle'] = self.angle
+            # Repeat verbose argument for _kl_divergence_bh
+            opt_args['kwargs']['verbose'] = self.verbose

@ogrisel

ogrisel Jul 12, 2017

Member

Good catch.

@ogrisel

Member

ogrisel commented Jul 12, 2017

@tomMoral I am investigating why the errors are often constant: I think it is because we truncate with too large an EPSILON in compute_gradient_positive. I am working on a fix.
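
A minimal sketch of the suspected effect on toy numbers (not the PR's code): clipping the probabilities with too large an EPSILON before taking the log flattens the reported KL divergence.

import numpy as np

p = np.array([1e-9, 1e-5, 1e-2])
q = np.array([2e-9, 1e-5, 1e-2])

for eps in (1e-6, 1e-12):
    kl = np.sum(p * np.log(np.maximum(p, eps) / np.maximum(q, eps)))
    # with eps=1e-6 the tiny entries are clipped to the same value and
    # stop contributing, so the reported error barely moves
    print(eps, kl)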

@ogrisel

Member

ogrisel commented Jul 12, 2017

Alright, the bug in the error computation is fixed, all tests pass (the CircleCI failure is the unrelated stock-market example issue), the examples look good, and the MNIST benchmark is accurate, reasonably fast, and memory efficient.

Merging!

@ogrisel ogrisel merged commit cb1b6c4 into scikit-learn:master Jul 12, 2017

2 of 3 checks passed:

ci/circleci: Your tests failed on CircleCI
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@ogrisel

Member

ogrisel commented Jul 12, 2017

There are still plenty of things to improve in this code base, but I think this is a great first step. Thanks @tomMoral for your patience :)

@tomMoral

Contributor

tomMoral commented Jul 12, 2017

Thanks @ogrisel for fixing the EPSILON bug!
Hopefully, I will get some time to work on some acceleration of our t-SNE soon!

@tomMoral tomMoral deleted the tomMoral:optim_tsne branch Jul 12, 2017

@jnothman jnothman commented Jul 13, 2017 Member

massich added a commit to massich/scikit-learn that referenced this pull request Jul 13, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.
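
For readers skimming the changelog, here is a rough sketch (not the PR's exact code) of what the sparse neighbor representation buys: only the k nearest-neighbor distances per sample are kept, in CSR form, so memory is O(k * N) rather than the O(N**2) of a dense pairwise matrix:

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.randn(10000, 30)

k = 90  # Barnes-Hut t-SNE uses roughly 3 * perplexity neighbors
knn = NearestNeighbors(n_neighbors=k).fit(X)

# mode='distance' returns a scipy.sparse CSR matrix holding only the
# k distances per sample instead of all N**2 pairwise distances.
D = knn.kneighbors_graph(mode='distance')
print(D.shape, D.nnz)  # (10000, 10000) 900000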

yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017

Merge tag '0.19b2' into releases
Release 0.19b2

* tag '0.19b2': (808 commits)
  Preparing 0.19b2
  [MRG+1] FIX out of bounds array access in SAGA (#9376)
  FIX make test_importances pass on 32 bit linux
  Release 0.19b1
  DOC remove 'in dev' header in whats_new.rst
  DOC typos in whats_news.rst [ci skip]
  [MRG] DOC cleaning up what's new for 0.19 (#9252)
  FIX t-SNE memory usage and many other optimizer issues (#9032)
  FIX broken link in gallery and bad title rendering
  [MRG] DOC Replace \acute by prime (#9332)
  Fix typos (#9320)
  [MRG + 1 (rv) + 1 (alex) + 1] Add a check to test the docstring params and their order (#9206)
  DOC Residual sum vs. regression sum (#9314)
  [MRG] [HOTFIX] Fix capitalization in test and hence fix failing travis at master (#9317)
  More informative error message for classification metrics given regression output (#9275)
  [MRG] COSMIT Remove unused parameters in private functions (#9310)
  [MRG+1] Ridgecv normalize (#9302)
  [MRG + 2] ENH Allow `cross_val_score`, `GridSearchCV` et al. to evaluate on multiple metrics (#7388)
  Add data_home parameter to fetch_kddcup99 (#9289)
  FIX makedirs(..., exists_ok) not available in Python 2 (#9284)
  ...

yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017

Merge branch 'releases' into dfsg (reremoved joblib and jquery)
* releases: (808 commits)
  ... (same 808-commit list as in the merge above)

yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017

Merge branch 'dfsg' into debian
* dfsg: (808 commits)
  ... (same 808-commit list as in the merge above)

dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.

NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.

paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.

AishwaryaRK added a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.

@rth rth referenced this pull request Sep 5, 2017

T-SNE fails for CSR matrix #9691 (open)

maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.

jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017

FIX t-SNE memory usage and many other optimizer issues (#9032)
Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.