### Interactive notebook for moving points with PPCA

In [185]:
from __future__ import print_function

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from bqplot import (
    Axis, ColorAxis, LinearScale, DateScale, DateColorScale, OrdinalScale,
    OrdinalColorScale, ColorScale, Scatter, Lines, Figure, Tooltip
)
from ipywidgets import VBox, HBox, Layout
from ipywidgets import Label, Button, IntSlider, FloatSlider
from ipywidgets import interact, interactive_output

from observations import iris, mnist
# Path handed to the `iris` / `mnist` loaders used further down.
DATA_DIR = './data'

# Seed NumPy's and TensorFlow's global random generators.
seed = 2018
np.random.seed(seed)
tf.set_random_seed(seed)

# Same filter as at the top of this cell, applied again after the imports.
warnings.filterwarnings('ignore')

### Prepare dataset

In [107]:
# Dataset used by the notebook; `load_dataset` accepts 'iris' or 'mnist'.
dataset_name='mnist'

In [186]:
def load_dataset(name):
    """Load the dataset called `name` and standardize its features.

    Args:
        name: 'iris' or 'mnist'.

    Returns:
        Tuple `(x_train, y_class)`: the features after
        `StandardScaler().fit_transform`, and the labels produced by the
        dataset-specific loader.

    Raises:
        ValueError: if `name` is neither 'iris' nor 'mnist'.
    """
    if name not in ('iris', 'mnist'):
        raise ValueError('Invalid dataset name: {}'.format(name))

    # The loaders are resolved at call time; they are defined in later cells.
    loader = load_iris if name == 'iris' else load_mnist
    features, y_class = loader()

    scaler = StandardScaler()
    return scaler.fit_transform(features), y_class

In [187]:
def load_iris():
    """Load iris through `iris(DATA_DIR)` and encode its labels as integers.

    Returns:
        Tuple `(x_train, y_class)`. `y_class` is a NumPy array holding, for
        every sample, the index of its label within `np.unique(y_train)`.
    """
    x_train, y_train, _ = iris(DATA_DIR)
    label_to_index = {
        label: index for index, label in enumerate(np.unique(y_train))
    }
    y_class = np.array([label_to_index[label] for label in y_train])
    return x_train, y_class

In [212]:
# Classes kept by `load_mnist` and the maximum number of samples it returns.
# Both are module-level so the `load_mnist_by_class` widget can rebind them.
selected_classes = [0,1]
num_datapoints = 123
def load_mnist():
    """Load MNIST through `mnist(DATA_DIR)` and keep a subset of the training split.

    Only samples whose label is in the module-level `selected_classes` are
    kept, and at most the first `num_datapoints` of those are returned.
    The test split is loaded but not used.

    Returns:
        Tuple `(x_train, y_class)` with the selected samples and their labels.
    """
    (x_all, y_all), _test_split = mnist(DATA_DIR)
    keep = [label in selected_classes for label in y_all]
    x_selected = x_all[keep]
    y_selected = y_all[keep]
    return x_selected[:num_datapoints], y_selected[:num_datapoints]

In [213]:
# Load the dataset and set up the bookkeeping that is used when the dataset
# is rearranged to put the fixed points at the bottom.
x_train, y_class = load_dataset(dataset_name)
print(x_train.shape, y_class.shape)

# row order of the dataset, initially 0..N-1
n_rows = x_train.shape[0]
point_indices = list(range(n_rows))
del n_rows

# point_id => ([old_x, old_y], [new_x, new_y])
fixed_points = dict()

# keep track of old positions of shape NxK, K=2
z_init = None

(123, 784) (123,)


