# Mixed Likelihood GPLVM

In [None]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

In [None]:
import time
import os

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from matplotlib import pyplot as plt
from IPython import display
%matplotlib inline
import seaborn as sns

In [None]:
# Apply seaborn's default theme, then switch to the compact "paper" context
# (smaller fonts/lines than the default "notebook" context).
sns.set()
sns.set_context("paper")

In [None]:
import tfgp
from tfgp.util import data
from tfgp.model import MLGPLVM

# Report which copy of tfgp was picked up (e.g. to spot a shadowed dev install).
package_path = tfgp.__file__
print(f"Succesfully imported package: {package_path}")

## Load data and hide part of the test set

In [None]:
# Training split of Binary Alphadigits plus the likelihood passed to MLGPLVM.
# NOTE(review): num_data=None presumably means "use every training sample";
# confirm in tfgp.util.data.make_binaryalphadigits.
num_data, num_classes = None, 36
y_train, likelihood, labels_train = data.make_binaryalphadigits(num_data, num_classes)

In [None]:
# Held-out test rows. The labels assume the CSV stores exactly 9 samples per
# class, grouped in class order 0, 0, ..., 1, 1, ...
num_test_per_class = 9
y_test = np.loadtxt("../../util/binaryalphadigits_test.csv", delimiter=",")
labels_test = np.repeat(np.arange(num_classes), num_test_per_class)

In [None]:
# Boolean mask of test entries to hide: every row gets exactly `num_missing`
# True entries, placed uniformly at random.
frac_missing = 0.2
num_missing = int(frac_missing * y_test.shape[1])
# argsort of i.i.d. uniform noise is an independent random permutation per row,
# and the positions whose permuted index is < num_missing form a uniformly
# random subset of size num_missing. This replaces shuffling each row in place
# through np.apply_along_axis, which only worked because it happened to pass
# views of the rows.
idx = np.argsort(np.random.rand(*y_test.shape), axis=1) < num_missing

In [None]:
# Mark the hidden test entries as missing. Use np.nan explicitly: assigning
# None only works because numpy silently converts it to NaN for float arrays.
y_test[idx] = np.nan

In [None]:
# The model sees training rows first, then the partially hidden test rows;
# labels follow the same row order.
y_noisy = np.concatenate([y_train, y_test], axis=0)
labels = np.concatenate([labels_train, labels_test])

## Create model

In [None]:
# Size of the latent space and number of inducing points.
latent_dim, num_inducing = 10, 50

In [None]:
# ARD RBF kernel over the latent space: one gamma per latent dimension
# (the training loop later ranks dimensions by these gammas).
kernel = tfgp.kernel.ARDRBF(
    variance=0.5,
    gamma=0.5,
    xdim=latent_dim,
    name="kernel",
)
# Mixed-likelihood GPLVM on train + test rows.
# NOTE(review): presumably NaN entries in y_noisy are treated as unobserved.
m = MLGPLVM(
    y_noisy,
    latent_dim,
    num_inducing=num_inducing,
    kernel=kernel,
    likelihood=likelihood,
)
m.initialize()

## Build graph

In [None]:
# Total loss registered by the model. tf.losses.get_total_loss includes
# regularisation losses by default; these are also summarised individually below.
loss = tf.losses.get_total_loss()
learning_rate = 1e-3
with tf.name_scope("train"):
    trainable_vars = tf.trainable_variables()
    optimizer = tf.train.RMSPropOptimizer(learning_rate, name="RMSProp")
    train_all = optimizer.minimize(
        loss,
        var_list=trainable_vars,  # reuse the list above instead of querying it twice
        global_step=tf.train.create_global_step(),
        name="train",
    )
with tf.name_scope("summary"):
    m.create_summaries()
    tf.summary.scalar("total_loss", loss, family="Loss")
    for reg_loss in tf.losses.get_regularization_losses():
        # Use the op name: the tensor name carries a ":0" suffix, which is not
        # a legal summary-name character and would be rewritten with a warning.
        tf.summary.scalar(reg_loss.op.name, reg_loss, family="Loss")
    merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()

## Plotting callback

