In [None]:
%load_ext autoreload
%autoreload 2

# Reference:
#     gpflow: https://gpflow.readthedocs.io/en/master/notebooks/advanced/natural_gradients.html

# --- Environment ----------------------------------------------------------------
# Device selection must happen before any GPU framework (jax, torch) is imported,
# otherwise it is not guaranteed to take effect.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # machine-specific GPU choice

# --- Standard library -----------------------------------------------------------
import itertools
import sys
from dataclasses import dataclass
from typing import Any, Callable, Sequence, Optional, Tuple

# Project modules (jaxkern, gpax, gp, plt_utils) are imported from ../kernel.
# The membership check keeps re-runs of this cell from appending duplicate entries.
if '../kernel' not in sys.path:
    sys.path.append('../kernel')

# --- Numerics -------------------------------------------------------------------
import numpy as onp
import numpy.random as npr
onp.set_printoptions(precision=3, suppress=True)

import jax
from jax import device_put, random
import jax.numpy as np
import jax.numpy.linalg as linalg
from jax.scipy.linalg import cho_solve, solve_triangular

import flax
from flax import linen as nn
from flax import optim, struct
from flax.core import freeze, unfreeze

import torch
print(torch.cuda.is_available(), jax.devices())

# --- Plotting -------------------------------------------------------------------
import matplotlib.pyplot as plt
import matplotlib as mpl
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was deprecated in Matplotlib 3.7 and
# removed in 3.9; `plt.get_cmap` returns the same colormap on old and new versions.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate

# --- Project modules (resolved via ../kernel) -----------------------------------
from jaxkern import (cov_se, cov_rq, cov_pe, LookupKernel, normalize_K, mtgp_k)

from plt_utils import plt_savefig, plt_scaled_colobar_ax
from gp import gp_regression_chol, run_sgd
from gpax import is_psd, cholesky_jitter, MultivariateNormalTril, MultivariateNormalInducing
from gpax import softplus_inv, BijSoftplus, BijFillTril, BijExp, BijSoftplusFillTril
from gpax import log_func_default, log_func_simple, flax_run_optim, filter_contains, flax_get_optimizer
from gpax import CovSE, LikNormal, GPR, GPRFITC, VFE, SVGP, VariationalMultivariateNormal, mvn_conditional_sparse
from gpax import kl_mvn, kl_mvn_tril, kl_mvn_tril_zero_mean_prior, mvn_conditional_variational, mvn_conditional_variational_us
from gpax import flax_create_optimizer, get_data_stream, diag_indices_kth, rand_μΣ, pytree_mutate


In [None]:
import warnings
import numpy as onp
import gpflow
import tensorflow as tf

from gpflow.ci_utils import ci_niter, ci_range
from gpflow.optimizers import NaturalGradient
from gpflow.optimizers.natgrad import XiSqrtMeanVar
from gpflow import set_trainable

%matplotlib inline
%precision 4

In [None]:

# Seed numpy's global RNG (used to draw the data below) and TensorFlow's global seed.
onp.random.seed(0)
tf.random.set_seed(0)

# Number of training points N and input dimension D.
N, D = 200, 1
# NOTE(review): not used in the cells shown here; presumably the mini-batch size for later training.
batch_size = 20

# Number of inducing points (initialised at the first M training inputs).
M = 50

def f_gen(x):
    """Noise-free target function used to generate the regression data.

    Evaluated with plain numpy (`onp`) rather than `jax.numpy`. jax defaults to
    float32, so the jax version silently downcast the float64 numpy inputs. That
    left `y` at lower precision than `X` before it was handed to gpflow (float64)
    and to the jax models.

    NOTE: uses 3.14 rather than pi; kept as-is so the generating function is unchanged.

    Args:
        x: numpy array of inputs (any shape); the function is applied elementwise.

    Returns:
        numpy array with the same shape (and float dtype) as `x`.
    """
    return onp.sin(x * 3 * 3.14) + \
           0.3 * onp.cos(x * 9 * 3.14) + \
           0.5 * onp.sin(x * 7 * 3.14)