In [190]:
def rearrange_fixed_points():
    """Move the rows of the fixed points to the bottom of the dataset.

    The keys of `fixed_points` are positions in the *current* row order of
    `x_train` / `y_class` / `z_init` (they are the `index` values reported by
    the scatter's drag events). The permutation is therefore built from
    positions, not from the original ids stored in `point_indices`, so the
    function stays correct when it is called again on already-rearranged data.

    Side effects (all on module-level globals):
        * `x_train`, `y_class` and, when set, `z_init` are reordered so that
          the fixed points occupy the last rows, in `fixed_points` iteration
          order.
        * `point_indices` is reordered the same way, so it keeps mapping each
          current row to its original id.
        * `fixed_points` is re-keyed to the new row positions of its points,
          keeping the same iteration order and values.
    """
    global x_train
    global y_class
    global point_indices
    global z_init
    global fixed_points

    fixed_positions = list(fixed_points)
    if not fixed_positions:
        # Nothing to move; leave every array as it is.
        return

    n_points = x_train.shape[0]
    fixed_lookup = set(fixed_positions)
    order = [pos for pos in range(n_points) if pos not in fixed_lookup]
    order += fixed_positions

    x_train = x_train[order]
    y_class = y_class[order]
    # Compose with the previous order instead of replacing it.
    point_indices = [point_indices[pos] for pos in order]
    if z_init is not None:
        z_init = z_init[order]
        assert(x_train.shape[0] == z_init.shape[0])

    # The fixed points now live in the last rows; re-key them accordingly so
    # later look-ups (labels, next rearrangement) hit the right rows.
    first_fixed_row = n_points - len(fixed_positions)
    fixed_points = dict(
        (first_fixed_row + offset, fixed_points[pos])
        for offset, pos in enumerate(fixed_positions))

### Util functions for interactive viz

In [204]:
def on_moving_started(source_obj, target):
    """Drag-start callback: register the dragged point as a fixed point.

    Args:
        source_obj: the object that fired the event (unused).
        target: event payload; `target['point']` holds the 'x' / 'y'
            coordinates and `target['index']` the index of the dragged point.
    """
    point, index = target['point'], target['index']
    start_pos = [point['x'], point['y']]
    # The new position stays empty until `update_fixed_points` fills it in.
    fixed_points[index] = (start_pos, [])
    lbl_info.value = 'Moving id {} class {}'.format(index, y_class[index])
    
def show_fixed_points():
    """Describe every fixed point as text.

    Returns:
        A string with one line per entry of `fixed_points` (each line starts
        with a newline character) giving the point id, its class and its old
        and new positions; the empty string when there are no fixed points.
    """
    lines = []
    for point_id, (old_pos, new_pos) in fixed_points.items():
        x0, y0 = old_pos
        x1, y1 = new_pos
        lines.append(
            '\nid {}, class {}: [{:.2f}, {:.2f}] -> [{:.2f}, {:.2f}]'.format(
                point_id, y_class[point_id], x0, y0, x1, y1))
    return ''.join(lines)
    
def update_fixed_points(source_obj, target):
    """Drag-end callback: store where the dragged point was dropped.

    Keeps the old position recorded for the point in `fixed_points`, replaces
    its new position with the drop coordinates, and refreshes the info label
    with the output of `show_fixed_points`.

    Args:
        source_obj: the object that fired the event (unused).
        target: event payload with 'point' (coordinates) and 'index'.
    """
    dropped_at = target['point']
    index = target['index']
    start_pos = fixed_points[index][0]
    fixed_points[index] = (start_pos, [dropped_at['x'], dropped_at['y']])
    lbl_info.value = show_fixed_points()
    
def reset_fixed_points():
    """Forget all fixed points and reload the dataset in its original order.

    Rebinds the globals `fixed_points`, `x_train`, `y_class` and
    `point_indices`, and resets the info label. `z_init` is not touched here.
    """
    global fixed_points, x_train, y_class, point_indices

    fixed_points = dict()
    x_train, y_class = load_dataset(dataset_name)
    n_points = x_train.shape[0]
    point_indices = list(range(n_points))
    lbl_info.value = 'Fixed points: []'
    
def reset_all():
    """Reset the fixed points and the dataset, then drop the stored `z_init`."""
    global z_init

    reset_fixed_points()
    z_init = None

