### Interactive notebook for moving points with PPCA

In [1]:
from __future__ import print_function

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

from sklearn.decomposition import PCA

from bqplot import (
    Axis, ColorAxis, LinearScale, DateScale, DateColorScale, OrdinalScale,
    OrdinalColorScale, ColorScale, Scatter, Lines, Figure, Tooltip
)
from ipywidgets import VBox, HBox, Layout
from ipywidgets import Label, Button, IntSlider, FloatSlider
from ipywidgets import interactive_output

from observations import iris, mnist
DATA_DIR = './data'

seed = 2018
np.random.seed(seed)
tf.set_random_seed(seed)

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*

### Prepare dataset

In [2]:
dataset_name='iris'

In [23]:
def load_dataset(name):
    if name=='iris':
        x_train, y_class = load_iris()
    elif name=='mnist':
        x_train, y_class = load_mnist()
    else:
        raise ValueError('Invalid dataset name: {}'.format(name))
    return x_train, y_class

In [4]:
def load_iris():
    x_train, y_train, _ = iris(DATA_DIR)
    y_labels = np.unique(y_train)
    y_label_map = dict((label, i) for i, label in enumerate(y_labels))
    y_class = np.array(list(map(lambda label: y_label_map[label] , y_train)))
    return x_train, y_class

In [5]:
def load_mnist(N=500):
    (x_train, y_train), (x_test, y_test) = mnist(DATA_DIR)
    x_train = x_train[:N]
    y_class = y_train[:N]
    return x_train, y_class

In [75]:
# prepare a list of fixed points and a list of indices of the dataset, which is
# used when rearrange the dataset to put the fixed points to the bottom
x_train, y_class = load_dataset(dataset_name)
print(x_train.shape, y_class.shape)
                                
point_indices = list(range(x_train.shape[0]))
fixed_points = {
    # point_id => ([old_x, old_y], [new_x, new_y])
}

# keep track of old positions of shape NxK, K=2
z_init = None

(150, 4) (150,)


In [77]:
def rearrange_fixed_points():
    """Put a list of fixed points to the bottom of the dataset"""
    global x_train
    global y_class
    global point_indices
    global z_init
    
    fixed_indices = fixed_points.keys()
    new_indices = [i for i in point_indices if i not in fixed_indices]
    new_indices += fixed_indices
    
    point_indices = new_indices
    x_train = x_train[new_indices]
    y_class = y_class[new_indices]
    if z_init is not None:
        z_init = z_init[new_indices]
        assert(x_train.shape[0] == z_init.shape[0])

### Util functions for interactive viz

In [82]:
def on_moving_started(source_obj, target):
    pos = target['point']
    idx = target['index']
    fixed_points[idx] = ([pos['x'], pos['y']], [])
    lbl_info.value = 'Moving id {} class {}'.format(idx, y_class[idx])
    
def update_fixed_points(source_obj, target):
    pos = target['point']
    idx = target['index']
    old_pos = fixed_points[idx][0]
    new_pos = [pos['x'], pos['y']]
    fixed_points[idx] = (old_pos, new_pos)
    
    info = ''
    for p in fixed_points:
        ([x0, y0], [x1, y1]) = fixed_points[p]
        info += '\nid {}, class {}: [{:.2f}, {:.2f}] -> [{:.2f}, {:.2f}]'.format(
            p, y_class[p], x0, y0, x1, y1)
    lbl_info.value = info
    
def reset_fixed_points():
    global fixed_points
    global x_train
    global y_class
    global point_indices
    global z_init
    
    fixed_points = {}
    x_train, y_class = load_dataset('iris')
    point_indices = list(range(x_train.shape[0]))
    z_init = None
    
    lbl_info.value = 'Fixed points: []'

In [120]:
def viz(x_2d, y_class): # x_2d of shape NxK, K=2
    sc_x = LinearScale()
    sc_y = LinearScale()
    sc_c = OrdinalColorScale(scheme='Paired') #scheme='RdYlGn'
    def_tt = Tooltip(fields=['x', 'y', 'color'], formats=['.2f', '.2f', ''])
    
    # plot current projected points
    scatt = Scatter(x=x_2d[:,0], y=x_2d[:,1], color=y_class,
                    scales={'x': sc_x, 'y': sc_y, 'color': sc_c},
                    # tooltip=def_tt,
                    enable_move=True,
                    display_legend=False)
    
    scatt.on_drag_start(on_moving_started)
    scatt.on_drag_end(update_fixed_points)
    
    ax_x = Axis(scale=sc_x)
    ax_y = Axis(scale=sc_y, tick_format='0.2f', orientation='vertical')
    
    # plot trace from old position
    traces = Lines(x=[], y=[], scales={'x': sc_x, 'y': sc_y},)
    
