- https://www.kaggle.com/siddharthchaini/gp-vae-intro
- https://github.com/siddharthchaini/GP-VAE/blob/master/train.py

### Download Physionet data

In [None]:
# Fetch the preprocessed Physionet archive as ./physionet.npz, which is the
# default `data_dir` used for physionet further below. Uncomment to run the
# download from inside the notebook.
# !wget https://www.dropbox.com/s/651d86winb4cy9n/physionet.npz?dl=1 -O physionet.npz

### Imports

In [None]:
import sys
import os
import time
from datetime import datetime
import numpy as np
import matplotlib
# Non-interactive backend: this notebook only writes figures to files
# (fig.savefig), it never shows them.
matplotlib.use("Agg")
from matplotlib import pyplot as plt
import tensorflow as tf

# The training/evaluation code calls `.numpy()` on tensors and uses
# GradientTape, which requires eager execution (needed explicitly on TF 1.x;
# presumably a no-op on TF 2.x — confirm for the installed version).
tf.compat.v1.enable_eager_execution()

from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.linear_model import LogisticRegression

import warnings
# Hide FutureWarnings. NOTE(review): this also hides library deprecation
# notices (e.g. from sklearn/TF) that may matter when upgrading.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
from lib.models import *

### Flags

In [None]:
# Experiment configuration. The quoted comments appear to be the help strings
# of the command-line flags of the original train.py linked at the top.
latent_dim = 35 # 'Dimensionality of the latent space'
encoder_sizes = [128, 128] # 'Layer sizes of the encoder'
decoder_sizes = [256, 256] # 'Layer sizes of the decoder'
window_size = 24 # 'Window size for the inference CNN: Ignored if model_type is not gp-vae'
sigma = 1.005 # 'Sigma value for the GP prior: Ignored if model_type is not gp-vae'
length_scale = 7.0 # 'Length scale value for the GP prior: Ignored if model_type is not gp-vae'
beta = 0.2 # 'Factor to weigh the KL term (similar to beta-VAE)'
num_epochs = 40 # 'Number of training epochs'

# Flags with common default values for all three datasets
learning_rate = 1e-3 # 'Learning rate for training'
gradient_clip = 1e4 # 'Maximum global gradient norm for the gradient clipping during training'
num_steps = 0 # 'Number of training steps: If non-zero it overwrites num_epochs'
print_interval = 0 # 'Interval for printing the loss and saving the model during training' (0 = derive from num_steps / num_epochs)
exp_name = "reproduce_physionet" # 'Name of the experiment'
basedir = "models" # 'Directory where the models should be stored'
data_dir = "" # 'Directory from where the data should be read in' (empty = dataset default)
data_type = 'physionet' # ['hmnist', 'physionet', 'sprites'], 'Type of data to be trained on'
seed = 1337 # 'Seed for the random number generator'
model_type = 'gp-vae' # ['vae', 'hi-vae', 'gp-vae'], 'Type of model to be trained'
cnn_kernel_size = 3 # 'Kernel size for the CNN preprocessor'
cnn_sizes = [256] # 'Number of filters for the layers of the CNN preprocessor'
# 'Use the actual test set for testing'. Note: for physionet the data-loading
# cell evaluates on the *training* arrays when this is True.
testing = True
banded_covar = True # 'Use a banded covariance matrix instead of a diagonal one for the output of the inference network: Ignored if model_type is not gp-vae'
batch_size = 64 # 'Batch size for training'

M = 1 # 'Number of samples for ELBO estimation'
K = 1 # 'Number of importance sampling weights'

kernel = 'cauchy' # ['rbf', 'diffusion', 'matern', 'cauchy'], 'Kernel to be used for the GP prior: Ignored if model_type is not gp-vae'
kernel_scales = 1 # 'Number of different length scales sigma for the GP prior: Ignored if model_type is not gp-vae'

### Prep

In [None]:
# Seed both NumPy's global RNG and TensorFlow's graph-level seed with the
# same value so repeated runs are reproducible.
for _seed_fn in (np.random.seed, tf.compat.v1.set_random_seed):
    _seed_fn(seed)