In [192]:
def viz(x_2d, y_class): # x_2d of shape NxK, K=2
    """Build the draggable scatter figure for a 2-D projection.

    Args:
        x_2d: array whose columns 0 and 1 give the x and y coordinates.
        y_class: per-point values used for the scatter's `color`.

    Returns:
        Tuple `(fig, scatt, traces, traces_fixed_points)`: the figure, the
        scatter mark, and two initially empty `Lines` marks (one for the trace
        from the old positions, one for the trace of the fixed points).
    """
    x_scale = LinearScale() # min=-2, max=2
    y_scale = LinearScale() # min=-2, max=2
    color_scale = OrdinalColorScale(scheme='Paired') #scheme='RdYlGn'
    # Only referenced by the disabled `tooltip=` argument of the scatter.
    def_tt = Tooltip(fields=['x', 'y', 'color'], formats=['.2f', '.2f', ''])

    # current projected points; dragging is wired to the fixed-point callbacks
    points = Scatter(
        x=x_2d[:,0], y=x_2d[:,1], color=y_class,
        # names=point_indices,
        stroke='black', stroke_width=0.2,
        scales={'x': x_scale, 'y': y_scale, 'color': color_scale},
        # tooltip=def_tt,
        enable_move=True,
    )
    points.on_drag_start(on_moving_started)
    points.on_drag_end(update_fixed_points)

    x_axis = Axis(scale=x_scale)
    y_axis = Axis(scale=y_scale, tick_format='0.2f', orientation='vertical')

    # trace from old position
    position_traces = Lines(
        x=[], y=[], colors=['black'], opacities=[0.6], stroke_width=0.6,
        scales={'x': x_scale, 'y': y_scale},
    )

    # trace of fixed points
    fixed_traces = Lines(x=[], y=[], scales={'x': x_scale, 'y': y_scale})

    fig = Figure(
        marks=[points, position_traces, fixed_traces],
        axes=[x_axis, y_axis],
    )
    return fig, points, position_traces, fixed_traces

In [226]:
def update_scatter(scatt, traces, traces_fixed_points, x_2d):
    """Push a new 2-D projection and the related traces into the marks.

    Args:
        scatt: scatter mark; gets the new coordinates, the global `y_class`
            as colors and an opacity of 0.6.
        traces: lines mark; gets one segment per point from its row in the
            global `z_init` to its row in `x_2d` (no segments while `z_init`
            is None).
        traces_fixed_points: lines mark; gets one segment per entry of
            `fixed_points`, from its old to its new position.
        x_2d: array whose columns 0 and 1 are the x and y coordinates.
    """
    scatt.x = x_2d[:,0]
    scatt.y = x_2d[:,1]
    scatt.color = y_class
    scatt.default_opacities = [0.6]

    moved_x, moved_y = [], []
    if z_init is not None:
        same_rows = x_2d.shape[0] == z_init.shape[0]
        assert same_rows and (x_2d.shape[1] == z_init.shape[1])
        for (x0, y0), (x1, y1) in zip(z_init, x_2d):
            moved_x.append([x0, x1])
            moved_y.append([y0, y1])
    traces.x = moved_x
    traces.y = moved_y

    fixed_x, fixed_y = [], []
    for (x0, y0), (x1, y1) in fixed_points.values():
        fixed_x.append([x0, x1])
        fixed_y.append([y0, y1])
    traces_fixed_points.x = fixed_x
    traces_fixed_points.y = fixed_y

In [227]:
def loss_chart(losses):
    """Build a line chart of `losses` against their index.

    Args:
        losses: sequence of loss values.

    Returns:
        Tuple `(fig, line)`: the figure and its single `Lines` mark.
    """
    x_scale = LinearScale()
    y_scale = LinearScale()
    # Created but not attached to any mark.
    def_tt = Tooltip(fields=['x', 'y'])

    steps = np.arange(len(losses))
    line = Lines(x=steps, y=losses, scales={'x': x_scale, 'y': y_scale})

    x_axis = Axis(scale=x_scale)
    y_axis = Axis(scale=y_scale, orientation='vertical') # tick_format='0.2f'
    return Figure(marks=[line], axes=[x_axis, y_axis]), line

