<a href="https://colab.research.google.com/github/RSNA/AI-Deep-Learning-Lab-2024/blob/main/sessions/uncertainty-quant/UQ_RSNA2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RSNA Deep Learning Lab: Quantifying Uncertainty in Deep Learning
## Shahriar Faghani, MD, Mana Moassefi, MD, Bradley J. Erickson, MD, PhD
### Radiology Informatics Lab, Department of Radiology, Mayo Clinic, Rochester, MN

## Packages and Imports
In this tutorial, you will learn how to train uncertainty-aware models using Ensemble, Monte Carlo Dropout, Evidential, and Conformal techniques. To get started with this notebook, you need to install three packages:
- MONAI: A framework for deep learning in medical image analysis.
- timm: A rich collection of many deep learning models pretrained on natural images.
- NetCal: A package for model output calibration.

In [None]:
!pip install netcal timm monai -q --no-deps

In [None]:
from glob import glob
import random
import os
import copy
from tqdm import tqdm
import monai as mn
import numpy as np
import torch
import timm
import gdown
import shutil

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.calibration import calibration_curve

## Data Preparation
For this notebook we will use a dataset from [Kermany et al.](https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia?datasetId=17810) that contains more than 5,000 pediatric chest radiographs with and without pneumonia. We use MONAI for dataset preparation because it offers many augmentations designed for medical images and applies them through a very simple interface.

We create our datasets and dataloaders once here and reuse them in all of the upcoming experiments. In brief, the following cell:

1. Gets the list of files in the train and test folders, and creates a small validation folder by moving 10 normal and 10 pneumonia images out of the training set.
2. Creates a dictionary of these file paths with their corresponding label (normal or pneumonia) to leverage MONAI's dictionary-based augmentation pipeline.
3. Defines training and validation transforms.
4. Creates datasets and dataloaders for each of the three sets (train, validation, and test).

In [None]:
# Downloading the zip archive containing training images
if not os.path.isdir('chest_xray'):
    gdown.download(
        "https://drive.google.com/uc?export=download&confirm=pbef&id=1L8ox5fIwb_PijLcPEofQyhe3oGiYESO2",
        "chest_xray.zip",
        quiet=True
    )
    !unzip -q chest_xray.zip
    os.remove('chest_xray.zip')

In [None]:
# Global Variables
DATA_DIR = 'chest_xray'
MONAI_CACHE_DIR = 'chest_xray/cache_UQ'
IMG_SIZE = 256
BATCH_SIZE = 32
EPOCHS = 3
DROP_RATE = 0.3
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MODEL_ARCHITECTURE = "convnext_small"

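# 1. Gets the list of files in the train and test folders, and creates a small validation folder
#    by moving 10 normal and 10 pneumonia images out of the training set (only on the first run,
#    so re-running the cell neither fails nor keeps moving images).
if not os.path.isdir(f"{DATA_DIR}/val"):
    train_normal = glob(f"{DATA_DIR}/train/NORMAL/*.jpeg")
    train_pneumonia = glob(f"{DATA_DIR}/train/PNEUMONIA/*.jpeg")
    os.makedirs(f"{DATA_DIR}/val/NORMAL")
    os.makedirs(f"{DATA_DIR}/val/PNEUMONIA")
    for file in train_normal[:10]:
        shutil.move(file, f"{DATA_DIR}/val/NORMAL")
    for file in train_pneumonia[:10]:
        shutil.move(file, f"{DATA_DIR}/val/PNEUMONIA")
# sorted() makes the file order (and therefore the experiments) reproducible across machines
train_files = sorted(glob(f"{DATA_DIR}/train/*/*.jpeg"))
val_files = sorted(glob(f"{DATA_DIR}/val/*/*.jpeg"))
test_files = sorted(glob(f"{DATA_DIR}/test/*/*.jpeg"))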

# 2. Creates a list of dictionaries with image paths and labels.
def get_data_dict(files: list) -> list:
    """creates a list of {"img": path, "label": label} dictionaries

    Args:
        files (list): list of image paths

    Returns:
        list: one dictionary per image, with label 0 for NORMAL and 1 for PNEUMONIA
    """
    data = []
    for file in files:
        class_name = os.path.basename(os.path.dirname(file))  # parent folder: NORMAL or PNEUMONIA
        label = 0 if class_name == "NORMAL" else 1
        data.append({"img": file, "label": label})

    return data

train_data = get_data_dict(train_files)
val_data = get_data_dict(val_files)
test_data = get_data_dict(test_files)

# 3. Defines training and validation transforms.
train_transforms = mn.transforms.Compose([
    mn.transforms.LoadImaged(keys=["img"]),
    mn.transforms.EnsureChannelFirstD(keys=["img"]),
    mn.transforms.LambdaD(keys=["img"], func=lambda x: x[0:1, :,:]),
    mn.transforms.ScaleIntensityd(keys=["img"],minv=0.0, maxv=1.0,),
    mn.transforms.ResizeD(keys=["img"], spatial_size=IMG_SIZE, size_mode="longest"),
    mn.transforms.SpatialPadD(keys=["img"], spatial_size=IMG_SIZE, method="symmetric"),
    mn.transforms.RandAffineD(keys='img', rotate_range=0.25, translate_range=int(IMG_SIZE * 0.05), scale_range=0.1, mode="bilinear", padding_mode="zeros", prob=0.8),
    mn.transforms.RandFlipD(keys='img', spatial_axis=1, prob=0.5),
    mn.transforms.RandGaussianNoiseD(keys='img', mean=0.5, std=0.2, prob=0.5),
    mn.transforms.ToTensorD(keys=["img"]),
    mn.transforms.ToTensorD(keys=["label"], dtype=torch.long),
    mn.transforms.SelectItemsD(keys=["img","label"]),
])

val_transforms = mn.transforms.Compose([
    mn.transforms.LoadImaged(keys=["img"]),
    mn.transforms.EnsureChannelFirstD(keys=["img"]),
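    mn.transforms.LambdaD(keys=["img"], func=lambda x: x[0:1, :,:]),
    mn.transforms.ScaleIntensityd(keys=["img"], minv=0.0, maxv=1.0),  # same intensity scaling as in training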
    mn.transforms.ResizeD(keys=["img"], spatial_size=IMG_SIZE, size_mode="longest"),
    mn.transforms.SpatialPadD(keys=["img"], spatial_size=IMG_SIZE, method="symmetric"),
    mn.transforms.ToTensorD(keys=["img"]),
    mn.transforms.ToTensorD(keys=["label"], dtype=torch.long),
    mn.transforms.SelectItemsD(keys=["img","label"]),
])

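# 4. Creates a dataset and a dataloader.
# glob() returns files grouped by class folder, so take a seeded random subset of 2000 training images
# instead of the first 2000 (which could come mostly or entirely from one class).
train_subset = random.Random(42).sample(train_data, k=2000)
train_ds = mn.data.PersistentDataset(data=train_subset, transform=train_transforms, cache_dir=MONAI_CACHE_DIR)
# Note: the larger original test/ folder is used for validation (model selection), while the
# 20-image val/ hold-out created above serves as the small test set.
val_ds = mn.data.PersistentDataset(data=test_data, transform=val_transforms, cache_dir=MONAI_CACHE_DIR)
test_ds = mn.data.Dataset(data=val_data, transform=val_transforms)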

def get_dataloaders(train_ds=train_ds, val_ds=val_ds, test_ds=test_ds, batch_size=BATCH_SIZE):
    """Creates a dataloader for the train, validation, and test datasets.

    Args:
        train_ds (PersistentDataset, optional): Training dataset. Defaults to train_ds.
        val_ds (PersistentDataset, optional): Validation dataset. Defaults to val_ds.
        test_ds (Dataset, optional): Test dataset. Defaults to test_ds.
        batch_size (int, optional): Batch size. Defaults to BATCH_SIZE.

    Returns:
        DataLoader: Dataloader for the train dataset.
        DataLoader: Dataloader for the validation dataset.
        DataLoader: Dataloader for the test dataset.
    """
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4, worker_init_fn=seed_worker)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2, worker_init_fn=seed_worker)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, drop_last=False, num_workers=2, worker_init_fn=seed_worker)
    return train_loader, val_loader, test_loader

## Model Training
In this tutorial we train several models, so we need a robust, reusable training pipeline. To that end, we write one function that trains a model given the architecture, training seed, and loss function. First, we need a few helper functions: one that seeds every random number generator (including MONAI's augmentation pipeline), one that calculates the area under the receiver operating characteristic curve (AUROC), and one that runs a single training or validation step on a batch:

In [None]:
def seed_everything(seed: int = 42):
    """sets the seed for all libraries

    Args:
        seed (int, optional): seed value. Defaults to 42.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    mn.utils.misc.set_determinism(seed=seed)


class AUCMetric(object):
    """
    Computes the Area Under the ROC (AUROC) for binary classification.
    """
    def __init__(self):
        self.y_preds = []
        self.y_trues = []

    def __call__(self, y_pred, y_true):
        self.y_preds.append(y_pred)
        self.y_trues.append(y_true)

    def reset(self):
        self.y_preds = []
        self.y_trues = []

    def compute(self):
        y_preds = torch.cat(self.y_preds, dim=0).cpu().numpy()[:, 1]
        y_trues = torch.cat(self.y_trues, dim=0).cpu().numpy()
        auc = roc_auc_score(y_trues, y_preds)
        return auc

def one_step(model, batch, loss_fn, optimizer, device: str = "cuda:0", training: bool = True):
    """runs the model on one batch and, in training mode, updates its weights

    Args:
        model (torch.nn.Module): model
        batch (dict): batch of data with "img" and "label" keys
        loss_fn (torch.nn.Module): loss function
        optimizer (torch.optim.Optimizer): optimizer
        device (str, optional): device to use. Defaults to "cuda:0".
        training (bool, optional): whether the model is in training mode or not. Defaults to True.

    Returns:
        float: loss value
        torch.Tensor: softmax probabilities for the batch
    """
    x = batch["img"].to(device)
    y = batch["label"].to(device)

    if training:
        model.train()
    else:
        model.eval()

    with torch.set_grad_enabled(training):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if training:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    return loss.item(), torch.softmax(y_pred, dim=-1)

Now we can write our generic training function:

In [None]:
def train_model(architecture, drop_rate, loss_fn, weight_filename: str = "best_model", num_classes: int = 2, training_seed: int = 42, epochs: int = 10, device: str = "cuda:0"):
    """trains a timm model and keeps the weights with the best validation AUROC

    Args:
        architecture (str): timm model name, e.g. "convnext_small"
        drop_rate (float): dropout rate before the classifier head
        loss_fn (torch.nn.Module): loss function
        weight_filename (str, optional): file name (without extension) for the best weights. Defaults to "best_model".
        num_classes (int, optional): number of output classes. Defaults to 2.
        training_seed (int, optional): seed value. Defaults to 42.
        epochs (int, optional): number of epochs. Defaults to 10.
        device (str, optional): device to use. Defaults to "cuda:0".

    Returns:
        torch.nn.Module: copy of the model from the epoch with the best validation AUROC
    """
    # some hyperparameters that you can use to tune your model
    lr = 1e-4
    weight_decay = 1e-5

    seed_everything(training_seed)

    model = timm.create_model(architecture, drop_rate=drop_rate, pretrained=False, in_chans=1, num_classes=num_classes).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_auc = 0
    best_model = None
    validation_auc = AUCMetric()

    train_dataloader, val_dataloader, _ = get_dataloaders()
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        print("-" * 10)
        train_losses = []
        val_losses = []
        for batch in tqdm(train_dataloader, desc="training"):
            train_loss, _ = one_step(model, batch, loss_fn, optimizer, device=device, training=True)
            train_losses.append(train_loss)

        validation_auc.reset()
        for batch in tqdm(val_dataloader, desc="validating"):
            val_loss, y_pred = one_step(model, batch, loss_fn, optimizer, device=device, training=False)
            val_losses.append(val_loss)
            validation_auc(y_pred, batch["label"])

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        val_auc = validation_auc.compute()
        print(f"train loss: {train_loss:.4f}")
        print(f"val loss: {val_loss:.4f}")
        print(f"val auc: {val_auc:.4f}")
        if val_auc > best_auc:
            best_auc = val_auc
            best_model = copy.deepcopy(model)
            print("saving best model")
            torch.save(best_model.state_dict(), f"{weight_filename}.pth")
        print("\n")

    return best_model

Let's train a simple model with `CrossEntropy` loss, the routine choice in many deep learning projects. `CrossEntropy` loss is well-suited for classification because its gradient stays large when the predicted probability (ŷ) is far from the ground truth (y), whereas the gradient of other popular loss functions such as Mean Squared Error (MSE) on activated outputs shrinks toward zero when the model is confidently wrong.
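The gradient argument can be checked numerically. The sketch below assumes the binary (single-logit, sigmoid) form; for two classes, softmax cross-entropy over two logits is equivalent to this form applied to the difference of the logits. The model is made confidently wrong (logit z = -5 for a positive sample), and the gradients of both losses with respect to the logit are compared.

```python
import torch
import torch.nn.functional as F

# Hypothetical single-logit setup: true label y = 1, model confidently wrong (z = -5)
y = torch.tensor([1.0])
z_ce = torch.tensor([-5.0], requires_grad=True)
z_mse = torch.tensor([-5.0], requires_grad=True)

# Binary cross-entropy computed directly from the logit (numerically stable form)
ce = F.binary_cross_entropy_with_logits(z_ce, y)
ce.backward()

# Mean squared error on the sigmoid-activated output
mse = ((torch.sigmoid(z_mse) - y) ** 2).mean()
mse.backward()

ce_grad = z_ce.grad.item()    # sigma(z) - y
mse_grad = z_mse.grad.item()  # 2 * (sigma(z) - y) * sigma(z) * (1 - sigma(z))
print(f"CE gradient: {ce_grad:.4f} | MSE gradient: {mse_grad:.4f}")
```

Cross-entropy yields a gradient of σ(z) − y ≈ −0.99, while MSE on the sigmoid output yields 2(σ(z) − y)σ(z)(1 − σ(z)) ≈ −0.013. Under MSE, a confidently wrong model therefore receives almost no correction.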

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
best_model = train_model(
    architecture=MODEL_ARCHITECTURE,
    drop_rate=DROP_RATE,
    loss_fn=loss_fn,
    weight_filename="CE_seed_42",
    training_seed=42,
    num_classes=2,
    epochs=EPOCHS,
    device=DEVICE
)

Now let's check out our model's calibration curve:

In [None]:
def get_inference_results(weight_path, set_name='val'):
    """runs a saved model on the validation or test set

    Args:
        weight_path (str): path to a saved state dict
        set_name (str, optional): "val" or "test". Defaults to "val".

    Returns:
        np.ndarray: softmax probabilities, shape (n_samples, 2)
        np.ndarray: true labels, shape (n_samples,)
    """
    model = timm.create_model(MODEL_ARCHITECTURE, pretrained=False, in_chans=1, num_classes=2).to(DEVICE)
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine (and vice versa)
    model.load_state_dict(torch.load(weight_path, map_location=DEVICE))
    model.eval()
    if set_name == 'val':
        _, dataloader, _ = get_dataloaders()
    elif set_name == 'test':
        _, _, dataloader = get_dataloaders()
    else:
        raise ValueError('set_name must be val or test')
    y_preds = []
    y_true = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x = batch["img"].to(DEVICE)
            y_pred = model(x)
            y_preds.extend(torch.softmax(y_pred, dim=-1).cpu().numpy())
            y_true.extend(batch["label"].numpy())

    return np.array(y_preds), np.array(y_true)

y_preds, y_true = get_inference_results("CE_seed_42.pth", set_name='test')

In [None]:
def plot_calibration_curve(y_true, y_pred, n_bins=10, title="Calibration Curve", ax=None):
    """plots calibration curve

    Args:
        y_true (np.array): true labels
        y_pred (np.array): predicted probabilities of the positive class (1-D)
        n_bins (int, optional): number of bins. Defaults to 10.
        title (str, optional): title of the plot. Defaults to "Calibration Curve".
        ax (matplotlib.axes.Axes, optional): axes to plot on. Defaults to None.

    Returns:
        matplotlib.axes.Axes: axes
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))
    fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_pred, n_bins=n_bins)
    ax.plot(mean_predicted_value, fraction_of_positives, "s-")
    ax.plot([0, 1], [0, 1], "--", color="gray")
    ax.set_ylabel("Fraction of positives")
    ax.set_ylim([-0.05, 1.05])
    ax.set_xlabel("Mean predicted value")
    ax.set_title(title)
    return ax

plot_calibration_curve(y_true, y_preds[:, 1], title="Calibration Curve (before calibration)");

## Calibration

One of the cornerstones of uncertainty quantification is probability calibration. Deep learning classifiers typically output logits ŷ, where each ŷ<sub>i</sub> ∈ ℝ. We then "activate" these logits with a function such as [Softmax](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) or [Sigmoid](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html) so that each ŷ<sub>i</sub> lies in the interval [0, 1]. With Softmax, the activated outputs over all K classes also sum to 1. Because of these properties, we often interpret the activated outputs of a model as class probabilities. Concretely, in our task, if the ConvNet outputs the activated predictions ŷ = [0.3, 0.7] for sample i, we assign the sample the label `1`, since the argmax of ŷ is `1`, and we say that the model assigns a 70% probability to the event that the sample's true class is `1`.

This, however, is a critical misconception. Although the activated outputs of our model possess the mathematical properties of probabilities, they cannot be interpreted as such until they are calibrated. In general, probability calibration means making a model's outputs match the *true* likelihood of an event. Since our events are a sample having label `0` and a sample having label `1`, probability calibration is the process of aligning our model's activated outputs with the *true* likelihood of the sample having each label. For a well-calibrated model, roughly 70% of the samples that receive ŷ<sub>1</sub> = 0.7 actually have label `1`.
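To make the definition concrete, here is a minimal NumPy sketch of what a calibration (reliability) curve measures. Predictions are grouped into equal-width probability bins, and in each bin the mean predicted probability is compared with the observed fraction of positives. `reliability_bins` is a hypothetical helper written for illustration. It mirrors the idea behind scikit-learn's `calibration_curve` (used in this notebook) but is not that function's implementation.

```python
import numpy as np

def reliability_bins(y_true, y_prob, n_bins=10):
    """Mean predicted probability and observed fraction of positives in each non-empty bin."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the interior edges yields bin indices 0..n_bins-1
    bin_ids = np.digitize(y_prob, edges[1:-1])
    mean_pred, frac_pos = [], []
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():  # empty bins are skipped
            mean_pred.append(y_prob[mask].mean())
            frac_pos.append(y_true[mask].mean())
    return np.array(mean_pred), np.array(frac_pos)

# Well calibrated: confident predictions that are always right
mean_pred, frac_pos = reliability_bins([0, 0, 1, 1], [0.05, 0.05, 0.95, 0.95])
# Overconfident: the model says 0.85, but only half of the samples are positive
over_pred, over_frac = reliability_bins([1, 0, 1, 0], [0.85, 0.85, 0.85, 0.85])
```

For a calibrated model, the two arrays are close and the points lie near the diagonal of the calibration plot. In the second call, the gap between 0.85 and 0.5 is the miscalibration that a calibrator tries to remove.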

To calibrate our model's predictions, we use the `netcal` package:

In [None]:
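from netcal.binning import IsotonicRegression

# Binary task: fit the calibrator on the 1-D positive-class (pneumonia) confidences
calibrator = IsotonicRegression()
calibrator.fit(y_preds[:, 1], y_true)
y_preds_calibrated = calibrator.transform(y_preds[:, 1])
# Note: the calibrator is fitted and evaluated on the same samples here, so this curve is optimistic
plot_calibration_curve(y_true, y_preds_calibrated, title="Calibration Curve (after calibration)");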
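Isotonic regression calibrates by learning a non-decreasing mapping from the model's confidence to the observed frequency of the positive class. The sketch below is a minimal pool-adjacent-violators implementation for illustration only. `isotonic_fit` and `isotonic_transform` are hypothetical helpers, not part of `netcal`, and the library's internals may differ (for example, in how ties and out-of-range inputs are handled).

```python
import numpy as np

def isotonic_fit(scores, labels):
    """Pool-adjacent-violators: fit a non-decreasing map from score to P(label = 1)."""
    order = np.argsort(scores)
    x = np.asarray(scores, dtype=float)[order]
    y = np.asarray(labels, dtype=float)[order]
    # Each block holds [sum of labels, count]; merge while a block's mean exceeds the next one's
    blocks = []
    for yi in y:
        blocks.append([yi, 1.0])
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            s, c = blocks.pop()
            blocks[-1][0] += s
            blocks[-1][1] += c
    fitted = np.concatenate([np.full(int(c), s / c) for s, c in blocks])
    return x, fitted

def isotonic_transform(x_fit, y_fit, new_scores):
    """Map new scores onto the fitted curve (linear interpolation, clipped at the ends)."""
    return np.interp(new_scores, x_fit, y_fit)

x_fit, y_fit = isotonic_fit([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1])
calibrated = isotonic_transform(x_fit, y_fit, 0.25)
```

Merging adjacent blocks is what guarantees monotonicity: a higher raw confidence never maps to a lower calibrated probability. The ranking of predictions is therefore preserved, apart from ties introduced by merged blocks.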

## UQ Techniques: Ensemble
The ensemble technique relies on the intrinsic variability of deep learning models. An ensemble is a collection of models, usually (but not always) of the same architecture, whose predictions are aggregated in a *soft* or *hard* manner into a single point prediction (for classification tasks, that is). Soft aggregation combines the models' activated outputs, while hard aggregation combines the models' class predictions (discrete integers 0, ..., K-1), for example by majority vote. In this tutorial, we use soft aggregation.

Most critically, however, **the variance of the ensemble's predictions is calculated and interpreted as a measure of uncertainty**.
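Here is a minimal NumPy sketch of the two aggregation schemes and of variance-based uncertainty. The member outputs are made up: the array `member_probs` is hypothetical and is not produced by the models in this notebook.

```python
import numpy as np

# Hypothetical activated outputs of 3 ensemble members for 2 samples: shape (models, samples, classes)
member_probs = np.array([
    [[0.10, 0.90], [0.55, 0.45]],
    [[0.20, 0.80], [0.55, 0.45]],
    [[0.15, 0.85], [0.05, 0.95]],
])

# Soft aggregation: average the activated outputs, then take the argmax
soft_probs = member_probs.mean(axis=0)
soft_labels = soft_probs.argmax(axis=1)

# Hard aggregation: each member votes with its own class prediction (an odd ensemble size avoids ties)
votes = member_probs.argmax(axis=2)  # shape (models, samples)
hard_labels = np.array([np.bincount(sample_votes, minlength=2).argmax() for sample_votes in votes.T])

# Uncertainty: standard deviation of P(class 1) across members
uncertainty = member_probs[:, :, 1].std(axis=0)
```

In the second sample, the members disagree. Soft aggregation still picks class 1 because one member is very confident, while hard voting picks class 0. The larger standard deviation flags this sample as uncertain.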

One analogy to help build intuition behind the ensemble UQ technique is a panel of doctors. Assume each doctor received his or her medical education from the same institution, completed their residency in the same hospital system, and now specializes in the same field. When presented with a sample, if the panel has unanimous or near-unanimous agreement about the sample's label, we would consider the panel to be highly certain. And conversely, if there are many dissenters or high variability among the opinions of each panelist, we would consider their ruling to be highly uncertain. Now, replace the panelists with deep learning models and medical school with a training dataset, and you understand the ensemble UQ technique!

The strength of the ensemble is its conceptual simplicity. The straightforwardness of variance as a measure of uncertainty can help people grasp the intuition of the ensemble without needing an advanced background in mathematics.

This simplicity, however, comes at a cost. The ensemble's primary weakness is that it is not mathematically rigorous. A UQ technique need not be mathematically complex or convoluted to be effective. Still, treating the variance of the members' predictions as a measure of uncertainty is an interpretation that the method itself does not justify. Qualitative similarities between real-life "ensembles" and deep learning ensembles do not allow us to place meaning where it does not exist. Hence, ensembles are often regarded as a naive approach to UQ. Ensembles are also far more computationally intensive than single-model UQ techniques, since every member must be trained and run separately. Nevertheless, let's try one out!

To use this technique, we train the same model with different initialization seeds. For more variance, we could also train each model on a different fold of the training set, as in *cross-validation*, though we do not demonstrate that in this tutorial.

Then we can run a single image through all of those models and compute the mean and standard deviation of their predictions. *The higher the standard deviation, the more uncertain the models are.*

In [None]:
ensemble_models = [best_model]
training_seeds = [42, 43]
loss_fn = torch.nn.CrossEntropyLoss()
for seed in training_seeds[1:]: # we already trained the first model (seed=42)
    print("*" * 40)
    print(f"Training model with seed {seed}")
    print("*" * 40)
    ensemble_models.append(train_model(
        architecture=MODEL_ARCHITECTURE,
        loss_fn=loss_fn,
        drop_rate=DROP_RATE,
        weight_filename=f"seed_{seed}",
        training_seed=seed,
        num_classes=2,
        epochs=EPOCHS,
        device=DEVICE
    ))

In [None]:
_, _, test_dataloader = get_dataloaders()
for i, item in enumerate(test_dataloader):
    if i >= 3 and i < (len(test_dataloader) - 3):
        continue # we only need a few samples

    img = item["img"].to(DEVICE)
    label = item["label"].to(DEVICE)
    y_preds = []
    for model in tqdm(ensemble_models):
        model.eval()
        with torch.no_grad():
            y_pred = model(img)
            y_preds.append(torch.softmax(y_pred, dim=-1).cpu().numpy()[:, 1].item())
    y_preds = np.array(y_preds)
    mean_pred = np.mean(y_preds)
    std_pred = np.std(y_preds)
    print(f"Mean P(pneumonia): {mean_pred:.4f} (SD: {std_pred:.4f})")
    print(f"True label: {label.item()}")
    print()

## UQ Techniques: Monte Carlo Dropout
Monte Carlo Dropout for UQ is similar to the ensemble technique because it also relies on the variance of predictions to quantify uncertainty. The training process, however, requires only one model; the variability in Monte Carlo Dropout is introduced at inference (test) time. Dropout refers to "turning off" randomly selected nodes in a model. If dropout stays active during inference, each forward pass uses a different random subset of nodes, so a single trained model behaves like many slightly different instances of itself. We run a sample through these instances and take the variance of their predictions as the model's uncertainty.

Monte Carlo Dropout often yields poorer uncertainty estimates than ensembles. Like the ensemble technique, it is not mathematically rigorous because it defines uncertainty as variance.

To use this technique, we turn on our model's dropout layer(s) during inference and run the same image through the model multiple times, which yields multiple predictions for a single image. As with the ensemble technique, we use the mean and standard deviation of these predictions as a measure of uncertainty. *The higher the standard deviation, the more uncertain the model is.*
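The same recipe can be shown on a toy network as a self-contained PyTorch sketch. Both `net` and `mc_dropout_predict` are hypothetical and only illustrate the mechanism. The model is put in evaluation mode, only the `nn.Dropout` modules are switched back to training mode, and the spread of several stochastic forward passes is reported.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy classifier with a dropout layer before the output layer
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(32, 2))
x = torch.randn(1, 16)

def mc_dropout_predict(model, inputs, n_iter=20):
    """Mean and standard deviation of P(class 1) over n_iter stochastic forward passes."""
    model.eval()                       # deterministic behavior everywhere (e.g. BatchNorm) ...
    for m in model.modules():
        if isinstance(m, nn.Dropout):  # ... except in the dropout layers
            m.train()
    with torch.no_grad():
        samples = torch.stack([torch.softmax(model(inputs), dim=-1)[:, 1] for _ in range(n_iter)])
    return samples.mean(dim=0), samples.std(dim=0)

mean_p, std_p = mc_dropout_predict(net, x)
print(f"Mean P(class 1): {mean_p.item():.4f} (SD: {std_p.item():.4f})")
```

The check `isinstance(m, nn.Dropout)` matches only standard `nn.Dropout` layers. The notebook's class-name check (`startswith("Dropout")`) also catches variants such as `Dropout2d`. Calling `net.eval()` again restores deterministic predictions.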

In [None]:
model = copy.deepcopy(best_model)  # work on a copy so best_model (also ensemble_models[0]) keeps dropout off
model.eval()
number_of_iterations = 10  # number of stochastic forward passes per image
# turn on dropout
for m in model.modules():
    if m.__class__.__name__.startswith("Dropout"):
        m.train()

_, _, test_dataloader = get_dataloaders()
for i, item in enumerate(test_dataloader):
    if i >= 3 and i < (len(test_dataloader) - 3):
        continue # we only need a few samples

    img = item["img"].to(DEVICE)
    label = item["label"].to(DEVICE)
    y_preds = []

    with torch.no_grad():
        for _ in range(number_of_iterations):  # avoid shadowing the outer loop variable i
            y_pred = model(img)
            y_preds.append(torch.softmax(y_pred, dim=-1).cpu().numpy()[:, 1].item())

    y_preds = np.array(y_preds)
    mean_pred = np.mean(y_preds)
    std_pred = np.std(y_preds)
    print(f"Mean P(pneumonia): {mean_pred:.4f} (SD: {std_pred:.4f})")
    print(f"True label: {label.item()}")
    print()

## UQ Techniques: Evidential Deep Learning

Evidential Deep Learning (EDL) is a UQ technique that uses a custom loss function so that the uncertainty is inversely proportional to the amount of evidence the model collects. EDL turns training into an evidence-acquisition process: the model tries to gather as much evidence as possible for correct predictions, which lowers its uncertainty when it classifies a sample correctly. Concretely, EDL converts the logits generated by our model into non-negative evidence (e.g., with a softplus) and uses evidence + 1 as the alpha parameters of a Dirichlet distribution. A Dirichlet distribution is a second-order probability distribution, that is, a distribution over probability vectors. In essence, it describes "probabilities of probabilities." For example, we are all familiar with first-order probability questions, such as flipping a fair coin some number of times. But what about a coin whose probability of heads is itself unknown? A Dirichlet distribution (for two outcomes, the Beta distribution) describes how plausible each value of P(heads) is, e.g., the density at P(heads) = 0.5. Applying this to deep learning allows us to calculate a model's uncertainty (even during training!) in a way that is both mathematically rigorous and computationally cheap, since it needs only a single forward pass.
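The quantities EDL works with can be written compactly. Below is a sketch in the notation of the code in this section, where z<sub>k</sub> are the model's logits and K is the number of classes. `EDLLossv1` and `get_edl_probs` use a softplus with `beta=0.1` as the evidence function:

$$
e_k = \operatorname{softplus}(z_k) \ge 0, \qquad \alpha_k = e_k + 1, \qquad S = \sum_{k=1}^{K} \alpha_k, \qquad \hat{p}_k = \frac{\alpha_k}{S}, \qquad u = \frac{K}{S}
$$

For example, with K = 2 and evidence e = (8, 0), we get α = (9, 1), S = 10, p̂ = (0.9, 0.1), and u = 2/10 = 0.2. With no evidence at all, e = (0, 0) gives S = 2 and u = 1, which is maximal uncertainty.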

Using EDL requires retraining our model with a new loss function. The following cells show two ways to define this loss function (`EDLLossv1` and `EDLLossv2`); we train with `EDLLossv1`:

---
>
There are multiple ways to define the EDL loss function. For an exhaustive discussion of this topic, please refer to the [original paper](https://arxiv.org/pdf/1806.01768.pdf).

---

In [None]:
class EDLLossv1(torch.nn.Module):
    """EDL loss: expected cross-entropy under Dir(alpha) minus a Dirichlet entropy regularizer."""
    def __init__(self, num_classes=2, regr=1e-5, to_one_hot=True):
        super().__init__()
        self.regr = regr  # weight of the entropy regularizer
        self.one_hot = to_one_hot
        self.num_classes = num_classes

    def forward(self, y_preds, y_true):
        if self.one_hot:
            y_true = torch.nn.functional.one_hot(y_true, num_classes=self.num_classes).float()
        evidence = torch.nn.functional.softplus(y_preds, beta=0.1)  # non-negative evidence from the logits
        alpha = evidence + 1  # Dirichlet parameters
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        entropy_reg = torch.distributions.dirichlet.Dirichlet(alpha).entropy()
        # summed over the batch: y_k * (digamma(S) - digamma(alpha_k)), minus the weighted Dirichlet entropy
        edl_loss = torch.sum(y_true * (torch.digamma(alpha_sum) - torch.digamma(alpha))) - self.regr * torch.sum(entropy_reg)
        return edl_loss

In [None]:
class EDLLossv2(torch.nn.Module):
    def __init__(self, annealing_step=0, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        self.annealing_step = annealing_step

    def _get_evidence(self, y):
        return torch.nn.functional.softplus(y)

    def _kl_divergence(self, alpha):
        device = alpha.device
        ones = torch.ones([1, self.num_classes], dtype=torch.float32, device=device)
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        first_term = (
            torch.lgamma(alpha_sum)
            - torch.lgamma(alpha).sum(dim=1, keepdim=True)
            + torch.lgamma(ones).sum(dim=1, keepdim=True)
            - torch.lgamma(ones.sum(dim=1, keepdim=True))
        )
        second_term = (
            (alpha - ones)
            .mul(torch.digamma(alpha) - torch.digamma(alpha_sum))
            .sum(dim=1, keepdim=True)
        )
        kl = first_term + second_term
        return kl

    def _loglikelihood_loss(self, y, alpha):
        device = alpha.device
        y = y.to(device)
        alpha = alpha.to(device)
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        loglikelihood_err = torch.sum((y - (alpha / alpha_sum)) ** 2, dim=1, keepdim=True)
        loglikelihood_var = torch.sum(alpha * (alpha_sum - alpha) / (alpha_sum * alpha_sum * (alpha_sum + 1)), dim=1, keepdim=True)
        loglikelihood = loglikelihood_err + loglikelihood_var
        return loglikelihood

    def forward(self, output, target, epoch_num=None):
        evidence = self._get_evidence(output)
        alpha = evidence + 1
        device = alpha.device
        if target.ndim == 1:
            y = torch.nn.functional.one_hot(target, num_classes=self.num_classes).float().to(device)
        else:
            y = target.to(device)
        alpha = alpha.to(device)
        loglikelihood = self._loglikelihood_loss(y, alpha)
        if self.annealing_step!= 0:
            assert epoch_num is not None, "epoch_num must be provided when using annealing"
            annealing_coef = torch.min(
                torch.tensor(1.0, dtype=torch.float32),
                torch.tensor(epoch_num / self.annealing_step, dtype=torch.float32),
            )
        else:
            annealing_coef = 1.0
        kl_alpha = (alpha - 1) * (1 - y) + 1
        kl_div = annealing_coef * self._kl_divergence(kl_alpha)
        loss = (loglikelihood + kl_div).mean()
        return loss

# Dummy test
loss = EDLLossv2()
trues = torch.randint(0, 2, (10,))
preds = torch.randn(10, 2)
print(trues.shape, preds.shape)
loss(preds, trues)

Now let's retrain our model with the EDL loss function:

In [None]:
loss_fn = EDLLossv1()
edl_model = train_model(
    architecture=MODEL_ARCHITECTURE,
    loss_fn=loss_fn,
    drop_rate=DROP_RATE,
    weight_filename="edl_seed_42",
    training_seed=42,
    num_classes=2,
    epochs=EPOCHS,
    device=DEVICE
)

Finally, we can run the trained model on our images and perform a handful of calculations to obtain the expected class probabilities and the uncertainty of each prediction:

---
>
This "handful of calculations" is described extensively in the original paper linked above. We do not go into detail here about the specific loss function because, first, there are many EDL loss functions and, second, the focus of this tutorial is a breadth-first exploration of uncertainty quantification rather than a thorough explanation of the mathematics of EDL.

---

In [None]:
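def get_edl_probs(y_preds):
    """converts raw EDL logits into expected probabilities, uncertainties u = K / S, and S = sum(alpha)"""
    if not isinstance(y_preds, torch.Tensor):
        y_preds = torch.as_tensor(y_preds, dtype=torch.float32)
    evidence = torch.nn.functional.softplus(y_preds, beta=0.1)  # same evidence function as EDLLossv1
    alpha = evidence + 1
    alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
    probs = alpha / alpha_sum  # expected class probabilities under Dir(alpha)
    uncertainties = y_preds.shape[1] / alpha_sum  # K / S: 1 with no evidence, approaches 0 as evidence grows
    return probs, uncertainties, alpha_sum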

In [None]:
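# EDL evidence must be computed from raw logits, but get_inference_results() returns softmax outputs,
# so we run the trained edl_model directly
edl_model.eval()
_, val_dataloader, _ = get_dataloaders()
with torch.no_grad():
    outputs = [(edl_model(batch["img"].to(DEVICE)).cpu(), batch["label"]) for batch in tqdm(val_dataloader)]
y_logits = torch.cat([logits for logits, _ in outputs])
y_true = torch.cat([labels for _, labels in outputs]).numpy()
y_probs, uncertainties, alpha_sums = get_edl_probs(y_logits)

for label, prob, uncertainty in zip(y_true, y_probs, uncertainties):
    print(f"Label: {label} | Prediction: {prob.argmax().item()} | Probability: {prob.max().item():.3f} | Uncertainty: {uncertainty.item():.6f}")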

## UQ Techniques: Conformal Uncertainty

The first step in conformal prediction is the fitting stage. Here, we assign a nonconformity score (also called an uncertainty score) to each data point in our calibration set. The simplest score function is one minus the probability that the model assigns to a class. The calibration scores are sorted and saved as reference scores for our conformal predictor. Note that we first use `netcal` to calibrate our model's predictions and then use the calibrated predictions to calculate the reference scores. Both of these processes are called "calibration," but they are distinct!

The next step is to quantify the uncertainty of our model's test predictions. We score each test prediction with the same score function as before. Then we count how many of the reference scores are higher than the test score. This count divided by the total number of reference scores is the p-value of that prediction. (This is not the same value as the traditional statistical p-value!) The greater the p-value, the more the test prediction resembles the calibration samples, and the more confidence we have in it. Because we use Mondrian (class-conditional) conformal prediction, we repeat this procedure independently for each class. The score for class k is compared only with the reference scores of calibration samples whose true label is k.
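Here is a minimal NumPy sketch of this p-value for one class, following the definition above (the fraction of reference scores strictly higher than the test score). `conformal_p_value` and the calibration probabilities are hypothetical. Many references use a variant instead: (number of reference scores ≥ the test score + 1) / (number of reference scores + 1), which treats ties conservatively.

```python
import numpy as np

def conformal_p_value(test_score, sorted_reference_scores):
    """Fraction of reference scores strictly higher than the test score."""
    ref = np.asarray(sorted_reference_scores)
    # On a sorted array, searchsorted(side="right") counts the scores <= test_score
    n_higher = ref.size - np.searchsorted(ref, test_score, side="right")
    return n_higher / ref.size

# Hypothetical calibration probabilities for samples whose true label is this class
ref_scores = np.sort(1 - np.array([0.95, 0.90, 0.80, 0.70, 0.60]))  # about [0.05, 0.10, 0.20, 0.30, 0.40]

p_confident = conformal_p_value(1 - 0.92, ref_scores)  # score 0.08: 4 of 5 reference scores are higher
p_doubtful = conformal_p_value(1 - 0.30, ref_scores)   # score 0.70: no reference score is higher
```

Because the reference scores are sorted, `np.searchsorted` finds the count with a binary search instead of comparing against every reference score.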


We then compare the p-values of the predictions with a threshold (usually 0.05 or 0.1) to decide whether each prediction is sufficiently confident. If the p-value is greater than the threshold, we can conclude that the prediction is supported by the calibration data and use it accordingly. If the p-value is less than the threshold, we can conclude that the prediction is not supported by sufficient evidence and that we should not use it.
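Putting the pieces together, here is a minimal sketch of Mondrian p-values and the resulting prediction sets. `mondrian_prediction_sets` is a hypothetical helper (not the `MondrianConformalPredictor` class in this notebook), the data are made up, and the p-value follows the "strictly higher" definition used above. A class enters a sample's prediction set when its p-value exceeds the threshold. An empty set means no class is supported by the calibration data, so that prediction should not be used.

```python
import numpy as np

def mondrian_prediction_sets(test_probs, calib_probs, calib_labels, threshold=0.1):
    """Per-class conformal p-values and prediction sets (classes with p-value > threshold)."""
    num_classes = calib_probs.shape[1]
    p_values = np.zeros(test_probs.shape, dtype=float)
    for k in range(num_classes):
        # Mondrian step: reference scores for class k come only from calibration samples labelled k
        reference = np.sort(1 - calib_probs[calib_labels == k, k])
        test_scores = 1 - test_probs[:, k]
        n_higher = reference.size - np.searchsorted(reference, test_scores, side="right")
        p_values[:, k] = n_higher / reference.size
    prediction_sets = [np.flatnonzero(row > threshold).tolist() for row in p_values]
    return p_values, prediction_sets

# Hypothetical calibration data: 5 samples per class, softmax outputs of a 2-class model
calib_probs = np.array([[0.95, 0.05], [0.90, 0.10], [0.80, 0.20], [0.70, 0.30], [0.60, 0.40],
                        [0.05, 0.95], [0.10, 0.90], [0.20, 0.80], [0.30, 0.70], [0.40, 0.60]])
calib_labels = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
test_probs = np.array([[0.92, 0.08], [0.50, 0.50], [0.75, 0.25]])

p_values, sets = mondrian_prediction_sets(test_probs, calib_probs, calib_labels, threshold=0.1)
```

The first and third test samples get the prediction set {0}. The second sample, which sits exactly between the classes, gets an empty set and would be flagged for review.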


---

>
The following class has two main methods: 1) "fit", which is used to calculate the reference scores, and 2) "predict", which is used to calculate the p-values for the test predictions.

---

In [None]:
class MondrianConformalPredictor:
    def __init__(self, num_classes: int):
        """A class-conditional (Mondrian) conformal predictor for classification.
        Reference: https://arxiv.org/pdf/2107.07511.pdf
        Args:
            num_classes (int): Number of classes
        """
        self.num_classes = num_classes
        self.reference_score_dict = None

    def fit(self, calibration_probs: np.ndarray, calibration_labels: np.ndarray) -> dict:
        """Fit the predictor to calibration data by calculating reference scores.
        Reference scores are computed from the calibration data and later serve as the reference
            distribution for calculating p-values, hence the name "reference" scores.
        Reference scores for each class are collected in a dictionary keyed by class index.
        Each reference score array is 1-D, but the arrays have different lengths for different classes
            (we are dealing with a presumably imbalanced dataset). This is why we use a dictionary
            instead of a 2-D array to store them.
        Args:
            calibration_probs (np.ndarray): Predicted probabilities for the calibration data, shape (N, num_classes)
            calibration_labels (np.ndarray): True class labels for the calibration data, shape (N,)
        Returns:
            reference_score_dict (dict): Sorted reference scores for each class. Sorting is required
                for the binary search used to calculate p-values later on.
        """
        calibration_labels = np.asarray(calibration_labels).ravel()
        calibration_scores = 1 - calibration_probs  # nonconformity score for every (sample, class) pair
        reference_score_dict = dict()
        for class_idx in range(self.num_classes):
            # Keep the class_idx scores of calibration samples whose true label is class_idx,
            # sorted from smaller scores (more conforming) to larger scores (less conforming)
            reference_score_dict[class_idx] = np.sort(calibration_scores[calibration_labels == class_idx, class_idx])
        self.reference_score_dict = reference_score_dict
        return reference_score_dict

    def _calculate_p_values(self, test_scores: np.ndarray, reference_score_dict: dict) -> np.ndarray:
        """Calculate the p-values of the test scores with respect to the calibration reference scores.
        Args:
            test_scores (np.ndarray): Nonconformity scores for the test data, shape (N, num_classes)
            reference_score_dict (dict): Sorted reference scores for each class (as produced by `fit`)
        Returns:
            test_pvalues (np.ndarray): P-values for the test scores, shape (N, num_classes)
        """
        test_pvalues = np.zeros((test_scores.shape[0], self.num_classes))
        for class_idx in range(self.num_classes):

            # Count the reference scores that are greater than or equal to each test score
            # (searchsorted with side='left' returns how many reference scores are strictly smaller)
            reference_scores = reference_score_dict[class_idx]
            num_reference_scores = reference_scores.shape[0]
            test_ranks = num_reference_scores - np.searchsorted(reference_scores, test_scores[:, class_idx], side='left')

            # Conformal p-value: (count + 1) / (number of reference scores + 1)
            test_pvalues[:, class_idx] = (test_ranks + 1) / (num_reference_scores + 1)

        return test_pvalues

    def predict(self, test_probs: np.ndarray, error_rate: float = 0.2) -> tuple:
        """Build a conformal prediction set for each test sample.
        Args:
            test_probs (np.ndarray): Predicted probabilities for the test data, shape (N, num_classes)
            error_rate (float): The tolerated error rate; a class is kept in a sample's prediction set
                if its p-value is greater than this value
        Returns:
            certain_predictions (list): Class indices in the prediction set of each test sample
            certain_probabilities (list): Probabilities of the classes in each prediction set
            certain_pvalues (list): P-values of the classes in each prediction set
        """
        assert self.reference_score_dict is not None, "Please fit the predictor first!"

        # Calculate the p-values of the test samples against the calibration reference scores
        test_scores = 1 - test_probs
        test_pvalues = self._calculate_p_values(test_scores, self.reference_score_dict)

        # A class enters a sample's prediction set if its p-value is greater than the error rate
        in_set = test_pvalues > error_rate
        certain_predictions = [np.where(in_set[i])[0] for i in range(in_set.shape[0])]
        certain_probabilities = [test_probs[i, in_set[i]] for i in range(in_set.shape[0])]
        certain_pvalues = [test_pvalues[i, in_set[i]] for i in range(in_set.shape[0])]

        return certain_predictions, certain_probabilities, certain_pvalues

Now let's see how our conformal predictor turns the model's predictions into prediction sets:

In [None]:
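from netcal.binning import IsotonicRegression

y_preds, y_true = get_inference_results("CE_seed_42.pth", set_name='test')

# Calibrate the predictions with netcal (y_preds holds the predicted probability of pneumonia for each image).
# Note: for brevity, calibration, conformal fitting, and prediction all use the test set here;
# in practice, fit on a held-out calibration split that is disjoint from the images being evaluated.
calibrator = IsotonicRegression()
calibrator.fit(y_preds, y_true)
p_pneumonia = calibrator.transform(y_preds)[:, None]
# Column c must hold P(class c): column 0 = normal, column 1 = pneumonia
y_preds_calibrated = np.concatenate([1 - p_pneumonia, p_pneumonia], axis=1)

# Mondrian Conformal Predictor, fitted on the calibrated two-column probabilities
mccp = MondrianConformalPredictor(num_classes=2)
mccp.fit(y_preds_calibrated, y_true)
certain_predictions, certain_probabilities, certain_pvalues = mccp.predict(y_preds_calibrated, error_rate=0.1)

# Show the prediction sets of the first 21 test images
for i, (label, prediction, probability, pvalue) in enumerate(zip(y_true, certain_predictions, certain_probabilities, certain_pvalues)):
    print(f"Label: {label} | Prediction set: {prediction} | Probability: {probability} | P-value: {pvalue}")
    if i == 20:
        break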

As you can see above, several predictions have both labels in their prediction sets. This means that at an error rate of 0.1 (we tolerate the true label falling outside the prediction set in roughly 10% of cases), the model cannot rule out either label for those images. You can also see a few predictions with only one label in their prediction sets; for these, the model is confident, and we can use the prediction accordingly. A better-trained model would likely produce more single-label sets, but even with the current model you can change `error_rate` and see how it affects the predictions: a larger error rate yields smaller prediction sets, and a smaller error rate yields larger ones.
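To see the effect of `error_rate` without retraining, the sketch below simulates a binary classifier with synthetic, Beta-distributed probabilities (an assumption; it is not the pneumonia model), splits the samples into disjoint calibration and test halves, and reports the empirical coverage (how often the true label lands in the prediction set) and the average prediction-set size for several error rates. The helpers `reference_scores` and `p_values` are hypothetical re-implementations of the same scoring and p-value logic as `MondrianConformalPredictor`, for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for model outputs (NOT the pneumonia model): P(class 1) drawn from Beta distributions
n = 2000
labels = rng.integers(0, 2, size=n)
p_pos = np.where(labels == 1, rng.beta(5, 2, size=n), rng.beta(2, 5, size=n))
probs = np.stack([1 - p_pos, p_pos], axis=1)  # column c holds P(class c)

# Disjoint calibration and test halves
idx = rng.permutation(n)
cal, test = idx[: n // 2], idx[n // 2:]


def reference_scores(probs, labels, num_classes=2):
    """Sorted per-class nonconformity scores (1 - probability of the true class)."""
    scores = 1 - probs
    return {c: np.sort(scores[labels == c, c]) for c in range(num_classes)}


def p_values(probs, reference):
    """Conformal p-values: (number of reference scores >= test score, plus 1) / (n + 1)."""
    scores = 1 - probs
    out = np.zeros_like(scores)
    for c, ref in reference.items():
        at_least = len(ref) - np.searchsorted(ref, scores[:, c], side="left")
        out[:, c] = (at_least + 1) / (len(ref) + 1)
    return out


reference = reference_scores(probs[cal], labels[cal])
pv = p_values(probs[test], reference)
test_labels = labels[test]

results = {}
for error_rate in (0.05, 0.1, 0.2, 0.3):
    in_set = pv > error_rate  # boolean prediction-set mask, shape (n_test, 2)
    coverage = in_set[np.arange(len(test_labels)), test_labels].mean()  # true label inside the set
    avg_size = in_set.sum(axis=1).mean()
    results[error_rate] = (coverage, avg_size)
    print(f"error_rate={error_rate:.2f}  coverage={coverage:.3f}  avg set size={avg_size:.2f}")
```

With exchangeable calibration and test data, the coverage should stay close to or above 1 - error_rate, while the average set size shrinks as the error rate grows.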

### Thank you to Dr. Bardia Khosravi, MD, MPH, and Dr. Pouria Rouzrokh, MD, MPH, for their invaluable assistance in preparing this notebook.