In [1]:
import os
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random

# from lsci.conformal import lsci

In [2]:
# Import the in-repo `lsci2` module by temporarily switching the working
# directory to the folder that contains it.
# NOTE(review): presumably this works because the kernel searches the current
# directory for modules -- confirm before relying on it elsewhere.
os.chdir('../src/lsci/conformal/')
try:
    import lsci2
finally:
    # Step back to the scripts directory even when the import raises;
    # otherwise the kernel is left inside the source tree and the relative
    # chdir above resolves against the wrong directory on the next run.
    os.chdir('../../../project_scripts')

In [19]:
# Synthetic data: standard-normal features, targets shifted by 1 and
# predictions shifted by 1.2, for a validation split and a test split.
n, p = 2000, 100
val_shape = (n, p)
test_shape = (n // 2, p)  # the test split has half as many rows

# one PRNG key per simulated array
data_rng = random.key(0)
data_keys = random.split(data_rng, 6)

target_shift, pred_shift = 1, 1.2

# validation split
xval = random.normal(data_keys[0], shape=val_shape)
yval = target_shift + random.normal(data_keys[1], shape=val_shape)
yval_hat = pred_shift + random.normal(data_keys[2], shape=val_shape)

# test split
xtest = random.normal(data_keys[3], shape=test_shape)
ytest = target_shift + random.normal(data_keys[4], shape=test_shape)
ytest_hat = pred_shift + random.normal(data_keys[5], shape=test_shape)

# residuals: target minus prediction
rval, rtest = yval - yval_hat, ytest - ytest_hat

In [21]:
# settings for this run
n_proj = 5
alpha = 0.1

# local weights, computed once and reused by every step below
local_weights = lsci2.localize(xval, xtest, 5)

# phi-depth of the validation residuals and of the test residuals
depth_val = lsci2.local_phi_depth(rval, rval, local_weights, n_proj)
depth_test = lsci2.local_phi_depth(rval, rtest, local_weights, n_proj, reduce=True)

# conformal quantiles of the validation depths
quant_val = lsci2.local_quantile(depth_val, alpha)

# coverage check: share of test depths above the quantile
# (the recorded output for this cell is 0.901)
coverage = jnp.mean(depth_test > quant_val)
coverage

Array(0.901, dtype=float32)

In [22]:
# draw an ensemble at test point X_i, using that point's row of local weights
n_samp = 5000
i = 10
weights_i = local_weights[i]
ens = lsci2.local_sampler(rval, weights_i, alpha, n_samp, n_proj)
ens.shape

(5000, 100)

In [4]:
# settings for this run
n_proj = 20
alpha = 0.1

# PRNG keys for the method
method_rng = random.key(1)
method_keys = random.split(method_rng, 3)

# localizers
local_weights = lsci2.localize(xval, xtest, 5)

# phi projectors: both depth evaluations receive the same key and options
depth_opts = dict(rng=method_keys[1], proj_fn='rand', depth_fn='tukey')
depth_val = lsci2.local_phi_depth(rval, rval, local_weights, n_proj, **depth_opts)
depth_test = lsci2.local_phi_depth(rval, rtest, local_weights, n_proj, reduce=True, **depth_opts)

# conformal cutoffs
quant_val = lsci2.local_quantile(depth_val, alpha)

# share of test depths above the cutoff (the recorded output is 0.905)
coverage = jnp.mean(depth_test > quant_val)
coverage

Array(0.90500003, dtype=float32)

In [5]:
# ensemble samplers: draw an ensemble using the first row of local weights
n_samp = 5000
weights_0 = local_weights[0]
ens = lsci2.local_sampler(rval, weights_0, alpha, n_samp, n_proj)
ens.shape

(5000, 25)

In [None]:
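# Component outline (notes only -- this cell runs no code).
# NOTE(review): presumably the intended layout of the lsci2 module; confirm against the module itself.

# localizers

## three localizers + a wrapper function localize()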

# phi projectors

## five projectors + a wrapper function project()

# depth functions

## three depth functions + a wrapper function depth()

# ensemble samplers

## one sampler function sample()

# depth filters

## one filtering function filter()