def update_loss_chart(chart, losses):
    """Replace the data of `chart` with `losses` plotted against 0..len(losses)-1."""
    n_values = len(losses)
    chart.x = np.arange(n_values)
    chart.y = losses

In [228]:
# Margins applied to both the loss figure and the scatter figure.
fig_margin = {'top':10, 'bottom':20, 'left':65, 'right':0}
# Start with an empty loss curve; `train_map` fills it through `loss_line`.
fig_loss, loss_line = loss_chart([])
fig_loss.layout.height = '250px'
# Placeholder scatter (one point at the origin); `train_update_result` replaces its data.
fig_scatter, scatt, traces, traces_fixed_points = viz(np.zeros([1, 2]), [0])
fig_loss.fig_margin = fig_margin
fig_scatter.fig_margin = fig_margin

# Observation noise handed to `train_map` / `ppca_model`, where it is used as
# the `scale` of the Normal over the data.
# The lower bound is strictly positive: with a scale of 0 the Normal
# log-density is undefined and the MAP loss becomes NaN.
# NOTE(review): the label says sigma^2 but the value is used as a scale
# (standard deviation) — confirm which one is intended.
ctrl_sigma_data = FloatSlider(
    value=0.5,
    min=0.1,
    max=2.0,
    step=0.1,
    description=r'\(\sigma^2\) data',
    readout_format='.1f',
)

# Scale of the prior placed on the fixed points in `ppca_model`; a smaller
# value makes that prior narrower around the dropped positions.
ctrl_sigma_fixed = FloatSlider(
    value=0.1,
    min=0.001,
    max=2.0,
    step=0.01,
    description=r'\(\sigma^2\) fix',
    readout_format='.3f',
)

# Number of optimizer steps run by `train_map`.
ctrl_num_epochs = IntSlider(
    value=1000,
    min=10,
    max=2000,
    step=10,
    description='n_epochs'
)

def reset_data_and_gui():
    """Reset all data (`reset_all`) and clear every mark shown in the GUI."""
    reset_all()
    for mark in (scatt, loss_line, traces, traces_fixed_points):
        mark.x = []
        mark.y = []

# Buttons and their click handlers. The lambdas resolve the handler names at
# click time, so `train_update_result` can be defined after this point.
# NOTE(review): `btn_reset` is wired up here but is not added to `btn_box`
# in this cell, so it is not part of `gui` — confirm whether that is intended.
btn_train = Button(description='Train MAP')
btn_reset = Button(description='Reset fixed points')
btn_reset_all = Button(description='Reset all')
btn_reset.on_click(lambda _:reset_fixed_points())
btn_reset_all.on_click(lambda _:reset_data_and_gui())
btn_train.on_click(lambda _:train_update_result())


def train_update_result():
    """Run MAP training with the current slider values and refresh the scatter.

    The new positions of the fixed points (in `fixed_points` iteration order)
    are passed to `train_map`. The transposed `z` it returns is drawn with
    `update_scatter`, and is stored in the global `z_init` if that is still
    None.
    """
    global z_init

    fixed_pos = [move[1] for move in fixed_points.values()]

    _losses, _w_map, z_mean = train_map(
        num_epochs=ctrl_num_epochs.value,
        sigma_data=ctrl_sigma_data.value,
        sigma_fixed=ctrl_sigma_fixed.value,
        fixed_pos=fixed_pos,
        loss_line=loss_line)

    projected = z_mean.T
    update_scatter(scatt, traces, traces_fixed_points, projected)
    if z_init is None:
        z_init = projected
    # reset_fixed_points()
    
# Status label written to by the drag callbacks and by `reset_fixed_points`.
# NOTE(review): `color` / `font_size` are passed straight to `Label`; confirm
# that the installed ipywidgets version honours these keyword arguments.
lbl_info = Label(color='Green', font_size='32px')
lbl_info.value = 'Fixed points: []'

# Left column: sliders, buttons and the status label.
ctrl_box = VBox([ctrl_sigma_data, ctrl_sigma_fixed, ctrl_num_epochs])
btn_box = HBox([btn_train, btn_reset_all])
left_box = VBox([ctrl_box, btn_box, lbl_info],
                layout=Layout(flex='2 1 0%', width='auto'))
