# Introduction

This notebooks gets the data from an experiment, which includes:
1. the forward passes of a network on validation data
2. the golden labels of the validation data

And generates and saves reliability metrics into the same path as the experiment path

In [8]:
import os
cwd = os.getcwd()

# protection against running this cell multiple times
assert os.path.dirname(cwd).split('/')[-1] == 'master-thesis','Oops, directory already changed previously as indended. Ignoring...'

# change working directory (if assert passed)
new_cwd = os.path.dirname(cwd) # parent directory
os.chdir(new_cwd)

AssertionError: Oops, directory already changed previously as indended. Ignoring...

# Imports

In [9]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model
from Logits2Predictions import Logits2Predictions

from mourga_variational.variational_rnn import VariationalRNN
from utils import plot_reliability,get_prediction_thresholds

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

import math
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

idx = pd.IndexSlice

from config import Settings; settings = Settings()

# Parameters

In [10]:
experiment_path = 'data/deterministic/C'
logits_path = os.path.join(experiment_path,'deterministic_forward.csv')
golden_path = os.path.join(experiment_path,'golden.csv')

save_path = os.path.join(experiment_path,'reliability.csv')

In [63]:
logits_df = pd.read_csv(logits_path,index_col=[0,1]).sort_index()
golden_df = pd.read_csv(golden_path)

# Define prediction method

In [12]:
# prediction based on ROC geometric mean
from Abstention.utils import get_prediction_thresholds

thresholds = get_prediction_thresholds(logits_df,golden_df)
thresholds.head(3)
thresholds.shape

Unnamed: 0,threshold,gmean (roc)
diag_0,0.514493,0.568329
diag_1,0.524683,0.499588
diag_2,0.518353,0.533982


(272, 2)

make the predictions

In [13]:
def predict(predictions: pd.Series, threshold : float):
    return predictions.apply(lambda x: 1 if x > threshold else 0)

predictions_df = logits_df.apply(lambda x: predict(x,thresholds.loc[x.name,'threshold']),axis=0)

