# Posterior Symmetry Reproduction

In [91]:
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR, LambdaLR
import torch.nn.functional as F

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.lenet import bayesian_lenet, lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.routines import ClassificationRoutine
from lightning.pytorch import LightningModule

from sklearn.metrics import precision_recall_curve, roc_curve, auc, accuracy_score
from sklearn.calibration import calibration_curve
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as st

from laplace import Laplace
from utils.swa_gaussian.swag.posteriors import SWAG
import utils.swa_gaussian.swag.posteriors as swag_posteriors
from utils.posterior_symmetry.mmd.mmd_torch import mmdagg
from utils.posterior_symmetry.symmetries.permutation import Permuter
# from utils.posterior_symmetry.symmetries.scale import Scaler # ANVÄNDER torch_symmetry SOM INTE FINNS ELLER ??
from utils.bayes_neural_networks.src.Stochastic_Gradient_HMC_SA.optimizers import H_SA_SGHMC

from pathlib import Path
from safetensors.torch import load_file

In [3]:
# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_PATH = "data"
MODEL_PATH = Path("models", "trained_optunets")
NUM_WORKERS = 4

# Parameters from paper
EPOCHS = 60
BATCH_SIZE = 64
LEARNING_RATE = 0.04
WEIGHT_DECAY = 2e-4

## Method params
DROPOUT_RATE = 0.2 # last layer dropout rate

# Models
NMODELS = 100
ENSEMBLE_MODELS = 10 # number of models to use in ensemble
N_SAMPLES = 100 # number of samples to draw from model

MC_SAMPLES = 3 # Posterior samples for model

In [4]:
# Load MNIST data
root = Path(DATA_PATH)
datamodule = MNISTDataModule(root=root, batch_size=BATCH_SIZE, eval_ood=False, num_workers=NUM_WORKERS)

## OptuNet

### Model Definition

In [5]:
class OptuNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Add layers for OptuNet (use Section C.2.1 from the paper for details)
        # Layers: Conv2D (out_ch=2, ks=4, groups=1) -> Max Pooling (ks=3, stride=3) -> ReLU -> Conv2D (out_ch=10, ks=5, groups=2) -> Average Pooling -> ReLU -> Linear 10x10
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=4, groups=1, bias=False)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=5, groups=2, bias=False)
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=10, out_features=10)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.pool1(self.conv1(x)))  # First conv, max pooling, ReLU
        x = self.relu(self.pool2(self.conv2(x)))  # Second conv, avg pooling, ReLU
        x = torch.mean(x, dim=(2, 3))
        x = self.fc1(x)  # Linear layer
        return x

# Optimizer and LR scheduler
def optim_optunet(model: nn.Module):
    optimizer = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Loss function
def loss_optunet(model: nn.Module):
    loss = ELBOLoss(
        model=model,
        inner_loss=nn.CrossEntropyLoss(),
        kl_weight= 1/10000,
        num_samples=3,
    )
    return loss

In [6]:
# Load functions
def load_optunet_model(version: int):
    model = OptuNet()
    path = Path(f"models/mnist-optunet-0-8191/version_{version}.safetensors")

    if not path.exists():
        raise ValueError("File does not exist")

    state_dict = load_file(path)
    model.load_state_dict(state_dict=state_dict)
    return model

def load_trained_optunet(path):
    checkpoint = torch.load(path)

    # Filter out unwanted keys (e.g., those related to loss)
    state_dict = {
        k.replace("model.", ""): v
        for k, v in checkpoint["state_dict"].items()
        if not k.startswith("loss.")
    }
    model = OptuNet()
    model.load_state_dict(state_dict)
    return model

### Model Training

