In [1]:
import sys, os; sys.path.append(os.path.abspath("../"));

from importlib import reload
import pandas as pd, numpy as np, torch
import _settings
import utils.utils as utils
import data.dataloader as dld; reload(dld)
import models.regmodel as regmodel; reload(regmodel)
import models.conformal as conformal; reload(conformal)
import demos.demo as demo; reload(demo)
import demos.regression as reg; reload(reg)
import demos.experiments as exp; reload(exp)
import utils.eval_utils as eval_utils; reload(eval_utils)
import matplotlib.pyplot as plt

# Raw-data location, plus the 'Baselines' cache directory under the workspace
# (passed to exp.run_experiments further down).
DATA_PATH, CACHE_PATH = _settings.DATA_PATH, os.path.join(_settings.WORKSPACE, 'Baselines')

In [2]:
# Dataset key from the project settings (the run log below shows it resolves to 'UCI_Yacht').
dataset = _settings.YACHT_NAME
# method name -> per-seed results; populated by the cells that follow.
all_dfs = {}
# Number of seeds. Kept at 5 because this is a short demo.
nseeds = 5
# Miscoverage level; the baselines below are run with coverage = 1 - ALPHA (i.e. 50%).
ALPHA = 0.5

## Compute and cache the per-seed embeddings (if necessary)

In [3]:
# Cache preprocessing results for every seed. Each call prints train/val/test MSE
# and data variance (see output below); quiet=True is passed through as-is.
# NOTE(review): confirm whether init=True forces recomputation even when a cache exists.
import data.preprocess_small_datasets
for seed in range(nseeds):
    data.preprocess_small_datasets.cache(
        dataset,
        seed=seed,
        init=True,
        quiet=True,
    )

train: MSE=9.790850553145379, Data Var=253.94810230533236
val: MSE=12.093842731329907, Data Var=163.29061956463318
test: MSE=9.664938856408282, Data Var=215.62144734651403
train: MSE=9.319103150060833, Data Var=212.57978120087654
val: MSE=11.409245221138145, Data Var=207.29003988175222
test: MSE=13.038572486105382, Data Var=299.5079933662851
train: MSE=9.731203980266105, Data Var=235.18683325054786
val: MSE=12.995070960531555, Data Var=235.29841187852728
test: MSE=8.605344322018196, Data Var=202.6004189386056
train: MSE=13.06395876716945, Data Var=240.29881780569755
val: MSE=9.80921976241949, Data Var=204.04668185971508
test: MSE=6.742651608643098, Data Var=220.0469747398543
train: MSE=15.451058040141035, Data Var=227.67842377209644
val: MSE=12.594747705440264, Data Var=270.1362748723461
test: MSE=15.197977455868411, Data Var=189.91275296566084


## Train $K_\mathbf{f}$ and evaluate the results for LVD

In [4]:
dfs = {}

# Data configuration shared by every seed; the seed itself is added per iteration below.
datakwargs = {'model_setting': 0, 'train_split': dld.TRAIN, 'val_split': dld.VALID, 'test_split': dld.TEST}
# Default fit kwargs; at the time of writing these were:
# {'d': 10, 'n_iters': 1000, 'max_n': 3000, 'batch_size': 100, 'lr': 1e-2, 'stop_iters':50, 'norm': True, 'ybar_bias': True}
default_fitkwargs = exp.get_default_fitkwargs()
for seed in range(nseeds):
    # Give each call its own dict instead of mutating the shared `datakwargs` in place:
    # a callee that keeps a reference (e.g. alongside cached results) would otherwise
    # silently see later seeds, and the cell would leave hidden state behind.
    seed_datakwargs = {**datakwargs, 'seed': seed}
    fitkwargs = utils.merge_dict_inline(default_fitkwargs, {"seed": seed})
    # Cached evaluation. demo.eval_exp accepts the same arguments (presumably uncached).
    dfs[seed] = demo.eval_exp_cached(dataset, seed_datakwargs, regmodel._KERNEL_MLKR, fitkwargs, alpha=ALPHA)
all_dfs['LVD'] = dfs

## Add marginally valid (conformal) baselines

In [5]:
# Marginally valid conformal baselines, evaluated on the same seeds as LVD.
conf_baselines = ['MADSplit', 'CQR']
for m in conf_baselines:
    other_dfs = {}
    for seed in range(nseeds):
        print(m, seed)
        # Per-seed copy rather than mutating the shared `datakwargs` dict in place,
        # so no call can observe a later seed through a retained reference.
        seed_datakwargs = {**datakwargs, 'seed': seed}
        # NOTE(review): the positional 0 is unexplained here; presumably it mirrors
        # 'model_setting': 0 in datakwargs — confirm against demo.conformal_baselines_cached.
        other_dfs[seed] = demo.conformal_baselines_cached(m, dataset, seed_datakwargs, 0, alpha=ALPHA)
    all_dfs[m] = other_dfs