# Training inputs drawn uniformly from [0, 1)^D with noise-free targets y = f_gen(X),
# plus 50 extra inputs Xs (presumably for prediction/testing).
X = onp.random.uniform(size=(N, D))
y = f_gen(X)
Xs = onp.random.uniform(size=(50, D))

data = (X, y)
# Inducing inputs are initialised at the first M training inputs.
inducing_variable = X[:M]
# NOTE(review): the three settings below are not used in the cells shown here.
adam_learning_rate = 0.01
iterations = ci_niter(5)
autotune = tf.data.experimental.AUTOTUNE

In [None]:
# gpflow baselines: exact GPR log marginal likelihood vs. the (non-sparse) VGP ELBO.
gpr = gpflow.models.GPR(data, kernel=gpflow.kernels.SquaredExponential())
print(f'gpr mll: {gpr.log_marginal_likelihood().numpy()}')
vgp = gpflow.models.VGP(data, kernel=gpflow.kernels.SquaredExponential(), likelihood=gpflow.likelihoods.Gaussian())
print(f'vgp elbo: {vgp.elbo().numpy()}')
# One natural-gradient step (gamma=1.0) on the variational parameters (q_mu, q_sqrt).
# Per the referenced gpflow tutorial, with a Gaussian likelihood this single step is
# expected to reach the optimal q, so the second ELBO should match the GPR mll.
natgrad_opt = NaturalGradient(gamma=1.0)
variational_params = [(vgp.q_mu, vgp.q_sqrt)]
natgrad_opt.minimize(vgp.training_loss, var_list=variational_params)
print(f'vgp elbo: {vgp.elbo().numpy()}')


In [None]:

# Sparse gpflow models sharing the same inducing inputs (first M training inputs):
# the SGPR (VFE) collapsed bound and the SVGP ELBO with an un-whitened q(u).
sgpr = gpflow.models.SGPR(data, kernel=gpflow.kernels.SquaredExponential(), inducing_variable=inducing_variable)
print(f'vfe: {sgpr.elbo().numpy()}')
svgp = gpflow.models.SVGP(kernel=gpflow.kernels.SquaredExponential(),
                          likelihood=gpflow.likelihoods.Gaussian(),
                          inducing_variable=inducing_variable,
                          whiten=False)
print(f'svgp: {svgp.elbo(data).numpy()}')
# Prints the tuple (elbo + prior_kl, -prior_kl): the ELBO split into its likelihood
# part and its negative-KL part.
print(f'\tlik+nkl={(svgp.elbo(data)+svgp.prior_kl()).numpy(), -svgp.prior_kl().numpy()}')

# Plot the SVGP predictive q(f) at its initial (not yet optimised) variational
# parameters -- not the VFE posterior. Points are sorted by X; the band is +/- 2 std,
# and inducing inputs are marked at y = -1.
f_mean, f_var = svgp.predict_f(X, full_cov=False, full_output_cov=False)
I = np.argsort(X,0).squeeze()
f_mean = np.array(f_mean).squeeze()[I]
f_var = np.array(f_var).squeeze()[I]

plt.plot(X[I], f_mean, color='k')
plt.fill_between(X[I].squeeze(),
                (f_mean-2*np.sqrt(f_var)).squeeze(),
                (f_mean+2*np.sqrt(f_var)).squeeze(), alpha=.2, color=cmap(0))
plt.plot(inducing_variable, -np.ones_like(inducing_variable), "k|", mew=2)

# Display the gpflow model (its notebook repr summarises the parameters).
svgp

In [None]:
# One natural-gradient step (gamma=1.0) on the gpflow SVGP's q(u) = (q_mu, q_sqrt).
# With a Gaussian likelihood this is expected (per the referenced gpflow tutorial) to move
# the SVGP ELBO to the SGPR/VFE bound for the same inducing inputs.
variational_params = [(svgp.q_mu, svgp.q_sqrt)]
natgrad_opt = NaturalGradient(gamma=1.0)
natgrad_opt.minimize(svgp.training_loss_closure(data), var_list=variational_params)
print(f'svgp: {svgp.elbo(data).numpy()}')
print(f'\t-prior_kl={svgp.prior_kl().numpy()}')