In [15]:
def train_optunets(n_models = 100, start_idx = 0, tag=""):
    # tag (str) is used to tag save name with specific tag

    # Train n_models OptuNets
    for i in range(start_idx, n_models):
        model = OptuNet()

        trainer = TUTrainer(
            accelerator="gpu",
            enable_progress_bar=False,
            max_epochs=EPOCHS)
        
        # loss
        loss = ELBOLoss(
            model=model,
            inner_loss=nn.CrossEntropyLoss(),
            kl_weight=1/10000,
            num_samples=3,
        )

        routine = ClassificationRoutine(
            model=model,
            num_classes=datamodule.num_classes,
            loss=loss,
            optim_recipe=optim_optunet(model),
            is_ensemble=True
        )

        trainer.fit(model=routine, datamodule=datamodule)

        # Save the trained model
        save_path = Path(MODEL_PATH, f"model_{tag}{i+1}.pt")
        trainer.save_checkpoint(save_path)
    
    print(f"Trained {n_models} models. Saved to {MODEL_PATH}")

In [None]:
train_optunets(n_models=NMODELS, start_idx=30, tag="t")

In [12]:
def warmup_cosine_scheduler(model, warmup_steps, total_steps, min_lr=0, max_lr=0.04):
    # Define the learning rate scheduler as a Lambda function
    def lr_lambda(epoch):
        if epoch < warmup_steps:
            # Linear warmup: Increase from 0 to max_lr
            return float(epoch) / float(max(1, warmup_steps))
        else:
            # Cosine decay after warmup
            progress = (epoch - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return min_lr + 0.5 * (max_lr - min_lr) * (1 + torch.cos(torch.pi * progress))

    optimizer = optim_optunet(model)["optimizer"]
    scheduler = LambdaLR(optimizer, lr_lambda)
    return scheduler

### Model Loading

In [15]:
def save_trained_model(model, path):
    # Ensure path exists
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Save the model
    torch.save(model.state_dict(), path)

In [7]:
def load_trained_models(n_models = 100, start_idx = 0, tag=""):
    posterior_models = []

    for i in range(start_idx, n_models):
        path = Path(MODEL_PATH, f"model_{tag}{i+1}.pt")
        model = load_trained_optunet(path)
        model = model.to(DEVICE) # Needed?
        posterior_models.append(model)
    
    print(f"Loaded {len(posterior_models)} models")
    return posterior_models

In [8]:
posterior_models = load_trained_models(n_models=30, tag="t")

  checkpoint = torch.load(path)


Loaded 30 models


## Baselines

### Dropout

In [13]:
class OptuDrop(OptuNet):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=DROPOUT_RATE)
    
    def forward(self, x):
        x = self.dropout(super().forward(x))
        return x

In [14]:
model = OptuDrop()
loss_fn = nn.CrossEntropyLoss()
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    is_ensemble=False
)

trainer = TUTrainer(
    accelerator="gpu",
    enable_progress_bar=True,
    max_epochs=EPOCHS
)

trainer.fit(model=routine, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuDrop         | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup   

Epoch 0: 100%|██████████| 938/938 [00:09<00:00, 94.13it/s, v_num=106, train_loss=1.910]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 163.68it/s, v_num=106, train_loss=1.650, Acc%=41.30]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 154.54it/s, v_num=106, train_loss=1.680, Acc%=48.20]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 147.46it/s, v_num=106, train_loss=1.620, Acc%=65.00]



Epoch 4: 100%|██████████| 938/938 [00:05<00:00, 157.83it/s, v_num=106, train_loss=1.710, Acc%=67.50]



Epoch 5: 100%|██████████| 938/938 [00:05<00:00, 164.73it/s, v_num=106, train_loss=1.250, Acc%=65.80]



Epoch 6: 100%|██████████| 938/938 [00:05<00:00, 158.81it/s, v_num=106, train_loss=1.640, Acc%=71.70]



Epoch 7: 100%|██████████| 938/938 [00:05<00:00, 160.11it/s, v_num=106, train_loss=1.310, Acc%=68.70]



Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 151.05it/s, v_num=106, train_loss=1.070, Acc%=71.80]



Epoch 9: 100%|██████████| 938/938 [00:06<00:00, 143.31it/s, v_num=106, train_loss=1.230, Acc%=74.90]