print("Testing: ", testing, f"\t Seed: {seed}")

In [None]:
# Normalise the layer-size lists: coerce every entry to int and drop ALL zero
# entries. The previous `list.remove(0)` removed only the first zero, so a
# list such as [128, 0, 0] kept a stray 0 layer.
encoder_sizes = [int(size) for size in encoder_sizes if int(size) != 0]
decoder_sizes = [int(size) for size in decoder_sizes if int(size) != 0]

In [None]:
# Make up the full experiment name ("<YYMMDD>_<exp_name>") and ensure its
# output directory exists; checkpoints use the "ckpt" prefix inside it.
timestamp = datetime.now().strftime("%y%m%d")
full_exp_name = "{}_{}".format(timestamp, exp_name)
outdir = os.path.join(basedir, full_exp_name)
# makedirs also creates `basedir` if it is missing (os.mkdir raised
# FileNotFoundError in that case) and is a no-op when outdir already exists.
os.makedirs(outdir, exist_ok=True)
checkpoint_prefix = os.path.join(outdir, "ckpt")
print("Full exp name: ", full_exp_name)

### Define data-specific parameters

In [None]:
# Show the selected dataset (a bare trailing expression is displayed as the
# cell's output in a notebook; it has no effect when run as a script).
data_type

In [None]:
# Dataset-specific settings: default input file (used only when data_dir is
# left empty), per-step feature dimension, sequence length, decoder class and,
# for the image datasets, the image shape and the train/validation split index.
if data_type == "hmnist":
    # Respect a user-supplied data_dir like the other datasets do (this branch
    # previously overwrote it unconditionally).
    if data_dir == "":
        data_dir = "data/hmnist/hmnist_mnar.npz"
    data_dim = 784
    time_length = 10
    num_classes = 10
    decoder = BernoulliDecoder
    img_shape = (28, 28, 1)
    val_split = 50000
elif data_type == "physionet":
    if data_dir == "":
        data_dir = "physionet.npz"
    data_dim = 35
    time_length = 48
    num_classes = 2
    decoder = GaussianDecoder
elif data_type == "sprites":
    if data_dir == "":
        data_dir = "data/sprites/sprites.npz"
    data_dim = 12288
    time_length = 8
    decoder = GaussianDecoder
    img_shape = (64, 64, 3)
    val_split = 8000
else:
    raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']")

### Load data

In [None]:
# Open the dataset archive; arrays are read from it lazily by key below.
data = np.load(file=data_dir)

In [None]:
# Training arrays: ground truth, the version with missing entries, and the
# matching mask (m_*), unpacked in that order.
x_train_full, x_train_miss, m_train_miss = (
    data[key] for key in ('x_train_full', 'x_train_miss', 'm_train_miss')
)

In [None]:
# Only the labelled datasets provide training labels; sprites has none.
has_train_labels = data_type in ('hmnist', 'physionet')
if has_train_labels:
    y_train = data['y_train']

In [None]:
# Build the evaluation ("val") split.
# * testing=True: hmnist/sprites use the dedicated test arrays; physionet
#   evaluates on its training arrays (NOTE(review): presumably because the
#   physionet archive has no separate test split — confirm this is intended).
# * testing=False: hmnist/sprites carve the validation set off the end of the
#   training arrays at `val_split`; physionet uses its stored validation arrays.
if testing:
    if data_type in ['hmnist', 'sprites']:
        x_val_full = data['x_test_full']
        x_val_miss = data['x_test_miss']
        m_val_miss = data['m_test_miss']
    if data_type == 'hmnist':
        y_val = data['y_test']
    elif data_type == 'physionet':
        x_val_full = data['x_train_full']
        x_val_miss = data['x_train_miss']
        m_val_miss = data['m_train_miss']
        y_val = data['y_train']
        m_val_artificial = data["m_train_artificial"]