# Right column: loss curve above the interactive scatter.
right_box = VBox([fig_loss, fig_scatter],
                 layout=Layout(flex='5 1 0%', width='auto'))
gui = HBox([left_box, right_box],layout=Layout(display='flex' ,width='100%', height='600px'))

### Base model with original PCA in scikit-learn

In [229]:
# Baseline 2-D projection of `x_train` with scikit-learn's PCA.
x_2d = (
    PCA(n_components=2, svd_solver='randomized')
    .fit_transform(x_train)
)

In [230]:
# Show the scikit-learn PCA baseline with the same `viz` helper, with
# `enable_move` switched off on its scatter.
fig_scat, scatt1, _, _ = viz(x_2d, y_class)
scatt1.enable_move = False
# `lbl_info` is the same Label instance that is also placed in `gui`.
VBox([fig_scat, lbl_info])

A Jupyter Widget

### Probabilistic PCA

In [231]:
# When True, `ppca_model` and `train_map` print debug information.
# The `togle_flags` widget rebinds this flag.
show_debug = True

def ppca_model(N, D, K=2, sigma_data=0.5, sigma_fixed=0.005, fixed_pos=[]):
    """Probabilistic PCA model written with Edward2 random variables.

    `w` (D x K) has a standard Normal prior. `z` (K x N) has a standard Normal
    prior, except for its last `len(fixed_pos)` columns: those are centred on
    the given positions with scale `sigma_fixed`. The data `x` (D x N) is
    Normal around `w @ z` with scale `sigma_data`.

    Args:
        N: number of data points (columns of `z` and `x`).
        D: number of features (rows of `w` and `x`).
        K: number of latent dimensions.
        sigma_data: scale of the Normal over `x`.
        sigma_fixed: scale of the prior on the fixed columns of `z`.
        fixed_pos: sequence of K-dimensional positions, one per fixed point;
            they are matched with the last columns of `z`, in order.

    Returns:
        Tuple `(x, w, z)` of Edward2 random variables.
    """
    if show_debug:
        print("Build PPCA model with sigma_data={}, sigma_fixed={}".format(sigma_data, sigma_fixed))

    w = ed.Normal(loc=tf.zeros([D, K]), scale=tf.ones([D, K]), name='w')

    n_fixed = len(fixed_pos)
    if n_fixed == 0:
        z_loc, z_std = tf.zeros([K, N]), tf.ones([K, N])
    else:
        n_free = N - n_fixed

        # free points keep a zero mean, fixed points are centred on their target
        fixed_loc = np.array(fixed_pos, dtype=np.float32)
        z_loc = tf.concat([tf.zeros([K, n_free]), tf.constant(fixed_loc.T)], axis=1)

        # free points keep a unit scale, fixed points get `sigma_fixed`
        fixed_std = np.array([[sigma_fixed] * K] * n_fixed, dtype=np.float32)
        z_std = tf.concat([tf.ones([K, n_free]), tf.constant(fixed_std.T)], axis=1)
    z = ed.Normal(loc=z_loc, scale=z_std, name='z')

    x_scale = sigma_data * tf.ones([D, N])
    x = ed.Normal(loc=tf.matmul(w, z), scale=x_scale, name='x')
    return x, w, z