predictions_df.head(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,1,1,1,0,0,1,0,1,1,1,...,1,0,0,0,0,1,0,0,0,0
23,1,0,0,1,0,0,1,1,1,1,1,...,1,0,0,0,1,0,1,1,0,1
61,1,0,1,0,0,1,0,1,1,1,0,...,1,0,0,0,0,1,0,0,0,0


# Create ECE for each diagnostic

In [19]:
np.array([True,False]).astype(int)

array([1, 0])

In [49]:
np.array([[1,2],[3,4]]).mean(axis=0)

array([2., 3.])

In [68]:
logits_df.max().max()

0.58227813

In [61]:
confidences = logits_df.where(preds==1,1-logits,axis=0)
confidences

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0.535628,0.525377,0.527379,0.497576,0.488512,0.485052,0.479470,0.513265,0.488949,0.526202,...,0.485728,0.464404,0.486183,0.505239,0.518107,0.519951,0.496050,0.491190,0.512884,0.487418
23,1,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,...,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783,0.488783
61,1,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,...,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861,0.486861
94,1,0.516627,0.524240,0.517302,0.493810,0.490940,0.496106,0.476769,0.508379,0.483581,0.518348,...,0.485349,0.477870,0.485690,0.514437,0.519057,0.518302,0.506622,0.504196,0.516316,0.497733
105,1,0.517499,0.526581,0.516084,0.491584,0.479159,0.480493,0.488201,0.509726,0.500441,0.517651,...,0.487531,0.465359,0.477943,0.514071,0.520461,0.519138,0.499049,0.504346,0.518888,0.485601
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99383,3,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,...,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588,0.487588
99650,1,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,...,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558,0.487558
99650,2,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,...,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668,0.487668
99756,1,0.515934,0.532526,0.519323,0.491696,0.491541,0.490135,0.486393,0.509251,0.495549,0.503772,...,0.493461,0.458696,0.483413,0.513972,0.527434,0.509724,0.510397,0.493296,0.527507,0.481414


In [54]:
def ECE(logits,preds,goldens,nbins =10):
    
    # confidences of predicted class, not positive class
    confidences = logits.where(preds==1,1-logits)
    
    accuracies = preds == goldens.to_numpy()
    
    ece = np.zeros(shape=(accuracies.shape[1],))
    
    bins = np.linspace(0,1,nbins+1)
    for left,right in zip(bins[:-1],bins[1:]):
        
        in_bin = ((confidences > left) & (confidences < right)).values
        
        acc_in_bin = accuracies[in_bin].mean(axis=0)
        np.nan_to_num(acc_in_bin,nan=0)
        
        avg_confidence_in_bin = confidences[in_bin].mean()
        np.nan_to_num(avg_confidence_in_bin,nan=0)
        
        weight = in_bin.astype(int).sum(axis=0) / preds.shape[0]
        
        ece += weight * abs(acc_in_bin - avg_confidence_in_bin)
    
    return ece
    

In [55]:
ECE(logits,preds,goldens)

IndexError: tuple index out of range

In [43]:
def ece(logits: pd.Series, preds : pd.Series, goldens : pd.Series, nbins:int = 10):
    
    # confidences of predicted class, not positive class
    confidences = logits.where(preds==1,1-logits)
    
    accuracies = preds == goldens.to_numpy()
    
    ece = 0
    
    bins = np.linspace(0,1,nbins+1)
    for left,right in zip(bins[:-1],bins[1:]):
        
        in_bin = ((confidences > left) & (confidences < right)).values
        
        acc_in_bin = accuracies[in_bin].mean() if in_bin.sum() > 0 else 0 # prevents nans
        
        avg_confidence_in_bin = confidences[in_bin].mean() if in_bin.sum() > 0 else 0 # prevents nans
        
        weight = in_bin.astype(int).sum() / preds.shape[0]
        
        ece += weight * abs(acc_in_bin - avg_confidence_in_bin)
    
    return ece
        

In [44]:
all_res = None
for diag in logits_df.columns:
    
    logits = logits_df.loc[:,diag]
    goldens = golden_df.loc[:,diag]
    preds = predictions_df.loc[:,diag]
    res = ece(logits, preds, goldens)
    break
res

0.4997603291507977

In [23]:
def get_reliability_metrics(logits: pd.Series, preds : pd.Series, goldens : pd.Series, nbins:int = 10):
    """
    Computes:
    1. Expected calibration error (ECE)
    2. Accuracy per bin
    3. Relative frequency of positive examples per bin
    4. % & Nº of samples of predicted class per bin
    5. % & Nº of samples of positive class per bin
    """
    
    confidences_predicted_class = logits.where(preds==1, 1-logits)
    confidences_positive_class = logits
    
    accuracies = preds == goldens.to_numpy()

    acc_in_bin_list = list()
    perc_samples_predicted_in_bin_list = list() # percentage of samples in bin
    perc_samples_positive_in_bin_list = list()
    rel_freq_positive_examples_in_bin_list = list()
    samples_in_predicted_bin_list = list()
    samples_in_positive_bin_list = list()
    ece = 0
    
    bins = np.linspace(0,1,nbins+1)
    for left,right in zip(bins[:-1],bins[1:]):

        # which examples have predicted-confidence in this bin?
        in_bin_predicted_mask = ((confidences_predicted_class > left) & (confidences_predicted_class < right)).values
        
        # which examples have positive-confidence in this bin?
        in_bin_positive_mask = ((confidences_positive_class > left) & (confidences_positive_class < right)).values
        
        # are there any examples with predicted-confidence in this bin?
        any_in_bin_predicted = in_bin_predicted_mask.sum() > 0
        
        # Count the Nº of predicted-confidence examples and positive-confidence samples in the bin
        samples_in_predicted_bin = in_bin_predicted_mask.sum()
        samples_in_positive_bin = in_bin_positive_mask.sum()
        perc_samples_predicted_in_bin = in_bin_predicted_mask.mean()
        perc_samples_positive_in_bin = in_bin_positive_mask.mean()
        
        # Out of the examples positive-confidence in this bin, how many are actually positive?
        rel_freq_positive_examples_in_bin = goldens[in_bin_positive_mask].mean()
        if math.isnan(rel_freq_positive_examples_in_bin):
            rel_freq_positive_examples_in_bin = 0
            
        # What is the accuracy in the predicted-confidence bin?
        acc_in_bin = accuracies[in_bin_predicted_mask].mean() if any_in_bin_predicted else 0
        
        # These are redundant, should delete later prob...
        conf_predicted_class_in_bin = confidences_predicted_class[in_bin_predicted_mask].mean()
        conf_positive_class_in_bin = confidences_positive_class[in_bin_positive_mask].mean()
        
        # update ECE
        if any_in_bin_predicted:
            # to be sure we don't turn ece into a np.nan
            ece += in_bin_predicted_mask.sum()/preds.shape[0] * abs(acc_in_bin - conf_predicted_class_in_bin)
        
        # Save everything
        acc_in_bin_list.append(acc_in_bin)
        rel_freq_positive_examples_in_bin_list.append(rel_freq_positive_examples_in_bin)
        perc_samples_predicted_in_bin_list.append(perc_samples_predicted_in_bin)
        perc_samples_positive_in_bin_list.append(perc_samples_positive_in_bin)
        samples_in_predicted_bin_list.append(samples_in_predicted_bin)
        samples_in_positive_bin_list.append(samples_in_positive_bin)
        
    return {'ece':ece,
            'accuracies':acc_in_bin_list,
            'rel_freq_positive_examples':rel_freq_positive_examples_in_bin_list,
            'n_samples_predicted_class':samples_in_predicted_bin_list,
            'perc_samples_predicted_class':perc_samples_predicted_in_bin_list,
            'n_samples_positive_class':samples_in_positive_bin_list,
            'perc_samples_positive_class':perc_samples_positive_in_bin_list,
            'nbins':nbins}

# Process the data and save

In [24]:
all_res = None
for diag in logits_df.columns:
    
    logits = logits_df.loc[:,diag]
    goldens = golden_df.loc[:,diag]
    preds = predictions_df.loc[:,diag]
    res = get_reliability_metrics(logits, preds, goldens)
    res['diag'] = diag
    
    res = pd.DataFrame(res)
    
    all_res = res if all_res is None else pd.concat([all_res,res],axis=0)
    
all_res.index.name = 'bin'
all_res = all_res.reset_index().set_index(['diag','bin']).sort_index()

all_res.head(5)

all_res.to_csv(save_path)

Unnamed: 0_level_0,Unnamed: 1_level_0,ece,accuracies,rel_freq_positive_examples,n_samples_predicted_class,perc_samples_predicted_class,n_samples_positive_class,perc_samples_positive_class,nbins
diag,bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
diag_0,0,0.49976,0.0,0.0,0,0.0,0,0.0,10
diag_0,1,0.49976,0.0,0.0,0,0.0,0,0.0,10
diag_0,2,0.49976,0.0,0.0,0,0.0,0,0.0,10
diag_0,3,0.49976,0.0,0.0,0,0.0,0,0.0,10
diag_0,4,0.49976,0.996939,0.041667,980,0.504375,24,0.012352,10
