# Mixed Likelihood GPLVM

In [None]:
# Set the CUDA environment through IPython magics: enumerate GPUs by PCI bus
# id and expose only device 0.
# NOTE(review): presumably this has to run before TensorFlow initialises CUDA,
# which is why it precedes the first tensorflow import — confirm.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
from tensorflow.python.client import device_lib
# Print the devices TensorFlow reports, as a sanity check of the settings above.
print(device_lib.list_local_devices())

In [None]:
import time
import os

In [None]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from IPython import display
%matplotlib inline
import seaborn as sns

In [None]:
# Apply seaborn's default theme first, then switch to its "paper" plotting
# context. The order matters: set() would otherwise reset the context.
sns.set()
sns.set_context("paper")

In [None]:
import tfgp
from tfgp.util import data
from tfgp.model import MLGPLVM
# Show which tfgp installation was picked up (useful when a development
# checkout and an installed copy coexist).
print(f"Successfully imported package: {tfgp.__file__}")

## Generate data

In [None]:
# NOTE(review): None presumably means "use every row" — confirm against
# tfgp.util.data.make_cleveland.
num_data = None
# y: observations fed to the model, likelihood: passed to MLGPLVM below,
# labels: used only for colouring plots and for the kNN error.
y, likelihood, labels = data.make_cleveland(num_data)

## Create model

In [None]:
# Shared by the MLGPLVM below and the GPy BayesianGPLVM baseline further down.
latent_dim = 10  # dimensionality of the latent space
num_inducing = 50  # number of inducing points

In [None]:
# ARD RBF kernel over the latent space; its per-dimension gamma values are
# later used to pick the two latent dimensions that get plotted.
kernel = tfgp.kernel.ARDRBF(
    variance=0.5,
    gamma=0.5,
    xdim=latent_dim,
    name="kernel",
)

# Mixed-likelihood GPLVM on the observations `y`.
m = MLGPLVM(
    y,
    latent_dim,
    num_inducing=num_inducing,
    kernel=kernel,
    likelihood=likelihood,
)

## Build graph

In [None]:
# NOTE(review): m.initialize() presumably registers the model's losses in the
# TF loss collections, so it has to run before get_total_loss() — confirm.
m.initialize()
loss = tf.losses.get_total_loss()
learning_rate = 1e-3

with tf.name_scope("train"):
    # Collect the trainable variables once and hand exactly that list to the
    # optimiser (the original looked them up twice and left this name unused).
    trainable_vars = tf.trainable_variables()
    optimizer = tf.train.RMSPropOptimizer(learning_rate, name="RMSProp")
    train_all = optimizer.minimize(
        loss,
        var_list=trainable_vars,
        global_step=tf.train.create_global_step(),
        name="train",
    )

with tf.name_scope("summary"):
    m.create_summaries()
    tf.summary.scalar("total_loss", loss, family="Loss")
    # One scalar summary per regularisation term, next to the total loss.
    for reg_loss in tf.losses.get_regularization_losses():
        tf.summary.scalar(reg_loss.name, reg_loss, family="Loss")
    merged_summary = tf.summary.merge_all()

# Created last so that it also covers the optimiser's own variables and the
# global step created above.
init = tf.global_variables_initializer()

## Callback

In [None]:
def plot(x: np.ndarray, *, z: np.ndarray = None, gammas: np.ndarray = None, loss) -> None:
    """Redraw the training figure and show it in the notebook.

    Draws on the module-level axes ``ax1``/``ax2``/``ax3`` of figure ``f`` and
    colours the latent points with the module-level ``labels``. The caller is
    responsible for clearing the axes between calls.

    Args:
        x: 2-D array with one row per data point and two columns (the latent
            coordinates to scatter on ``ax1``).
        z: Optional points drawn as black crosses on top of ``x``; must have
            the same two-column layout as ``x``.
        gammas: Optional 1-D array drawn as a bar chart on ``ax3``.
        loss: Sequence of ``[step, loss_value]`` pairs. The whole history is
            plotted on ``ax2`` and the last pair provides both titles.
    """
    ax1.scatter(*x.T, c=labels)
    if z is not None:
        ax1.scatter(*z.T, c="k", marker="x")
    # Independent of `z`: the bar chart lives on its own axes, so supplying
    # `z` must not suppress it (the original used `elif` here).
    if gammas is not None:
        ax3.bar(range(len(gammas)), gammas)

    # Fit the latent-space view tightly around the data.
    ax_x_min, ax_y_min = np.min(x, axis=0)
    ax_x_max, ax_y_max = np.max(x, axis=0)
    ax1.set_xlim(ax_x_min, ax_x_max)
    ax1.set_ylim(ax_y_min, ax_y_max)

    if len(loss) > 0:
        # Take step and loss from the history that was passed in rather than
        # from notebook globals, so the titles always match the plotted data.
        step, current_loss = loss[-1]
        ax1.set_title(f"Step {step}")
        ax2.plot(*np.array(loss).T)
        ax2.set_title(f"Loss: {current_loss}")

    display.display(f)
    # wait=True defers the clearing until the next output arrives, so the
    # figure is replaced in place instead of flickering.
    display.clear_output(wait=True)

