# Comparing CRPS/MDN models with and without epistemic uncertainty methods

In this notebook, we compare different methods with and without epistemic uncertainty modeling. We implement and evaluate four models: a basic CRPS model, a CRPS ensemble, a Mixture Density Network (MDN), and an MDN with Bayesian Neural Network (BNN) parameters. Using the diabetes dataset as an example, we demonstrate how to incorporate epistemic uncertainty through ensembling and variational inference, and compare their performance using CRPS, negative log-likelihood (NLL), and RMSE metrics. The notebook shows how TorchNaut makes it easy to experiment with different uncertainty quantification approaches.

# Imports, loading and splitting data

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import StandardScaler
from torchnaut import mdn, utils, kde, crps, epistemic
from torchnaut.utils import LabelScaler

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load dataset
data = load_diabetes()
X, y = data.data, data.target.reshape(-1, 1) # Note: reshape y to 2D.

# Split into train (70%), validation (15%), and test (15%)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Preprocessing

In [2]:
# Standardize features using training data only
scaler_X = StandardScaler().fit(X_train)
X_train = scaler_X.transform(X_train)
X_val = scaler_X.transform(X_val)
X_test = scaler_X.transform(X_test)

# The utils.LabelScaler class provides a convenient way to standardize targets, with the ability to
# also inverse transform tensors of shape [batch_size, n_samples, n_features]
scaler_y = LabelScaler()
y_train_scaled = scaler_y.fit_transform(y_train)

# Convert numpy arrays to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)

