## Uncertainty estimation for regression

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn import metrics

from alpaca.utils.datasets.builder import build_dataset
from alpaca.utils.ue_metrics import get_uq_metrics, ndcg, uq_ll
from alpaca.ue.masks import BasicBernoulliMask, DecorrelationMask, LeverageScoreMask
from alpaca.utils import model_builder
import alpaca.nn as ann

from nuq import NuqRegressor

In [2]:
datasets = ["red_wine", "kin8nm", "concrete" , "ccpp"]#, "boston_housing", "naval_propulsion"]
for dataset_name in datasets:
    dataset = build_dataset(dataset_name, val_split=0.1)
    x_train, y_train = dataset.dataset('train')
    x_val, y_val = dataset.dataset('val')
    
    regressor = NuqRegressor()
    regressor.fit(x_train, y_train.reshape(-1))
    
    uncertainty = regressor.predict_uncertainty(x_val, infinity=100)
    predictions = regressor.predict(x_val)
    
    del regressor
    
    acc, ndcg, ll = get_uq_metrics(uncertainty["total"], predictions)
    print("====================================================================================")
    print(dataset_name)
    print("Mean squared error: ", metrics.mean_squared_error(y_val, predictions))
    print("Uncertainty quality by \naccuracy: ", acc, "\nndcg: ", ndcg, "\nlog-likelihood: ", ll)
    print("====================================================================================")

red_wine
Mean squared error:  0.6579511772658864
Uncertainty quality by 
accuracy:  0.125 
ndcg:  0.23796303677132144 
log-likelihood:  -16815303530.432201
kin8nm
Mean squared error:  0.016264685238861175
Uncertainty quality by 
accuracy:  0.04878048780487805 
ndcg:  0.1801923825519764 
log-likelihood:  -4.495082825678347
concrete
Mean squared error:  58.66329340190997
Uncertainty quality by 
accuracy:  0.2 
ndcg:  0.18455449989700945 
log-likelihood:  -4.67850019953174
ccpp
Mean squared error:  18.38311967807019
Uncertainty quality by 
accuracy:  0.10526315789473684 
ndcg:  0.22586763243253938 
log-likelihood:  -51.19594847037606