elif data_type in ['hmnist', 'sprites']:
    x_val_full = x_train_full[val_split:]
    x_val_miss = x_train_miss[val_split:]
    m_val_miss = m_train_miss[val_split:]
    x_train_full = x_train_full[:val_split]
    x_train_miss = x_train_miss[:val_split]
    m_train_miss = m_train_miss[:val_split]
    # Only hmnist has labels. sprites never defines y_train, so slicing it
    # unconditionally raised NameError for sprites.
    if data_type == 'hmnist':
        y_val = y_train[val_split:]
        y_train = y_train[:val_split]
elif data_type == 'physionet':
    x_val_full = data["x_val_full"]  # full for artificial missings
    x_val_miss = data["x_val_miss"]
    m_val_miss = data["m_val_miss"]
    m_val_artificial = data["m_val_artificial"]
    y_val = data["y_val"]
else:
    raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']")

In [None]:
# Training pipeline: (inputs, mask) pairs shuffled over the whole training
# set, batched, and repeated indefinitely (the loop bounds it with take()).
train_pairs = (x_train_miss, m_train_miss)
tf_x_train_miss = (
    tf.data.Dataset.from_tensor_slices(train_pairs)
    .shuffle(len(x_train_miss))
    .batch(batch_size)
    .repeat()
)

# Validation pipeline (not shuffled), consumed one batch at a time through a
# one-shot iterator.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val_miss, m_val_miss)).batch(batch_size).repeat()
tf_x_val_miss = tf.compat.v1.data.make_one_shot_iterator(val_dataset)

In [None]:
# Image datasets (hmnist, sprites) get a Conv2D preprocessor built from the
# CNN flags; physionet uses no preprocessor.
if data_type == 'physionet':
    image_preprocessor = None
elif data_type in ['hmnist', 'sprites']:
    print("Using CNN preprocessor")
    image_preprocessor = ImagePreprocessor(img_shape, cnn_sizes, cnn_kernel_size)
else:
    raise ValueError("Data type must be one of ['hmnist', 'physionet', 'sprites']")

### Build Model

In [None]:
# Instantiate the requested model. All variants share the same architecture,
# likelihood and ELBO arguments; the GP-VAE additionally receives its GP-prior
# settings and a joint (optionally banded) encoder.
shared_model_kwargs = dict(
    latent_dim=latent_dim,
    data_dim=data_dim,
    time_length=time_length,
    encoder_sizes=encoder_sizes,
    decoder_sizes=decoder_sizes,
    decoder=decoder,
    image_preprocessor=image_preprocessor,
    window_size=window_size,
    beta=beta,
    M=M,
    K=K,
)

if model_type in ("vae", "hi-vae"):
    model_cls = VAE if model_type == "vae" else HI_VAE
    model = model_cls(encoder=DiagonalEncoder, **shared_model_kwargs)
elif model_type == "gp-vae":
    encoder = BandedJointEncoder if banded_covar else JointEncoder
    model = GP_VAE(
        encoder=encoder,
        kernel=kernel,
        sigma=sigma,
        length_scale=length_scale,
        kernel_scales=kernel_scales,
        data_type=data_type,
        **shared_model_kwargs,
    )
else:
    raise ValueError("Model type must be one of ['vae', 'hi-vae', 'gp-vae']")

### Training preparation

In [None]:
# Report whether TensorFlow can see a GPU.
gpu_available = tf.test.is_gpu_available()
print("GPU support: ", gpu_available)

In [None]:
print("Training...")
# Create the global step variable up front so the optimizer and checkpoint
# share it.
_ = tf.compat.v1.train.get_or_create_global_step()
trainable_vars = model.get_trainable_vars()
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# summary() prints the layer table itself and returns None (presumably a
# tf.keras Model.summary — confirm), so print the label first instead of
# wrapping the call in print(), which emitted a stray "None".
print("Encoder: ")
model.encoder.net.summary()
print("Decoder: ")
model.decoder.net.summary()

In [None]:
# Objects tracked by the training checkpoint. The preprocessor network is
# included only when the model has one (image datasets).
checkpoint_objects = dict(optimizer=optimizer,
                          encoder=model.encoder.net,
                          decoder=model.decoder.net,
                          optimizer_step=tf.compat.v1.train.get_or_create_global_step())
