In [17]:
# numpy
import numpy as np
import scipy
np.set_printoptions(suppress=True)

import jax
from jax import vmap, grad, jit, random
import jax.numpy as jnp

# plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.ticker import FormatStrFormatter
plt.style.use('default')

from matplotlib.gridspec import GridSpec
from matplotlib import colors
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

from tqdm.notebook import tqdm
from tqdm.notebook import trange

from scipy.stats import genextreme, norm

In [223]:
def _tukey_depth_1d(samples, point):
    '''
    Univariate depth of a scalar `point` relative to a 1D array `samples`.

    Returns 2 * min(fraction of samples strictly below point,
                    fraction of samples strictly above point),
    a value in [0, 1]. Samples equal to `point` (e.g. the point itself when the
    query set is the sample set) count toward neither side.
    '''
    frac_below = jnp.mean(samples < point)
    frac_above = jnp.mean(samples > point)
    return 2 * jnp.minimum(frac_below, frac_above)


# samples (n,), points (m,) -> depths (m,): one depth per query point.
tukey1d = vmap(_tukey_depth_1d, in_axes=(None, 0))

# Lift to fields: x (n, p1, p2), y (m, p1, p2) -> (p1, p2, m).
# The outer vmap walks axis 1 of the inputs (p1), the inner one the remaining
# spatial axis (p2); each stacks its results on a new leading axis.
tukey2d = vmap(vmap(tukey1d, in_axes=(1, 1)), in_axes=(1, 1))


@jit
def tukey_depth(x, y):
    '''
    Depth of each field in `y` relative to the sample of fields `x`.

    Parameters
    ----------
    x: array (n, p1, p2)
        Reference fields.
    y: array (m, p1, p2)
        Query fields.

    Returns
    -------
    Array (m,): the pointwise univariate depth averaged over all p1 * p2 locations.
    '''
    per_location = tukey2d(x, y)  # (p1, p2, m)
    return per_location.mean(axis=(0, 1))

In [212]:
# Conformal inference on spatiotemporal residual fields. Both entry points assume Tukey depth.
def conf_cutoff_from_depths(depth_val, alpha):
    '''
    Conformal depth cutoff from precomputed calibration depths.

    Parameters
    ----------

    depth_val: 1D array (n,)
        Depth of each calibration residual field, e.g. tukey_depth(res_val, res_val).

    alpha: float
        Miscoverage (significance) level, strictly between 0 and 1; the
        prediction set targets 1 - alpha coverage.

    Returns
    -------
    Scalar cutoff q: a field whose depth is >= q belongs to the prediction set.
    q is the k-th largest calibration depth, k = ceil((1 - alpha) * (n + 1)).
    If k > n (too few calibration samples for this alpha), q = -inf, i.e. the
    trivial set that accepts every field.

    Raises
    ------
    ValueError
        If alpha is not strictly between 0 and 1.
    '''
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")

    nval = depth_val.shape[0]
    # Finite-sample conformal rank, counted from the deepest calibration field.
    k = int(np.ceil((1 - alpha) * (nval + 1)))

    if k > nval:
        # No calibration depth attains the required rank; only the trivial set is
        # valid. (Indexing with nval - k < 0 would wrap around to the *largest*
        # depth and produce a badly under-covering set.)
        return jnp.asarray(-jnp.inf)

    # Ascending sort, so position nval - k holds the k-th largest depth.
    return jnp.sort(depth_val)[nval - k]


# Estimates the depth cutoff for conformal inference on spatiotemporal processes.
def conf_quantile(res_val, alpha):
    '''
    Parameters
    ----------

    res_val: 3D tensor (n, p1, p2)
        n = sample size
        p1 = spatial dim 1
        p2 = spatial dim 2

        Residual fields from the calibration / validation set, i.e. y - y_hat


    alpha: float
        Miscoverage (significance) level, strictly between 0 and 1; the
        prediction set targets 1 - alpha coverage.

    Returns
    -------
    Cutoff value for depths: a residual field whose Tukey depth relative to
    res_val is >= the cutoff lies in the 1 - alpha prediction set. Equals -inf
    (accept everything) when n is too small for the requested alpha.

    Raises
    ------
    ValueError
        If alpha is not strictly between 0 and 1.
    '''
    # Depth of every calibration field relative to the whole calibration set.
    depth_val = tukey_depth(res_val, res_val)
    return conf_cutoff_from_depths(depth_val, alpha)


# Generates the conformal ensemble.
def conf_ensemble(res_val, alpha):
    '''
    Parameters
    ----------

    res_val: 3D tensor (n, p1, p2)
        n = sample size
        p1 = spatial dim 1
        p2 = spatial dim 2

        Residual fields from the calibration / validation set, i.e. y - y_hat


    alpha: float
        Miscoverage (significance) level, strictly between 0 and 1; the
        prediction set targets 1 - alpha coverage.

    Returns
    -------
    The calibration residual fields (rows of res_val) whose depth is at or
    above the conformal cutoff, i.e. the full 1 - alpha conformal ensemble.
    Add these fields onto predictions to generate prediction sets.

    Raises
    ------
    ValueError
        If alpha is not strictly between 0 and 1.
    '''
    # Depths are computed once and reused for both the cutoff and the mask.
    depth_val = tukey_depth(res_val, res_val)
    q_val = conf_cutoff_from_depths(depth_val, alpha)
    return res_val[depth_val >= q_val]

In [256]:
# Miscoverage (significance) level: prediction sets target 1 - alpha coverage.
alpha = 0.1

# Synthetic calibration and test data: i.i.d. standard normal fields on a 30 x 30 grid.
# After seeding, the draw order y_cal, yhat_cal, y_test, yhat_test fixes the values.
np.random.seed(1023)
y_cal, yhat_cal, y_test, yhat_test = [np.random.randn(500, 30, 30) for _ in range(4)]

# Residual fields (observed minus predicted).
res_val, res_test = y_cal - yhat_cal, y_test - yhat_test

# Depth cutoff, then the calibration residuals whose depth reaches it.
q_val = conf_quantile(res_val, alpha)
ens_val = conf_ensemble(res_val, alpha)

In [257]:
# q_val was already computed from the same res_val and alpha in the setup cell;
# display it rather than re-running the full calibration depth computation.
q_val

Array(0.48624894, dtype=float32)

In [258]:
# Depth of every test residual field relative to the calibration fields, then the
# empirical coverage: fraction of test fields at or above the cutoff.
depth_test = tukey_depth(res_val, res_test)
(depth_test >= q_val).mean()

In [259]:
# Empirical test-set coverage of the conformal set (target: 1 - alpha).
(depth_test >= q_val).mean()

Array(0.90200007, dtype=float32)

In [230]:
# Calibration residual fields whose depth meets the cutoff (the conformal ensemble).
ens_val = conf_ensemble(res_val=res_val, alpha=alpha)

In [231]:
# Number of retained fields; about ceil((1 - alpha) * (n + 1)) when calibration depths are untied.
np.shape(ens_val)

(451, 30, 30)