#     z_init = x_2d + 0.2
#     x_pos = []
#     y_pos = []
#     assert( (x_2d.shape[0] == z_init.shape[0]) and (x_2d.shape[1] == z_init.shape[1]) )
#     for i in range(x_2d.shape[0]):
#         x0, y0 = z_init[i, :]
#         x1, y1 = x_2d[i, :]
#         x_pos.append([x0, x1])
#         y_pos.append([y0, y1])
#     traces.x = x_pos
#     traces.y = y_pos
    
    fig = Figure(marks=[scatt, traces], axes=[ax_x, ax_y])
    return fig, scatt, traces

In [130]:
def update_scatter(scatt, traces, x_2d, y_class):   
    scatt.x = x_2d[:,0]
    scatt.y = x_2d[:,1]
    scatt.color = y_class
    
    if z_init is not None:
        # draw lines indicating the fixed points
        x_pos = []
        y_pos = []
        for p in fixed_points:
            ([x0, y0], [x1, y1]) = fixed_points[p]
            [x2, y2] = x_2d[p, :]
            x_pos.append([x0, x1, x2])
            y_pos.append([y0, y1, y2])
#         assert( (x_2d.shape[0] == z_init.shape[0]) and (x_2d.shape[1] == z_init.shape[1]) )
#         for i in range(x_2d.shape[0]):
#             x0, y0 = z_init[i, :]
#             x1, y1 = x_2d[i, :]
#             x_pos.append([x0, x1])
#             y_pos.append([y0, y1])
        traces.x = x_pos
        traces.y = y_pos

In [104]:
def loss_chart(losses):
    sc_x = LinearScale()
    sc_y = LinearScale()
    def_tt = Tooltip(fields=['x', 'y'])
    line = Lines(x=np.arange(len(losses)), y=losses,
                 scales={'x': sc_x, 'y': sc_y})
    
    ax_x = Axis(scale=sc_x)
    ax_y = Axis(scale=sc_y, tick_format='0.2f', orientation='vertical')
    fig = Figure(marks=[line], axes=[ax_x, ax_y])
    return fig, line

def update_loss_chart(chart, losses):
    chart.x=np.arange(len(losses))
    chart.y=losses

In [122]:
fig_loss, loss_line = loss_chart([])
fig_scat, scatter, traces = viz(np.zeros([1, 2]), [0])

ctrl_sigma_data = FloatSlider(
    value=0.5,
    min=0,
    max=1.0,
    step=0.1,
    description=r'\(\sigma^2\) data',
    readout_format='.1f',
)

ctrl_sigma_fixed = FloatSlider(
    value=0.1,
    min=0,
    max=1.0,
    step=0.05,
    description=r'\(\sigma^2\) fix',
    readout_format='.2f',
)

ctrl_num_epochs = IntSlider(
    value=1000,
    min=10,
    max=2000,
    step=10,
    description='num epochs'
)

btn_train = Button(description='Train MAP')
btn_reset = Button(description='Reset fixed points')
btn_reset.on_click(lambda _:reset_fixed_points())
btn_train.on_click(lambda _:train_update_result())

def train_update_result():
    global z_init
    fixed_pos = [fixed_points[p][1] for p in fixed_points ]
    
    losses, w_map, z_mean = train_map(
        num_epochs=ctrl_num_epochs.value, 
        sigma_data=ctrl_sigma_data.value, 
        sigma_fixed=ctrl_sigma_fixed.value,
        fixed_pos=fixed_pos,
        loss_line=loss_line)
    
    if z_init is None:
        z_init = z_mean.T
    update_scatter(scatter, traces, z_mean.T, y_class)
    reset_fixed_points()
    
lbl_info = Label(color='Green', font_size='32px')
lbl_info.value = 'Fixed points: []'

