In [21]:
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

import torch_harmonics as th
from torch_harmonics.random_fields import GaussianRandomFieldS2

import torch
from torch.utils.data import DataLoader, TensorDataset

from flax import nnx
import optax
import pcax

from tqdm.notebook import tqdm
from tqdm.notebook import trange

In [22]:
import os
os.chdir('../methods')
import lsci, supr, conf, uqno
os.chdir('../gpsims')

In [23]:
def risk(lower, upper, residual):
    return jnp.mean((residual > lower)*(residual < upper))

In [32]:
def split_data(data, lag, horizon):
    horizon = horizon-1
    y_t = data[(lag + horizon):][:,None]
    x_t = np.stack([data[(lag-i-1):(-(i+1+horizon))] for i in range(lag)], axis = 1)
    return x_t.copy(), y_t.copy()

def torch2jax(x):
    return jnp.array(x.numpy())

In [33]:
class ANO_layer(nnx.Module):
    def __init__(self, width, rngs: nnx.Rngs):
        self.linear = nnx.Linear(width, width, rngs=rngs)
        # self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        # self.dropout = nnx.Dropout(0.2, rngs=rngs)
        self.linear_out = nnx.Linear(width, width, rngs=rngs)
        
    def __call__(self, x):
        # channel mix
        h = self.linear(x)

        # spatial mix
        g = jnp.mean(x, axis = (1, 2))[:,None,None,:]

        # sum
        x = h + g
        x = nnx.relu(x)

        return self.linear_out(x)

class encode_layer(nnx.Module):
    def __init__(self, in_dim, out_dim, rngs):
        self.linear = nnx.Linear(in_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

class DeepANO(nnx.Module):
    def __init__(self, in_dim, width, out_dim, rngs):
        self.encode_layer = encode_layer(in_dim, width, rngs)
        self.ano1 = ANO_layer(width, rngs)
        self.ano2 = ANO_layer(width, rngs)
        self.ano3 = ANO_layer(width, rngs)
        self.decode_layer = encode_layer(width, out_dim, rngs)

    def __call__(self, x):
        x = self.encode_layer(x)
        x = self.ano1(x)
        x = self.ano2(x)
        x = self.ano3(x)
        x = self.decode_layer(x)
        return x

In [34]:
class ANO_layer(nnx.Module):
    def __init__(self, width, rngs: nnx.Rngs):
        self.conv = nnx.Conv(width, width, (1, 1), rngs=rngs)
        self.conv_out = nnx.Conv(width, width, (1, 1), rngs=rngs)
        
    def __call__(self, x):
        # channel mix
        h = self.conv(x)

        # spatial mix
        g = jnp.mean(x, axis = (1, 2))[:,None,None,:]

        # sum
        x = h + g
        x = nnx.relu(x)

        return self.conv_out(x)

class encode_layer(nnx.Module):
    def __init__(self, in_dim, out_dim, rngs):
        self.conv = nnx.Conv(in_dim, out_dim, (1, 1), rngs=rngs)

    def __call__(self, x):
        return self.conv(x)

class DeepANO(nnx.Module):
    def __init__(self, in_dim, width, out_dim, rngs):
        self.encode_layer = encode_layer(in_dim, width, rngs)
        self.ano1 = ANO_layer(width, rngs)
        self.ano2 = ANO_layer(width, rngs)
        self.ano3 = ANO_layer(width, rngs)
        self.decode_layer = encode_layer(width, out_dim, rngs)

    def __call__(self, x):
        x = self.encode_layer(x)
        x = self.ano1(x)
        x = self.ano2(x)
        x = self.ano3(x)
        x = self.decode_layer(x)
        return x

In [40]:
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # In place updates.

    return loss

def quant_step(model, optimizer, x, y):
    def loss_fn(model):
        quant = 1 - 0.1
        y_pred = model(x)
        y_abs = jnp.abs(y)
        resid = y_abs - y_pred
        loss = jnp.max(jnp.concat([quant * resid, -(1-quant) * resid], axis = 3), axis = 3)
        return jnp.mean(loss)
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place updates

    return loss

train_step = nnx.jit(train_step)
quant_step = nnx.jit(quant_step)

In [41]:
n = 501
s = jnp.linspace(-2*math.pi, 2*math.pi, n+1)
amp = jnp.sin(s)
sd = 1.25 + jnp.sin(s)


gp2d = GaussianRandomFieldS2(nlat = 30)
xtrain = gp2d(n+1).numpy()
xtrain = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtrain)
xtrain, ytrain = split_data(xtrain, 1, 1)
xtrain = xtrain[:,0,:,:,None]
ytrain = ytrain[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xval = gp2d(n+1).numpy()
xval = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xval)
xval, yval = split_data(xval, 1, 1)
xval = xval[:,0,:,:,None]
yval = yval[:,0,:,:,None]

