### Interactive notebook for moving points with PPCA

In [1]:
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

from sklearn.decomposition import PCA

from bqplot import (
    Axis, ColorAxis, LinearScale, DateScale, DateColorScale, OrdinalScale,
    OrdinalColorScale, ColorScale, Scatter, Lines, Figure, Tooltip
)
from ipywidgets import VBox, HBox, Layout
from ipywidgets import Label

from observations import iris, mnist
DATA_DIR = './data'

import warnings
warnings.filterwarnings('ignore')

seed = 2018
np.random.seed(seed)
tf.set_random_seed(seed)

  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)

### Prepare dataset

In [2]:
x_train, y_train, _ = iris(DATA_DIR)
y_labels = np.unique(y_train)
y_label_map = dict((label, i) for i, label in enumerate(y_labels))
y_class = np.array(list(map(lambda label: y_label_map[label] , y_train)))
print(x_train.shape, y_train.shape)

(150, 4) (150,)


In [None]:
(x_train, y_train), (x_test, y_test) = mnist(DATA_DIR)
N = 500
x_train = x_train[:N]
y_class = y_train[:N]
print(x_train.shape, y_class.shape)

In [3]:
# prepare a list of fixed points and a list of indices of the dataset, which is
# used when rearrange the dataset to put the fixed points to the bottom
N, D = x_train.shape
K = 2
point_indices = list(range(N))
fixed_points = {}

In [4]:
def rearrange_fixed_points():
    """Put a list of fixed points to the bottom of the dataset"""
    global x_train
    global y_class
    global point_indices
    
    fixed_indices = fixed_points.keys()
    new_indices = [i for i in point_indices if i not in fixed_indices]
    new_indices += fixed_indices
    
    point_indices = new_indices    
    x_train = x_train[new_indices]
    y_class = y_class[new_indices]

### Util functions for interactive viz

In [5]:
lbl_info = Label(color='Green', font_size='32px')
lbl_info.value = 'Fixed points: []'

In [150]:
def update_fixed_points(name, value):
    pos = value['point']
    idx = value['index']
    fixed_points[idx] = [pos['x'], pos['y']]
    lbl_info.value = "Fixed points: [{}]".format('  ,'.join([str(pi) for pi in fixed_points]))
    
def reset_fixed_points():
    global fixed_points
    fixed_points = {}
    lbl_info.value = 'Fixed points: []'

In [7]:
def viz(x_2d, y_class):
    sc_x = LinearScale()
    sc_y = LinearScale()
    sc_c = ColorScale(scheme='RdYlGn')
    def_tt = Tooltip(fields=['x', 'y', 'color'], formats=['.2f', '.2f', ''])
    scatt = Scatter(x=x_2d[:,0], y=x_2d[:,1], color=y_class,
                    scales={'x': sc_x, 'y': sc_y, 'color': sc_c},
                    tooltip=def_tt,
                    enable_move=True)
    scatt.on_drag_end(update_fixed_points)
    
    ax_x = Axis(scale=sc_x)
    ax_y = Axis(scale=sc_y, tick_format='0.2f', orientation='vertical')
    fig = Figure(marks=[scatt], axes=[ax_x, ax_y])
    return fig, scatt

In [151]:
def update_scatter(scatt, x_2d, y_class):
    scatt.x = x_2d[:,0]
    scatt.y = x_2d[:,1]
    scatt.color = y_class

In [152]:
def loss_chart(losses):
    sc_x = LinearScale()
    sc_y = LinearScale()
    def_tt = Tooltip(fields=['x', 'y'])
    line = Lines(x=np.arange(len(losses)), y=losses,
                 scales={'x': sc_x, 'y': sc_y})
    
    ax_x = Axis(scale=sc_x)
    ax_y = Axis(scale=sc_y, tick_format='0.2f', orientation='vertical')
    fig = Figure(marks=[line], axes=[ax_x, ax_y])
    return fig, line

def update_loss_chart(chart, losses):
    chart.x=np.arange(len(losses))
    chart.y=losses

### Base model with original PCA in scikit-learn

In [153]:
x_2d = PCA(n_components=2).fit_transform(x_train)

In [154]:
fig1, scatt1 = viz(x_2d, y_class)
scatt1.enable_move = False
fig1

A Jupyter Widget

### Probabilistic PCA

In [179]:
def ppca_model(sigma_noise=0.5, sigma_fixed=0.1):
    
    w = ed.Normal(loc=tf.zeros([D, K]),
                  scale=tf.ones([D, K]),
                  name='w')
        
    n_fixed = len(fixed_points)
    if 0 == n_fixed:
        z = ed.Normal(loc=tf.zeros([K, N]),
                      scale=tf.ones([K, N]),
                      name='z')
    else:
        print('{} fixed points:'.format(n_fixed))
        for fixed_id in fixed_points.keys():
            print('id: {}, class: {}'.format(fixed_id, y_class[fixed_id]))
        fixed_pos = list(fixed_points.values())
        fixed_pos = np.array(fixed_pos, dtype=np.float32)
        
        z_loc = tf.concat([
            tf.zeros([K, N - n_fixed]),
            tf.constant(fixed_pos.T)
        ], axis=1)
        
        stddv_one_fixed_point = [sigma_fixed] * K
        stddv_all_fixed_points = [stddv_one_fixed_point for _ in range(n_fixed)]
        stddv_all_fixed_points = np.array(stddv_all_fixed_points, dtype=np.float32)
        
        z_std = tf.concat([
            tf.ones([K, N - n_fixed]),
            tf.constant(stddv_all_fixed_points.T)
        ], axis=1)
        
        z = ed.Normal(loc=z_loc, scale=z_std, name='z')
    
    x = ed.Normal(loc=tf.matmul(w, z),
                  scale=sigma_noise * tf.ones([D, N]),
                  name='x')
    return x, w, z

In [157]:
def train_map(num_epochs=1000):
    log_joint = ed.make_log_joint_fn(ppca_model)
    
    tf.reset_default_graph()
    w = tf.Variable(np.ones([D,K]), dtype=tf.float32)
    z = tf.Variable(np.ones([K,N]), dtype=tf.float32)
    
    map_obj = -log_joint(
        fixed_points=list(fixed_points.values()),
        x=x_train.T, w=w,z=z)
    train_proc = tf.train.AdamOptimizer(learning_rate=0.05).minimize(map_obj)

    losses = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(1000):
            sess.run(train_proc)
            if i % 10 == 0:
                loss = sess.run(map_obj)
                losses.append(loss)
                update_loss_chart(loss_line, losses)
        w,z = sess.run([w,z])

    return losses, w, z

In [156]:
fig_loss, loss_line = loss_chart([])
fig_loss

A Jupyter Widget

In [191]:
losses, w_map, z_mean = train_map(num_epochs=1000)

2 fixed points:
id: 50, class: 1
id: 41, class: 0


In [192]:
fig2, scatt2 = viz(z_mean.T, y_class)
VBox([lbl_info, fig2])

A Jupyter Widget

In [190]:
update_scatter(scatt2, z_mean.T, y_class)

In [181]:
print(z_mean.T[50])

[0.72764516 0.7498041 ]


In [193]:
reset_fixed_points()
print(fixed_points)

{}
