In [1]:
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

from flax import nnx
import optax
import pcax

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import torch
import torch_harmonics as th
from torch_harmonics.random_fields import GaussianRandomFieldS2

from torch.utils.data import DataLoader, TensorDataset

import os
os.chdir('../methods')
import lsci, supr, conf, uqno, prob_don, quant_don, gaus
os.chdir('../gpsims')

os.chdir('../models_and_metrics')
from models import *
from metrics import * 
from utility import *
os.chdir('../gpsims')

In [2]:
n = 500
s = jnp.linspace(-2*math.pi, 2*math.pi, n+1)
amp = jnp.sin(s)
sd = 1.25 + jnp.sin(s)

gp2d = GaussianRandomFieldS2(nlat = 30)
xtrain = gp2d(n+1).numpy()
xtrain = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtrain)
xtrain, ytrain = split_data(xtrain, 1, 1)
xtrain = xtrain[:,0,:,:,None]
ytrain = ytrain[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xval = gp2d(n+1).numpy()
xval = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xval)
xval, yval = split_data(xval, 1, 1)
xval = xval[:,0,:,:,None]
yval = yval[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xtest = gp2d(n+1).numpy()
xtest = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtest)
xtest, ytest = split_data(xtest, 1, 1)
xtest = xtest[:,0,:,:,None]
ytest = ytest[:,0,:,:,None]

In [3]:
train_data = TensorDataset(torch.Tensor(xtrain), torch.Tensor(ytrain))
train_loader = DataLoader(train_data, batch_size = 30, shuffle = True)

In [4]:
epochs = 50
trace = []
lag, lead = 1, 1

width = 50
drop_prob = 0.1

rng = random.PRNGKey(0)
model_rng = nnx.Rngs(0)
base_model = DeepANO(lag, width, lead, rngs=model_rng)
drop_model = DropANO(lag, width, lead, drop_prob, rngs=model_rng)
prob_model = ProbANO(lag, width, lead, rngs=model_rng)
quant_model = DeepANO(lag, width, lead, rngs=model_rng)

base_optim = nnx.Optimizer(base_model, optax.adam(1e-3))
drop_optim = nnx.Optimizer(drop_model, optax.adam(1e-3))
prob_optim = nnx.Optimizer(prob_model, optax.adam(1e-3))
quant_optim = nnx.Optimizer(quant_model, optax.adam(1e-3))

for _ in trange(epochs):
    for xt, yt in train_loader:
        xt = torch2jax(xt)
        yt = torch2jax(yt)
        
        base_loss = train_step(base_model, base_optim, xt, yt)
        drop_loss = train_step(drop_model, drop_optim, xt, yt)
        prob_loss = prob_step(prob_model, prob_optim, xt, yt)
        quant_loss = quant_step(quant_model, quant_optim, xt, yt)

  0%|          | 0/50 [00:00<?, ?it/s]

In [5]:
yval_hat = base_model(xval)
ytest_hat = base_model(xtest)

yval_quant = quant_model(xval)
ytest_quant = quant_model(xtest)

yval_mu, yval_sd = prob_model(xval)
ytest_mu, ytest_sd = prob_model(xtest)

In [6]:
yval = yval.reshape(n, -1)
ytest = ytest.reshape(n, -1)

yval_hat = yval_hat.reshape(n, -1)
ytest_hat = ytest_hat.reshape(n, -1)
yval_quant = yval_quant.reshape(n, -1)
ytest_quant = ytest_quant.reshape(n, -1)
yval_mu = yval_mu.reshape(n, -1)
yval_sd = yval_sd.reshape(n, -1)
ytest_mu = ytest_mu.reshape(n, -1)
ytest_sd = ytest_sd.reshape(n, -1)

In [7]:
yval_hat.shape

(500, 1800)

In [8]:
nproj = npc = 200
gamma1 = 0.1
gamma2 = 0.05
alpha = 0.1
nval = xval.shape[0]

drop_risk, drop_width = [], []
orcl_risk, orcl_width = [], []
conf_risk, conf_width = [], []
gaus_risk, gaus_width = [], []
supr_risk, supr_width = [], []
uqn1_risk, uqn1_width = [], []
pdon_risk, pdon_width = [], []
qdon_risk, qdon_width = [], []
lsc1_risk, lsc1_width = [], []
lsc2_risk, lsc2_width = [], []

rval = (yval - yval_hat).reshape(n, -1)
rtest = (ytest - ytest_hat).reshape(n, -1)
rtest2 = (ytest - ytest_mu).reshape(n, -1)