if model.preprocessor is not None:
    # summary() prints its own table and returns None, so don't print() it.
    print("Preprocessor: ")
    model.preprocessor.net.summary()
    checkpoint_objects["preprocessor"] = model.preprocessor.net
saver = tf.compat.v1.train.Checkpoint(**checkpoint_objects)

In [None]:
# TensorBoard event writer in the experiment directory; pending events are
# flushed at least every 10 seconds.
summary_writer = tf.compat.v2.summary.create_file_writer(outdir, flush_millis=10000)

In [None]:
# Unless num_steps was set explicitly (non-zero), train for num_epochs passes
# over the training set. (The former `else: num_steps = num_steps` was a no-op.)
if num_steps == 0:
    num_steps = num_epochs * len(x_train_miss) // batch_size

In [None]:
# Default reporting interval: roughly once per epoch. Clamp to at least 1 so
# the `i % print_interval` checks in the training loop never divide by zero
# (e.g. when num_steps < num_epochs), and avoid dividing by num_epochs == 0.
if print_interval == 0:
    print_interval = max(1, num_steps // max(1, num_epochs))

### Training

In [None]:
# Per-step training losses and per-report validation losses, appended by the
# training loop.
losses_train, losses_val = [], []

In [None]:
# Wall-clock reference for the per-interval timing printed during training
# (the training loop resets it after each report).
t0 = time.time()

In [None]:
# Main training loop. Each step computes the loss on one shuffled training
# batch, sanitises and clips the gradients, and applies an Adam update. Every
# `print_interval` steps it also checkpoints, logs train/validation losses to
# stdout and TensorBoard, and runs dataset-specific diagnostics.
with summary_writer.as_default(), tf.compat.v2.summary.record_if(True):
    # The training dataset repeats forever; take() bounds the loop to num_steps batches.
    for i, (x_seq, m_seq) in enumerate(tf_x_train_miss.take(num_steps)):
        try:
            with tf.GradientTape() as tape:
                tape.watch(trainable_vars)
                loss = model.compute_loss(x_seq, m_mask=m_seq)
                losses_train.append(loss.numpy())
            grads = tape.gradient(loss, trainable_vars)
            # np.nan_to_num maps NaN -> 0 and +/-inf -> large finite values,
            # presumably so one bad batch cannot turn the parameters into NaN.
            grads = [np.nan_to_num(grad) for grad in grads]
            grads, global_norm = tf.clip_by_global_norm(grads, gradient_clip)
            optimizer.apply_gradients(zip(grads, trainable_vars),
                                      global_step=tf.compat.v1.train.get_or_create_global_step())

            # Print intermediate results
            if i % print_interval == 0:
                print("================================================")
                # NOTE(review): reads the optimizer's private `_lr` attribute.
                print("Learning rate: {} | Global gradient norm: {:.2f}".format(optimizer._lr, global_norm))
                # "{:.2f}" = two decimals (was "{:2f}": width 2, six decimals).
                print("Step {}) Time = {:.2f}".format(i, time.time() - t0))
                loss, nll, kl = model.compute_loss(x_seq, m_mask=m_seq, return_parts=True)
                print("Train loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format(loss, nll, kl))

                saver.save(checkpoint_prefix)
                tf.compat.v2.summary.scalar(name="loss_train", data=loss, step=tf.compat.v1.train.get_or_create_global_step())
                tf.compat.v2.summary.scalar(name="kl_train", data=kl, step=tf.compat.v1.train.get_or_create_global_step())
                tf.compat.v2.summary.scalar(name="nll_train", data=nll, step=tf.compat.v1.train.get_or_create_global_step())

                # Validation loss on the next batch of the (repeating) validation iterator
                x_val_batch, m_val_batch = tf_x_val_miss.get_next()
                val_loss, val_nll, val_kl = model.compute_loss(x_val_batch, m_mask=m_val_batch, return_parts=True)
                losses_val.append(val_loss.numpy())
                print("Validation loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format(val_loss, val_nll, val_kl))

                tf.compat.v2.summary.scalar(name="loss_val", data=val_loss, step=tf.compat.v1.train.get_or_create_global_step())
                tf.compat.v2.summary.scalar(name="kl_val", data=val_kl, step=tf.compat.v1.train.get_or_create_global_step())
                tf.compat.v2.summary.scalar(name="nll_val", data=val_nll, step=tf.compat.v1.train.get_or_create_global_step())

                if data_type in ["hmnist", "sprites"]:
                    # Draw reconstructed images
                    x_hat = model.decode(model.encode(x_seq).sample()).mean()
                    tf.compat.v2.summary.image(name="input_train", data=tf.reshape(x_seq, [-1]+list(img_shape)), step=tf.compat.v1.train.get_or_create_global_step())
                    tf.compat.v2.summary.image(name="reconstruction_train", data=tf.reshape(x_hat, [-1]+list(img_shape)), step=tf.compat.v1.train.get_or_create_global_step())
                elif data_type == 'physionet':
                    # Eval MSE and AUROC on entire val set.
                    # np.array_split(a, batch_size) yields `batch_size` roughly
                    # equal chunks (not chunks of size batch_size).
                    x_val_miss_batches = np.array_split(x_val_miss, batch_size, axis=0)
                    x_val_full_batches = np.array_split(x_val_full, batch_size, axis=0)
                    m_val_artificial_batches = np.array_split(m_val_artificial, batch_size, axis=0)
                    get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches, m_val_artificial_batches)

                    # Summed per-chunk MSE, normalised by the number of artificially masked entries.
                    n_missings = m_val_artificial.sum()
                    mse_miss = np.sum([model.compute_mse(x, y=y, m_mask=m).numpy()
                                       for x, y, m in get_val_batches()]) / n_missings

                    # Reconstruct from the posterior mean, then restore the
                    # observed inputs wherever m_val_miss == 0.
                    x_val_imputed = np.vstack([model.decode(model.encode(x_batch).mean()).mean().numpy()
                                               for x_batch in x_val_miss_batches])
                    x_val_imputed[m_val_miss == 0] = x_val_miss[m_val_miss == 0]  # impute gt observed values

                    # Downstream check: logistic regression fitted on the first half
                    # of the flattened imputations, AUROC on the second half.
                    x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim])
                    val_split = len(x_val_imputed) // 2
                    cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000)
                    cls_model.fit(x_val_imputed[:val_split], y_val[:val_split])
                    probs = cls_model.predict_proba(x_val_imputed[val_split:])[:, 1]
                    auroc = roc_auc_score(y_val[val_split:], probs)
                    print("MSE miss: {:.4f} | AUROC: {:.4f}".format(mse_miss, auroc))

                    # Update learning rate (used only for physionet with decay=0.5):
                    # every 10 reports halve it, never below 10% of the initial rate.
                    # NOTE(review): mutates the optimizer's private `_lr`; confirm the
                    # installed TF version picks the new value up.
                    if i > 0 and i % (10*print_interval) == 0:
                        optimizer._lr = max(0.5 * optimizer._lr, 0.1 * learning_rate)
                t0 = time.time()
        except KeyboardInterrupt:
            # Ctrl-C ends training early: save a checkpoint and leave the loop.
            print("KeyboardInterrupt")
            saver.save(checkpoint_prefix)
            break

