In [None]:

import numpy as onp
onp.set_printoptions(precision=3,suppress=True)
from sklearn.metrics import mean_squared_error

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn


import matplotlib.pyplot as plt
import matplotlib as mpl
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# Diverging colormap shared by all figures below.
# NOTE: `plt.cm.get_cmap` (== `matplotlib.cm.get_cmap`) was deprecated in
# matplotlib 3.7 and removed in 3.9; `plt.get_cmap` works on old and new versions.
cmap = plt.get_cmap('bwr')


from tabulate import tabulate

import sys
sys.path.append('../kernel')
from jaxkern import normalize_K

from plt_utils import plt_savefig, plt_scaled_colobar_ax
from gpax import *

In [None]:
## Parameters
# Dataset sizes: number of outputs, training points, test (plotting) points.
M, n_train, n_test = 1, 50, 100
# Axis limits used by the plots.
xlim, ylim = (-1.2, 1.2), (-.5, 2.5)
# σn scales the standard-normal observation noise added to the targets;
# ℓ is a lengthscale value (not referenced in the cells shown here).
σn, ℓ = .003, .2


## Data
# Sorted training inputs on [-1, 1) plus noisy observations of f, and a dense
# test grid Xs spanning xlim for plotting.
# NOTE: the legacy numpy seed only affects `onp.random`; the data below is
# drawn with JAX's explicit PRNG.
onp.random.seed(0)
key = jax.random.PRNGKey(0)
# Independent subkeys for inputs and noise: feeding the same key to both
# `uniform` and `normal` makes the noise depend on the same random bits as X.
key_X, key_noise = random.split(key)

def f(X, slope=2, intercept=0):
    """Ground-truth function: a sinusoid on top of a linear trend."""
    return .3*np.sin(5*X) + slope*X + intercept

X = np.sort(random.uniform(key_X, (n_train, 1)), axis=0)*2-1
y = f(X, 2) + random.normal(key_noise, (n_train, M))*σn
data = (X, y)
print(X.shape, y.shape)

Xs = np.linspace(xlim[0], xlim[1], n_test).reshape(-1, 1)

for i in range(M):
    plt.scatter(X, y[:, i])

In [None]:

class SVGP(nn.Module, GPModel):
    """Sparse variational GP with learnable inducing inputs `Xu`.

    Fields:
        n_data: total number of training points. `mll` scales the expected
            log-likelihood by n_data / len(X), which allows minibatch training.
        Xu_initial: initial inducing inputs; `setup` reads its shape as
            (n_inducing, d).
    """
    n_data: int
    Xu_initial: np.ndarray

    def setup(self):
        self.d = self.Xu_initial.shape[1]
        self.n_inducing = self.Xu_initial.shape[0]
        # `MeanConstant` is the name used elsewhere in this notebook (the previous
        # `MeanContant` was a typo). NOTE(review): assumes MeanConstant's fields
        # all have defaults — confirm against gpax.
        # NOTE(review): mean_fn is not used by `mll` or `pred_f` below.
        self.mean_fn = MeanConstant()
        self.k = CovSE()
        self.lik = LikNormal()
        # Inducing inputs are a trainable parameter initialized to Xu_initial.
        self.Xu = self.param('Xu', lambda k, s: self.Xu_initial,
                             (self.n_inducing, self.d))
        self.q = VariationalMultivariateNormal(self.n_inducing)

    def get_init_params(self, key, n_tasks=1):
        """Initialize parameters by tracing `mll` on one dummy point.

        Args:
            key: JAX PRNG key passed to `self.init`.
            n_tasks: number of target columns in the dummy data.

        Returns:
            The parameter collection returned by `self.init`.
        """
        n = 1
        Xs = np.ones((n, self.Xu_initial.shape[-1]))
        ys = np.ones((n, n_tasks))
        params = self.init(key, (Xs, ys), method=self.mll)
        return params

    def mll(self, data):
        """Evidence lower bound (ELBO) for `data = (X, y)`.

        ELBO = α * E_q[log p(y | f)] - KL(q(u) || p(u)), where α = n_data / len(X)
        (or 1 when n_data is None). `y` is flattened before use.
        """
        X, y = data
        y = y.flatten()
        k = self.k
        Xu, μq, Lq = self.Xu, self.q.μ, self.q.L

        Kff = k(X, full_cov=False)  # only the diagonal of K(X, X) is needed
        Kuf = k(Xu, X)
        Kuu = k(Xu)
        Luu = cholesky_jitter(Kuu, jitter=5e-5)

        # Rescale the likelihood term when X is a minibatch of the full dataset.
        α = self.n_data/len(X) \
            if self.n_data is not None else 1.

        μqf, σ2qf = mvn_conditional_variational(Kff, Kuf,
                                                Luu, μq, Lq, full_cov=False)
        if isinstance(self.lik, LikMultipleNormal):
            # The last input column is passed to the likelihood
            # (presumably a task index — confirm against gpax).
            elbo_lik = α*self.lik.variational_log_prob(y, μqf, σ2qf, X[:, -1])
        else:
            elbo_lik = α*self.lik.variational_log_prob(y, μqf, σ2qf)
        elbo_nkl = -kl_mvn_tril_zero_mean_prior(μq, Lq, Luu)
        elbo = elbo_lik + elbo_nkl

        return elbo

    def pred_f(self, Xs, full_cov=True):
        """Predictive mean and covariance of f at `Xs` under q(u).

        With full_cov=False the second output is presumably the marginal
        variances rather than a full covariance — confirm against gpax.
        """
        k = self.k
        Xu, μq, Lq = self.Xu, self.q.μ, self.q.L

        Kss = k(Xs, full_cov=full_cov)
        Kus = k(Xu, Xs)
        Kuu = k(Xu)
        Luu = cholesky_jitter(Kuu, jitter=5e-5)

        μf, Σf = mvn_conditional_variational(Kss, Kus,
                                             Luu, μq, Lq, full_cov=full_cov)
        return μf, Σf