state = lsci.lsci_state(xval.reshape(n, -1), rval, npc)
pca_state = state[-1]

quant_scores = jnp.abs(yval - yval_hat) / yval_quant
lam_uqno = uqno.estimate_lambda(quant_scores, 0.1, 0.01, 1.1)

conf_lower, conf_upper = conf.conf_band(rval, pca_state, alpha)
gaus_lower, gaus_upper = gaus.gaus_band(rval, pca_state, alpha)
supr_lower, supr_upper = supr.supr_band(rval, alpha)
uqn1_lower, uqn1_upper = uqno.uqno_band(ytest_quant, lam_uqno)
pdon_lower, pdon_upper = prob_don.prob_don(prob_model, xval, xtest, yval, alpha)
qdon_lower, qdon_upper = quant_don.quant_don(quant_model, xval, xtest, yval - yval_hat, alpha)

for i in trange(0, ytest.shape[0]):
    
    xtest_i = xtest[i].reshape(1, -1)
    rtest_i = rtest[i].reshape(1, -1)
    
    # LSCI
    lsc1_lower, lsc1_upper = lsci.lsci_band(xtest_i, state, alpha, 2000, gamma1)
    lsc1_risk.append(risk(lsc1_lower, lsc1_upper, rtest_i))
    lsc1_width.append(jnp.median(lsc1_upper - lsc1_lower))
    
    # LSCI
    lsc2_lower, lsc2_upper = lsci.lsci_band(xtest_i, state, alpha, 2000, gamma2)
    lsc2_risk.append(risk(lsc2_lower, lsc2_upper, rtest_i))
    lsc2_width.append(jnp.median(lsc2_upper - lsc2_lower))

    # Oracle
    orcl_lower = -(jnp.abs(rtest[i]) + 1e-4)
    orcl_upper = jnp.abs(rtest[i]) + 1e-4
    orcl_risk.append(risk(orcl_lower, orcl_upper, rtest_i))
    orcl_width.append(jnp.median(orcl_upper - orcl_lower))
    
    # CONF 
    conf_risk.append(risk(conf_lower, conf_upper, rtest_i))
    conf_width.append(jnp.median(conf_upper - conf_lower))
    
    # GAUSS 
    gaus_risk.append(risk(gaus_lower, gaus_upper, rtest_i))
    gaus_width.append(jnp.median(gaus_upper - gaus_lower))
    
    # SUPR
    supr_risk.append(risk(supr_lower, supr_upper, rtest_i))
    supr_width.append(jnp.median(supr_upper - supr_lower))
    
    # UQNO
    uqn1_risk.append(risk(uqn1_lower[i], uqn1_upper[i], rtest_i))
    uqn1_width.append(jnp.median(uqn1_upper[i] - uqn1_lower[i]))
    
    # PDON
    pdon_risk.append(risk(pdon_lower[i], pdon_upper[i], rtest_i))
    pdon_width.append(jnp.median(pdon_upper[i] - pdon_lower[i]))
    
    # PDON
    qdon_risk.append(risk(qdon_lower[i], qdon_upper[i], rtest_i))
    qdon_width.append(jnp.median(qdon_upper[i] - qdon_lower[i]))
    
    # DROPOUT
    drop_model.train()
    drop_set = jnp.stack([drop_model(xtest[i:(i+1)]).squeeze() for _ in range(500)])
    drop_set = drop_set.reshape(500, -1)
    drop_lower = jnp.quantile(drop_set, alpha/2, axis = 0)
    drop_upper = jnp.quantile(drop_set, 1 - alpha/2, axis = 0)
    drop_model.eval()
    
    drop_risk.append(risk(drop_lower, drop_upper, rtest_i))
    drop_width.append(jnp.median(drop_upper - drop_lower))


orcl_risk, orcl_width = np.array(orcl_risk), np.array(orcl_width)
drop_risk, drop_width = np.array(drop_risk), np.array(drop_width)
conf_risk, conf_width = np.array(conf_risk), np.array(conf_width)
gaus_risk, gaus_width = np.array(gaus_risk), np.array(gaus_width)
supr_risk, supr_width = np.array(supr_risk), np.array(supr_width)
uqn1_risk, uqn1_width = np.array(uqn1_risk), np.array(uqn1_width)
lsc1_risk, lsc1_width = np.array(lsc1_risk), np.array(lsc1_width)
lsc2_risk, lsc2_width = np.array(lsc2_risk), np.array(lsc2_width)
pdon_risk, pdon_width = np.array(pdon_risk), np.array(pdon_width)
qdon_risk, qdon_width = np.array(qdon_risk), np.array(qdon_width)

  0%|          | 0/500 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
