In [1]:
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

from flax import nnx
import optax
import pcax

from tqdm.notebook import tqdm
from tqdm.notebook import trange

In [20]:
import os
os.chdir('../methods')
import lsci, supr, conf, uqno
os.chdir('../gpsims')

In [3]:
def risk(lower, upper, residual):
    return jnp.mean((residual > lower)*(residual < upper))

In [4]:
def split_data(data, lag, horizon):
    horizon = horizon-1
    y_t = data[(lag + horizon):][:,None]
    x_t = np.stack([data[(lag-i-1):(-(i+1+horizon))] for i in range(lag)], axis = 1)
    return x_t.copy(), y_t.copy()

In [5]:
class ANO_layer(nnx.Module):
    def __init__(self, width, rngs: nnx.Rngs):
        self.linear = nnx.Linear(width, width, rngs=rngs)
        # self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        # self.dropout = nnx.Dropout(0.2, rngs=rngs)
        self.linear_out = nnx.Linear(width, width, rngs=rngs)
        
    def __call__(self, x):
        # channel mix
        h = self.linear(x)

        # spatial mix
        g = jnp.mean(x, axis = (1, 2))[:,None,None,:]

        # sum
        x = h + g
        x = nnx.relu(x)

        return self.linear_out(x)

class encode_layer(nnx.Module):
    def __init__(self, in_dim, out_dim, rngs):
        self.linear = nnx.Linear(in_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

class DeepANO(nnx.Module):
    def __init__(self, in_dim, width, out_dim, rngs):
        self.encode_layer = encode_layer(in_dim, width, rngs)
        self.ano1 = ANO_layer(width, rngs)
        self.ano2 = ANO_layer(width, rngs)
        self.ano3 = ANO_layer(width, rngs)
        self.decode_layer = encode_layer(width, out_dim, rngs)

    def __call__(self, x):
        x = self.encode_layer(x)
        x = self.ano1(x)
        x = self.ano2(x)
        x = self.ano3(x)
        x = self.decode_layer(x)
        return x

In [69]:
key1 = random.PRNGKey(0)
key2 = random.PRNGKey(0)
key3 = random.PRNGKey(0)

t = jnp.linspace(0, 1, 100)
s = jnp.linspace(-2*math.pi, 2*math.pi, 1001)

amp = jnp.sin(s)
sd = jnp.sin(s)

f = jnp.sin(2 * math.pi * t)
f = 2 + amp[:,None] * f[None,:]

xtrain = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key1, f.shape)
xtrain, ytrain = split_data(xtrain, 1, 1)
xtrain = xtrain[:,:,:,None]
ytrain = ytrain[:,:,:,None]

xval = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key2, f.shape)
xval, yval = split_data(xval, 1, 1)
xval = xval[:,:,:,None]
yval = yval[:,:,:,None]

xtest = f + 0.25 * (1.15 + sd)[:,None] * random.normal(key3, f.shape)
xtest, ytest = split_data(xtest, 1, 1)
xtest = xtest[:,:,:,None]
ytest = ytest[:,:,:,None]

