In [1]:
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

from flax import nnx
import optax
import pcax

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import torch
from torch.utils.data import DataLoader, TensorDataset

import os
os.chdir('../methods')
import lsci, supr, conf, uqno, prob_don, quant_don, gaus
os.chdir('../gpsims')

os.chdir('../models_and_metrics')
from models import *
from metrics import * 
from utility import *
os.chdir('../gpsims')

In [114]:
import pandas as pd
air = pd.read_csv('../data/PRSA_Data_Tiantan_20130301-20170228.csv')
air["date"] = air["year"].astype(str) + '_' + air["month"].astype(str) + '_' + air["day"].astype(str)

In [115]:
air

Unnamed: 0,No,year,month,day,hour,PM2.5,PM10,SO2,NO2,CO,O3,TEMP,PRES,DEWP,RAIN,wd,WSPM,station,date
0,1,2013,3,1,0,6.0,6.0,4.0,8.0,300.0,81.0,-0.5,1024.5,-21.4,0.0,NNW,5.7,Tiantan,2013_3_1
1,2,2013,3,1,1,6.0,29.0,5.0,9.0,300.0,80.0,-0.7,1025.1,-22.1,0.0,NW,3.9,Tiantan,2013_3_1
2,3,2013,3,1,2,6.0,6.0,4.0,12.0,300.0,75.0,-1.2,1025.3,-24.6,0.0,NNW,5.3,Tiantan,2013_3_1
3,4,2013,3,1,3,6.0,6.0,4.0,12.0,300.0,74.0,-1.4,1026.2,-25.5,0.0,N,4.9,Tiantan,2013_3_1
4,5,2013,3,1,4,5.0,5.0,7.0,15.0,400.0,70.0,-1.9,1027.1,-24.5,0.0,NNW,3.2,Tiantan,2013_3_1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35059,35060,2017,2,28,19,20.0,48.0,2.0,,500.0,,12.5,1013.5,-16.2,0.0,NW,2.4,Tiantan,2017_2_28
35060,35061,2017,2,28,20,11.0,34.0,3.0,36.0,500.0,,11.6,1013.6,-15.1,0.0,WNW,0.9,Tiantan,2017_2_28
35061,35062,2017,2,28,21,18.0,32.0,4.0,48.0,500.0,48.0,10.8,1014.2,-13.3,0.0,NW,1.1,Tiantan,2017_2_28
35062,35063,2017,2,28,22,15.0,42.0,5.0,52.0,600.0,44.0,10.5,1014.4,-12.9,0.0,NNW,1.2,Tiantan,2017_2_28


In [117]:
air_curves = np.array(air.groupby('date')['PM2.5'].apply(list).reset_index())
dts = air_curves[:,0]
fns = air_curves[:,1]

# transform
fns = [np.array(f) for f in fns]
fns = [jax.image.resize(f, 24, 'bicubic') for f in fns]
fns = np.array(fns)

# backfill missing
missing = jnp.sum(jnp.isnan(fns), axis = 1) > 0
for i in range(fns.shape[0]):
    if missing[i]:
        fns[i] = fns[i-1]

y = fns
# y = np.log(y[:,None,:,None])

x = []
variables = ['TEMP', 'PRES', 'DEWP']
for k in trange(len(variables)):
    air_curves = np.array(air.groupby('date')[variables[k]].apply(list).reset_index())
    dts = air_curves[:,0]
    fns = air_curves[:,1]
    
    # transform
    fns = [np.array(f) for f in fns]
    fns = [jax.image.resize(f, 24, 'bicubic') for f in fns]
    fns = np.array(fns)
    
    # backfill missing
    missing = jnp.sum(jnp.isnan(fns), axis = 1) > 0
    for i in range(fns.shape[0]):
        if missing[i]:
            fns[i] = fns[i-1]

    x.append(fns)
    
x = np.array(x)
x = np.moveaxis(x, 0, 2)[:,None]

