In [None]:
# Reference:
#     gpflow: https://gpflow.readthedocs.io/en/master/notebooks/advanced/gps_for_big_data.html
#             https://github.com/GPflow/GPflow/blob/develop/gpflow/models/sgpr.py#L263
#     julia:  https://github.com/STOR-i/GaussianProcesses.jl/blob/master/src/sparse/fully_indep_train_conditional.jl
#     ladax:  https://github.com/danieljtait/ladax
#

import sys
sys.path.append('../kernel')

import numpy as np
import numpy.random as npr
from numpy.linalg import inv, det, cholesky
from numpy.linalg import solve as backsolve
np.set_printoptions(precision=3,suppress=True)
from sklearn.metrics import mean_squared_error

import jax
import jax.numpy as np
import jax.numpy.linalg as linalg
from typing import Any, Callable, Sequence, Optional, Tuple
import flax
from flax import linen as nn

import itertools

import matplotlib.pyplot as plt
import matplotlib as mpl
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
# Global figure styling applied to every plot in this notebook.
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# `matplotlib.cm.get_cmap` (reached here as `plt.cm.get_cmap`) was deprecated in
# Matplotlib 3.7 and removed in 3.9; `plt.get_cmap` is the supported lookup and
# also works on older releases.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate

import sys
sys.path.append('../kernel')
from jaxkern import (cov_se, cov_rq, cov_pe, LookupKernel, normalize_K, mtgp_k)

from plt_utils import plt_savefig, plt_scaled_colobar_ax
from gp import gp_regression_chol, run_sgd
from gpax import log_func_default, log_func_simple, flax_run_optim, CovSE, GPR, GPRFITC


In [None]:

## Parameters

# xlim: range the inputs are sampled from (and the x-range of the plots);
# ylim: y-range used when plotting.
xlim, ylim = (-1, 1), (-2, 2)
# Number of training points / number of test grid points.
n_train = n_test = 200
# Scale applied to the observation noise, and its logarithm.
σn = 0.5
logsn = np.log(σn)
# Default learning rate and number of optimisation steps.
lr = 0.01
num_steps = 20


## Data

def f_gen(x):
    """Noise-free target function: a sum of three sinusoids.

    Evaluates ``sin(3·3.14·x) + 0.3·cos(9·3.14·x) + 0.5·sin(7·3.14·x)``
    element-wise on ``x`` (note the literal 3.14, not an exact π).

    Args:
        x: scalar or array of input locations.

    Returns:
        The function values, with the same shape as ``x``.
    """
    base_wave = np.sin(x * 3 * 3.14)
    cos_ripple = 0.3 * np.cos(x * 9 * 3.14)
    sin_ripple = 0.5 * np.sin(x * 7 * 3.14)
    return base_wave + cos_ripple + sin_ripple

## Data generation

key = jax.random.PRNGKey(0)  # JAX PRNG key; passed to `get_init_params` further down
npr.seed(0)                  # seed NumPy's global RNG so the sampled data is reproducible

# Training inputs: n_train points drawn uniformly from xlim, shaped (n_train, 1).
X_train = np.expand_dims(npr.uniform(xlim[0], xlim[1], size=n_train), 1)
# Observations: target function plus zero-mean Gaussian noise with std σn.
# `npr.randn` draws from N(0, 1); the previous `npr.rand` drew from uniform
# [0, 1), which shifts every observation upwards by σn/2 on average and is not
# the noise model that a standard deviation σn describes.
y_train = f_gen(X_train) + σn * npr.randn(n_train, 1)
data = (X_train, y_train)
# Test inputs: n_test evenly spaced points spanning xlim, shaped (n_test, 1).
X_test  = np.expand_dims(np.linspace(*xlim, n_test), 1)
print(X_train.shape, y_train.shape)


## Plot the training data

fig, ax = plt.subplots(1,1,figsize=(10,5))
ax.plot(X_train, y_train, 'x', alpha=1)
ax.grid()
ax.set_ylim(ylim);  # trailing `;` keeps the returned limits out of the cell output