In [71]:
@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        y_diff = jnp.diff(y_pred, axis = 0)
        return jnp.mean((y_pred - y) ** 2) + jnp.mean(y_diff**2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # In place updates.

    return loss

model = DeepANO(1, 50, 1, nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) 

epochs = 50
nbat = 50
for _ in trange(epochs):
    for i in range(len(f)//nbat):
        xi = xtrain[i*nbat:(i+1)*nbat]
        yi = ytrain[i*nbat:(i+1)*nbat]

        loss = train_step(model, optimizer, xi, yi)

  0%|          | 0/50 [00:00<?, ?it/s]

In [72]:
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        quant = 1 - 0.1
        y_pred = model(x)
        y_abs = jnp.abs(y)
        resid = y_abs - y_pred
        loss = jnp.max(jnp.concat([quant * resid, -(1-quant) * resid], axis = 3), axis = 3)
        return jnp.mean(loss)
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates

    return loss
train_step = nnx.jit(train_step)

quant = DeepANO(1, 50, 1, nnx.Rngs(0))
optimizer = nnx.Optimizer(quant, optax.adam(1e-3)) 

epochs = 100
nbat = 50
for _ in trange(epochs):
    for i in range(len(f)//nbat):
        xi = xtrain[i*nbat:(i+1)*nbat]
        yi = ytrain[i*nbat:(i+1)*nbat]

        loss = train_step(quant, optimizer, xi, yi)

  0%|          | 0/100 [00:00<?, ?it/s]

In [73]:
yval_hat = model(xval)
ytest_hat = model(xtest)

yval_quant = model(xval)
ytest_quant = model(xtest)

In [74]:
nproj = 90
gamma1 = 0.2
alpha = 0.1
nval = xval.shape[0]
alpha1 = 1 - jnp.ceil((1-alpha) * (gamma1*nval + 1))/(gamma1*nval)

In [75]:
yval = yval.reshape(yval.shape[0], -1)
yval_hat = yval_hat.reshape(yval_hat.shape[0], -1)
pca_state = lsci.phi_state(yval, yval_hat, nproj)

In [76]:
# UQNO lambda estimate
yval_quant = yval.reshape(yval_quant.shape[0], -1)
ytest_quant = ytest_quant.reshape(ytest_quant.shape[0], -1)

alpha = 0.1
delta = 0.1
m = 32*64
tau = 1.1 * jnp.sqrt(-jnp.log(delta)/(2*m))
sg = jnp.abs(yval - yval_hat) / yval_quant
sg = jnp.quantile(sg, 1-alpha+tau, axis = (1))
nval = sg.shape[0]

adj_alpha = 1 - jnp.ceil((nval + 1) * (delta - jnp.exp(-2*m*tau**2)))/nval
lam_uqno = jnp.quantile(sg, adj_alpha)

In [133]:
nproj = 90
gamma1 = 0.1
alpha = 0.1
nval = xval.shape[0]
alpha1 = 1 - jnp.ceil((1-alpha) * (gamma1*nval + 1))/(gamma1*nval)

lsc1_rc = []
lsc1_width = []

conf_rc = []
conf_width = []

supr_rc = []
supr_width = []

uqn1_rc = []
uqn1_width = []

yval = yval.reshape(yval.shape[0], -1)
yval_hat = yval_hat.reshape(yval_hat.shape[0], -1)
pca_state = lsci.phi_state(yval, yval_hat, nproj)

rval = (yval - yval_hat).squeeze()
rtest = (ytest - ytest_hat).squeeze()

conf_lower, conf_upper = conf.conf_band(rval, pca_state, alpha)
supr_lower, supr_upper = supr.supr_band(rval, alpha)
uqn1_lower, uqn1_upper = uqno.uqno_band(ytest_quant, lam_uqno)

for i in trange(0, ytest.shape[0]):
    
    # LSCI
    lsc1_lower, lsc1_upper = lsci.lsci(yval - yval_hat, xval, xtest[i], pca_state, alpha1, gamma1, 2000)
    lsc1_rc.append(risk(lsc1_lower, lsc1_upper, rtest[i]))
    lsc1_width.append(jnp.median(lsc1_upper - lsc1_lower))
    
    # CONF 
    conf_rc.append(risk(conf_lower, conf_upper, rtest[i]))
    conf_width.append(jnp.mean(conf_upper - conf_lower))
    
    # SUPR
    supr_rc.append(risk(supr_lower, supr_upper, rtest[i]))
    supr_width.append(jnp.mean(supr_upper - supr_lower))
    
    # SUPR
    uqn1_rc.append(risk(uqn1_lower[i], uqn1_upper[i], rtest[i]))
    uqn1_width.append(jnp.mean(uqn1_upper[i] - uqn1_lower[i]))

    
conf_rc = np.array(conf_rc)
supr_rc = np.array(supr_rc)
uqn1_rc = np.array(uqn1_rc)
lsc1_rc = np.array(lsc1_rc)

conf_width = np.array(conf_width)
supr_width = np.array(supr_width)
uqn1_width = np.array(uqn1_width)
lsc1_width = np.array(lsc1_width)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [134]:
noise_sd = np.std(rtest, axis = 1)[:conf_rc.shape[0]]

risk_control = np.mean(conf_rc >= 0.975), \
               np.mean(supr_rc >= 0.975), \
               np.mean(uqn1_rc >= 0.975), \
               np.mean(lsc1_rc >= 0.975)

risk_control = np.mean(conf_rc), \
               np.mean(supr_rc), \
               np.mean(uqn1_rc), \
               np.mean(lsc1_rc)

width = np.mean(conf_width), \
        np.mean(supr_width), \
        np.mean(uqn1_width), \
        np.mean(lsc1_width)

risk_cor = np.abs(np.corrcoef([noise_sd, conf_rc])[0,1]), \
           np.abs(np.corrcoef([noise_sd, supr_rc])[0,1]), \
           np.abs(np.corrcoef([noise_sd, uqn1_rc])[0,1]), \
           np.abs(np.corrcoef([noise_sd, lsc1_rc])[0,1])

width_cor = 0, \
            0, \
            np.corrcoef([noise_sd, uqn1_width])[0,1], \
            np.corrcoef([noise_sd, lsc1_width])[0,1]

metrics = np.array([risk_control, width, risk_cor, width_cor]).T

In [132]:
metrics

array([[0.99556845, 3.00904608, 0.55782437, 0.        ],
       [0.99807578, 3.8011682 , 0.47198912, 0.        ],
       [1.        , 6.61982822,        nan, 0.02234269],
       [0.99900866, 2.47479582, 0.0304006 , 0.89587904]])

In [137]:
metrics

array([[ 0.99698997,  3.00904679,  0.61376196,  0.        ],
       [ 0.99866998,  3.80116749,  0.49841795,  0.        ],
       [ 1.        ,  6.62046432,         nan, -0.03630156],
       [ 0.99863994,  1.93129361,  0.07498678,  0.84899161]])