In [None]:
class SVGP(nn.Module):
    """Sparse variational GP with a Gaussian likelihood, written with flax.linen.

    This notebook-local class shadows the `SVGP` imported from `gpax`.

    Attributes:
        data: training pair `(X, y)`. Only used to initialise the inducing inputs
            and to size the dummy batch in `get_init_params`.
        n_inducing: number of inducing points; `Xu` starts at `X[:n_inducing]`.
        n_data: total number of training points. The expected log-likelihood term
            in `mll` is multiplied by `n_data / len(X_batch)` (or by 1 if None).
    """
    data: Tuple[np.ndarray, np.ndarray]
    n_inducing: int
    n_data: int

    def setup(self):
        self.k = CovSE()
        self.lik = LikNormal()
        X, y = self.data
        # Trainable inducing inputs, initialised at the first `n_inducing` training
        # inputs (the init fn ignores the PRNG key and shape it receives).
        init_fn = lambda k, s: X[:self.n_inducing]
        self.Xu = self.param('Xu', init_fn,
                             (self.n_inducing, X.shape[-1]))
        # Variational distribution q(u) over inducing outputs, constructed from an
        # (n_inducing x n_inducing) identity -- presumably its initial scale; confirm in gpax.
        self.q = VariationalMultivariateNormal(np.eye(len(self.Xu)))

    def get_init_params(self, key):
        """Initialise all parameters by tracing `mll` on a dummy two-point batch."""
        Xs = np.ones((2, self.data[0].shape[-1]))
        ys = np.ones((2, self.data[1].shape[-1]))
        return self.init(key, (Xs, ys), method=self.mll)

    def _chol_Kuu(self):
        """Jittered Cholesky factor of Kuu = k(Xu, Xu).

        Shared by `mll` and `pred_f` so that prediction conditions on exactly the
        same factorisation (same jitter) that the ELBO is optimised with.
        """
        return cholesky_jitter(self.k(self.Xu), jitter=1e-5)

    def mll(self, data):
        """ELBO for a (mini-)batch `data = (X, y)`.

        Returns the expected log-likelihood term, rescaled from the batch size to
        `n_data`, plus the negative KL term from `kl_mvn_tril_zero_mean_prior`
        (not rescaled).
        """
        X, y = data
        k = self.k
        μq, Lq = self.q.μ, self.q.L

        Kff = k(X, full_cov=False)
        Kuf = k(self.Xu, X)
        Luu = self._chol_Kuu()

        α = self.n_data/len(X) \
            if self.n_data is not None else 1.

        μqf, σ2qf = mvn_conditional_variational(Kff, Kuf,
                                                Luu, μq, Lq, full_cov=False)

        elbo_lik = α*self.lik.variational_log_prob(y, μqf, σ2qf)
        elbo_nkl = -kl_mvn_tril_zero_mean_prior(μq, Lq, Luu)
        # Diagnostic print, mirrors gpflow's (elbo + prior_kl, -prior_kl) output for comparison.
        print('\t lik+nkl',elbo_lik, elbo_nkl)
        return elbo_lik + elbo_nkl

    def pred_f(self, Xs, full_cov=False):
        """Predictive q(f(Xs)) obtained by conditioning on q(u).

        Returns the mean and either the full covariance (`full_cov=True`) or the
        marginal variances (`full_cov=False`), as produced by
        `mvn_conditional_variational`.
        """
        k = self.k
        μq, Lq = self.q.μ, self.q.L

        Kss = k(Xs, full_cov=full_cov)
        Kus = k(self.Xu, Xs)
        Luu = self._chol_Kuu()

        μf, Σf = mvn_conditional_variational(Kss, Kus,
                                             Luu, μq, Lq, full_cov=full_cov)
        return μf, Σf

    def pred_y(self, Xs, full_cov=True):
        """Predictive distribution of noisy observations at `Xs`.

        Args:
            Xs: test inputs.
            full_cov: if True (default), return the full covariance Σf + σ²I;
                otherwise return marginal variances Σf + σ².

        Previously this called `pred_f` with its default `full_cov=False` and then
        added `σ²I`. That broadcast a variance vector against an (ns, ns) matrix
        instead of adding noise to the diagonal of the covariance.
        """
        μf, Σf = self.pred_f(Xs, full_cov=full_cov)
        if full_cov:
            # Σf is the (ns, ns) covariance: observation noise goes on the diagonal.
            Σy = Σf + self.lik.σ2*np.eye(len(Σf))
        else:
            Σy = Σf + self.lik.σ2
        return μf, Σy
    
    