ctrl_box = VBox([ctrl_sigma_data, ctrl_sigma_fixed, ctrl_num_epochs])
btn_box = VBox([btn_train, btn_reset])
left_box = VBox([ctrl_box, btn_box, fig_loss], layout=Layout(width='32%'))
right_box = VBox([lbl_info, fig_scat], layout=Layout(width='68%'))
gui = HBox([left_box, right_box],layout=Layout(display='flex' ,width='100%', height='580px'))

  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
object.__init__() takes no parameters
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super(Widget, self).__init__(**kwargs)


### Base model with original PCA in scikit-learn

In [14]:
x_2d = PCA(n_components=2, svd_solver='randomized') \
    .fit_transform(x_train)

In [15]:
fig_scat, scatt1 = viz(x_2d, y_class)
scatt1.enable_move = False
VBox([fig_scat, lbl_info])

  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):
  if scales[name].rtype != trait.get_metadata('rtype'):


A Jupyter Widget

### Probabilistic PCA

In [16]:
def ppca_model(N, D, K=2, sigma_data=0.5, sigma_fixed=0.005, fixed_pos=[]):
    print("Build PPCA model with sigma_data={}, sigma_fixed={}, list fixed point: {}".format(
          sigma_data, sigma_fixed, fixed_pos))
    
    w = ed.Normal(loc=tf.zeros([D, K]), scale=tf.ones([D, K]), name='w')
        
    n_fixed = len(fixed_pos)
    if 0 == n_fixed:
        z = ed.Normal(loc=tf.zeros([K, N]),scale=tf.ones([K, N]), name='z')
    else:
        fixed_pos = np.array(fixed_pos, dtype=np.float32)
        z_loc = tf.concat([tf.zeros([K, N - n_fixed]), tf.constant(fixed_pos.T)], axis=1)
        
        stddv_one_fixed_point = [sigma_fixed] * K
        stddv_all_fixed_points = [stddv_one_fixed_point for _ in range(n_fixed)]
        stddv_all_fixed_points = np.array(stddv_all_fixed_points, dtype=np.float32)
        z_std = tf.concat([tf.ones([K, N - n_fixed]), tf.constant(stddv_all_fixed_points.T)], axis=1)
        
        z = ed.Normal(loc=z_loc, scale=z_std, name='z')
    
    x = ed.Normal(loc=tf.matmul(w, z), scale=sigma_data * tf.ones([D, N]), name='x')
    return x, w, z

In [21]:
def train_map(num_epochs=1000, sigma_data=0.5, sigma_fixed=0.005, fixed_pos=[], loss_line=None):
    N, D = x_train.shape
    K = 2
    
    log_joint = ed.make_log_joint_fn(ppca_model)
    
    tf.reset_default_graph()
    w = tf.Variable(np.ones([D,K]), dtype=tf.float32)
    z = tf.Variable(np.ones([K,N]), dtype=tf.float32)
    
    # TODO: fix this
    rearrange_fixed_points()
    
    map_obj = -log_joint(N=N, D=D, K=K,
        sigma_data=sigma_data, sigma_fixed=sigma_fixed,
        fixed_pos=np.array(fixed_pos, dtype=np.float32),
        x=x_train.T, w=w,z=z)
    train_proc = tf.train.AdamOptimizer(learning_rate=0.05).minimize(map_obj)

    losses = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(num_epochs):
            sess.run(train_proc)
            if i % 10 == 0:
                loss = sess.run(map_obj)
                losses.append(loss)
                if loss_line is not None:
                    update_loss_chart(loss_line, losses)
        w,z = sess.run([w,z])

    return losses, w, z

In [131]:
gui

A Jupyter Widget

Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: []
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.6324415  0.6756253 ]
 [1.0820774  0.38298535]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.7455326 1.0018146]
 [1.0624291 1.0059854]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.6296884  0.46497095]
 [0.53870356 1.1237406 ]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.82871836 0.47940478]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.3078929  0.20805965]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.7641003 0.6143052]]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1, list fixed point: [[0.7911093  0.49217698]
 [0.92008626 0.6455749 ]
 [0.8638054  0.5985659 ]
 [0.8169047  0.46990955]
 [0.903671   0.5218669 ]
 [0.93415654 0.31403747]]
Build PPCA model with sigma_data=0.5,