gp2d = GaussianRandomFieldS2(nlat = 30)
xtest = gp2d(n+1).numpy()
xtest = 10 + amp[:,None,None] + sd[:,None,None] * jnp.array(xtest)
xtest, ytest = split_data(xtest, 1, 1)
xtest = xtest[:,0,:,:,None]
ytest = ytest[:,0,:,:,None]

In [42]:
train_data = TensorDataset(torch.Tensor(xtrain), torch.Tensor(ytrain))
train_loader = DataLoader(train_data, batch_size = 30, shuffle = True)

In [43]:
# @nnx.jit  # Automatic state management
# def train_step(model, optimizer, x, y):
#     def loss_fn(model):
#         y_pred = model(x)
#         y_diff = jnp.diff(y_pred, axis = 0)
#         return jnp.mean((y_pred - y) ** 2) + jnp.mean(y_diff**2)

#     loss, grads = nnx.value_and_grad(loss_fn)(model)
#     optimizer.update(grads)  # In place updates.

#     return loss

# epochs = 5
# trace = []

# model = ClimateNO(1, 32, 1, rngs=nnx.Rngs(0))
# optim = nnx.Optimizer(model, optax.adam(1e-3))
# rng = random.PRNGKey(0)

# for _ in trange(epochs):
#     for xt, yt in tqdm(train_loader, leave = False):
#         xt = torch2jax(xt)
#         yt = torch2jax(yt)
        
#         loss = train_step(model, optim, xt, yt)
#         trace.append(loss)

In [44]:
epochs = 100
trace = []
lag, lead = 1, 1

model = DeepANO(lag, 50, lead, rngs=nnx.Rngs(0))
optim = nnx.Optimizer(model, optax.adam(1e-3))
rng = random.PRNGKey(0)

for _ in trange(epochs):
    for xt, yt in train_loader:
        xt = torch2jax(xt)
        yt = torch2jax(yt)
        
        loss = train_step(model, optim, xt, yt)
        trace.append(loss)

  0%|          | 0/100 [00:00<?, ?it/s]

In [45]:
quant = DeepANO(lag, 50, lead, rngs=nnx.Rngs(0))
optim = nnx.Optimizer(quant, optax.adam(1e-3))
rng = random.PRNGKey(0)

for _ in trange(epochs):
    for xt, yt in train_loader:
        xt = torch2jax(xt)
        yt = torch2jax(yt)
        
        loss = quant_step(quant, optim, xt, yt)
        trace.append(loss)

  0%|          | 0/100 [00:00<?, ?it/s]

In [46]:
# def train_step(model, optimizer, x, y):
#     def loss_fn(model):
#         quant = 1 - 0.1
#         y_pred = model(x)
#         y_abs = jnp.abs(y)
#         resid = y_abs - y_pred
#         loss = jnp.max(jnp.concat([quant * resid, -(1-quant) * resid], axis = 3), axis = 3)
#         return jnp.mean(loss)
    
#     loss, grads = nnx.value_and_grad(loss_fn)(model)
#     optimizer.update(grads)  # in-place updates

#     return loss
# train_step = nnx.jit(train_step)

# quant = DeepANO(1, 50, 1, nnx.Rngs(0))
# optimizer = nnx.Optimizer(quant, optax.adam(1e-3)) 

# epochs = 100
# nbat = 50
# for _ in trange(epochs):
#     for i in trange(xtrain.shape[0]//nbat, leave = False):
#         xi = xtrain[i*nbat:(i+1)*nbat]
#         yi = ytrain[i*nbat:(i+1)*nbat]

#         loss = train_step(quant, optimizer, xi, yi)

In [47]:
yval_hat = model(xval)
ytest_hat = model(xtest)
yval_quant = quant(xval)
ytest_quant = quant(xtest)

In [48]:
nproj = 100
gamma1 = 0.2
alpha = 0.1
nval = xval.shape[0]
alpha1 = 1 - jnp.ceil((1-alpha) * (gamma1*nval + 1))/(gamma1*nval)

In [49]:
yval = yval.reshape(yval.shape[0], -1)
yval_hat = yval_hat.reshape(yval_hat.shape[0], -1)
pca_state = lsci.phi_state(yval, yval_hat, nproj)