Epoch 10: 100%|██████████| 938/938 [00:05<00:00, 159.40it/s, v_num=106, train_loss=1.050, Acc%=72.70]



Epoch 11: 100%|██████████| 938/938 [00:05<00:00, 162.21it/s, v_num=106, train_loss=1.260, Acc%=76.10]



Epoch 12: 100%|██████████| 938/938 [00:06<00:00, 144.78it/s, v_num=106, train_loss=1.490, Acc%=68.00]



Epoch 13: 100%|██████████| 938/938 [00:06<00:00, 142.87it/s, v_num=106, train_loss=0.844, Acc%=75.20]



Epoch 14: 100%|██████████| 938/938 [00:06<00:00, 151.33it/s, v_num=106, train_loss=0.933, Acc%=74.30]



Epoch 15: 100%|██████████| 938/938 [00:06<00:00, 153.40it/s, v_num=106, train_loss=1.250, Acc%=73.60]



Epoch 16: 100%|██████████| 938/938 [00:06<00:00, 148.41it/s, v_num=106, train_loss=1.200, Acc%=77.80]



Epoch 17: 100%|██████████| 938/938 [00:06<00:00, 144.26it/s, v_num=106, train_loss=1.040, Acc%=76.70]



Epoch 18: 100%|██████████| 938/938 [00:06<00:00, 150.53it/s, v_num=106, train_loss=1.030, Acc%=77.90]



Epoch 19: 100%|██████████| 938/938 [00:06<00:00, 145.90it/s, v_num=106, train_loss=0.948, Acc%=76.80]



Epoch 25: 100%|██████████| 938/938 [00:06<00:00, 154.13it/s, v_num=106, train_loss=1.170, Acc%=79.00]



Epoch 27: 100%|██████████| 938/938 [00:06<00:00, 142.56it/s, v_num=106, train_loss=1.210, Acc%=78.50]



Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 128.11it/s, v_num=106, train_loss=1.190, Acc%=83.20]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 127.97it/s, v_num=106, train_loss=1.190, Acc%=83.20]


In [96]:
# Testing
results = trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 88.05it/s]


### viBNN

### SWAG

In [None]:
# Train OptuNet for SWAG
model = OptuNet()

routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=nn.CrossEntropyLoss(),
    optim_recipe=optim_optunet(model),
    is_ensemble=False # Single model (not ensemble here)
)

trainer = TUTrainer(
    accelerator="gpu",
    max_epochs=2*EPOCHS, # Train twice as long for SWAG
    enable_progress_bar=True
)

# Fit the model
trainer.fit(model=routine, datamodule=datamodule)

# Save checkpoints every 10 epochs from epoch 80 onward
checkpoint_dir = Path("models/swag_checkpoints/")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

checkpoints = []
for epoch in range(80, 121, 10):
    checkpoint_path = checkpoint_dir / f"model_epoch_{epoch}.pt"
    trainer.save_checkpoint(checkpoint_path)
    checkpoints.append(checkpoint_path)


In [159]:
# Create the SWAG object
swag_model = SWAG(
    base=OptuNet,
    max_num_models=20,
    var_clamp=1e-30
)

# Add the collected checkpoints to the SWAG posterior
for checkpoint_path in checkpoints:
    checkpoint = torch.load(checkpoint_path)
    swag_model.collect_model(load_trained_optunet(checkpoint_path)) 

  checkpoint = torch.load(checkpoint_path)
  checkpoint = torch.load(path)


In [None]:
# For consistent results
class SWAGLightningWrapper(LightningModule):
    def __init__(self, swag_model, num_samples=10, scale=0.1):
        super().__init__()
        self.swag_model = swag_model
        self.num_samples = num_samples
        self.scale = scale

    def forward(self, x):
        preds = []
        for _ in range(self.num_samples):
            sampled_model = self.swag_model.sample(scale=self.scale)  # Sample from SWAG posterior
            sampled_model.eval()
            with torch.no_grad():
                preds.append(sampled_model(x))
        return torch.stack(preds).mean(dim=0)  # Aggregate predictions

    def predict_step(self, batch, batch_idx):
        x, _ = batch
        return self.forward(x)