x_mu = np.mean(x, axis = (0,2))[None,:,None]
x_sd = np.std(x, axis = (0,2))[None,:,None]
x = (x - x_mu)/x_sd

  0%|          | 0/3 [00:00<?, ?it/s]

In [118]:
ntrain, nval = 600, 1200

xtrain = x[:ntrain].reshape(-1, 1, 24, lag)
ytrain = y[:ntrain].reshape(-1, 1, 24, lead)

xval = x[ntrain:nval].reshape(-1, 1, 24, lag)
yval = y[ntrain:nval].reshape(-1, 1, 24, lead)

xtest = x[nval:].reshape(-1, 1, 24, lag)
ytest = y[nval:].reshape(-1, 1, 24, lead)

train_data = TensorDataset(torch.Tensor(np.array(xtrain)), torch.Tensor(np.array(ytrain)))
train_loader = DataLoader(train_data, batch_size = 30, shuffle = True)

In [119]:
epochs = 50
lag, lead = len(variables), 1

width = 50
drop_prob = 0.1

rng = random.PRNGKey(0)
model_rng = nnx.Rngs(1)
base_model = DeepANO(lag, width, lead, rngs=model_rng)
prob_model = ProbANO(lag, width, lead, rngs=model_rng)
quant_model = DeepANO(lag, width, lead, rngs=model_rng)

base_optim = nnx.Optimizer(base_model, optax.adam(1e-3))
prob_optim = nnx.Optimizer(prob_model, optax.adam(1e-3))
quant_optim = nnx.Optimizer(quant_model, optax.adam(1e-3))

for _ in trange(epochs):
    for xt, yt in train_loader:
        xt = torch2jax(xt)
        yt = torch2jax(yt)
        
        base_loss = train_step(base_model, base_optim, xt, yt)
        prob_loss = prob_step(prob_model, prob_optim, xt, yt)
        quant_loss = quant_step(quant_model, quant_optim, xt, yt)

  0%|          | 0/50 [00:00<?, ?it/s]

In [120]:
yval_hat = base_model(xval)
ytest_hat = base_model(xtest)

yval_quant = quant_model(xval)
ytest_quant = quant_model(xtest)

yval_mu, yval_sd = prob_model(xval)
ytest_mu, ytest_sd = prob_model(xtest)

In [121]:
yval = yval.squeeze()
ytest = ytest.squeeze()

yval_hat = yval_hat.squeeze()
ytest_hat = ytest_hat.squeeze()
yval_quant = yval_quant.squeeze()
ytest_quant = ytest_quant.squeeze()
yval_mu = yval_mu.squeeze()
yval_sd = yval_sd.squeeze()
ytest_mu = ytest_mu.squeeze()
ytest_sd = ytest_sd.squeeze()

In [122]:
nproj = npc = 24
gamma1 = 0.1
gamma2 = 0.05
alpha = 0.1
nval = xval.shape[0]

drop_risk, drop_width = [], []
orcl_risk, orcl_width = [], []
conf_risk, conf_width = [], []
gaus_risk, gaus_width = [], []
supr_risk, supr_width = [], []
uqn1_risk, uqn1_width = [], []
pdon_risk, pdon_width = [], []
qdon_risk, qdon_width = [], []
lsc1_risk, lsc1_width = [], []
lsc2_risk, lsc2_width = [], []

rval = (yval - yval_hat).squeeze()
rtest = (ytest - ytest_hat).squeeze()
rtest2 = (ytest - ytest_mu).squeeze()

state = lsci.lsci_state(xval, rval, npc, localization='l2')
pca_state = state[-1]

quant_scores = jnp.abs(yval - yval_hat) / yval_quant
lam_uqno = uqno.estimate_lambda(quant_scores, 0.1, 0.01, 1.1)

