### Plot a meta-map of the resulting embeddings using t-SNE

In [1]:
import numpy as np
from sklearn.manifold import TSNE
from dataset_utils import load_dataset

#### Read all calculated embeddings from pickle files into memory

In [2]:
# Dataset whose pre-computed embeddings are explored in this notebook.
dataset_name = 'MNIST-SMALL'

# Data, targets and label names as returned by the project helper.
X, y, labels = load_dataset(dataset_name)

# Directory that is scanned for the saved embedding files.
embedding_dir = './output/{}'.format(dataset_name)

Loading dataset: MNIST-SMALL


In [4]:
# Load the calculated embeddings, one saved TSNE object per perplexity value.
import joblib
import os

embeddings = []   # flattened embedding_ of every loaded object
perps = []        # 'perplexity' parameter of every loaded object
all_losses = []   # kl_divergence_ of every loaded object

# os.listdir() returns entries in arbitrary order. Sorting makes the order of
# `embeddings` / `perps` / `all_losses` identical on every run, so anything
# fitted on the stacked embeddings afterwards sees the same row order.
# NOTE(review): joblib.load unpickles the file content; only point
# `embedding_dir` at files you trust.
for file in sorted(os.listdir(embedding_dir)):
    if not file.endswith('z'):
        continue
    in_name = os.path.join(embedding_dir, file)
    tsne_obj = joblib.load(in_name)
    # Flatten so that each embedding becomes one row vector.
    embeddings.append(tsne_obj.embedding_.ravel())
    all_losses.append(tsne_obj.kl_divergence_)
    perps.append(tsne_obj.get_params()['perplexity'])

In [5]:
# Build the meta-map: every flattened embedding is treated as one sample
# and the whole collection is embedded again with t-SNE.

from sklearn.preprocessing import normalize, scale

# Fixed seed so the meta-map is repeatable for a given input.
meta_tsne = TSNE(random_state=0)

# One row per loaded embedding.
all_embeddings = np.asarray(embeddings)
meta_map = meta_tsne.fit_transform(all_embeddings)

In [12]:
from bqplot import LinearScale, ColorScale, OrdinalColorScale, Axis, ColorAxis, Scatter, Figure
from bqplot.colorschemes import CATEGORY10
from ipywidgets import VBox, HBox

# --- Meta scatter: one point per embedding, coloured by its perplexity ---
sc_color = ColorScale()
ax_color = ColorAxis(
    scale=sc_color,
    label='Perplexity',
    orientation='vertical',
    side='left',
)
scatter = Scatter(
    x=meta_map[:, 0],
    y=meta_map[:, 1],
    color=perps,
    stroke='black',
    scales={'x': LinearScale(), 'y': LinearScale(), 'color': sc_color},
)

# --- Child scatter: starts empty, coloured by the dataset targets `y` ---
sc_color2 = OrdinalColorScale(colors=CATEGORY10)
ax_color2 = ColorAxis(
    scale=sc_color2,
    label='Class',
    orientation='vertical',
    side='right',
)
child_scatter = Scatter(
    x=[],
    y=[],
    color=y,
    scales={'x': LinearScale(), 'y': LinearScale(), 'color': sc_color2},
    default_size=12,
)

# --- Figures wrapping the two marks ---
scat_fig = Figure(
    axes=[ax_color],
    marks=[scatter],
    title='Metamap for {}'.format(dataset_name),
)
child_fig = Figure(
    axes=[ax_color2],
    marks=[child_scatter],
    title='Detail view',
)

In [15]:
def plot_detail(name, value):
    """Show the embedding behind a clicked meta-map point in the detail figure.

    Registered below as the element-click callback of ``scatter``.

    Parameters
    ----------
    name
        Passed by the callback machinery; not used here.
    value : dict
        Click payload; ``value['data']['index']`` is read as the position of
        the clicked point in ``perps`` / ``all_embeddings``.
    """
    idx = value['data']['index']
    perp = perps[idx]
    # Each stored embedding was flattened; restore it to two columns (x, y).
    X2d = all_embeddings[idx].reshape(-1, 2)
    # Send x and y to the front end in a single sync so it never has to
    # render an intermediate state that pairs the new x with the old y.
    with child_scatter.hold_sync():
        child_scatter.x = X2d[:, 0]
        child_scatter.y = X2d[:, 1]
    child_fig.title = 'Detail scatter for perp = {}'.format(perp)


scatter.on_element_click(plot_detail)

In [16]:
# Meta-map on the left, detail view on the right.
HBox(children=[scat_fig, child_fig])

HBox(children=(Figure(axes=[ColorAxis(label='Perplexity', orientation='vertical', scale=ColorScale(), side='le…