In [232]:
def train_map(num_epochs=1000, sigma_data=0.5, sigma_fixed=0.005, fixed_pos=[], loss_line=None):
    """Fit `w` and `z` of `ppca_model` by minimising the negative log joint.

    Uses the global `x_train` as data. `rearrange_fixed_points` is called
    before the objective is built, so the fixed points are the last rows of
    `x_train` when `fixed_pos` is matched against them.

    Args:
        num_epochs: number of Adam steps.
        sigma_data: forwarded to `ppca_model`.
        sigma_fixed: forwarded to `ppca_model`.
        fixed_pos: positions of the fixed points, forwarded to `ppca_model`
            as a float32 array.
        loss_line: optional chart mark; when given it is refreshed with
            `update_loss_chart` every time a loss value is recorded.

    Returns:
        Tuple `(losses, w, z)`: the objective recorded every 10th step, and
        the final values of `w` (D x K) and `z` (K x N).
    """
    if show_debug:
        print('Classes: ', np.unique(y_class))
        print('Fixed points: ', show_fixed_points())

    N, D = x_train.shape
    K = 2

    log_joint = ed.make_log_joint_fn(ppca_model)

    tf.reset_default_graph()
    w_var = tf.Variable(np.ones([D,K]), dtype=tf.float32)
    z_var = tf.Variable(np.ones([K,N]), dtype=tf.float32)

    rearrange_fixed_points()

    neg_log_joint = -log_joint(
        N=N, D=D, K=K,
        sigma_data=sigma_data, sigma_fixed=sigma_fixed,
        fixed_pos=np.array(fixed_pos, dtype=np.float32),
        x=x_train.T, w=w_var, z=z_var)
    train_op = tf.train.AdamOptimizer(learning_rate=0.05).minimize(neg_log_joint)

    losses = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(num_epochs):
            sess.run(train_op)
            if epoch % 10 != 0:
                continue
            # record the objective every 10th step
            losses.append(sess.run(neg_log_joint))
            if loss_line is not None:
                update_loss_chart(loss_line, losses)
        w_map, z_map = sess.run([w_var, z_var])

    return losses, w_map, z_map

In [233]:
# Display the interactive GUI (controls on the left, loss curve and scatter on the right).
gui

A Jupyter Widget

Classes:  [0 1]
Fixed points:  
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1
Classes:  [0 1]
Fixed points:  
id 85, class 1: [-0.55, -0.30] -> [-0.59, 0.05]
id 125, class 1: [-0.56, -0.40] -> [-0.61, 0.31]
id 74, class 1: [-0.53, -0.43] -> [-0.60, 0.18]
Build PPCA model with sigma_data=0.5, sigma_fixed=0.1
Classes:  [0 1]
Fixed points:  
Build PPCA model with sigma_data=0.1, sigma_fixed=0.1
Classes:  [0 1]
Fixed points:  
id 89, class 1: [-0.21, -0.44] -> [-0.01, -0.46]
id 85, class 1: [-0.25, -0.47] -> [0.04, -0.48]
id 125, class 1: [-0.33, -0.47] -> [0.09, -0.49]
Build PPCA model with sigma_data=0.1, sigma_fixed=0.1
Classes:  [0 1]
Fixed points:  
Build PPCA model with sigma_data=0.2, sigma_fixed=0.1


In [224]:
@interact(_show_label=False, _show_traces=True, _show_debug=True)
def togle_flags(_show_label, _show_traces, _show_debug):
    """Widget callback toggling point labels, traces and debug printing.

    Args:
        _show_label: whether the scatter displays point names.
        _show_traces: whether both trace marks are shown next to the scatter.
        _show_debug: new value of the global `show_debug` flag.
    """
    global show_debug
    show_debug = _show_debug

    scatt.names = point_indices
    scatt.display_names = _show_label

    visible_marks = [scatt]
    if _show_traces:
        visible_marks += [traces, traces_fixed_points]
    fig_scatter.marks = visible_marks

A Jupyter Widget

In [211]:
@interact(n_selected=(10, 500), c0=True, c1=True, c2=False, c3=False, c4=False, c5=False, c6=False, c7=False, c8=False, c9=False)
def load_mnist_by_class(n_selected,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9):
    """Widget callback choosing which MNIST classes and how many samples to use.

    Rebinds the globals `num_datapoints` (to `n_selected`) and
    `selected_classes` (to the digits whose checkbox is ticked), then prints
    the selected classes. The dataset itself is not reloaded here.
    """
    global selected_classes
    global num_datapoints

    num_datapoints = n_selected
    class_flags = (c0, c1, c2, c3, c4, c5, c6, c7, c8, c9)
    selected_classes = [digit for digit, enabled in enumerate(class_flags) if enabled]
    print(selected_classes)

A Jupyter Widget