In [None]:
def plot(x: np.ndarray, *, z: np.ndarray = None, gammas: np.ndarray = None, loss) -> None:
    """Draw the current training state on the notebook axes ``ax1``-``ax3``.

    * ax1: rows of ``x`` coloured by the global ``labels``, with a different
      marker for the label groups [0, 12), [12, 24) and [24, ...); optional
      points ``z`` are overlaid as black crosses. Axis limits follow ``x``.
    * ax2: the loss curve.
    * ax3: bar chart of ``gammas`` (if given).

    Args:
        x: Array with exactly two columns and one row per entry of ``labels``.
        z: Optional array of extra 2-D points to overlay on ax1.
        gammas: Optional 1-D array of per-dimension kernel weights.
        loss: Sequence of ``[step, loss_value]`` pairs. The last pair supplies
            the step and loss shown in the titles.

    The figure owning ``ax1`` is displayed and the cell output is replaced on
    the next display.
    """
    # One marker per group of 12 labels; each scatter call normalises its own
    # colours, so each group spans the whole "Paired" colormap.
    groups = (
        (labels < 12, "d"),
        ((labels >= 12) & (labels < 24), "x"),
        (labels >= 24, "*"),
    )
    for mask, marker in groups:
        ax1.scatter(*x[mask].T, c=labels[mask], cmap="Paired", marker=marker)
    if z is not None:
        ax1.scatter(*z.T, c="k", marker="x")
    (x_lo, y_lo), (x_hi, y_hi) = np.min(x, axis=0), np.max(x, axis=0)
    ax1.set_xlim(x_lo, x_hi)
    ax1.set_ylim(y_lo, y_hi)

    if len(loss) > 0:
        # Take the step and loss from the history itself rather than from the
        # notebook globals `i` / `train_loss`, which can be stale (e.g. for the
        # final evaluation after training stops).
        step, last_loss = loss[-1]
        ax1.set_title(f"Step {step}")
        ax2.plot(*np.array(loss).T)
        ax2.set_title(f"Loss: {last_loss}")

    if gammas is not None:
        ax3.bar(range(len(gammas)), gammas, tick_label=range(len(gammas)))

    display.display(ax1.figure)
    display.clear_output(wait=True)

## Setup optimisation

In [None]:
# Per-run output locations, keyed by dataset and the start timestamp.
root_dir = "../.."
dataset = "alphadigits"
start_time = time.strftime('%Y%m%d%H%M%S')
run_dirs = {
    kind: f"{root_dir}/{kind}/{dataset}/{start_time}"
    for kind in ("log", "save", "output")
}
log_dir, save_dir, output_dir = run_dirs["log"], run_dirs["save"], run_dirs["output"]
# Checkpoints and exported arrays/figures need their directories up front;
# makedirs raises if a directory already exists for this timestamp.
for directory in (save_dir, output_dir):
    os.makedirs(directory)

In [None]:
# An InteractiveSession installs itself as the default session, which is what
# lets later cells call `.eval()` directly. Device-placement logging shows
# where each op runs.
session_config = tf.ConfigProto(log_device_placement=True)
sess = tf.InteractiveSession(config=session_config)
saver = tf.train.Saver()
# To resume, restore a checkpoint here.
# NOTE(review): save_dir is new for every run, and saver.save writes
# "model.ckpt-<step>", so this path must point at an earlier run's file; the
# training cell also runs `init`, which would overwrite restored values.
# saver.restore(sess, f"{save_dir}/model.ckpt")

## Run optimisation

In [None]:
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
loss_list = []
n_iter = 40000
print_interval = 500
save_interval = 5000


def _snapshot():
    """Return ``(gammas, x_mean, z)`` evaluated in the default session.

    ``x_mean`` is ``m.qx_mean`` transposed and restricted to the two columns
    with the largest kernel gammas (in ascending-gamma order).
    """
    gammas = m.kernel._gamma.eval()
    x_mean = m.qx_mean.eval().T[:, np.argsort(gammas)[-2:]]
    z = m.z.eval()
    return gammas, x_mean, z


def _clear_axes():
    """Clear the three training-plot axes so the next plot starts fresh."""
    for ax in (ax1, ax2, ax3):
        ax.cla()


i = 0  # defined up front so the final evaluation below never hits a NameError
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
try:
    sess.run(init)
    for i in range(n_iter):
        sess.run(train_all)
        log_step = i % print_interval == 0
        save_step = i % save_interval == 0
        if log_step or save_step:
            # Refresh on save steps too, so the saved arrays are current even
            # when save_interval is not a multiple of print_interval.
            gammas, x_mean, z = _snapshot()
        if log_step:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            train_loss, summary = sess.run(
                [loss, merged_summary], options=run_options, run_metadata=run_metadata
            )
            summary_writer.add_run_metadata(run_metadata, f"step_{i}", global_step=i)
            summary_writer.add_summary(summary, i)
            loss_list.append([i, train_loss])
            plot(x_mean, gammas=gammas, loss=loss_list)
            _clear_axes()
        if save_step:
            saver.save(sess, f"{save_dir}/model.ckpt", global_step=i)
            np.savetxt(f"{output_dir}/x_mean_{i}.csv", x_mean)
            np.savetxt(f"{output_dir}/z_{i}.csv", z)
            np.savetxt(f"{output_dir}/labels.csv", labels)
            plot(x_mean, gammas=gammas, loss=loss_list)
            plt.savefig(f"{output_dir}/fig_{i}.eps")
            _clear_axes()