In [None]:

class ExMTGP(GPR):
    """GPR model with a CovICM kernel, constant mean and Kronecker likelihood.

    Overrides only `setup`; all other behavior is inherited from `GPR`.
    The number of outputs is read from the notebook-global `M`.
    """
    def setup(self):
        # CovICM (presumably an intrinsic coregionalization model); its task
        # kernel is built with output_dim=M.
        self.k = CovICM(kt_kwargs={'output_dim': M})
        # Constant mean over M outputs, initialized at 0.35.
        self.mean_fn = MeanConstant(M, init_val_m=.35)
        self.lik = LikMultipleNormalKron(M)

# n = 10
# key = random.PRNGKey(0)
# X = np.sort(random.uniform(key, (10, 1)), axis=0)
# y = random.normal(key, X.shape)*.2
# y = y*np.arange(M).reshape(1,-1)
# data = (X, y)

# model = ExMTGP(data)
# params = model.get_init_params(key)
# model = model.bind(params)
# model.pred_f(X[:2])


In [None]:
lr = 0.003
num_steps = 300


# For each intercept: generate flat (slope=0) data, fit ExMTGP, then plot the
# posterior fit (left column) and the normalized kernel matrix (right column).
colors_b = [cmap(.1), cmap(.3)]
colors_r = [cmap(.9), cmap(.7)]
fig, axs = plt.subplots(4, 2, figsize=(20, 20))

for it, intercept in enumerate([.1, .5, 1, 2]):

    def f_gen(X): return f(X, slope=0, intercept=intercept)

    # Per-panel noise key: `key` was already used to draw X, and reusing it here
    # would give every panel identical noise tied to the same random bits as X.
    key_it = random.fold_in(key, it)
    y = f_gen(X) + random.normal(key_it, (n_train, M))*σn
    data = (X, y)

    ## Training
    model = ExMTGP(data)
    params = model.get_init_params(key)

    @jax.jit
    def nmll(params):
        return -model.apply(params, method=model.mll)

    params = flax_run_optim(nmll, params, num_steps=num_steps,
                            optimizer_kwargs={'learning_rate': lr},
                            log_func=log_func_default)

    ## Plotting
    model = model.bind(params)
    mll = model.mll()
    μ, Σ = model.pred_f(Xs, full_cov=False)
    μ = μ.flatten()
    std = np.sqrt(Σ).flatten()

    ax = axs[it, 0]
    ax.plot(Xs, μ, color=colors_b[0], lw=2)
    ax.fill_between(Xs.flatten(), μ-2*std, μ+2*std, alpha=.2, color=colors_b[0])
    ax.plot(Xs, f_gen(Xs), linewidth=1, color='k', linestyle='dashed')
    ax.scatter(X, y, marker='x', color=colors_r[0], s=50)
    ax.grid()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Raw strings: '\e' and '\s' are invalid escape sequences (SyntaxWarning on
    # Python 3.12+). The likelihood parameter shown is σ2 (a variance), so it is
    # labelled sigma_n^2.
    title = r' $-mll$' + f'={-mll.item():.2f}, ' + \
            r'$\ell$' + f'={model.k.kx.ℓ.item():.2f}, ' + \
            r'$m$' + f'={model.mean_fn.c[0]:.2f}, ' + \
            r'$\sigma_n^2$' + f'={model.lik.σ2.item():.4f}'
    ax.set_title(title)

    ax = axs[it, 1]
    XX = np.vstack((X, Xs))
    K = model.k(XX)
    im = ax.imshow(normalize_K(K), cmap=cmap)
    fig.colorbar(im, cax=plt_scaled_colobar_ax(ax))
    ax.set_title('$K(X, Xs)$')