### Evaluation

In [None]:
# Split the validation arrays along axis 0 into `batch_size` roughly equal
# chunks (np.array_split with an int gives that many sections, NOT sections
# of batch_size rows).
x_val_miss_batches, x_val_full_batches = [
    np.array_split(arr, batch_size, axis=0) for arr in (x_val_miss, x_val_full)
]

In [None]:
# Mask chunks for scoring: physionet is scored on its artificial-missingness
# mask (m_val_artificial), every other dataset on m_val_miss.
m_val_batches = np.array_split(
    m_val_artificial if data_type == 'physionet' else m_val_miss,
    batch_size,
    axis=0,
)

In [None]:
def get_val_batches():
    """Return a fresh iterator of aligned (x_miss, x_full, mask) validation chunks."""
    return zip(x_val_miss_batches, x_val_full_batches, m_val_batches)

In [None]:
# NLL and MSE summed over all validation chunks and divided by the number of
# entries selected by the scoring mask. NLL is computed over every chunk
# before MSE (kept in that order in case the model draws random samples).
if data_type == 'physionet':
    n_missings = m_val_artificial.sum()
else:
    n_missings = m_val_miss.sum()
is_binary = data_type == "hmnist"
nll_per_chunk = [model.compute_nll(x, y=y, m_mask=m).numpy() for x, y, m in get_val_batches()]
nll_miss = np.sum(nll_per_chunk) / n_missings
mse_per_chunk = [model.compute_mse(x, y=y, m_mask=m, binary=is_binary).numpy() for x, y, m in get_val_batches()]
mse_miss = np.sum(mse_per_chunk) / n_missings