except KeyboardInterrupt:
    pass  # stopping the cell by hand still falls through to the final plot below
finally:
    # Flush pending events and release the event file in every case.
    summary_writer.close()

# Final state after normal completion or a manual interrupt. Any other error
# propagates above instead of being masked by this evaluation.
gammas, x_mean, z = _snapshot()
train_loss = loss.eval()
loss_list.append([i, train_loss])
plot(x_mean, gammas=gammas, loss=loss_list)


## Some more plotting

In [None]:
def plot_scale():
    """Visualise the two factors returned by ``m.qx_scale.eval()``.

    Each factor ``L`` is turned into ``L @ L.T`` (via ``np.dot``). Row r of a
    2x2 grid shows that matrix (left) and the same matrix with its diagonal
    zeroed (right), each with its own colour bar. Unpacking requires the
    evaluated ``qx_scale`` to contain exactly two entries.
    """
    factor_a, factor_b = m.qx_scale.eval()
    _, grid = plt.subplots(2, 2)
    for row, factor in enumerate((factor_a, factor_b)):
        product = np.dot(factor, factor.T)
        without_diag = product - np.diag(np.diag(product))
        for col, matrix in enumerate((product, without_diag)):
            image = grid[row, col].imshow(matrix)
            plt.colorbar(image, ax=grid[row, col])

In [None]:
# Show L @ L.T (full and off-diagonal) for both factors of m.qx_scale.
plot_scale()

## PCA

In [None]:
# PCA baseline on the same rows the MLGPLVM saw. `y` is never defined in this
# notebook, so the original line raised NameError. Use y_noisy (train rows +
# masked test rows) so rows align with `labels`. The hidden (NaN) entries are
# filled with the per-column mean of the observed values, because PCA cannot
# handle NaN.
y_pca_input = np.where(np.isnan(y_noisy), np.nanmean(y_noisy, axis=0), y_noisy)
x_pca = tfgp.util.pca_reduce(y_pca_input, 2)
for pca_mask, pca_marker in (
    (labels < 12, "d"),
    ((labels >= 12) & (labels < 24), "x"),
    (labels >= 24, "*"),
):
    plt.scatter(*x_pca[pca_mask].T, c=labels[pca_mask], cmap="Paired", marker=pca_marker)

## Compute k-NN classification error

In [None]:
# Nearest-neighbour classification error of each 2-D embedding.
# NOTE(review): k = 2 — confirm whether tfgp.util.knn_error counts each point
# as its own nearest neighbour (which would make this a 1-NN error).
k = 2
err_mlgplvm, err_pca = (tfgp.util.knn_error(embedding, labels, k) for embedding in (x_mean, x_pca))
for method_name, error in (("MLGPLVM", err_mlgplvm), ("PCA", err_pca)):
    print(f"Missclasifications with {method_name}: {error}")

# Save figures

In [None]:
# Final MLGPLVM embedding with the same label grouping/markers as the training
# plot, written to the run's output directory as EPS.
for group_mask, group_marker in (
    (labels < 12, "d"),
    ((labels >= 12) & (labels < 24), "x"),
    (labels >= 24, "*"),
):
    plt.scatter(*x_mean[group_mask].T, c=labels[group_mask], cmap="Paired", marker=group_marker)
plt.savefig(f"{output_dir}/{dataset}.eps", format="eps", dpi=1000)

# Perplexity on the test set

In [None]:
# Rows from index `split` onwards (in y_noisy / qx_mean) are the test rows.
split = len(y_train)
# Reload the clean test targets: y_test itself was overwritten with NaNs when
# the missing entries were hidden, so it cannot serve as ground truth.
y_true = np.loadtxt("../../util/binaryalphadigits_test.csv", delimiter=",")

The predictive likelihood $p(y)$ of a test entry is $\int p(y \,|\, l) \, p(l \,|\, \text{model}) \, dl$, where $l$ is the logit parameter of the Bernoulli distribution. Since this integral is intractable, we approximate it by Monte Carlo as $\frac{1}{T} \sum_{t=1}^T p(y \,|\, l_t)$, where each $l_t$ is sampled from $p(l \,|\, \text{model})$.