MADSplit 0
MADSplit 1
MADSplit 2
MADSplit 3
MADSplit 4
CQR 0
CQR 1
CQR 2
CQR 3
CQR 4


## Add the remaining baselines (DJ, DE, MCDP, PBP)

In [6]:
# Remaining baselines, run through the shared experiment harness.
baselines = ["DJ", "DE", "MCDP", "PBP"]
params = {"activation": 'ReLU', "num_hidden": 100, "num_layers": 1}
train_params = {"num_iter": 1000, "learning_rate": 1e-3}
results = exp.run_experiments(
    baselines,
    [dataset],
    N_exp=nseeds,
    damp=1e-2,
    mode='exact',
    coverage=1 - ALPHA,  # same ALPHA as the methods evaluated above
    params=params,
    train_params=train_params,
    data_path=DATA_PATH,
    cache_path=CACHE_PATH,
)

# `results` is keyed by method name, then by dataset name.
for m, per_dataset in results.items():
    all_dfs[m] = per_dataset[dataset]

Running experiments on dataset:  UCI_Yacht
Exp: 0
Exp: 1
Exp: 2
Exp: 3
Exp: 4


## Inspect an example and compute summary statistics

In [7]:
# Per-sample results for LVD, seed 0 (columns lo/hi presumably the interval bounds, y the target, yhat the prediction).
all_dfs["LVD"][0]

Unnamed: 0,lo,hi,y,yhat,extra,index
0,-0.710170,0.164930,0.16,-0.272620,,281
1,0.330634,1.136433,0.40,0.733533,,142
2,33.691284,48.358576,41.77,41.024930,,264
3,3.605136,5.195530,4.09,4.400333,,91
4,33.660430,48.327722,34.50,40.994076,,82
...,...,...,...,...,...,...
57,34.206027,52.364293,50.94,43.285160,,251
58,12.695515,19.539047,14.11,16.117281,,192
59,1.384729,2.190527,1.76,1.787628,,117
60,1.421465,2.227264,1.97,1.824364,,47


In [8]:
# Column order for the table: LVD first, then the conformal and the remaining baselines.
method_order = ['LVD'] + conf_baselines + baselines
# summary_by_dataset's result is indexable; element [1] is the measure x method table shown below.
# NOTE(review): confirm what element [0] holds.
summary_stats = eval_utils.summary_by_dataset(all_dfs, alpha=ALPHA, filter_idx=True)[1]
summs = summary_stats.reindex(columns=method_order)
summs

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 15.63it/s]


method,LVD,MADSplit,CQR,DJ,DE,MCDP,PBP
measure,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
AUROC,0.8642(0.0406)[9.56e-01>CQR],0.8171(0.0818),0.8622(0.0653),0.4821(0.1007),0.6000(0.0520),0.4814(0.0724)[9.90e-01<DJ],0.7971(0.0494)
Corr(noninf),0.8042(0.1004),0.7411(0.0595),0.8086(0.0452)[9.32e-01>LVD],0.0309(0.1090),-0.1170(0.0312)[3.60e-02<DJ],0.2877(0.0864),0.6039(0.0935)
MSE,10.6499(3.4208),10.6499(3.4208),9.5419(2.3385),192.8352(43.0033),229.4146(43.9959)[7.60e-01>MCDP],220.6949(43.4040),7.1064(2.7385)[1.70e-01<CQR]
R^2,0.9517(0.0186),0.9517(0.0186),0.9572(0.0100),0.1485(0.0271),-0.0171(0.0086)[1.53e-02<MCDP],0.0220(0.0229),0.9689(0.0076)[7.38e-02>CQR]
cnt_width(noninf),62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP],62.0000(0.0000)[nan<PBP][nan>PBP]
cover,53.5484(13.1726),48.0645(9.0811),56.4516(7.3027),77.0968(5.8600)[1.10e-01>PBP],7.7419(2.3923)[1.20e-03<MCDP],26.4516(6.0992),62.2581(16.0725)
cover(noninf),53.5484(13.1726),48.0645(9.0811),56.4516(7.3027),77.0968(5.8600)[1.10e-01>PBP],7.7419(2.3923)[1.20e-03<MCDP],26.4516(6.0992),62.2581(16.0725)
cover(tail),47.6923(20.6406)[7.58e-01>DJ],36.9231(18.3651),41.5385(20.0591),44.6154(3.4401),0.0000(0.0000)[1.09e-02<MADSplit],0.0000(0.0000)[1.09e-02<MADSplit],44.6154(31.9022)
mean_width(noninf),3.9839(0.9393),3.2634(1.0338),2.7971(0.3539),18.8169(1.3704)[8.36e-04>MCDP],4.9450(0.6949),14.5509(0.7003),2.4775(0.5573)[3.16e-01<CQR]
resid,1.8343(0.5859),1.8343(0.5859),1.3967(0.2135)[6.14e-01<PBP],10.5181(0.8960),11.3251(0.7319)[9.91e-01>MCDP],11.3190(0.8491),1.4695(0.2251)