In [50]:
# UQNO lambda estimate
yval_quant = yval.reshape(yval_quant.shape[0], -1)
ytest_quant = ytest_quant.reshape(ytest_quant.shape[0], -1)

alpha = 0.1
delta = 0.1
m = 30*60
tau = 2 * jnp.sqrt(-jnp.log(delta)/(2*m))
sg = jnp.abs(yval - yval_hat) / yval_quant
sg = jnp.quantile(sg, 1-alpha+tau, axis = (1))
nval = sg.shape[0]

adj_alpha = 1 - jnp.ceil((nval + 1) * (delta - jnp.exp(-2*m*tau**2)))/nval
lam_uqno = jnp.quantile(sg, adj_alpha)
lam_uqno

Array(0.03918276, dtype=float32)

In [51]:
nproj = 150
gamma1 = 0.1
alpha = 0.1
nval = xval.shape[0]
alpha1 = 1 - jnp.ceil((1-alpha) * (gamma1*nval + 1))/(gamma1*nval)

lsc1_rc = []
lsc1_width = []

conf_rc = []
conf_width = []

supr_rc = []
supr_width = []

uqn1_rc = []
uqn1_width = []

yval = yval.reshape(yval.shape[0], -1)
yval_hat = yval_hat.reshape(yval_hat.shape[0], -1)
pca_state = lsci.phi_state(yval, yval_hat, nproj)

rval = (yval - yval_hat).squeeze().reshape(-1, 30*60)
rtest = (ytest - ytest_hat).squeeze().reshape(-1, 30*60)

conf_lower, conf_upper = conf.conf_band(rval, pca_state, alpha)
supr_lower, supr_upper = supr.supr_band(rval, alpha)
uqn1_lower, uqn1_upper = uqno.uqno_band(ytest_quant, lam_uqno)

for i in trange(0, 100):
    
    # LSCI
    lsc1_lower, lsc1_upper = lsci.lsci((yval - yval_hat).reshape(-1, 30*60), xval.reshape(-1, 30*60), xtest[i].reshape(-1, 30*60), pca_state, alpha1, gamma1, 2000)
    lsc1_rc.append(risk(lsc1_lower, lsc1_upper, rtest[i]))
    lsc1_width.append(jnp.median(lsc1_upper - lsc1_lower))
    
    # CONF 
    conf_rc.append(risk(conf_lower, conf_upper, rtest[i]))
    conf_width.append(jnp.mean(conf_upper - conf_lower))
    
    # SUPR
    supr_rc.append(risk(supr_lower, supr_upper, rtest[i]))
    supr_width.append(jnp.mean(supr_upper - supr_lower))
    
    # UQNO
    uqn1_rc.append(risk(uqn1_lower[i], uqn1_upper[i], rtest[i]))
    uqn1_width.append(jnp.mean(uqn1_upper[i] - uqn1_lower[i]))

    
conf_rc = np.array(conf_rc)
supr_rc = np.array(supr_rc)
uqn1_rc = np.array(uqn1_rc)
lsc1_rc = np.array(lsc1_rc)

conf_width = np.array(conf_width)
supr_width = np.array(supr_width)
uqn1_width = np.array(uqn1_width)
lsc1_width = np.array(lsc1_width)

  0%|          | 0/100 [00:00<?, ?it/s]

In [52]:
noise_sd = np.std(rtest, axis = 1)[:conf_rc.shape[0]]

risk_control = np.mean(conf_rc >= 0.99), \
               np.mean(supr_rc >= 0.99), \
               np.mean(uqn1_rc >= 0.99), \
               np.mean(lsc1_rc >= 0.99),  \
               np.mean(lsc2_rc >= 0.99)

width = np.mean(conf_width), \
        np.mean(supr_width), \
        np.mean(uqn1_width), \
        np.mean(lsc1_width), \
        np.mean(lsc2_width)

risk_cor = np.corrcoef([noise_sd, conf_rc])[0,1], \
           np.corrcoef([noise_sd, supr_rc])[0,1], \
           np.corrcoef([noise_sd, uqn1_rc])[0,1], \
           np.corrcoef([noise_sd, lsc1_rc])[0,1], \
           np.corrcoef([noise_sd, lsc2_rc])[0,1]

width_cor = 0, \
            0, \
            np.corrcoef([noise_sd, uqn1_width])[0,1], \
            np.corrcoef([noise_sd, lsc1_width])[0,1], \
            np.corrcoef([noise_sd, lsc2_width])[0,1]

metrics = np.array([risk_control, risk_cor, width, width_cor]).T

NameError: name 'lsc2_rc' is not defined

In [None]:
print(np.round(metrics, 3))