# Code used to get low-data regime result for Battle et al. (2024)

This code is reused from Battle et al. (2024). It was used to generate the Helmholtz low-data result reported in Table 4.5.

CAUTION: Very memory-intensive for large `n_components` (>50)

Our result used `Ntrain = 1000`, `n_components = 20`.

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import pairwise_distances
import numpy as onp
import jax
import jax.numpy as np
from jax import jit, vmap
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as t
from ipywidgets import interact
from jax import grad
from jax.scipy.optimize import minimize
from jax import config
config.update("jax_enable_x64", True)
from jax.scipy.linalg import cholesky, cho_factor, cho_solve
from jax.scipy.optimize import minimize
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_box

import pandas as pd

In [None]:
# Load the raw Helmholtz input and output arrays from .npy files
# (relative paths, resolved against the current working directory).
Inputs = onp.load('../helmholtz_data/Helmholtz_inputs.npy')
Outputs = onp.load('../helmholtz_data/Helmholtz_outputs.npy')

In [None]:
# Reverse the axis order of the 3-D arrays so that the last axis of the raw
# data becomes the leading axis; the cells below slice that leading axis as
# the sample index.
Inputs = Inputs.transpose((2,1,0))
Outputs = Outputs.transpose((2,1,0))

# Flatten every sample to a single vector. Using -1 lets NumPy infer the
# per-sample size instead of hard-coding the 101*101 grid, so the same cell
# works for data stored at other resolutions (identical result for 101x101).
Inputs_fl = Inputs.reshape(len(Inputs), -1)
Outputs_fl = Outputs.reshape(len(Outputs), -1)

In [None]:
# Number of leading samples used for training; every sample after the first
# Ntrain is used as test data by the cells below.
Ntrain = 1000

In [None]:
# Training inputs: the first Ntrain flattened input samples (pre-PCA).
Xtr = Inputs_fl[:Ntrain]

In [None]:
%%time
# Fit PCA on the training inputs only, then overwrite Xtr with its
# 20-component projection. Re-running this cell without re-running the
# previous one would fit PCA on already-reduced data.
pca = PCA(n_components=20)
Xtr = pca.fit_transform(Xtr)

In [None]:
# Project the held-out inputs with the PCA fitted on the training inputs.
Xtest = pca.transform(Inputs_fl[Ntrain:])
# Targets are the flattened outputs, left un-reduced (no PCA on the outputs).
Ytr = Outputs_fl[:Ntrain]
Ytest = Outputs_fl[Ntrain:]


In [None]:
def sqeuclidean_distances(x: np.ndarray, y: np.ndarray) -> float:
    """Return the squared Euclidean distance between two points.

    Args:
        x: Coordinates of the first point.
        y: Coordinates of the second point (same shape as ``x``).

    Returns:
        The sum of squared element-wise differences, as a scalar.
    """
    delta = x - y
    return np.sum(delta ** 2)


# Pairwise version: for inputs whose leading axes index k1 and k2 points, the
# result is a (k1, k2) matrix. NOTE: despite its name, ``dists`` holds SQUARED
# distances (the inner vmap runs over ``y``, the outer one over ``x``).
_sq_over_second = vmap(sqeuclidean_distances, in_axes=(None, 0))
dists = jit(vmap(_sq_over_second, in_axes=(0, None)))