sigma = 0.25 * (1.15 + sd)
noise_sd = sigma[1:]
gamma = 0.99

risk_control = np.mean(orcl_risk >= gamma), \
               np.mean(drop_risk >= gamma), \
               np.mean(conf_risk >= gamma), \
               np.mean(gaus_risk >= gamma), \
               np.mean(supr_risk >= gamma), \
               np.mean(uqn1_risk >= gamma), \
               np.mean(pdon_risk >= gamma), \
               np.mean(qdon_risk >= gamma), \
               np.mean(lsc1_risk >= gamma), \
               np.mean(lsc2_risk >= gamma)

width = np.mean(orcl_width), \
        np.mean(drop_width), \
        np.mean(conf_width), \
        np.mean(gaus_width), \
        np.mean(supr_width), \
        np.mean(uqn1_width), \
        np.mean(pdon_width), \
        np.mean(qdon_width), \
        np.mean(lsc1_width), \
        np.mean(lsc2_width)

risk_cor = -0, \
           np.corrcoef([noise_sd, drop_risk])[0,1], \
           np.corrcoef([noise_sd, conf_risk])[0,1], \
           np.corrcoef([noise_sd, gaus_risk])[0,1], \
           np.corrcoef([noise_sd, supr_risk])[0,1], \
           np.corrcoef([noise_sd, uqn1_risk])[0,1], \
           np.corrcoef([noise_sd, pdon_risk])[0,1], \
           np.corrcoef([noise_sd, qdon_risk])[0,1], \
           np.corrcoef([noise_sd, lsc1_risk])[0,1], \
           np.corrcoef([noise_sd, lsc2_risk])[0,1]

width_cor = 1, \
            np.corrcoef([noise_sd, drop_width])[0,1], \
            0, \
            0, \
            0, \
            np.corrcoef([noise_sd, uqn1_width])[0,1], \
            np.corrcoef([noise_sd, pdon_width])[0,1], \
            np.corrcoef([noise_sd, qdon_width])[0,1], \
            np.corrcoef([noise_sd, lsc1_width])[0,1], \
            np.corrcoef([noise_sd, lsc2_width])[0,1]

metrics = np.array([risk_control, risk_cor, width, width_cor]).T

In [None]:
orcl_width

In [None]:
for i in range(metrics.shape[0]):
    for j in range(metrics.shape[1]):
        val = f'{np.round(np.nan_to_num(metrics[i,j]), 3):.3f}'
        if j < 3:
            val += ' & '
            print(val, end = '')
        else:
            val += ' \\\\'
            print(val)

In [25]:
methods = ['Oracle', 'MC-Drop', 'Conf1', 'Conf2', 
           'Supr', 'CQR-NO', 'Prob-NO', 'Quant-NO', 
           'LSCI1', 'LSCI2']

for i in range(metrics.shape[0]):
    print(f'{methods[i]} & ', end = '')
    for j in range(metrics.shape[1]):
        val = f'{np.round(np.nan_to_num(metrics[i,j]), 3):.3f}'
        if j < 3:
            val += ' & '
            print(val, end = '')
        else:
            val += ' \\\\'
            print(val)

Oracle & 1.000 & 0.000 & 0.170 & 1.000 \\
MC-Drop & 0.000 & 0.000 & 0.392 & 0.997 \\
Conf1 & 0.810 & -0.582 & 1.070 & 0.000 \\
Conf2 & 0.778 & -0.613 & 1.028 & 0.000 \\
Supr & 0.930 & -0.511 & 1.210 & 0.000 \\
CQR-NO & 0.778 & -0.628 & 0.916 & 0.999 \\
Prob-NO & 0.447 & -0.867 & 0.553 & -0.999 \\
Quant-NO & 0.854 & 0.482 & 2.219 & 0.999 \\
LSCI1 & 0.946 & 0.006 & 0.837 & 0.998 \\
LSCI2 & 0.906 & 0.010 & 0.823 & 0.998 \\


In [None]:
# DCT based fpca?

In [382]:
def dct_basis(N, K):
    basis = np.zeros((N, K))
    n_seq = np.arange(N)
    
    basis[:,0] = np.sqrt(1 / N)
    for k in range(1, K):
        basis[:,k] = np.cos((math.pi / N) * (n_seq + 0.5) * k)
        basis[:,k] = basis[:,k] * np.sqrt(2 / N)  
    return basis.T

