# Imports

In [1]:
from time import time
import sys
import os

from copy import deepcopy

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import jax.scipy as jsp

from itertools import product
from functools import partial
from scipy.io import savemat
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import KFold
import pandas as pd

import optax


import matplotlib.pyplot as plt

from base import loss_fn, predict_fn, get_params
from utils import train_fn, latexify
from data import get_nonstat2d, get_jump1d, get_mcycle, get_simulated

jax.config.update("jax_enable_x64", True)

%reload_ext watermark

# Config

In [2]:
default_params = False
methods = ["delta_inducing", "heinonen"]

# Dataset

In [4]:
data = {}
for name, loader in zip(["Motorcycle", "SYNTH-1D", "NONSTAT2D", "Jump"], [get_mcycle, get_simulated, get_nonstat2d, get_jump1d]):
    X, y = loader()
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    data[name] = {"X": X, "y": y}

# Run

In [6]:
init_time = time()
ablation_results = {}
idx = 0
for method in methods:
    result1 = ablation_results[method] = {}
    for name, data_dict in data.items():
        result2 = result1[name] = {}
        X = data_dict["X"]
        y = data_dict["y"]
        
        # Normalize
        x_scaler = MinMaxScaler()
        X = x_scaler.fit_transform(X)
        xscale = x_scaler.data_max_ - x_scaler.data_min_
        yscale = jnp.max(jnp.abs(y - jnp.mean(y)))
        ymean = jnp.mean(y)
        y = (y - ymean) / yscale
        
        # Split into train and test
        for fold_i, (train_idx, test_idx) in enumerate(KFold(n_splits=2, shuffle=True, random_state=0).split(X)):
            result3 = result2[fold_i] = {}
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]
            
            for ell, sigma, omega in product([1, 0], repeat=3):
                result4 = result3[(ell, sigma, omega)] = {}
                model_flex_dict = {"ell": ell, "sigma": sigma, "omega": omega}
        
                value_and_grad_fn = partial(loss_fn, X=X_train, y=y_train, flex_dict=model_flex_dict, method=method)
                partial_get_params = partial(
                    get_params, X=X_train, flex_dict=model_flex_dict, method=method, default=default_params
                )
                params = jax.vmap(partial_get_params)(jax.random.split(jax.random.PRNGKey(1000), 10))
                partial_train_fn = partial(
                    train_fn, loss_fn=value_and_grad_fn, optimizer=optax.adam(0.001), n_iters=5000
                )
                
                # print(partial_train_fn(init_raw_params=jtu.tree_map(lambda x: x[0], params)))
                train_init = time()
                results = jax.vmap(partial_train_fn)(init_raw_params=params)
                train_end = time()
                # print("Losses: ", results["loss_history"][:, -1])
                best_idx = jnp.nanargmin(results["loss_history"][:, -1])
                result = jtu.tree_map(lambda x: x[best_idx], results)
                
                pred_init = time()               
                pred_mean, pred_var, pred_ell, pred_sigma, pred_omega = predict_fn(
                    result["raw_params"],
                    X_train,
                    y_train,
                    X_test,
                    model_flex_dict,
                    method,
                )
                pred_end = time()
                
                result4["X_train"] = X_train
                result4["y_train"] = y_train
                result4["X_test"] = X_test
                result4["y_test"] = y_test
                result4["pred_mean"] = pred_mean
                result4["pred_var"] = pred_var
                result4["pred_ell"] = pred_ell
                result4["pred_sigma"] = pred_sigma
                result4["pred_omega"] = pred_omega
                result4["train_time"] = train_time = train_end - train_init
                result4["pred_time"] = pred_time = pred_end - pred_init
                
                print(f"{method} {name} {fold_i} {ell} {sigma} {omega} {idx} {train_time=}, {pred_time=}")
                idx += 1

print(f"Total time: {(time() - init_time)/60:.3f} min")
pd.to_pickle(ablation_results, "results/ablation_results.pkl")

delta_inducing Motorcycle 0 1 1 1 0 train_time=8.23494291305542, pred_time=0.08530902862548828
delta_inducing Motorcycle 0 1 1 0 1 train_time=7.989428758621216, pred_time=0.29902076721191406
delta_inducing Motorcycle 0 1 0 1 2 train_time=7.560467958450317, pred_time=0.08227777481079102
delta_inducing Motorcycle 0 1 0 0 3 train_time=7.284414052963257, pred_time=0.06609725952148438
delta_inducing Motorcycle 0 0 1 1 4 train_time=7.915605306625366, pred_time=0.129655122756958
delta_inducing Motorcycle 0 0 1 0 5 train_time=7.2352235317230225, pred_time=0.05625510215759277
delta_inducing Motorcycle 0 0 0 1 6 train_time=6.885382652282715, pred_time=0.05504345893859863
delta_inducing Motorcycle 0 0 0 0 7 train_time=6.692352771759033, pred_time=0.04555821418762207
delta_inducing Motorcycle 1 1 1 1 8 train_time=8.427367448806763, pred_time=0.08721041679382324
delta_inducing Motorcycle 1 1 1 0 9 train_time=7.523759365081787, pred_time=0.07950806617736816
delta_inducing Motorcycle 1 1 0 1 10 train