## Setup optimisation

In [None]:
# All artefacts of one run live under <root>/{log,save,output}/<dataset>/<timestamp>.
root_dir = "../.."
dataset = "cleveland"
start_time = time.strftime("%Y%m%d%H%M%S")
log_dir = f"{root_dir}/log/{dataset}/{start_time}"
save_dir = f"{root_dir}/save/{dataset}/{start_time}"
output_dir = f"{root_dir}/output/{dataset}/{start_time}"

# exist_ok=True keeps the cell re-runnable within the same second.
# log_dir is not created here; it is only handed to tf.summary.FileWriter in
# the training cell (presumably the writer creates it — confirm).
os.makedirs(save_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Session configured with log_device_placement=True; as an InteractiveSession
# it lets later cells call `.eval()` on tensors without passing the session.
session_config = tf.ConfigProto(log_device_placement=True)
sess = tf.InteractiveSession(config=session_config)

# Checkpoint writer/reader for the variables in the graph.
saver = tf.train.Saver()
# To resume from a checkpoint, restore it here.
# NOTE(review): save_dir carries a fresh timestamp and checkpoints are saved
# with a global_step suffix, so the path below likely needs adjusting.
# saver.restore(sess, f"{save_dir}/model.ckpt")

## Run optimisation

In [None]:
# Training loop: optimise, periodically log summaries and redraw the figure,
# periodically checkpoint and dump the latent representation. Interrupting
# the kernel stops training early and still produces a final plot.
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
loss_list = []
n_iter = 100000
print_interval = 500
save_interval = 5000


def _snapshot():
    """Evaluate the kernel gammas, the 2-D latent means and the inducing points."""
    gamma_values = m.kernel._gamma.eval()
    latent_mean = m.qx_mean.eval().T
    # Keep the two latent dimensions with the largest gamma values.
    top_two = np.argsort(gamma_values)[-2:]
    return gamma_values, latent_mean[:, top_two], m.z.eval()


def _clear_axes():
    """Wipe all three axes so the next call to plot() starts from scratch."""
    for ax in (ax1, ax2, ax3):
        ax.cla()


# Created outside the try block: if either of these fails there is nothing
# sensible to plot, and the writer must exist for the cleanup below.
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
sess.run(init)
i = 0  # keeps the final bookkeeping valid even if the loop never starts
try:
    for i in range(n_iter):
        sess.run(train_all)
        log_step = i % print_interval == 0
        save_step = i % save_interval == 0

        if log_step:
            # Full trace so that the run metadata can be attached to the
            # summaries for this step.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            train_loss, summary = sess.run([loss, merged_summary], options=run_options, run_metadata=run_metadata)
            summary_writer.add_run_metadata(run_metadata, f"step_{i}", global_step=i)
            summary_writer.add_summary(summary, i)
            loss_list.append([i, train_loss])

        if log_step or save_step:
            # Take a fresh snapshot for either kind of step so that saving
            # never depends on save_interval being a multiple of
            # print_interval.
            gammas, x_mean, z = _snapshot()
            plot(x_mean, gammas=gammas, loss=loss_list)
            if save_step:
                saver.save(sess, f"{save_dir}/model.ckpt", global_step=i)
                np.savetxt(f"{output_dir}/x_mean_{i}.csv", x_mean)
                np.savetxt(f"{output_dir}/z_{i}.csv", z)
                np.savetxt(f"{output_dir}/labels.csv", labels)
                # Save the figure that was just drawn instead of drawing it
                # a second time.
                f.savefig(f"{output_dir}/fig_{i}.eps")
            _clear_axes()
except KeyboardInterrupt:
    # Deliberate: interrupting the kernel is the way to stop training early.
    pass
finally:
    try:
        # Refresh the module-level results used by the cells below and draw
        # the final state; train_loss is refreshed too so that it matches the
        # last entry of loss_list.
        train_loss = loss.eval()
        gammas, x_mean, z = _snapshot()
        loss_list.append([i, train_loss])
        plot(x_mean, gammas=gammas, loss=loss_list)
    finally:
        # Flush pending events and release the event file.
        summary_writer.close()

## BGPLVM

In [None]:
import GPy

In [None]:
# Load the comma-separated Cleveland table used by the BGPLVM and PCA baselines.
# NOTE(review): presumably the same data set as `y` above, in the raw layout
# expected by the baselines — confirm.
y_sparse = np.loadtxt("../../util/cleveland.csv", delimiter=",")

In [None]:
# Baseline: GPy's Bayesian GPLVM with an ARD RBF kernel, using the same latent
# dimensionality and number of inducing points as the MLGPLVM above.
k = GPy.kern.RBF(latent_dim, ARD=True)
bgplvm = GPy.models.BayesianGPLVM(
    y_sparse,
    latent_dim,
    num_inducing=num_inducing,
    kernel=k,
)

In [None]:
# Fit the baseline model; messages=True asks GPy to report optimisation progress.
bgplvm.optimize(messages=True)

In [None]:
# Mean of the BGPLVM latent posterior, reduced to the two dimensions with the
# smallest ARD lengthscales.
bgplvm_latent_mean = np.array(bgplvm.latent_space.mean)
shortest_lengthscale_dims = np.argsort(bgplvm.rbf.lengthscale)[:2]
x_bgplvm = bgplvm_latent_mean[:, shortest_lengthscale_dims]

plt.scatter(*x_bgplvm.T, c=labels)

## PCA

In [None]:
# Second baseline: reduce the raw table to two components and plot them,
# coloured by label.
n_pca_components = 2
x_pca = tfgp.util.pca_reduce(y_sparse, n_pca_components)

plt.scatter(*x_pca.T, c=labels)
plt.colorbar()

## Compute 1NN error

In [None]:
# Collapse the labels to two classes: 0 stays 0, every label >= 1 becomes 1.
binary_labels = (labels >= 1).astype(int)

In [None]:
# Compare the three 2-D embeddings by their k-nearest-neighbour error on the
# binary labels.
# NOTE(review): the section heading says 1NN but k = 2 — presumably
# knn_abs_error counts the query point itself as a neighbour; confirm.
# NOTE: this rebinds `k`, which held the GPy kernel in the BGPLVM section.
k = 2
l = binary_labels  # short alias kept in case later cells rely on it
err_mlgplvm = tfgp.util.knn_abs_error(x_mean, l, k)
err_bgplvm = tfgp.util.knn_abs_error(x_bgplvm, l, k)
err_pca = tfgp.util.knn_abs_error(x_pca, l, k)

for method_name, n_errors in (
    ("MLGPLVM", err_mlgplvm),
    ("BGPLVM", err_bgplvm),
    ("PCA", err_pca),
):
    print(f"Misclassifications with {method_name}: {n_errors}")

## Save figures

In [None]:
# Final MLGPLVM latent space (two selected dimensions), coloured by label.
mlgplvm_eps_path = f"{output_dir}/{dataset}_mlgplvm.eps"

plt.scatter(*x_mean.T, c=labels)
plt.colorbar()
plt.savefig(mlgplvm_eps_path, format="eps", dpi=1000)

In [None]:
# Bar chart of the kernel's gamma value for each latent dimension.
gamma_eps_path = f"{output_dir}/{dataset}_gamma.eps"
gamma_positions = range(len(gammas))

plt.bar(gamma_positions, gammas)
plt.savefig(gamma_eps_path, format="eps", dpi=1000)

In [None]:
# BGPLVM baseline latent space, coloured by label.
bgplvm_eps_path = f"{output_dir}/{dataset}_bgplvm.eps"

plt.scatter(*x_bgplvm.T, c=labels)
plt.colorbar()
plt.savefig(bgplvm_eps_path, format="eps", dpi=1000)

In [None]:
# PCA baseline projection, coloured by label.
pca_eps_path = f"{output_dir}/{dataset}_pca.eps"

plt.scatter(*x_pca.T, c=labels)
plt.colorbar()
plt.savefig(pca_eps_path, format="eps", dpi=1000)