In [None]:
# Wrap SWAG model for testing
swag_wrapper = SWAGLightningWrapper(
    swag_model=swag_model,
    num_samples=10,  # Number of posterior samples as per the paper
    scale=0.1        # Scale parameter for SWAG sampling
)

# Test the SWAG model
results = trainer.test(model=swag_wrapper, datamodule=datamodule)
print("SWAG Test Results:", results)


In [155]:
trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 97.30it/s] 


[{'test/cal/ECE': 0.05344296991825104,
  'test/cal/aECE': 0.05416256561875343,
  'test/cls/Acc': 0.8267999887466431,
  'test/cls/Brier': 0.25886476039886475,
  'test/cls/NLL': 0.570395290851593,
  'test/sc/AUGRC': 0.037595879286527634,
  'test/sc/AURC': 0.050700489431619644,
  'test/sc/Cov@5Risk': 0.6302000284194946,
  'test/sc/Risk@80Cov': 0.09300000220537186,
  'test/cls/Entropy': 0.6775237917900085}]

### Laplace

In [77]:
# Model
model = OptuNet()

def scheduler_laplace(optimizer):
    return MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )

# Routine
loss_fn = nn.CrossEntropyLoss()  # Standard cross-entropy loss
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    # scheduler_recipe=scheduler_laplace,
    is_ensemble=False
)

# Train the model to MAP estimate
trainer = TUTrainer(
    accelerator="gpu",
    max_epochs=EPOCHS,
    enable_progress_bar=True
)
trainer.fit(model=routine, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuNet          | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup            | Identity         | 0      | train
--------------------------------------------------------------
392       Trainable params
0         Non-trainable params
392       Total params
0.002     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 938/938 [00:09<00:00, 96.90it/s, v_num=57, train_loss=1.770]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 160.81it/s, v_num=57, train_loss=1.070, Acc%=52.30]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 143.72it/s, v_num=57, train_loss=1.740, Acc%=59.20]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 139.31it/s, v_num=57, train_loss=0.841, Acc%=53.10]



Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 144.08it/s, v_num=57, train_loss=1.010, Acc%=73.50]



Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 148.42it/s, v_num=57, train_loss=1.680, Acc%=65.40]



Epoch 6: 100%|██████████| 938/938 [00:06<00:00, 148.22it/s, v_num=57, train_loss=0.599, Acc%=76.90]



Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 143.06it/s, v_num=57, train_loss=0.713, Acc%=79.40]



Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 138.37it/s, v_num=57, train_loss=1.010, Acc%=79.20]



Epoch 9: 100%|██████████| 938/938 [00:06<00:00, 144.06it/s, v_num=57, train_loss=0.803, Acc%=80.20]



Epoch 10: 100%|██████████| 938/938 [00:06<00:00, 144.23it/s, v_num=57, train_loss=0.867, Acc%=74.70]



Epoch 19: 100%|██████████| 938/938 [00:06<00:00, 151.63it/s, v_num=57, train_loss=1.390, Acc%=80.30]



Epoch 24: 100%|██████████| 938/938 [00:06<00:00, 152.22it/s, v_num=57, train_loss=0.696, Acc%=82.30]



Epoch 27: 100%|██████████| 938/938 [00:07<00:00, 123.66it/s, v_num=57, train_loss=0.513, Acc%=83.00]



Epoch 28: 100%|██████████| 938/938 [00:06<00:00, 134.27it/s, v_num=57, train_loss=1.040, Acc%=75.80]



Epoch 30: 100%|██████████| 938/938 [00:07<00:00, 120.40it/s, v_num=57, train_loss=0.573, Acc%=83.70]