KeyboardInterrupt: 

# Analysis

In [20]:
ablation_results = pd.read_pickle("results/ablation_results.pkl")

In [21]:
dfs = {}
for method in methods:
    header = pd.MultiIndex.from_product([[name for name in data.keys()],
                                     ["NLPD", "RMSE"]])
    dfs[method] = pd.DataFrame(index=[(ell, sigma, omega) for ell, sigma, omega in sorted(product([1, 0], repeat=3), key=lambda x: -sum(x))], columns=header)
    for name in data.keys():
        for ell, sigma, omega in sorted(product([1, 0], repeat=3), key=lambda x: sum(x)):
            pred_mean = ablation_results[method][name][0][(ell, sigma, omega)]["pred_mean"]
            pred_var = ablation_results[method][name][0][(ell, sigma, omega)]["pred_var"]
            pred_omega = ablation_results[method][name][0][(ell, sigma, omega)]["pred_omega"]
            y_test = ablation_results[method][name][0][(ell, sigma, omega)]["y_test"]
            pred_scale = jnp.sqrt(pred_var + omega**2)
            dfs[method][(name, "RMSE")][(ell, sigma, omega)] = jnp.sqrt(jnp.mean((pred_mean - y_test)**2))
            dfs[method][(name, "NLPD")][(ell, sigma, omega)] = -jsp.stats.norm.logpdf(y_test, loc=pred_mean, scale=pred_scale).mean()

# dfs["heinonen"]

In [23]:
def convert(ell, sigma, omega):
    if ell == sigma == omega == 0:
        return "Stationary Homoskedastic GP"
    names = []
    if ell:
        names.append(r"\ell")
    if sigma:
        names.append(r"\sigma")
    if omega:
        names.append(r"\omega")
    return f"(${','.join(names)}$)-GP"

dfs_new = deepcopy(dfs)
dfs_new["delta_inducing"].index = [convert(ell, sigma, omega) for ell, sigma, omega in dfs["delta_inducing"].index]
dfs_new["heinonen"].index = [convert(ell, sigma, omega) for ell, sigma, omega in dfs["heinonen"].index]

# bold minimum values in each column
diff = dfs_new["heinonen"] - dfs_new["delta_inducing"]


res = dfs_new["delta_inducing"].astype(float)#.round(4).astype(str)# + " (" + diff.astype(float).round(2).astype(str) + ")"

# show only upto 3 decimal places
style = res.style.highlight_min(color='green').format("{:.2f}")
style.to_latex("results/ablation_results.tex")
style

Unnamed: 0_level_0,SYNTH-1D,SYNTH-1D,NONSTAT2D,NONSTAT2D,Jump,Jump,Motorcycle,Motorcycle
Unnamed: 0_level_1,NLPD,RMSE,NLPD,RMSE,NLPD,RMSE,NLPD,RMSE
"($\ell,\sigma,\omega$)-GP",0.94,0.17,0.94,0.17,0.96,0.3,0.95,0.23
"($\ell,\sigma$)-GP",0.98,0.21,1.17,0.19,5.49,0.3,1.69,0.22
"($\ell,\omega$)-GP",0.94,0.17,0.94,0.16,0.96,0.3,0.95,0.23
"($\sigma,\omega$)-GP",0.94,0.17,0.94,0.17,0.96,0.3,0.95,0.21
($\ell$)-GP,2.39,0.17,0.02,0.18,6.51,0.3,1.62,0.22
($\sigma$)-GP,2.13,0.17,0.71,0.17,5.79,0.31,1.7,0.22
($\omega$)-GP,0.94,0.17,0.94,0.17,0.97,0.3,0.95,0.21
Stationary Homoskedastic GP,0.9,0.17,0.46,0.17,6.37,0.3,1.61,0.21


In [7]:
%watermark --iversions

jax       : 0.3.25
optax     : 0.1.3
matplotlib: 3.5.1
pandas    : 1.4.2
sys       : 3.9.12 (main, Apr  5 2022, 06:56:58) 
[GCC 7.5.0]



# End