conf_lower, conf_upper = conf.conf_band(rval, pca_state, alpha)
gaus_lower, gaus_upper = gaus.gaus_band(rval, pca_state, alpha)
supr_lower, supr_upper = supr.supr_band(rval, alpha)
uqn1_lower, uqn1_upper = uqno.uqno_band(ytest_quant, lam_uqno)
pdon_lower, pdon_upper = prob_don.prob_don(prob_model, xval, xtest, yval, alpha)
qdon_lower, qdon_upper = quant_don.quant_don(quant_model, xval, xtest, yval - yval_hat, alpha)

for i in trange(0, ytest.shape[0]):
    
    # LSCI
    lsc1_lower, lsc1_upper = lsci.lsci_band(xtest[i].squeeze(), state, alpha, 2000, gamma1)
    lsc1_risk.append(risk(lsc1_lower, lsc1_upper, rtest[i]))
    lsc1_width.append(jnp.median(lsc1_upper - lsc1_lower))
    
    # LSCI
    lsc2_lower, lsc2_upper = lsci.lsci_band(xtest[i].squeeze(), state, alpha, 2000, gamma2)
    lsc2_risk.append(risk(lsc2_lower, lsc2_upper, rtest[i]))
    lsc2_width.append(jnp.median(lsc2_upper - lsc2_lower))

    # Oracle
    orcl_lower = -(jnp.abs(rtest[i]) + 1e-4)
    orcl_upper = jnp.abs(rtest[i]) + 1e-4
    orcl_risk.append(risk(orcl_lower, orcl_upper, rtest[i]))
    orcl_width.append(jnp.median(orcl_upper - orcl_lower))
    
    # CONF 
    conf_risk.append(risk(conf_lower, conf_upper, rtest[i]))
    conf_width.append(jnp.median(conf_upper - conf_lower))
    
    # GAUSS 
    gaus_risk.append(risk(gaus_lower, gaus_upper, rtest[i]))
    gaus_width.append(jnp.median(gaus_upper - gaus_lower))
    
    # SUPR
    supr_risk.append(risk(supr_lower, supr_upper, rtest[i]))
    supr_width.append(jnp.median(supr_upper - supr_lower))
    
    # UQNO
    uqn1_risk.append(risk(uqn1_lower[i], uqn1_upper[i], rtest[i]))
    uqn1_width.append(jnp.median(uqn1_upper[i] - uqn1_lower[i]))
    
    # PDON
    pdon_risk.append(risk(pdon_lower[i], pdon_upper[i], rtest2[i]))
    pdon_width.append(jnp.median(pdon_upper[i] - pdon_lower[i]))
    
    # PDON
    qdon_risk.append(risk(qdon_lower[i], qdon_upper[i], rtest[i]))
    qdon_width.append(jnp.median(qdon_upper[i] - qdon_lower[i]))
    
    # # DROPOUT
    # drop_model.train()
    # drop_set = jnp.stack([drop_model(xtest[i:(i+1)]).squeeze() for _ in range(500)])
    # drop_lower = jnp.quantile(drop_set, alpha/2, axis = 0)
    # drop_upper = jnp.quantile(drop_set, 1 - alpha/2, axis = 0)
    # drop_model.eval()
    
    # drop_risk.append(risk(drop_lower, drop_upper, rtest[i]))
    # drop_width.append(jnp.median(drop_upper - drop_lower))

orcl_risk, orcl_width = np.array(orcl_risk), np.array(orcl_width)
drop_risk, drop_width = np.array(drop_risk), np.array(drop_width)
conf_risk, conf_width = np.array(conf_risk), np.array(conf_width)
gaus_risk, gaus_width = np.array(gaus_risk), np.array(gaus_width)
supr_risk, supr_width = np.array(supr_risk), np.array(supr_width)
uqn1_risk, uqn1_width = np.array(uqn1_risk), np.array(uqn1_width)
lsc1_risk, lsc1_width = np.array(lsc1_risk), np.array(lsc1_width)
lsc2_risk, lsc2_width = np.array(lsc2_risk), np.array(lsc2_width)
pdon_risk, pdon_width = np.array(pdon_risk), np.array(pdon_width)
qdon_risk, qdon_width = np.array(qdon_risk), np.array(qdon_width)

  0%|          | 0/261 [00:00<?, ?it/s]