Epoch 38: 100%|██████████| 938/938 [00:06<00:00, 151.51it/s, v_num=57, train_loss=0.517, Acc%=81.90]



Epoch 43: 100%|██████████| 938/938 [00:06<00:00, 143.95it/s, v_num=57, train_loss=0.956, Acc%=83.40]



Epoch 45: 100%|██████████| 938/938 [00:06<00:00, 155.05it/s, v_num=57, train_loss=0.769, Acc%=79.90]



Epoch 48: 100%|██████████| 938/938 [00:06<00:00, 146.01it/s, v_num=57, train_loss=0.626, Acc%=81.50]



Epoch 51: 100%|██████████| 938/938 [00:06<00:00, 145.88it/s, v_num=57, train_loss=0.469, Acc%=82.30]



Epoch 54: 100%|██████████| 938/938 [00:06<00:00, 147.30it/s, v_num=57, train_loss=1.710, Acc%=83.80]



Epoch 55: 100%|██████████| 938/938 [00:05<00:00, 158.17it/s, v_num=57, train_loss=0.833, Acc%=76.70]



Epoch 56: 100%|██████████| 938/938 [00:06<00:00, 154.98it/s, v_num=57, train_loss=0.886, Acc%=75.40]



Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.42it/s, v_num=57, train_loss=0.419, Acc%=81.80]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.28it/s, v_num=57, train_loss=0.419, Acc%=81.80]


In [25]:
# Apply Laplace approximation
laplace_model = Laplace(model, likelihood='classification', subset_of_weights='last_layer', hessian_structure='full')
laplace_model.fit(datamodule.train_dataloader())  # Fit the Laplace model on training data
laplace_model.optimize_prior_precision()  # Optimize prior precision



In [79]:
# Test
results = trainer.test(model=routine, datamodule=datamodule)
print("Laplace Test Results:", results)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 88.21it/s]


Laplace Test Results: [{'test/cal/ECE': 0.020094318315386772, 'test/cal/aECE': 0.02005278505384922, 'test/cls/Acc': 0.8176000118255615, 'test/cls/Brier': 0.2700233459472656, 'test/cls/NLL': 0.5702756643295288, 'test/sc/AUGRC': 0.04408245161175728, 'test/sc/AURC': 0.05958322063088417, 'test/sc/Cov@5Risk': 0.5226000547409058, 'test/sc/Risk@80Cov': 0.10912500321865082, 'test/cls/Entropy': 0.5809915661811829}]


### SGHMC

### pSGLD

### DE