def dct_project(f, nvr = None, nvc = None):
    n1, p1r, p1c = f.shape
    dct1r = dct_basis(p1r, nvr)
    dct1c = dct_basis(p1c, nvc).T
    xi = (dct1r @ f @ dct1c).reshape(-1, nvr * nvc)
    return xi, [dct1r, dct1c]

def dct_invert(xi, basis, nvr = None, nvc = None):
    dct1r, dct1c = basis
    p1r, p1c = dct1r.shape[0], dct1c.shape[1]
    return dct1r.T @ xi.reshape(-1, p1r, p1c) @ dct1c.T

def low_pass(x, d1, d2):
    xi, basis = dct_project(x, d1, d2)
    return dct_invert(xi, basis)

In [383]:
n = 501
s = jnp.linspace(-2*math.pi, 2*math.pi, n+1)
amp = jnp.sin(s)
sd = 1.25 + jnp.sin(s)

gp2d = GaussianRandomFieldS2(nlat = 30)
xtrain = gp2d(n+1).numpy()
xtrain = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtrain)
xtrain, ytrain = split_data(xtrain, 1, 1)
xtrain = xtrain[:,0,:,:,None]
ytrain = ytrain[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xval = gp2d(n+1).numpy()
xval = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xval)
xval, yval = split_data(xval, 1, 1)
xval = xval[:,0,:,:,None]
yval = yval[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xtest = gp2d(n+1).numpy()
xtest = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtest)
xtest, ytest = split_data(xtest, 1, 1)
xtest = xtest[:,0,:,:,None]
ytest = ytest[:,0,:,:,None]

In [384]:
yval_hat = base_model(xval)
ytest_hat = base_model(xtest)

yval = yval.reshape(n, -1)
ytest = ytest.reshape(n, -1)
yval_hat = yval_hat.reshape(n, -1)
ytest_hat = ytest_hat.reshape(n, -1)

In [385]:
nproj = npc = 50
gamma1 = 0.1
gamma2 = 0.01
alpha = 0.1
p1, p2 = 20, 30
rval = (yval - yval_hat).reshape(-1, 30, 60)
rtest = (ytest - ytest_hat).reshape(-1, 30, 60)

rval = low_pass(rval, p1, p2).reshape(n, -1)
rtest = low_pass(rtest, p1, p2).reshape(n, -1)
xval = low_pass(xval.squeeze(), p1, p2).reshape(n, -1)
xtest = low_pass(xtest.squeeze(), p1, p2).reshape(n, -1)

state = lsci.lsci_state(xval, rval, npc)
pca_state = state[-1]

In [386]:
xtest.shape

(501, 1800)

In [387]:
lsc1_risk = []
for i in trange(500):
    lsc1_lower, lsc1_upper = lsci.lsci_band(xtest[i].squeeze(), state, alpha, 2000, gamma2)
    lsc1_risk.append(risk(lsc1_lower, lsc1_upper, rtest[i]))

  0%|          | 0/500 [00:00<?, ?it/s]

In [388]:
np.mean(np.array(lsc1_risk) >= 0.99)

0.088

In [389]:
lsc1_risk

[Array(0.9727778, dtype=float32),
 Array(0.8983334, dtype=float32),
 Array(0.98444444, dtype=float32),
 Array(0.98444444, dtype=float32),
 Array(0.97944444, dtype=float32),
 Array(0.95944446, dtype=float32),
 Array(0.9711111, dtype=float32),
 Array(0.9616667, dtype=float32),
 Array(0.9255556, dtype=float32),
 Array(0.95500004, dtype=float32),
 Array(0.9572222, dtype=float32),
 Array(0.8927778, dtype=float32),
 Array(0.9277778, dtype=float32),
 Array(0.88222224, dtype=float32),
 Array(0.8616667, dtype=float32),
 Array(0.9016667, dtype=float32),
 Array(0.98222226, dtype=float32),
 Array(0.9027778, dtype=float32),
 Array(0.9538889, dtype=float32),
 Array(0.95000005, dtype=float32),
 Array(0.95055556, dtype=float32),
 Array(0.93833333, dtype=float32),
 Array(0.95000005, dtype=float32),
 Array(0.88222224, dtype=float32),
 Array(0.87111115, dtype=float32),
 Array(0.94611114, dtype=float32),
 Array(0.935, dtype=float32),
 Array(0.9105556, dtype=float32),
 Array(0.9827778, dtype=float32),
 Arr

In [390]:
sigma = 0.25 * (1.15 + sd)
noise_sd = sigma[1:]
risk_cor = np.corrcoef([noise_sd, lsc1_risk])[0,1]
width_cor = np.corrcoef([noise_sd, lsc1_width])[0,1]

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.