In [None]:
# Report the normalised reconstruction metrics on the scored entries.
for _fmt, _metric in (("NLL miss: {:.4f}", nll_miss), ("MSE miss: {:.4f}", mse_miss)):
    print(_fmt.format(_metric))

In [None]:
# Save the posterior-mean latents and their decoded reconstructions (before
# any observed values are restored) as .npy files in the experiment dir.
z_mean = []
for x_chunk in x_val_miss_batches:
    z_mean.append(model.encode(x_chunk).mean().numpy())
np.save(os.path.join(outdir, "z_mean"), np.vstack(z_mean))

decoded_chunks = [model.decode(z_chunk).mean().numpy() for z_chunk in z_mean]
x_val_imputed = np.vstack(decoded_chunks)
np.save(os.path.join(outdir, "imputed_no_gt"), x_val_imputed)

In [None]:
# Keep the model's output only where values are missing: entries with
# m_val_miss == 0 (observed) are copied back from the input, then saved.
observed_mask = m_val_miss == 0
x_val_imputed[observed_mask] = x_val_miss[observed_mask]
np.save(os.path.join(outdir, "imputed"), x_val_imputed)

In [None]:
# Downstream evaluation of the imputations: fit a logistic-regression
# classifier on the first half of the flattened imputed validation set and
# report AUROC / AUPRC on the second half. sprites has no labels, so both
# scores are set to 0.
if data_type == "sprites":
    auroc, auprc = 0, 0
elif data_type in ("hmnist", "physionet"):
    is_hmnist = data_type == "hmnist"

    # Optional (physionet only): uncomment to preserve some z_samples and
    # their reconstructions.
    # if not is_hmnist:
    #     for i in range(5):
    #         z_sample = [model.encode(x_batch).sample().numpy() for x_batch in x_val_miss_batches]
    #         np.save(os.path.join(outdir, "z_sample_{}".format(i)), np.vstack(z_sample))
    #         x_val_imputed_sample = np.vstack([model.decode(z_batch).mean().numpy() for z_batch in z_sample])
    #         np.save(os.path.join(outdir, "imputed_sample_{}_no_gt".format(i)), x_val_imputed_sample)
    #         x_val_imputed_sample[m_val_miss == 0] = x_val_miss[m_val_miss == 0]
    #         np.save(os.path.join(outdir, "imputed_sample_{}".format(i)), x_val_imputed_sample)

    if is_hmnist:
        # hmnist imputations are rounded to {0, 1} before classification.
        x_val_imputed = np.round(x_val_imputed)
    x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim])
    val_split = len(x_val_imputed) // 2
    x_fit, x_score = x_val_imputed[:val_split], x_val_imputed[val_split:]
    y_fit, y_score = y_val[:val_split], y_val[val_split:]

    if is_hmnist:
        # Multi-class: compare one-hot targets with the full probability matrix.
        cls_model = LogisticRegression(solver='lbfgs', multi_class='multinomial', tol=1e-10, max_iter=10000)
        cls_model.fit(x_fit, y_fit)
        probs = cls_model.predict_proba(x_score)
        y_target = np.eye(num_classes)[y_score]
    else:
        # Binary: score with the positive-class probability only.
        cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000)
        cls_model.fit(x_fit, y_fit)
        probs = cls_model.predict_proba(x_score)[:, 1]
        y_target = y_score

    auprc = average_precision_score(y_target, probs)
    auroc = roc_auc_score(y_target, probs)
    print("AUROC: {:.4f}".format(auroc))
    print("AUPRC: {:.4f}".format(auprc))