# Evaluate the notebook-local SVGP's ELBO at its initial parameters.
# `key` is created here so this cell does not depend on a later cell having run
# (previously it raised NameError under Restart & Run All).
key = jax.random.PRNGKey(0)
svgp = SVGP(data, M, len(X))
mll = svgp.apply(svgp.get_init_params(key), data, method=svgp.mll)
print(f'svgp: {mll}')


In [None]:
# Objectives of the jax/flax models at their initial parameters, all from the same key:
# exact GPR log marginal likelihood, VFE bound, and the SVGP ELBO.
key = jax.random.PRNGKey(0)

# GPR and VFE store their data, so `mll` takes no extra argument.
gpr = GPR(data)
mll = gpr.apply(gpr.get_init_params(key), method=gpr.mll)
print(f'gpr: {mll}')

vfe = VFE(data, M)
mll = vfe.apply(vfe.get_init_params(key), method=vfe.mll)
print(f'vfe: {mll}')

# `SVGP` here is the notebook-local class, which shadows the gpax import; its `mll`
# takes the (mini-)batch explicitly.
svgp = SVGP(data, M, len(X))
mll = svgp.apply(svgp.get_init_params(key), data, method=svgp.mll)
print(f'svgp: {mll}')


In [None]:
# Inspect the initial covariance of the jax SVGP's variational distribution q(u).
params = svgp.get_init_params(key)


# Re-pack q's raw parameters ('μ', 'L') as a standalone params tree.
pp = {'params': {'μ': params['params']['q']['μ'], 'L': params['params']['q']['L']}}
# NOTE(review): assumes the stored 'L' is the unconstrained form that BijFillTril.forward
# maps to a lower-triangular matrix, and that `.apply(pp)` returns an object exposing
# `.cov()` -- confirm against gpax.
mvn = VariationalMultivariateNormal(BijFillTril.forward(pp['params']['L'])).apply(pp)
plt.imshow(mvn.cov())



In [None]:


# Cross-check the KL implementations KL[N(μ0, Σ0) || N(μ1, Σ1)] against each other and torch.
m = 100
μ0, Σ0 = rand_μΣ(jax.random.PRNGKey(0), m)
# Only the covariance of the second draw is used. The mean is zero so that the
# zero-mean-prior variant computes the same quantity as the general formulas.
_, Σ1 = rand_μΣ(jax.random.PRNGKey(1), m)
μ1 = np.zeros((m, 1))
L0 = linalg.cholesky(Σ0)
L1 = linalg.cholesky(Σ1)
print(kl_mvn(μ0, Σ0, μ1, Σ1))
print(kl_mvn_tril(μ0, L0, μ1, L1))
print(kl_mvn_tril_zero_mean_prior(μ0, L0, L1))

# torch's MultivariateNormal takes the event axis from the LAST dim of `loc`. An (m, 1)
# mean is therefore not an m-dimensional event: depending on the torch version it either
# raises or is broadcast into a batch of m distributions, which the former `.mean()`
# averaged. Flattening the means gives event shape (m,) and a single KL value.
p_torch = torch.distributions.MultivariateNormal(
    loc=torch.tensor(onp.array(μ0)).reshape(-1),
    covariance_matrix=torch.tensor(onp.array(Σ0)))
q_torch = torch.distributions.MultivariateNormal(
    loc=torch.tensor(onp.array(μ1)).reshape(-1),
    covariance_matrix=torch.tensor(onp.array(Σ1)))
print(torch.distributions.kl_divergence(p_torch, q_torch))