In [35]:
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random

from lsci.conformal import lsci

In [51]:
n, p = 100, 25

xval = np.random.randn(n, p)
yval = 1 + np.random.randn(n, p)
yval_hat = 1.2 + np.random.randn(n, p)

xtest = np.random.randn(n, p)
ytest = 1 + np.random.randn(n, p)
ytest_hat = 1.2 + np.random.randn(n, p)

rval = yval - yval_hat
rtest = ytest - ytest_hat

In [42]:
# local weights
local_weights = lsci.l2_localizer(xval, xtest, lam = 1)

In [43]:
# sample candidate set
ens = lsci.fpca_sampler(rval, local_weights[0:1], n_samp = 5000, rng = random.key(0))

# filter out-of-bounds functions
ens = lsci.depth_reject(rval, ens, local_weights[0:1], n_phi = 10, alpha = 0.1, rng = random.key(0))

In [196]:
n, p = 2000, 25

data_rng = random.key(0)
data_keys = random.split(data_rng, 6)

xval = random.normal(data_keys[0], (n, p))
yval = 1 + random.normal(data_keys[1], (n, p))
yval_hat = 1.2 + random.normal(data_keys[2], (n, p))

xtest = random.normal(data_keys[3], (n, p))
ytest = 1 + random.normal(data_keys[4], (n, p))
ytest_hat = 1.2 + random.normal(data_keys[5], (n, p))

rval = yval - yval_hat
rtest = ytest - ytest_hat

In [197]:
alpha = 0.1
alpha_conf = jnp.ceil((n+1)*(1-alpha))/n

In [199]:
# local weights and projections
method_rng = random.key(0)
method_keys = random.split(method_rng, 2)

noise = 0.1*random.normal(method_keys[0], xtest.shape)
local_weights = lsci.linf_localizer(xval, xtest + noise, 5)
phi = lsci.phi_slice(method_keys[1], p, 25)

In [200]:
# project data
phi_val = rval @ phi
phi_test = rtest @ phi

In [201]:
# conformal cutoffs 
depth_val = lsci.phi_tukey(phi_val, phi_val, local_weights)
quant_val = jnp.quantile(depth_val, 1-alpha_conf, axis = 1)

In [202]:
# check coverage (no ensemble req)
depth_test = jnp.diag(lsci.phi_tukey(phi_val, phi_test, local_weights))
jnp.mean(depth_test > quant_val) #0.901

Array(0.90150005, dtype=float32)

In [177]:
# sample ensemble