In [123]:
drop_risk, drop_width = jnp.zeros(ytest.shape[0]),  jnp.zeros(ytest.shape[0])

In [124]:
noise_sd = jnp.mean(rtest**2, axis = 1)

In [125]:
gamma = 0.99

risk_control = np.mean(orcl_risk >= gamma), \
               np.mean(drop_risk >= gamma), \
               np.mean(conf_risk >= gamma), \
               np.mean(gaus_risk >= gamma), \
               np.mean(supr_risk >= gamma), \
               np.mean(uqn1_risk >= gamma), \
               np.mean(pdon_risk >= gamma), \
               np.mean(qdon_risk >= gamma), \
               np.mean(lsc1_risk >= gamma), \
               np.mean(lsc2_risk >= gamma)

width = np.mean(orcl_width), \
        np.mean(drop_width), \
        np.mean(conf_width), \
        np.mean(gaus_width), \
        np.mean(supr_width), \
        np.mean(uqn1_width), \
        np.mean(pdon_width), \
        np.mean(qdon_width), \
        np.mean(lsc1_width), \
        np.mean(lsc2_width)

risk_cor = -0, \
           np.corrcoef([noise_sd, drop_risk])[0,1], \
           np.corrcoef([noise_sd, conf_risk])[0,1], \
           np.corrcoef([noise_sd, gaus_risk])[0,1], \
           np.corrcoef([noise_sd, supr_risk])[0,1], \
           np.corrcoef([noise_sd, uqn1_risk])[0,1], \
           np.corrcoef([noise_sd, pdon_risk])[0,1], \
           np.corrcoef([noise_sd, qdon_risk])[0,1], \
           np.corrcoef([noise_sd, lsc1_risk])[0,1], \
           np.corrcoef([noise_sd, lsc2_risk])[0,1]

width_cor = 1, \
            np.corrcoef([noise_sd, drop_width])[0,1], \
            0, \
            0, \
            0, \
            np.corrcoef([noise_sd, uqn1_width])[0,1], \
            np.corrcoef([noise_sd, pdon_width])[0,1], \
            np.corrcoef([noise_sd, qdon_width])[0,1], \
            np.corrcoef([noise_sd, lsc1_width])[0,1], \
            np.corrcoef([noise_sd, lsc2_width])[0,1]

metrics = np.array([risk_control, risk_cor, width, width_cor]).T

  c /= stddev[:, None]
  c /= stddev[None, :]


In [126]:
methods = ['Oracle', 'MC-Drop', 'Conf1', 'Conf2', 
           'Supr', 'CQR-NO', 'Prob-NO', 'Quant-NO', 
           'LSCI1', 'LSCI2']

for i in range(metrics.shape[0]):
    print(f'{methods[i]} & ', end = '')
    for j in range(metrics.shape[1]):
        val = f'{np.round(np.nan_to_num(metrics[i,j]), 3):.3f}'
        if j < 3:
            val += ' & '
            print(val, end = '')
        else:
            val += ' \\\\'
            print(val)

Oracle & 1.000 & 0.000 & 87.938 & 1.000 \\
MC-Drop & 0.000 & 0.000 & 0.000 & 0.000 \\
Conf1 & 0.969 & -0.867 & 453.874 & 0.000 \\
Conf2 & 0.958 & -0.872 & 448.337 & 0.000 \\
Supr & 0.946 & -0.805 & 449.777 & 0.000 \\
CQR-NO & 0.985 & -0.114 & 1309.577 & 0.459 \\
Prob-NO & 0.874 & -0.672 & 313.344 & -0.032 \\
Quant-NO & 0.678 & -0.061 & 222.917 & 0.460 \\
LSCI1 & 0.981 & -0.667 & 365.285 & 0.390 \\
LSCI2 & 0.946 & -0.523 & 330.759 & 0.439 \\