In [None]:
# Visualize reconstructions of one validation sequence (image datasets only):
# row 0 = input with missing values, row 1 = reconstruction decoded from the
# posterior mean, row 2 = ground truth; one column per time step.
if data_type in ["hmnist", "sprites"]:
    img_index = 0
    if data_type == "hmnist":
        img_shape = (28, 28)
        cmap = "gray"
    elif data_type == "sprites":
        img_shape = (64, 64, 3)
        cmap = None

    n_frames = x_val_miss.shape[1]
    # squeeze=False keeps `axes` 2-D even for a single-frame sequence, so the
    # row/column zips below always work.
    fig, axes = plt.subplots(nrows=3, ncols=n_frames, figsize=(2 * n_frames, 6), squeeze=False)

    x_hat = model.decode(model.encode(x_val_miss[img_index: img_index+1]).mean()).mean().numpy()
    seqs = [x_val_miss[img_index:img_index+1], x_hat, x_val_full[img_index:img_index+1]]

    for axs, seq in zip(axes, seqs):
        for ax, img in zip(axs, seq[0]):
            ax.imshow(img.reshape(img_shape), cmap=cmap)
            ax.axis('off')

    # The title is labelled NLL, so report the NLL (it previously showed mse_miss).
    suptitle = model_type + f" reconstruction, NLL missing = {nll_miss}"
    fig.suptitle(suptitle, size=18)
    fig.savefig(os.path.join(outdir, data_type + "_reconstruction.pdf"))
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close(fig)

# Summarise the run: one TSV row with the configuration and final metrics,
# plus the raw training/validation loss curves.
# The loss histories can be empty (num_steps == 0, or training interrupted
# before the first step); fall back to NaN instead of raising IndexError.
last_train_loss = losses_train[-1] if losses_train else float("nan")
last_val_loss = losses_val[-1] if losses_val else float("nan")

results_all = [seed, model_type, data_type, kernel, beta, latent_dim,
               num_epochs, batch_size, learning_rate, window_size,
               kernel_scales, sigma, length_scale,
               len(encoder_sizes), encoder_sizes[0] if len(encoder_sizes) > 0 else 0,
               len(decoder_sizes), decoder_sizes[0] if len(decoder_sizes) > 0 else 0,
               cnn_kernel_size, cnn_sizes,
               nll_miss, mse_miss, last_train_loss, last_val_loss, auprc, auroc, testing, data_dir]

with open(os.path.join(outdir, "results.tsv"), "w", encoding="utf-8") as outfile:
    outfile.write("seed\tmodel\tdata\tkernel\tbeta\tz_size\tnum_epochs"
                  "\tbatch_size\tlearning_rate\twindow_size\tkernel_scales\t"
                  "sigma\tlength_scale\tencoder_depth\tencoder_width\t"
                  "decoder_depth\tdecoder_width\tcnn_kernel_size\t"
                  "cnn_sizes\tNLL\tMSE\tlast_train_loss\tlast_val_loss\tAUPRC\tAUROC\ttesting\tdata_dir\n")
    # Terminate the data row so the file ends with a newline like a normal TSV.
    outfile.write("\t".join(map(str, results_all)) + "\n")

# Line 1: per-step training losses; line 2: validation losses from each report.
with open(os.path.join(outdir, "training_curve.tsv"), "w", encoding="utf-8") as outfile:
    outfile.write("\t".join(map(str, losses_train)))
    outfile.write("\n")
    outfile.write("\t".join(map(str, losses_val)))

print("Training finished.")