# t-SNE Animations

*fastTSNE* includes a callback system, which can be triggered every *n* iterations and can also be used to control the optimization and decide when to stop.
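To make the idea concrete, here is a toy sketch of a callback-driven loop. It is not fastTSNE's optimizer: `run_toy_optimization` and `stop_when_converged` are hypothetical names, the "embedding" is a single number, and the rule that a truthy return value stops the loop is an assumption made for this illustration. Only the callback signature `(iteration, error, embedding)` mirrors the one used later in this notebook.

```python
# Toy illustration only: this is NOT fastTSNE's optimizer. It mimics calling
# callback(iteration, error, embedding) every few iterations and treats a
# truthy return value as a request to stop (an assumed convention).

def run_toy_optimization(callback, callbacks_every_iters=5, n_iter=100):
    value = 10.0  # stand-in for the embedding
    for iteration in range(1, n_iter + 1):
        value *= 0.5          # stand-in for one gradient step
        error = value ** 2    # stand-in for the objective value
        if iteration % callbacks_every_iters == 0:
            if callback(iteration, error, value):
                return iteration  # the callback asked us to stop early
    return n_iter


seen = []

def stop_when_converged(iteration, error, embedding):
    seen.append(iteration)   # record when the callback fired
    return error < 1e-6      # request a stop once the error is tiny


stopped_at = run_toy_optimization(stop_when_converged)
print(stopped_at, seen)
```

The callback only fires on iterations 5, 10 and 15, and the loop ends at the first of those where the error has dropped below the threshold.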

In this notebook, we'll look at an example and use callbacks to generate an animation of the optimization. In practice, this serves no real purpose other than being fun to look at.

In [1]:
from fastTSNE import TSNE
from fastTSNE.callbacks import ErrorLogger

from examples import utils

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import Image, display

For this example, we'll be using the MNIST handwritten digits data set, which is probably the best-known t-SNE success story.

In [2]:
x, y = utils.get_mnist()

In [3]:
print('MNIST contains %d images with %d pixels each.' % x.shape)

MNIST contains 70000 images with 784 pixels each.


We pass a callback that takes the current embedding, makes a copy (this is important because the embedding is modified in place during optimization) and appends it to a list. We can also specify how often the callback should be called. In this instance, we'll call it at every iteration.

In [4]:
embeddings = []

tsne = TSNE(
    # Let's use the fast approximation methods
    neighbors='approx', negative_gradient_method='fft', initialization='random',
    # The embedding will be appended to the list we defined above. Make sure to copy the
    # embedding; otherwise the same object reference is stored for every iteration
    callbacks=lambda it, err, emb: embeddings.append(np.array(emb)),
    # This should be done on every iteration
    callbacks_every_iters=1,
    # -2 will use all but one core so I can look at cute cat pictures while this computes
    n_jobs=-2
)
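The copy in the callback matters more than it might seem. The following minimal sketch uses a small NumPy array as a stand-in for the embedding (the names `by_reference` and `by_copy` are illustrative and not part of fastTSNE) and shows what happens when an array that is updated in place is stored with and without copying.

```python
import numpy as np

# Stand-in for an embedding that an optimizer updates in place
embedding = np.zeros((3, 2))

by_reference, by_copy = [], []
for step in range(3):
    embedding += 1.0                      # in-place update, same buffer
    by_reference.append(embedding)        # stores the same object each time
    by_copy.append(np.array(embedding))   # np.array copies by default

# Every by_reference entry reflects the final state; the copies keep the history
reference_values = [float(frame[0, 0]) for frame in by_reference]
copied_values = [float(frame[0, 0]) for frame in by_copy]
print(reference_values)  # [3.0, 3.0, 3.0]
print(copied_values)     # [1.0, 2.0, 3.0]
```

Without the copy, every frame of the animation would show the final embedding.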

In [5]:
%time tsne.fit(x)

CPU times: user 13min 17s, sys: 2.31 s, total: 13min 20s
Wall time: 2min 47s


TSNEEmbedding([[-22.13216169,  -2.206027  ],
               [ -2.47832169,  28.91961601],
               [-22.39443692, -26.16850575],
               ...,
               [ -8.16410769, -13.20077047],
               [-25.1144806 , -15.23230912],
               [-37.41359755,   6.07385153]])

Now that we have all the iterations in our list, we need to create the animation. We do this here using matplotlib, which is relatively straightforward. Generating the animation can take a long time, so we will save it as a GIF so we can come back to it whenever we want, without having to wait again. Note that the `imagemagick` writer used here requires ImageMagick to be installed.

In [6]:
%%time
fig = plt.figure(figsize=(7, 7))
ax = fig.add_axes([0, 0, 1, 1])
ax.set_xticks([]), ax.set_yticks([])

pathcol = ax.scatter(embeddings[0][:, 0], embeddings[0][:, 1], c=y, s=1, cmap='tab10')

def update(embedding, ax, pathcol):
    # Update point positions
    pathcol.set_offsets(embedding)
    
    # Adjust x/y limits so all the points are visible
    ax.set_xlim(np.min(embedding[:, 0]), np.max(embedding[:, 0]))
    ax.set_ylim(np.min(embedding[:, 1]), np.max(embedding[:, 1]))
    
    return [pathcol]

anim = animation.FuncAnimation(
    fig, update, fargs=(ax, pathcol), interval=20,
    frames=embeddings, blit=True,
)

anim.save('mnist.gif', dpi=60, writer='imagemagick')
plt.close()

CPU times: user 11min 54s, sys: 2min 43s, total: 14min 38s
Wall time: 13min 5s
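Rendering every stored iteration is what makes this step slow. One option, sketched here with a hypothetical helper `thin_frames` (not part of fastTSNE or matplotlib), is to pass only every *k*-th embedding as `frames`, while always keeping the final one so the animation ends on the finished embedding. The effect on rendering time was not measured in this notebook.

```python
def thin_frames(frames, step):
    """Keep every `step`-th frame, always including the final one."""
    kept = list(frames[::step])
    # Append the last frame if the stride skipped over it
    if frames and (len(frames) - 1) % step != 0:
        kept.append(frames[-1])
    return kept


# Integers stand in for the stored embeddings
demo_frames = list(range(10))
print(thin_frames(demo_frames, 4))  # [0, 4, 8, 9]
```

In the animation cell, this would amount to passing `frames=thin_frames(embeddings, 4)` instead of `frames=embeddings`; the step of 4 is an arbitrary choice.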