In [None]:
# Optimiser settings for the model comparison below; these replace the values
# of `lr` and `num_steps` assigned earlier in the notebook.
lr, num_steps = 2e-4, 150
# Value passed to GPRFITC as its second argument in `get_model`.
n_inducing = 30


# One panel per model, side by side, sharing the y-axis.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 5), sharey=True)

def get_model(i):
    """Return the ``(display name, model)`` pair for panel index ``i``.

    Args:
        i: 0 for ``GPR(data)``, 1 for ``GPRFITC(data, n_inducing)``.

    Returns:
        Tuple ``(name, model)`` where ``name`` is the string used in the plot
        title and ``model`` is the freshly constructed model.

    Raises:
        ValueError: if ``i`` is not 0 or 1. Previously the function fell
            through and returned ``None``, which only surfaced later as an
            unpacking error at the call site.
    """
    if i == 0:
        return 'GPR', GPR(data)
    if i == 1:
        return 'GPR+FITC', GPRFITC(data, n_inducing)
    raise ValueError(f'unknown model index {i!r}; expected 0 or 1')
    
    
def log_func(i, f, params):
    """Print the loss and σn on every 10th step.

    Args:
        i: step index; nothing is printed unless it is a multiple of 10.
        f: callable evaluated as ``f(params)`` to obtain the loss to report.
        params: nested mapping; ``params["params"]["logσn"][0]`` is
            exponentiated and reported as σn.
    """
    if i % 10 != 0:
        return
    loss = f(params)
    noise_std = np.exp(params["params"]["logσn"][0])
    print(f'[{i:3}]\tLoss={loss:.3f}\tσn={noise_std:.3f}')


# Fit each model and draw its predictions on the test grid in its own panel.
for i in range(2):
    name, model = get_model(i)
    params = model.get_init_params(key)
    # Objective handed to the optimiser: the negated value of `model.mll`.
    nmll = lambda params: -model.apply(params, method=model.mll)
    params = flax_run_optim(nmll, params, lr=lr,
                            num_steps=num_steps, log_func=log_func)

    # Re-evaluate `model.mll` at the returned parameters for the panel title.
    mll = model.apply(params, method=model.mll)
    # presumably the predictive mean and covariance at X_test — confirm in gpax
    μ, Σ = model.apply(params, X_test, method=model.pred_y)
    # Clamp the diagonal at 0 before the square root: round-off can leave tiny
    # negative entries, which would otherwise turn the shaded band into NaNs.
    std = np.expand_dims(np.sqrt(np.maximum(np.diag(Σ), 0.)), 1)

    ax = axs[i]
    ax.plot(X_test, μ, color='k')
    # Shaded band: μ ± 2·std.
    ax.fill_between(X_test.squeeze(), (μ-2*std).squeeze(), (μ+2*std).squeeze(), alpha=.2, color=cmap(0))
    ax.scatter(X_train, y_train, marker='x', color='r', s=50, alpha=.4)
    if i == 1:
        # Mark the entries of params['params']['Xu'] as ticks along y = 0.
        Xu = params['params']['Xu']
        ax.plot(Xu, np.zeros_like(Xu), "k|", mew=2, label="Inducing locations")
        # Without a legend the label set above is never displayed.
        ax.legend()
    ax.grid()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # The value shown is -mll, i.e. the quantity that was minimised, so it is
    # labelled "nmll" (it was previously mislabelled "mll").
    ax.set_title(f'{name}: nmll={-mll:.2f}')


In [None]:
# NOTE(review): `opt`, `m` and `print_summary` are not defined or imported in
# any cell visible above, so this cell is expected to raise NameError under
# Restart & Run All. It looks like a leftover from the GPflow reference
# notebook cited at the top — confirm, then either delete the cell or add the
# missing definitions.
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=100))
print_summary(m)