y_train_tensor = torch.tensor(y_train_scaled, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

# Create TensorDatasets
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Models and optimizers

At this point we are comparing a number of models, so we need some modularity

In [3]:
models = {}

In [4]:
# CRPS Model
class CRPS_Model(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            crps.EpsilonSampler(16), 
            nn.Linear(64 + 16, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        ])

    def forward(self, x, n_samples=100):
        with crps.EpsilonSampler.n_samples(n_samples):
            for layer in self.layers:
                x = layer(x)
        return x.squeeze()

    # methods for templated training and evaluation:

    @staticmethod
    def train_loss(model, batch_X, batch_y):
        outputs = model(batch_X)
        return crps.crps_loss(outputs, batch_y).mean()
    
    @staticmethod
    def validation_loss(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = scaler_y.inverse_transform(outputs.unsqueeze(-1)).squeeze()
        return crps.crps_loss(outputs_scaled, batch_y)

    @staticmethod
    def test_eval(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = scaler_y.inverse_transform(outputs.unsqueeze(-1)).squeeze()

        crps_loss = crps.crps_loss(outputs_scaled, batch_y)
        with torch.device(outputs.device):
            nll = kde.nll_gpu(outputs_scaled, batch_y)
        pred_mean = outputs_scaled.mean(dim=-1)

        return nll, crps_loss, pred_mean

models["CRPS"] = CRPS_Model(input_dim=X_train.shape[1]).to(device)

In [5]:
# CRPS Ensemble Model
class CRPS_Ensemble_Model(epistemic.CRPSEnsemble):
    # most of the logic is already implemented in the parent class
    
    # methods for templated training and evaluation:

    @staticmethod
    def train_loss(model, batch_X, batch_y):
        outputs = model(batch_X)

        # ensemble outputs are given as [batch_size x n_ensemble x n_samples] and we
        # apply crps loss to each model individually (i.e., last dimension as usual)
        return crps.crps_loss(outputs, batch_y.unsqueeze(-2)).mean()
    
    @staticmethod
    def validation_loss(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = scaler_y.inverse_transform(outputs.unsqueeze(-1)).squeeze()

        # for validation and testing, we need to flatten the last two dimensions to
        # get a sample from the mixture distribution of the component models
        outputs_scaled = outputs_scaled.flatten(start_dim=-2)
        return crps.crps_loss(outputs_scaled, batch_y)

    @staticmethod
    def test_eval(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = scaler_y.inverse_transform(outputs.unsqueeze(-1)).squeeze()
        outputs_scaled = outputs_scaled.flatten(start_dim=-2)

        crps_loss = crps.crps_loss(outputs_scaled, batch_y)
        with torch.device(outputs.device):
            nll = kde.nll_gpu(outputs_scaled, batch_y)
        pred_mean = outputs_scaled.mean(dim=-1)

        return nll, crps_loss, pred_mean
    
models["CRPS Ensemble"] = CRPS_Ensemble_Model([
    CRPS_Model(input_dim=X_train.shape[1]).to(device) for _ in range(5)
])

In [6]:
# MDN model
class MDN_Model(nn.Module):
    def __init__(self, input_dim, n_components=100):
        super().__init__()
        self.mdn = mdn.MDN(n_components=n_components) # Helper class
        self.layers = nn.ModuleList([
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.mdn.network_output_dim), # MDN output parameters
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    # methods for templated training and evaluation:

    @staticmethod
    def train_loss(model, batch_X, batch_y):
        outputs = model(batch_X)
        
        nll = (-1) * model.mdn.get_dist(outputs).log_prob(batch_y.squeeze())
        return torch.clamp(nll, min=-20).mean()
    
    @staticmethod
    def validation_loss(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = model.mdn.inverse_transform(outputs, scaler_y).squeeze(-1)
        nll = (-1) * model.mdn.get_dist(outputs_scaled).log_prob(batch_y.squeeze())
        return nll

    @staticmethod
    def test_eval(model, batch_X, batch_y, scaler_y):
        outputs = model(batch_X)
        outputs_scaled = model.mdn.inverse_transform(outputs, scaler_y).squeeze(-1)
        samples = model.mdn.get_dist(outputs_scaled).sample((1000,)).T

        crps_loss = crps.crps_loss(samples, batch_y)
        pred_mean = model.mdn.get_dist(outputs_scaled).mean
        nll = (-1) * model.mdn.get_dist(outputs_scaled).log_prob(batch_y.squeeze())
        return nll, crps_loss, pred_mean

models["MDN"] = MDN_Model(input_dim=X_train.shape[1]).to(device)

In [7]:
# MDN model with Bayesian Neural Network
class MDN_BNN_Model(nn.Module):
    def __init__(self, input_dim, n_components=100, kl_weight=0.01):
        super().__init__()
        self.mdn = mdn.MDN(n_components=n_components)
        self.kl_weight = kl_weight
        self.layers = nn.ModuleList([
            epistemic.BayesianLinear(input_dim, 64),
            nn.ReLU(),
            epistemic.BayesianLinear(64, 32),
            nn.ReLU(),
            epistemic.BayesianLinear(32, self.mdn.network_output_dim),
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def train_loss(model, batch_X, batch_y):
        outputs = model(batch_X)

        # BNN training is done by just doing backprop on a single forward pass
        nll = (-1) * model.mdn.get_dist(outputs).log_prob(batch_y.squeeze())
        # However, we also need KL divergence loss for the BNN
        kl_loss = epistemic.get_kl_term(model)
        return torch.clamp(nll, min=-20).mean() + model.kl_weight * kl_loss

    @staticmethod
    def validation_loss(model, batch_X, batch_y, scaler_y):
        # BNN validation is done by sampling from the posterior
        num_evals = 10
        lls = []
        for i in range(num_evals):
            outputs = model(batch_X)
            outputs_scaled = model.mdn.inverse_transform(outputs, scaler_y).squeeze()
            lls.append(model.mdn.get_dist(outputs_scaled).log_prob(batch_y.squeeze()))

        # We need to average the likelihoods in a numerically stable way
        nll = (-1)*(
            torch.logsumexp(
                torch.stack(lls, dim=0),
                dim=0,
            ) - torch.log(torch.tensor(num_evals))
        )

        return nll

    @staticmethod
    def test_eval(model, batch_X, batch_y, scaler_y):
        # BNN validation is done by sampling from the posterior
        num_evals = 50
        lls = []
        samples = []
        means = []

        for i in range(num_evals):
            outputs = model(batch_X)
            outputs_scaled = model.mdn.inverse_transform(outputs, scaler_y).squeeze()
            pred_samples = model.mdn.get_dist(outputs_scaled).sample((1000,)).T
            pred_mean = model.mdn.get_dist(outputs_scaled).mean

            lls.append(model.mdn.get_dist(outputs_scaled).log_prob(batch_y.squeeze()))
            means.append(pred_mean)
            samples.append(pred_samples)
    
        pred_mean = torch.stack(means, dim=0).mean(dim=0)

        nll = (-1)*(
            torch.logsumexp(
                torch.stack(lls, dim=0),
                dim=0,
            ) - torch.log(torch.tensor(num_evals))
        )
    
        # Mixture of samples
        crps_loss = crps.crps_loss(torch.concatenate(samples, dim=-1), batch_y)

        return nll, crps_loss, pred_mean

# The theory suggest a kl_weight of 1/training_set_size, however in practice it seems that
# any nonzero kl_weight always results in a worse performing model.
models["MDN BNN"] = MDN_BNN_Model(input_dim=X_train.shape[1], kl_weight=0).to(device)

In [8]:
def train(model, lr=0.001):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)

    # Training loop with validation
    epochs = 100
    best_val_loss = np.inf
    best_model = None
    patience = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()

            loss = model.train_loss(model, batch_X, batch_y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()
            train_loss += loss.item()
        scheduler.step()

        # Validation phase
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                val_loss = model.validation_loss(model, batch_X, batch_y, scaler_y)
                val_losses.append(val_loss.cpu().numpy())
            val_loss = np.concatenate(val_losses).mean()

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}")
        if best_val_loss > val_loss:
            best_val_loss = val_loss
            best_model = model.state_dict()
            patience = 0
        else:
            patience += 1
            if patience >= 9:
                print("Early stopping!")
                model.load_state_dict(best_model)
                break

In [9]:
train(models["CRPS"], lr=0.0001)

Epoch [1/100], Train Loss: 0.8233, Val Loss: 59.9787
Epoch [2/100], Train Loss: 0.8147, Val Loss: 59.6076
Epoch [3/100], Train Loss: 0.8108, Val Loss: 59.0376
Epoch [4/100], Train Loss: 0.8156, Val Loss: 57.9708
Epoch [5/100], Train Loss: 0.7983, Val Loss: 56.7538
Epoch [6/100], Train Loss: 0.7877, Val Loss: 55.2330
Epoch [7/100], Train Loss: 0.7650, Val Loss: 53.6087
Epoch [8/100], Train Loss: 0.7480, Val Loss: 52.0432
Epoch [9/100], Train Loss: 0.7310, Val Loss: 50.1892
Epoch [10/100], Train Loss: 0.7317, Val Loss: 48.2144
Epoch [11/100], Train Loss: 0.6915, Val Loss: 46.1995
Epoch [12/100], Train Loss: 0.6688, Val Loss: 44.5743
Epoch [13/100], Train Loss: 0.6613, Val Loss: 42.7342
Epoch [14/100], Train Loss: 0.6416, Val Loss: 40.2690
Epoch [15/100], Train Loss: 0.6236, Val Loss: 39.0085
Epoch [16/100], Train Loss: 0.5995, Val Loss: 37.0422
Epoch [17/100], Train Loss: 0.5759, Val Loss: 35.5175
Epoch [18/100], Train Loss: 0.5560, Val Loss: 34.1464
Epoch [19/100], Train Loss: 0.5377, V

In [10]:
train(models["MDN"], lr=0.0001)

Epoch [1/100], Train Loss: 1.5578, Val Loss: 9.7773
Epoch [2/100], Train Loss: 1.5521, Val Loss: 9.6779
Epoch [3/100], Train Loss: 1.5345, Val Loss: 9.5565
Epoch [4/100], Train Loss: 1.5498, Val Loss: 9.3632
Epoch [5/100], Train Loss: 1.5425, Val Loss: 9.0799
Epoch [6/100], Train Loss: 1.5382, Val Loss: 8.7556
Epoch [7/100], Train Loss: 1.5011, Val Loss: 8.4014
Epoch [8/100], Train Loss: 1.5278, Val Loss: 8.1186
Epoch [9/100], Train Loss: 1.4936, Val Loss: 7.7884
Epoch [10/100], Train Loss: 1.4489, Val Loss: 7.5522
Epoch [11/100], Train Loss: 1.4546, Val Loss: 7.3057
Epoch [12/100], Train Loss: 1.4374, Val Loss: 7.1236
Epoch [13/100], Train Loss: 1.4170, Val Loss: 6.9355
Epoch [14/100], Train Loss: 1.3942, Val Loss: 6.7836
Epoch [15/100], Train Loss: 1.3792, Val Loss: 6.6755
Epoch [16/100], Train Loss: 1.3567, Val Loss: 6.5542
Epoch [17/100], Train Loss: 1.3507, Val Loss: 6.4471
Epoch [18/100], Train Loss: 1.3285, Val Loss: 6.3828
Epoch [19/100], Train Loss: 1.3292, Val Loss: 6.3472
Ep

In [11]:
train(models["CRPS Ensemble"], lr=0.0001)

Epoch [1/100], Train Loss: 0.7840, Val Loss: 53.6423
Epoch [2/100], Train Loss: 0.7873, Val Loss: 53.4073
Epoch [3/100], Train Loss: 0.7793, Val Loss: 52.8570
Epoch [4/100], Train Loss: 0.7720, Val Loss: 52.1343
Epoch [5/100], Train Loss: 0.7466, Val Loss: 51.1202
Epoch [6/100], Train Loss: 0.7512, Val Loss: 49.8888
Epoch [7/100], Train Loss: 0.7286, Val Loss: 48.5927
Epoch [8/100], Train Loss: 0.7142, Val Loss: 46.7620
Epoch [9/100], Train Loss: 0.6798, Val Loss: 45.0625
Epoch [10/100], Train Loss: 0.6827, Val Loss: 43.3226
Epoch [11/100], Train Loss: 0.6495, Val Loss: 41.3851
Epoch [12/100], Train Loss: 0.6421, Val Loss: 39.6103
Epoch [13/100], Train Loss: 0.6225, Val Loss: 37.7615
Epoch [14/100], Train Loss: 0.5861, Val Loss: 35.9893
Epoch [15/100], Train Loss: 0.5591, Val Loss: 34.1865
Epoch [16/100], Train Loss: 0.5435, Val Loss: 32.8487
Epoch [17/100], Train Loss: 0.5148, Val Loss: 31.2191
Epoch [18/100], Train Loss: 0.5017, Val Loss: 30.2618
Epoch [19/100], Train Loss: 0.4860, V

In [12]:
train(models["MDN BNN"], lr=0.0001)

Epoch [1/100], Train Loss: 1.5075, Val Loss: 7.1795
Epoch [2/100], Train Loss: 1.5035, Val Loss: 7.1583
Epoch [3/100], Train Loss: 1.5205, Val Loss: 7.1344
Epoch [4/100], Train Loss: 1.5137, Val Loss: 7.1187
Epoch [5/100], Train Loss: 1.5124, Val Loss: 7.0804
Epoch [6/100], Train Loss: 1.4523, Val Loss: 7.0249
Epoch [7/100], Train Loss: 1.4669, Val Loss: 6.8752
Epoch [8/100], Train Loss: 1.4735, Val Loss: 6.8369
Epoch [9/100], Train Loss: 1.4532, Val Loss: 6.8497
Epoch [10/100], Train Loss: 1.4582, Val Loss: 6.7065
Epoch [11/100], Train Loss: 1.4617, Val Loss: 6.7093
Epoch [12/100], Train Loss: 1.4220, Val Loss: 6.6346
Epoch [13/100], Train Loss: 1.4459, Val Loss: 6.6052
Epoch [14/100], Train Loss: 1.4232, Val Loss: 6.5564
Epoch [15/100], Train Loss: 1.4247, Val Loss: 6.5590
Epoch [16/100], Train Loss: 1.4236, Val Loss: 6.5264
Epoch [17/100], Train Loss: 1.4195, Val Loss: 6.5078
Epoch [18/100], Train Loss: 1.4012, Val Loss: 6.4634
Epoch [19/100], Train Loss: 1.3856, Val Loss: 6.3981
Ep

In [13]:
marginal_point_pred = X_train_tensor.mean()
marginal_dist_pred = torch.tensor(y_train).T.to(device)

marginal_crps = []
marginal_nll = []
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        with torch.device(device):
            mnll = kde.nll_gpu(marginal_dist_pred.repeat(batch_y.shape[0], 1), batch_y.to(device))
            marginal_nll.append(mnll.cpu().numpy())
        mcrps = crps.crps_loss(marginal_dist_pred.repeat(batch_y.shape[0], 1), batch_y.to(device)).cpu()
        marginal_crps.append(mcrps)
    test_crps_marginal = np.concatenate(marginal_crps).mean()
    test_nll_marginal = np.concatenate(marginal_nll).mean()

print(f"Marginal Test CRPS: {test_crps_marginal:.4f}")
print(f"Marginal Test NLL: {test_nll_marginal:.4f}")
print(f"Marginal Test RMSE: {((y_test_tensor-marginal_point_pred)**2).mean()**(1/2):.4f})")

Marginal Test CRPS: 43.0078
Marginal Test NLL: 5.6368
Marginal Test RMSE: 166.5705)


In [14]:
for name, model in models.items():
    test_crps = []
    test_rmse = []
    test_nll = []
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            nll, crps_loss, pred_mean = model.test_eval(model, batch_X.to(device), batch_y.to(device), scaler_y)
            test_crps.append(crps_loss.cpu().numpy())
            test_nll.append(nll.cpu().numpy())
            test_rmse.append(((batch_y - pred_mean.cpu().numpy()) ** 2).mean(dim=-1).cpu().numpy())
        test_crps = np.concatenate(test_crps).mean()
        test_nll = np.concatenate(test_nll).mean()
        test_rmse = np.sqrt(np.concatenate(test_rmse).mean())

    print(f"{name} Test CRPS: {test_crps:.4f}")
    print(f"{name} Test NLL: {test_nll:.4f}")
    print(f"{name} Test RMSE: {test_rmse:.4f}")

CRPS Test CRPS: 31.0463
CRPS Test NLL: 5.4289
CRPS Test RMSE: 86.8475
CRPS Ensemble Test CRPS: 30.6537
CRPS Ensemble Test NLL: 5.3919
CRPS Ensemble Test RMSE: 85.9982
MDN Test CRPS: 38.3105
MDN Test NLL: 7.2156
MDN Test RMSE: 79.8511
MDN BNN Test CRPS: 38.9359
MDN BNN Test NLL: 6.3563
MDN BNN Test RMSE: 77.6576