In [None]:
def train_deep_ensemble(model_class, datamodule, num_ensembles=10, save_dir="models/deep_ensemble"):
    """
    Train a deep ensemble of models.
    Args:
        model_class: The model class (e.g., OptuNet).
        datamodule: Data module providing train/val/test splits.
        num_ensembles (int): Number of models in the ensemble.
        save_dir (str): Directory to save the trained models.
    Returns:
        list: Trained models.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    trained_models = []
    for i in range(num_ensembles):
        print(f"Training model {i + 1}/{num_ensembles}...")
        
        # Initialize a new model
        model = model_class()
        
        # Define loss and optimizer
        loss_fn = nn.CrossEntropyLoss()

        routine = ClassificationRoutine(
            model=model,
            num_classes=datamodule.num_classes,
            loss=loss_fn,
            optim_recipe=optim_optunet(model),  # Replace with appropriate optimizer
            is_ensemble=False
        )

        # Trainer
        trainer = TUTrainer(
            accelerator="gpu",
            max_epochs=EPOCHS,
            enable_progress_bar=True
        )

        # Train the model
        trainer.fit(model=routine, datamodule=datamodule)

        # Save the model
        model_path = save_path / f"model_{i+1}.pt"
        trainer.save_checkpoint(model_path)

        trained_models.append(model)

    return trained_models


In [35]:
def ensemble_predict(models, dataloader):
    """
    Perform inference using a deep ensemble.
    Args:
        models (list): List of trained models.
        dataloader (DataLoader): DataLoader for test data.
    Returns:
        np.ndarray: Averaged predictions from the ensemble.
    """
    all_preds = []

    for model in models:
        model.eval()
        preds = []

        for inputs, _ in dataloader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)  # Logits shape: (batch_size, num_classes)
                preds.append(outputs.cpu().numpy())
        
        all_preds.append(np.concatenate(preds, axis=0))  # Combine batches

    # Stack predictions from all models and average
    ensemble_preds = np.mean(np.stack(all_preds, axis=0), axis=0)
    return ensemble_preds


In [None]:
de_models = train_deep_ensemble(OptuNet, datamodule, num_ensembles=5)

## Scores

### Preds Sampling

In [95]:
def evaluate_confidence_interval(model, dataloader, confidence=0.95):
    model.eval()

    all_preds = []
    all_targets = []
    lower_bounds = []
    upper_bounds = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            # For Monte carlo estimation of the ELBO using 3 samples
            preds = torch.stack([model(inputs) for _ in range(MC_SAMPLES)], dim=0)

            # Calculate mean and standard deviation
            preds_mean = preds.mean(dim=0)
            preds_std = preds.std(dim=0)

            # Apply softmax to the mean predictions for probabilities
            preds_mean = F.softmax(preds_mean, dim=1)


            # Compute confidence intervals
            z_value = st.norm.ppf(1 - (1 - confidence) / 2)  #dynamically computing the z-score for a given confidence level (e.g., 95%, 99%).
            ci_lower = preds_mean - z_value * preds_std
            ci_upper = preds_mean + z_value * preds_std

            all_preds.append(preds_mean)
            all_targets.append(targets)
            lower_bounds.append(ci_lower)
            upper_bounds.append(ci_upper)

    # Concatenate results for all batches
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    lower_bounds = torch.cat(lower_bounds)
    upper_bounds = torch.cat(upper_bounds)

    return all_preds, all_targets, lower_bounds, upper_bounds


### Posterior Estimation

In [53]:
def monte_carlo_sampling(model, data_loader, num_samples=50):
    """
    Perform Monte Carlo Dropout sampling on the model.
    Args:
        model (nn.Module): Trained OptuDrop model with dropout.
        data_loader (DataLoader): DataLoader for test data.
        num_samples (int): Number of MC samples.
    Returns:
        np.ndarray: Array of predictions from all samples.
    """
    model.eval()
    predictions = []

    for _ in range(num_samples):
        sampled_preds = []
        for inputs, _ in data_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)
                sampled_preds.append(outputs.cpu().numpy())
        predictions.append(np.concatenate(sampled_preds, axis=0))

    return np.array(predictions)


In [54]:
def extract_model_weights(models):
    """
    Extract model weights to use as samples for posterior comparison.
    Args:
        models (list): List of trained models.
    Returns:
        np.ndarray: Flattened weight arrays for each model.
    """
    weight_samples = []
    for model in models:
        weights = []
        for param in model.parameters():
            weights.append(param.detach().cpu().numpy().flatten())
        weight_samples.append(np.concatenate(weights))
    return np.array(weight_samples)


In [64]:
def target_model_predictions(models, data_loader):
    """
    Generate predictions from trained models to form the target posterior in prediction space.
    Args:
        models (list): List of trained models.
        data_loader (DataLoader): DataLoader for test data.
    Returns:
        np.ndarray: Array of predictions for all models.
    """
    all_predictions = []

    for model in models:
        model.eval()
        preds = []
        for inputs, _ in data_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)  # Logits
                preds.append(outputs.cpu().numpy())
        all_predictions.append(np.concatenate(preds, axis=0))  # Combine batches

    return np.array(all_predictions)  # Shape: (num_models, num_datapoints, num_classes)

In [66]:
# Sample model posterior
dropout_samples = monte_carlo_sampling(model, datamodule.test_dataloader()[0], num_samples=5)
dropout_samples_flat = dropout_samples.reshape(dropout_samples.shape[0], -1)

In [None]:
# Target posterior
target_predictions = target_model_predictions(posterior_models, datamodule.test_dataloader()[0])
# target_posterior_samples = extract_model_weights(posterior_models)
target_posterior_samples = target_predictions.reshape(30, -1)  # Shape: (30, 10000 * 10)

In [218]:
# Use mmdagg function
mmd_score = mmdagg(
    X=dropout_samples_flat,
    Y=target_posterior_samples,
    alpha=0.05,
    kernel="laplace_gaussian",
    number_bandwidths=10,
    weights_type="uniform",
    B1=2000,
    B2=2000,
    B3=50,
    seed=42424242,
)
print("MMD Score:", mmd_score)

MMD Score: [0.84391623 0.79542857 0.73696664 0.67138301 0.60202212 0.53211281
 0.4643618  0.40076836 0.34261158 0.29054258 0.0176627  0.58616206
 0.88334639 0.92725724 0.93261007 0.93324742 0.93332312 0.93326705
 0.64932177 0.11887361]


### AUPR

In [92]:
def aupr_score(predictions, labels):
    n_classes = predictions.shape[1]
    precision = {}
    recall = {}
    aupr = {}

    for i in range(n_classes):
        # Binarize the labels for class i
        binary_labels = (labels == i).astype(int)
        precision[i], recall[i], _ = precision_recall_curve(binary_labels, predictions[:, i])
        aupr[i] = auc(recall[i], precision[i])

    # Optional: Aggregate AUPR
    mean_aupr = np.mean(list(aupr.values()))
    return mean_aupr, aupr

### FPR95

In [85]:
def fpr95_score(predictions, labels):
    # Calculate predicted classes and confidences
    predicted_classes = np.argmax(predictions, axis=1)  # Class with highest probability
    confidences = np.max(predictions, axis=1)           # Confidence scores (max probability)
    binary_labels = (predicted_classes == labels).astype(int) # Determine binary labels (1 for correct, 0 for incorrect)
    fpr, tpr, thresholds = roc_curve(binary_labels, confidences)

    # Find the threshold where TPR is closest to 95%
    idx = np.where(tpr >= 0.95)[0][0]
    return fpr[idx]


### ACE

In [86]:
def calculate_ace(preds, targets, num_bins=10):
    # Convert preds to probabilities
    probs = torch.softmax(preds, dim=1).cpu().numpy()  # For multi-class, probs[:, 1] for class 1, or probs[:, positive_class] for specific class
    true_targets = targets.cpu().numpy()
    positive_class = 1

    # For binary classification, you can use probs[:, 1] for class 1 (positive_class)
    prob_true = probs[:, positive_class]

    binary_targets = (true_targets == positive_class).astype(int)

    # Get calibration curve: This will return the true fraction of positives and predicted probabilities for each bin
    fraction_of_positives, mean_predicted_value = calibration_curve(binary_targets, prob_true, n_bins=num_bins)

    # Calculate ACE: This is the average absolute difference between the true fraction of positives and the predicted probability
    ace = np.mean(np.abs(fraction_of_positives - mean_predicted_value))
    return ace

### General Scoring Function

In [101]:
def score_model(model, test_loader):
    model.to(DEVICE)
    preds, targets, lower_bounds, upper_bounds = evaluate_confidence_interval(model, test_loader)
    predictions = preds.cpu().numpy()
    labels = targets.cpu().numpy()

    # Scores
    mean_aupr, aupr = aupr_score(predictions, labels)
    fpr95 = fpr95_score(predictions, labels)
    ace = calculate_ace(preds, targets)

    print(f"Mean AUPR: {mean_aupr}")
    print(f"FPR95: {fpr95}")
    print(f"ACE: {ace:.4f}")

In [102]:
score_model(model, datamodule.test_dataloader()[0])

Mean AUPR: 0.9040881282443637
FPR95: 0.6742108397855867
ACE: 0.4506