def euclidean_distances(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean distance between two points.

    Args:
        x: Coordinates of the first point.
        y: Coordinates of the second point (same shape as ``x``).

    Returns:
        The square root of the sum of squared element-wise differences.
    """
    delta = x - y
    squared_norm = np.sum(delta ** 2)
    return np.sqrt(squared_norm)


# Pairwise version: for inputs whose leading axes index k1 and k2 points, the
# result is a (k1, k2) matrix. NOTE: despite its name, ``sqdists`` holds plain
# (NOT squared) Euclidean distances.
_dist_over_second = vmap(euclidean_distances, in_axes=(None, 0))
sqdists = jit(vmap(_dist_over_second, in_axes=(0, None)))


@jit
def matern(v1, v2, sigma = 50):
    """Matern-5/2 kernel matrix between two sets of points.

    Args:
        v1: Array whose leading axis indexes k1 points.
        v2: Array whose leading axis indexes k2 points.
        sigma: Length scale dividing the pairwise distances.

    Returns:
        A (k1, k2) matrix with entries
        ``(1 + sqrt(5) r / sigma + 5 r^2 / (3 sigma^2)) * exp(-sqrt(5) r / sigma)``,
        where ``r`` is the pairwise Euclidean distance.
    """
    # ``sqdists`` returns plain Euclidean distances despite its name, which is
    # what this formula requires.
    r = sqdists(v1, v2)
    root5 = np.sqrt(5)
    linear_term = root5*r/sigma
    quadratic_term = 5*r**2/(3*sigma**2)
    decay = np.exp(-root5*r/sigma)
    return (1 + linear_term + quadratic_term)*decay

@jit
def exp(v1, v2, sigma):
    """Exponential kernel of the squared distance (Gaussian-type kernel).

    Args:
        v1: Array whose leading axis indexes k1 points.
        v2: Array whose leading axis indexes k2 points.
        sigma: Scale dividing the squared pairwise distances.

    Returns:
        A (k1, k2) matrix with entries ``exp(-r^2 / sigma)``, where ``r^2`` is
        the pairwise squared Euclidean distance.
    """
    # ``dists`` returns SQUARED Euclidean distances despite its name.
    sq_dist = dists(v1, v2)
    exponent = -sq_dist/sigma
    return np.exp(exponent)

@jit
def iq(v1, v2, sigma):
    """Inverse square-root kernel of the shifted squared distance.

    Args:
        v1: Array whose leading axis indexes k1 points.
        v2: Array whose leading axis indexes k2 points.
        sigma: Offset added to the squared pairwise distances.

    Returns:
        A (k1, k2) matrix with entries ``1 / sqrt(r^2 + sigma)``, where
        ``r^2`` is the pairwise squared Euclidean distance.
    """
    # ``dists`` returns SQUARED Euclidean distances despite its name.
    sq_dist = dists(v1, v2)
    shifted = sq_dist+sigma
    return 1/np.sqrt(shifted)

In [None]:
# Default diagonal regularization for the kernel matrix. The sweep loop at the
# end of the notebook rebinds this name through its own loop variable.
nugget = 1e-8

In [None]:
def aux(kernel, s, nugget):
    """Fit kernel regression on the training set and report its errors.

    Solves ``(K(Xtr, Xtr) + nugget * I) @ coeffs = Ytr`` through a Cholesky
    factorization, predicts on the training and test inputs, and prints the
    mean absolute and relative L2 errors (norms taken per sample along axis 1).

    Reads the module-level arrays ``Xtr``, ``Ytr``, ``Xtest`` and ``Ytest``.

    Args:
        kernel: Callable ``kernel(A, B, s)`` returning the kernel matrix
            between the rows of ``A`` and the rows of ``B``.
        s: Kernel scale parameter, passed through to ``kernel``.
        nugget: Value added to the diagonal of the training kernel matrix
            before factorization.

    Returns:
        Tuple ``(train_abs, train_rel, test_abs, test_rel)`` holding the same
        four metrics that are printed.
    """
    Kxx = kernel(Xtr, Xtr, s)
    regularized = Kxx.at[np.diag_indices_from(Kxx)].add(nugget)
    factor = cho_factor(regularized)
    coeffs = cho_solve(factor, Ytr)

    # Training predictions use the kernel matrix WITHOUT the nugget, so they
    # do not interpolate the training targets exactly.
    Train_pred = Kxx@coeffs
    Test_pred = kernel(Xtest, Xtr, s)@coeffs

    print(Test_pred.shape)
    print(Ytest.shape)

    # Per-sample residual norms, computed once and reused for both the
    # absolute and the relative error.
    train_residual = np.linalg.norm(Ytr-Train_pred, axis = 1)
    test_residual = np.linalg.norm(Ytest-Test_pred, axis = 1)

    aux1 = np.mean(train_residual)
    aux2 = np.mean(train_residual/np.linalg.norm(Ytr, axis = 1))
    aux3 = np.mean(test_residual)
    aux4 = np.mean(test_residual/np.linalg.norm(Ytest, axis = 1))

    print(s, nugget)
    print("\n Train error (abs): {0} \n Train error (rel): {1} \n Test error (abs): {2} \n Test error (rel): {3}".format(aux1, aux2, aux3, aux4))
    print('---')

    return aux1, aux2, aux3, aux4

In [None]:
# Grid sweep over kernel, scale and nugget; each call prints its own metrics.
# The scales tried here are far above matern's default sigma of 50.
# Note: the innermost loop variable rebinds the module-level `nugget`.
for kernel in [matern]:
    for s in [600, 700, 800, 900, 1000]:
        for nugget in [1e-8]:
            aux(kernel, s, nugget)