### Interactive notebook for moving points with PPCA

In [2]:
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

from sklearn.decomposition import PCA

from bqplot import (
    Axis, ColorAxis, LinearScale, DateScale, DateColorScale, OrdinalScale,
    OrdinalColorScale, ColorScale, Scatter, Lines, Figure, Tooltip
)
from ipywidgets import VBox, HBox, Layout
from ipywidgets import Label

from observations import iris, mnist
# Directory passed to the dataset loaders (`iris`, `mnist`) used below
DATA_DIR = './data'

import warnings
# Suppress all warnings for the rest of the session
warnings.filterwarnings('ignore')

# Use one seed value for both the NumPy and the TensorFlow random generators
seed = 2018
np.random.seed(seed)
tf.set_random_seed(seed)

### Prepare dataset

In [107]:
# Load iris and encode every distinct label as an integer class id
x_train, y_train, _ = iris(DATA_DIR)
y_labels = np.unique(y_train)
# label -> position of that label in the sorted unique labels
y_label_map = dict(zip(y_labels, range(len(y_labels))))
y_class = np.array(list(y_label_map[label] for label in y_train))
print(x_train.shape, y_train.shape)

(150, 4) (150,)


In [79]:
# Alternative dataset: keep only the first N rows of the MNIST training split.
# NOTE: this rebinds x_train / y_class, replacing whatever an earlier cell loaded.
N = 500
(x_train, y_train), (x_test, y_test) = mnist(DATA_DIR)
x_train, y_class = x_train[:N], y_train[:N]
print(x_train.shape, y_class.shape)

(500, 784) (500,)


In [108]:
# Bookkeeping for the interactive layout:
#   point_indices - the ids 0..N-1, one per dataset row; rearrange_fixed_points()
#                   reorders this list together with the dataset so that the
#                   fixed points end up at the bottom
#   fixed_points  - starts empty; update_fixed_points() stores
#                   {index: [x, y]} for every point that has been dragged
N, D = x_train.shape
# presumably the dimensionality of the low-dimensional embedding -- TODO confirm
K = 2
point_indices = list(range(N))
fixed_points = {}

In [109]:
def rearrange_fixed_points():
    """Move the rows of the fixed points to the bottom of the dataset.

    Reorders the globals ``x_train``, ``y_class`` and ``point_indices`` in
    lock-step. Entries whose id is not a key of ``fixed_points`` keep their
    relative order and come first; the fixed ones follow in the iteration
    order of ``fixed_points``.

    ``point_indices[p]`` is the id of the row currently stored at position
    ``p``. The ids are translated to current row positions before the arrays
    are indexed, so the function stays correct when it is called repeatedly
    (the arrays are already permuted after the first call).

    NOTE(review): keys of ``fixed_points`` are matched against the ids in
    ``point_indices``. Confirm that the indices stored by the drag callback
    are ids of that kind and not positions in the currently plotted array.

    Raises:
        KeyError: if a key of ``fixed_points`` does not occur in
            ``point_indices``.
    """
    global x_train
    global y_class
    global point_indices

    # id -> position of that id's row in the (possibly already permuted) arrays
    position_of = dict((pid, pos) for pos, pid in enumerate(point_indices))

    free_ids = [pid for pid in point_indices if pid not in fixed_points]
    fixed_ids = list(fixed_points)
    new_indices = free_ids + fixed_ids

    # Index the arrays by position, never by id: the two only coincide
    # before the first rearrangement.
    row_order = [position_of[pid] for pid in new_indices]

    point_indices = new_indices
    x_train = x_train[row_order]
    y_class = y_class[row_order]

### Utility functions for the interactive visualization

In [83]:
# Status label that lists the indices of the points fixed so far
lbl_info = Label(value='Fixed points: []', color='Green', font_size='32px')

In [84]:
def update_fixed_points(name, value):
    """Drag-end callback: record where a point was dropped.

    Stores the drop position in the module-level ``fixed_points`` dict under
    the point's index and refreshes ``lbl_info`` with all recorded indices.

    Args:
        name: first argument supplied by the caller; not used here.
        value: mapping with an ``'index'`` entry (used as the key) and a
            ``'point'`` entry that itself has ``'x'`` and ``'y'`` entries.
    """
    pos = value['point']
    idx = value['index']
    fixed_points[idx] = [pos['x'], pos['y']]
    # Join with ", " so the label reads "[21, 28]" rather than "[21  ,28]"
    listed = ', '.join(str(pi) for pi in fixed_points)
    lbl_info.value = "Fixed points: [{}]".format(listed)
    
def reset_fixed_points():
    """Forget every fixed point and reset the status label."""
    # Clear the shared dict in place. A plain `fixed_points = {}` here would
    # only bind a local name and leave the module-level dict untouched.
    fixed_points.clear()
    lbl_info.value = 'Fixed points: []'

In [97]:
def viz(x_2d, y_class):
    """Build a scatter figure with movable points for a 2-D embedding.

    Column 0 of ``x_2d`` is plotted on the horizontal axis and column 1 on
    the vertical axis; ``y_class`` supplies the colour value of each point.
    ``update_fixed_points`` is registered as the drag-end callback.

    Returns:
        A ``(figure, scatter)`` tuple.
    """
    scales = {
        'x': LinearScale(),
        'y': LinearScale(),
        'color': ColorScale(scheme='RdYlGn'),
    }

    points = Scatter(x=x_2d[:, 0], y=x_2d[:, 1], color=y_class,
                     scales=scales, enable_move=True)
    # Record every point the user drops after dragging it
    points.on_drag_end(update_fixed_points)

    axes = [
        Axis(scale=scales['x']),
        Axis(scale=scales['y'], tick_format='0.2f', orientation='vertical'),
    ]
    return Figure(marks=[points], axes=axes), points

In [86]:
def update_scatter(scatt, x_2d, y_class):
    """Push new coordinates and colour values into an existing scatter mark.

    Sets ``scatt.x`` to column 0 of ``x_2d``, ``scatt.y`` to column 1 and
    ``scatt.color`` to ``y_class``, in that order.
    """
    updates = (
        ('x', x_2d[:, 0]),
        ('y', x_2d[:, 1]),
        ('color', y_class),
    )
    for attribute, data in updates:
        setattr(scatt, attribute, data)

### Base model with original PCA in scikit-learn

In [104]:
# Baseline: project x_train onto its first two principal components
x_2d = PCA(n_components=2).fit_transform(x_train)

In [105]:
# Build the draggable scatter plot and show it below the status label
fig, scatt = viz(x_2d, y_class)
VBox([lbl_info, fig])

A Jupyter Widget

In [100]:
# Inspect the current row order (ids) of the dataset
print(point_indices)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 131, 50, 52, 148, 21, 28]


In [101]:
# Move the points recorded in fixed_points to the bottom, then show the new order
rearrange_fixed_points()
print(point_indices)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 26, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 145, 146, 147, 149, 131, 50, 52, 148, 21, 28, 15, 144, 71, 25, 133, 57, 27]


In [94]:
# Spot-check one row of the data next to its 2-D coordinates.
# NOTE(review): rearrange_fixed_points() rebinds x_train but not x_2d, so after
# a rearrangement row 28 of the two arrays may not be the same sample -- confirm.
print(x_train[28], x_2d[28])

[4.8 3.1 1.6 0.2] [-2.63982127  0.31929007]