## Construct $p(l \,|\, \text{model})$

In [None]:
# Kernel blocks between the inducing inputs m.z and the test rows of the
# (transposed) latent means.
x_test = tf.matrix_transpose(m.qx_mean)[split:]
kzz = m.kernel(m.z)
kzz_inv = tf.matrix_inverse(kzz)
kxx = m.kernel(x_test)
kxz = m.kernel(x_test, m.z)
kzx = tf.matrix_transpose(kxz)

In [None]:
# Sparse-GP conditional at the test points with A = Kxz Kzz^-1:
#   mean = A @ qu_mean^T,   cov = Kxx - A @ Kzx
# NOTE(review): if qu_mean is the mean of q(u), the full predictive covariance
# would also include A S A^T, where S is the covariance of q(u). That term is
# absent here, so the spread of the sampled logits is likely underestimated —
# confirm.
proj = kxz @ kzz_inv
mean = proj @ tf.matrix_transpose(m.qu_mean)
cov = kxx - proj @ kzx

In [None]:
# Kxx - Kxz Kzz^-1 Kzx is positive semi-definite in exact arithmetic, but it
# is often numerically indefinite, which makes tf.cholesky fail. Add a small
# diagonal jitter; the earlier attempt with unit diagonal noise was far too large.
jitter = 1e-6
cov_chol = tf.cholesky(cov + jitter * tf.eye(tf.shape(cov)[-1], dtype=cov.dtype))

In [None]:
# Batch of multivariate normals over the test rows (one per row of the
# transposed mean, presumably one per output dimension), all sharing cov_chol.
norm = tfp.distributions.MultivariateNormalTriL(
    loc=tf.matrix_transpose(mean),
    scale_tril=cov_chol,
)

## Compute $\frac{1}{T} \sum_{t=1}^T p(y_i \,|\, l_t)$

In [None]:
# Monte Carlo estimate of p(y) per test entry: draw logit samples, swap the
# last two axes so they line up with y_true, then average the Bernoulli
# likelihood over the sample axis.
# NOTE(review): this scores every test entry, including the ~80% that were
# observed during training; restricting to the hidden entries (idx) may be the
# intended imputation metric.
num_samples = 100
logit_samples = tf.matrix_transpose(norm.sample(num_samples))
logits = logit_samples.eval()
ber = tfp.distributions.Bernoulli(logits=logits)
mean_prob = ber.prob(y_true).eval().mean(axis=0)

## Compute perplexity $2^{-\frac{1}{N}\sum_{i=1}^N \log_2 p(y_i)}$ over the $N$ test entries

In [None]:
# Per-entry perplexity: 2 raised to the negative mean of log2 p(y_i).
log2_perplexity = -np.mean(np.log2(mean_prob))
perplexity = 2 ** log2_perplexity

In [None]:
report = f"The log2 perplexity is {log2_perplexity} and the perplexity is {perplexity}"
print(report)

# Old (superseded) perplexity computation

### Sample logits

In [None]:
# Legacy: a single logit sample, superseded by the Monte Carlo cells above.
# NOTE: running the OLD cells overwrites `logits`, `ber` and `perplexity`.
single_sample = norm.sample(1)
logits = tf.matrix_transpose(single_sample).eval()

In [None]:
# Average over the leading sample axis; with the single-sample cell above this
# simply drops that axis.
logits_mean = logits.mean(axis=0)

In [None]:
# Plug-in Bernoulli over the mean logits, scored on all test entries
# (an earlier variant restricted it to logits_mean[idx]).
ber = tfp.distributions.Bernoulli(logits=logits_mean)

### Compute perplexity

In [None]:
# Natural-log perplexity: exp of the negative mean log-likelihood per entry.
log_lik = ber.log_prob(y_true)
mean_log_lik = np.mean(log_lik.eval())
log_perplexity = -mean_log_lik
perplexity = np.exp(log_perplexity)

In [None]:
log_report = f"The log perplexity is {log_perplexity} and the perplexity is {perplexity}"
print(log_report)

In [None]:
# Same plug-in estimate expressed in base 2.
prob = ber.prob(y_true)
log2_prob = np.log2(prob.eval())
mean_log2_lik = np.mean(log2_prob)
log2_perplexity = -mean_log2_lik
perplexity = 2 ** log2_perplexity

In [None]:
log2_report = f"The log2 perplexity is {log2_perplexity} and the perplexity is {perplexity}"
print(log2_report)