In [1]:
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

from flax import nnx
import optax
import pcax

from tqdm.notebook import tqdm
from tqdm.notebook import trange

In [2]:
import os
os.chdir('../methods')
import lsci, supr, conf, uqno
os.chdir('../gpsims')

In [3]:
def risk(lower, upper, residual):
    return jnp.mean((residual > lower)*(residual < upper))

In [4]:
def split_data(data, lag, horizon):
    horizon = horizon-1
    y_t = data[(lag + horizon):][:,None]
    x_t = np.stack([data[(lag-i-1):(-(i+1+horizon))] for i in range(lag)], axis = 1)
    return x_t.copy(), y_t.copy()

In [5]:
class ANO_layer(nnx.Module):
    def __init__(self, width, rngs: nnx.Rngs):
        self.linear = nnx.Linear(width, width, rngs=rngs)
        # self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        # self.dropout = nnx.Dropout(0.2, rngs=rngs)
        self.linear_out = nnx.Linear(width, width, rngs=rngs)
        
    def __call__(self, x):
        # channel mix
        h = self.linear(x)

        # spatial mix
        g = jnp.mean(x, axis = (1, 2))[:,None,None,:]

        # sum
        x = h + g
        x = nnx.relu(x)

        return self.linear_out(x)

class encode_layer(nnx.Module):
    def __init__(self, in_dim, out_dim, rngs):
        self.linear = nnx.Linear(in_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

class DeepANO(nnx.Module):
    def __init__(self, in_dim, width, out_dim, rngs):
        self.encode_layer = encode_layer(in_dim, width, rngs)
        self.ano1 = ANO_layer(width, rngs)
        self.ano2 = ANO_layer(width, rngs)
        self.ano3 = ANO_layer(width, rngs)
        self.decode_layer = encode_layer(width, out_dim, rngs)

    def __call__(self, x):
        x = self.encode_layer(x)
        x = self.ano1(x)
        x = self.ano2(x)
        x = self.ano3(x)
        x = self.decode_layer(x)
        return x

In [7]:
@nnx.jit  # Automatic state management
def train_step_base(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        y_diff = jnp.diff(y_pred, axis = 0)
        return jnp.mean((y_pred - y) ** 2) + jnp.mean(y_diff**2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # In place updates.

    return loss

@nnx.jit  # Automatic state management
def train_step_quant(model, optimizer, x, y):
    def loss_fn(model):
        quant = 1 - 0.1
        y_pred = model(x)
        y_abs = jnp.abs(y)
        resid = y_abs - y_pred
        loss = jnp.max(jnp.concat([quant * resid, -(1-quant) * resid], axis = 3), axis = 3)
        return jnp.mean(loss)
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates

    return loss

In [54]:
ntrain, nval, ntest = 1000, 1000, 1000
n_points = 100
n_samp = 2000
n_sims = 10


ano_width = 50
epochs = 50
nbat = 50
n_eval = 1000

alpha = 0.1
delta = alpha
gamma = 0.2
nproj = 90

alpha1 = 1 - jnp.ceil((1-alpha) * (gamma*nval + 1))/(gamma*nval)

In [55]:
t = jnp.linspace(0, 1, n_points)
s = jnp.linspace(-2*math.pi, 2*math.pi, ntrain+1)

amp = jnp.sin(s)
sd = jnp.sin(s)

f = jnp.sin(2 * math.pi * t)
f = 10 + amp[:,None] * f[None,:]

In [56]:
metrics = []

for s in trange(n_sims):
    
    rng0 = s
    key1 = random.PRNGKey(rng0 + 1)
    key2 = random.PRNGKey(rng0 + 2)
    key3 = random.PRNGKey(rng0 + 3)

    xtrain = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key1, f.shape)
    xtrain, ytrain = split_data(xtrain, 1, 1)
    xtrain = xtrain[:,:,:,None]
    ytrain = ytrain[:,:,:,None]

    xval = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key2, f.shape)
    xval, yval = split_data(xval, 1, 1)
    xval = xval[:,:,:,None]
    yval = yval[:,:,:,None]

    xtest = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key3, f.shape)
    xtest, ytest = split_data(xtest, 1, 1)
    xtest = xtest[:,:,:,None]
    ytest = ytest[:,:,:,None]


    model = DeepANO(1, ano_width, 1, nnx.Rngs(s))
    quant = DeepANO(1, ano_width, 1, nnx.Rngs(s))
    optim_model = nnx.Optimizer(model, optax.adam(1e-3))
    optim_quant = nnx.Optimizer(quant, optax.adam(1e-3))  

    for _ in trange(epochs, leave = False):
        for i in range(ntrain//nbat):
            xi = xtrain[i*nbat:(i+1)*nbat]
            yi = ytrain[i*nbat:(i+1)*nbat]
            loss = train_step_base(model, optim_model, xi, yi)
            

    for _ in trange(epochs, leave = False):
        for i in range(len(f)//nbat):
            xi = xtrain[i*nbat:(i+1)*nbat]
            yi = ytrain[i*nbat:(i+1)*nbat]
            loss = train_step_quant(quant, optim_quant, xi, yi)
            
    yval_hat = model(xval)
    yval_quant = quant(xval)
    ytest_hat = model(xtest)
    ytest_quant = quant(xtest)

    ## estimate EIGEN
    yval = yval.reshape(nval, -1)
    yval_hat = yval_hat.reshape(nval, -1)
    pca_state = lsci.phi_state(yval, yval_hat, nproj)
    
    # UQNO lambda estimate
    yval_quant = yval.reshape(nval, -1)
    ytest_quant = ytest_quant.reshape(ntest, -1)

    tau = 1.01 * jnp.sqrt(-jnp.log(delta)/(2*n_points))
    sg = jnp.abs(yval - yval_hat) / yval_quant
    sg = jnp.quantile(sg, 1-alpha+tau, axis = (1))

    adj_alpha = 1 - jnp.ceil((nval + 1) * (delta - jnp.exp(-2*n_points*tau**2)))/nval
    lam_uqno = jnp.quantile(sg, adj_alpha)
    
    lsc1_rc = []
    conf_rc = []
    supr_rc = []
    uqn1_rc = []
    lsc1_width = []
    conf_width = []
    supr_width = []
    uqn1_width = []

    # compute static / deterministic bands
    rval = (yval - yval_hat).squeeze()
    rtest = (ytest - ytest_hat).squeeze()
    conf_lower, conf_upper = conf.conf_band(rval, pca_state, alpha)
    supr_lower, supr_upper = supr.supr_band(rval, alpha)
    uqn1_lower, uqn1_upper = uqno.uqno_band(ytest_quant, lam_uqno)

    for i in trange(0, n_eval, leave = False):

        # LSCI
        lsc1_lower, lsc1_upper = lsci.lsci(yval - yval_hat, xval, xtest[i], pca_state, alpha1, gamma1, n_samp)
        lsc1_rc.append(risk(lsc1_lower, lsc1_upper, rtest[i]))
        lsc1_width.append(jnp.median(lsc1_upper - lsc1_lower))

        # CONF 
        conf_rc.append(risk(conf_lower, conf_upper, rtest[i]))
        conf_width.append(jnp.mean(conf_upper - conf_lower))

        # SUPR
        supr_rc.append(risk(supr_lower, supr_upper, rtest[i]))
        supr_width.append(jnp.mean(supr_upper - supr_lower))

        # SUPR
        uqn1_rc.append(risk(uqn1_lower[i], uqn1_upper[i], rtest[i]))
        uqn1_width.append(jnp.mean(uqn1_upper[i] - uqn1_lower[i]))


    conf_rc = np.array(conf_rc)
    supr_rc = np.array(supr_rc)
    uqn1_rc = np.array(uqn1_rc)
    lsc1_rc = np.array(lsc1_rc)

    conf_width = np.array(conf_width)
    supr_width = np.array(supr_width)
    uqn1_width = np.array(uqn1_width)
    lsc1_width = np.array(lsc1_width)
    
    noise_sd = np.std(rtest, axis = 1)[:conf_rc.shape[0]]


    risk_control = np.mean(conf_rc), \
                   np.mean(supr_rc), \
                   np.mean(uqn1_rc), \
                   np.mean(lsc1_rc)

    width = np.mean(conf_width), \
            np.mean(supr_width), \
            np.mean(uqn1_width), \
            np.mean(lsc1_width)

    risk_cor = np.corrcoef([noise_sd, conf_rc])[0,1], \
               np.corrcoef([noise_sd, supr_rc])[0,1], \
               np.corrcoef([noise_sd, uqn1_rc])[0,1], \
               np.corrcoef([noise_sd, lsc1_rc])[0,1]

    width_cor = 0, \
                0, \
                np.corrcoef([noise_sd, uqn1_width])[0,1], \
                np.corrcoef([noise_sd, lsc1_width])[0,1]

    metrics.append(np.array([risk_control, risk_cor, width, width_cor]).T)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [65]:
metric_means

array([[ 0.99627295, -0.65859826,  2.95463824,  0.        ],
       [ 0.99869598, -0.51223943,  3.78698554,  0.        ],
       [ 0.999998  ,         nan,  6.05731649,  0.29250552],
       [ 0.99546993,  0.02189394,  1.9045119 ,  0.79359517]])

In [97]:
metrics2 = np.array(metrics)
metrics2.shape

metric_means = np.mean(np.nan_to_num(metrics2), axis = 0)
metric_std = np.std(np.nan_to_num(metrics2), axis = 0)

for i in range(4):
    for j in range(4):
        val = f'{np.round(metric_means[i,j], 3):.3f} ({np.round(metric_std[i,j], 3):.3f})'
        if j < 3:
            val += ' & '
            print(val, end = '')
        else:
            val += ' \\\\'
            print(val)

0.996 (0.000) & -0.659 (0.023) & 2.955 (0.026) & 0.000 (0.000) \\
0.999 (0.000) & -0.512 (0.039) & 3.787 (0.050) & 0.000 (0.000) \\
1.000 (0.000) & -0.008 (0.023) & 6.057 (0.223) & 0.293 (0.152) \\
0.995 (0.000) & 0.022 (0.036) & 1.905 (0.012) & 0.794 (0.012) \\


In [102]:
np.round(np.mean(np.nan_to_num(metrics2), axis = 0), 3)

array([[ 0.996, -0.659,  2.955,  0.   ],
       [ 0.999, -0.512,  3.787,  0.   ],
       [ 1.   , -0.008,  6.057,  0.293],
       [ 0.995,  0.022,  1.905,  0.794]])