# Posterior Symmetry Reproduction

In [2]:
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR, LambdaLR
import torch.nn.functional as F

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.lenet import bayesian_lenet, lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.routines import ClassificationRoutine
from lightning.pytorch import LightningModule

from sklearn.metrics import precision_recall_curve, roc_curve, auc, accuracy_score
from sklearn.calibration import calibration_curve
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as st

from laplace import Laplace
from utils.swa_gaussian.swag.posteriors import SWAG
from torch_uncertainty.models import SWAG as TUSWAG
import utils.swa_gaussian.swag.posteriors as swag_posteriors
from utils.posterior_symmetry.mmd.mmd_torch import mmdagg
from utils.posterior_symmetry.symmetries.permutation import Permuter
from utils.posterior_symmetry.symmetries.scale import Scaler
from utils.bayes_neural_networks.src.Stochastic_Gradient_HMC_SA.optimizers import H_SA_SGHMC

from pathlib import Path
from safetensors.torch import load_file

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_PATH = "data"
MODEL_PATH = Path("models", "trained_optunets")
NUM_WORKERS = 4

# Parameters from paper
EPOCHS = 60
BATCH_SIZE = 64
LEARNING_RATE = 0.04
WEIGHT_DECAY = 2e-4

## Method params
DROPOUT_RATE = 0.2 # last layer dropout rate

# Models
NMODELS = 100
ENSEMBLE_MODELS = 10 # number of models to use in ensemble
N_SAMPLES = 100 # number of samples to draw from model

MC_SAMPLES = 3 # Posterior samples for model

In [7]:
# Load MNIST data
root = Path(DATA_PATH)
datamodule = MNISTDataModule(root=root, batch_size=BATCH_SIZE, eval_ood=False, num_workers=NUM_WORKERS)

## OptuNet

### Model Definition

In [8]:
class OptuNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Add layers for OptuNet (use Section C.2.1 from the paper for details)
        # Layers: Conv2D (out_ch=2, ks=4, groups=1) -> Max Pooling (ks=3, stride=3) -> ReLU -> Conv2D (out_ch=10, ks=5, groups=2) -> Average Pooling -> ReLU -> Linear 10x10
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=4, groups=1, bias=False)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=5, groups=2, bias=False)
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=10, out_features=10)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.pool1(self.conv1(x)))  # First conv, max pooling, ReLU
        x = self.relu(self.pool2(self.conv2(x)))  # Second conv, avg pooling, ReLU
        x = torch.mean(x, dim=(2, 3))
        x = self.fc1(x)  # Linear layer
        return x

# Optimizer and LR scheduler
def optim_optunet(model: nn.Module):
    optimizer = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

def optim_cosine_optunet(model: nn.Module, warmup_steps=5, total_steps=60, min_lr=0, max_lr=0.04):
    optimizer = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    def lr_lambda(epoch):
        if epoch < warmup_steps:
            # Linear warmup: Increase from 0 to max_lr
            return float(epoch) / float(max(1, warmup_steps))
        else:
            # Cosine decay after warmup
            progress = (epoch - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return min_lr + 0.5 * (max_lr - min_lr) * (1 + torch.cos(torch.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Loss function
def loss_optunet(model: nn.Module):
    loss = ELBOLoss(
        model=model,
        inner_loss=nn.CrossEntropyLoss(),
        kl_weight= 1/10000,
        num_samples=3,
    )
    return loss

In [9]:
# Load functions
def load_optunet_model(version: int):
    model = OptuNet()
    path = Path(f"models/mnist-optunet-0-8191/version_{version}.safetensors")

    if not path.exists():
        raise ValueError("File does not exist")

    state_dict = load_file(path)
    model.load_state_dict(state_dict=state_dict)
    return model

def load_trained_optunet(path):
    checkpoint = torch.load(path)

    # Filter out unwanted keys (e.g., those related to loss)
    state_dict = {
        k.replace("model.", ""): v
        for k, v in checkpoint["state_dict"].items()
        if not k.startswith("loss.")
    }
    model = OptuNet()
    model.load_state_dict(state_dict)
    return model

### Model Training

In [10]:
def train_optunets(n_models = 100, start_idx = 0, tag=""):
    # tag (str) is used to tag save name with specific tag

    # Train n_models OptuNets
    for i in range(start_idx, n_models):
        model = OptuNet()

        trainer = TUTrainer(
            accelerator="gpu",
            enable_progress_bar=False,
            max_epochs=EPOCHS)
        
        # loss
        loss = ELBOLoss(
            model=model,
            inner_loss=nn.CrossEntropyLoss(),
            kl_weight=1/10000,
            num_samples=3,
        )

        routine = ClassificationRoutine(
            model=model,
            num_classes=datamodule.num_classes,
            loss=loss,
            optim_recipe=optim_optunet(model),
            is_ensemble=True
        )

        trainer.fit(model=routine, datamodule=datamodule)

        # Save the trained model
        save_path = Path(MODEL_PATH, f"model_{tag}{i}.pt")
        trainer.save_checkpoint(save_path)
    
    print(f"Trained {n_models} models. Saved to {MODEL_PATH}")

In [None]:
train_optunets(n_models=NMODELS, start_idx=30, tag="t")

### Model Loading

In [12]:
def save_trained_model(model, path):
    # Ensure path exists
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Save the model
    torch.save(model.state_dict(), path)

In [13]:
def load_trained_models(n_models = 100, start_idx = 0, tag=""):
    posterior_models = []

    for i in range(start_idx, n_models):
        path = Path(MODEL_PATH, f"model_{tag}{i}.pt")
        model = load_trained_optunet(path)
        model = model.to(DEVICE) # Needed?
        posterior_models.append(model)
    
    print(f"Loaded {len(posterior_models)} models")
    return posterior_models

In [14]:
posterior_models = load_trained_models(n_models=31, tag="t")

  checkpoint = torch.load(path)


Loaded 31 models


## Baselines

### Dropout

In [13]:
class OptuDrop(OptuNet):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=DROPOUT_RATE)
    
    def forward(self, x):
        x = self.dropout(super().forward(x))
        return x

In [None]:
model = OptuDrop()
loss_fn = nn.CrossEntropyLoss()
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    is_ensemble=False
)

trainer = TUTrainer(
    accelerator="gpu",
    enable_progress_bar=True,
    max_epochs=EPOCHS
)

trainer.fit(model=routine, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuDrop         | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup   

Epoch 0: 100%|██████████| 938/938 [00:09<00:00, 94.13it/s, v_num=106, train_loss=1.910]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 163.68it/s, v_num=106, train_loss=1.650, Acc%=41.30]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 154.54it/s, v_num=106, train_loss=1.680, Acc%=48.20]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 147.46it/s, v_num=106, train_loss=1.620, Acc%=65.00]



Epoch 4: 100%|██████████| 938/938 [00:05<00:00, 157.83it/s, v_num=106, train_loss=1.710, Acc%=67.50]



Epoch 5: 100%|██████████| 938/938 [00:05<00:00, 164.73it/s, v_num=106, train_loss=1.250, Acc%=65.80]



Epoch 6: 100%|██████████| 938/938 [00:05<00:00, 158.81it/s, v_num=106, train_loss=1.640, Acc%=71.70]



Epoch 7: 100%|██████████| 938/938 [00:05<00:00, 160.11it/s, v_num=106, train_loss=1.310, Acc%=68.70]



Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 151.05it/s, v_num=106, train_loss=1.070, Acc%=71.80]



Epoch 9: 100%|██████████| 938/938 [00:06<00:00, 143.31it/s, v_num=106, train_loss=1.230, Acc%=74.90]



Epoch 10: 100%|██████████| 938/938 [00:05<00:00, 159.40it/s, v_num=106, train_loss=1.050, Acc%=72.70]



Epoch 11: 100%|██████████| 938/938 [00:05<00:00, 162.21it/s, v_num=106, train_loss=1.260, Acc%=76.10]



Epoch 12: 100%|██████████| 938/938 [00:06<00:00, 144.78it/s, v_num=106, train_loss=1.490, Acc%=68.00]



Epoch 13: 100%|██████████| 938/938 [00:06<00:00, 142.87it/s, v_num=106, train_loss=0.844, Acc%=75.20]



Epoch 14: 100%|██████████| 938/938 [00:06<00:00, 151.33it/s, v_num=106, train_loss=0.933, Acc%=74.30]



Epoch 15: 100%|██████████| 938/938 [00:06<00:00, 153.40it/s, v_num=106, train_loss=1.250, Acc%=73.60]



Epoch 16: 100%|██████████| 938/938 [00:06<00:00, 148.41it/s, v_num=106, train_loss=1.200, Acc%=77.80]



Epoch 17: 100%|██████████| 938/938 [00:06<00:00, 144.26it/s, v_num=106, train_loss=1.040, Acc%=76.70]



Epoch 18: 100%|██████████| 938/938 [00:06<00:00, 150.53it/s, v_num=106, train_loss=1.030, Acc%=77.90]



Epoch 19: 100%|██████████| 938/938 [00:06<00:00, 145.90it/s, v_num=106, train_loss=0.948, Acc%=76.80]



Epoch 25: 100%|██████████| 938/938 [00:06<00:00, 154.13it/s, v_num=106, train_loss=1.170, Acc%=79.00]



Epoch 27: 100%|██████████| 938/938 [00:06<00:00, 142.56it/s, v_num=106, train_loss=1.210, Acc%=78.50]



Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 128.11it/s, v_num=106, train_loss=1.190, Acc%=83.20]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 127.97it/s, v_num=106, train_loss=1.190, Acc%=83.20]


In [96]:
# Testing
results = trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 88.05it/s]


### viBNN

### SWAG

In [244]:
# Instantiate the model
model = OptuNet()

# Define SWAG model, using the number of samples to approximate the posterior
swag_model = TUSWAG(model, cycle_start=20, cycle_length=10)  # Set the number of samples for SWAG
# Set up optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Set up the ClassificationRoutine
routine = ClassificationRoutine(model,
                                num_classes=datamodule.num_classes,
                                loss=nn.CrossEntropyLoss(), 
                                optim_recipe=optim_optunet(model),
                                is_ensemble=False)

# Instantiate the trainer with the routine
trainer = TUTrainer(accelerator="gpu", max_epochs=EPOCHS, enable_progress_bar=True)

# Start the training process
trainer.fit(model=routine, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuNet          | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup            | Identity         | 0      | train
--------------------------------------------------------------
392       Trainable params
0         Non-trainable params
392       Total params
0.002     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 938/938 [00:10<00:00, 89.24it/s, v_num=110, train_loss=2.190]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 163.87it/s, v_num=110, train_loss=1.550, Acc%=19.40]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 143.22it/s, v_num=110, train_loss=1.560, Acc%=47.30]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 156.00it/s, v_num=110, train_loss=0.687, Acc%=70.30]



Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 153.72it/s, v_num=110, train_loss=0.852, Acc%=71.80]



Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 148.31it/s, v_num=110, train_loss=0.868, Acc%=71.80]



Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 142.90it/s, v_num=110, train_loss=0.647, Acc%=79.10]



Epoch 8: 100%|██████████| 938/938 [00:07<00:00, 126.76it/s, v_num=110, train_loss=1.140, Acc%=78.10]



Epoch 10: 100%|██████████| 938/938 [00:07<00:00, 128.91it/s, v_num=110, train_loss=0.784, Acc%=78.60]



Epoch 11: 100%|██████████| 938/938 [00:06<00:00, 144.44it/s, v_num=110, train_loss=1.130, Acc%=77.80]



Epoch 12: 100%|██████████| 938/938 [00:06<00:00, 144.13it/s, v_num=110, train_loss=1.150, Acc%=76.90]



Epoch 13: 100%|██████████| 938/938 [00:06<00:00, 144.36it/s, v_num=110, train_loss=0.975, Acc%=74.60]



Epoch 14: 100%|██████████| 938/938 [00:06<00:00, 142.73it/s, v_num=110, train_loss=1.030, Acc%=68.10]



Epoch 20: 100%|██████████| 938/938 [00:06<00:00, 136.00it/s, v_num=110, train_loss=1.100, Acc%=79.20]



Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 133.25it/s, v_num=110, train_loss=0.944, Acc%=79.70]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:07<00:00, 133.12it/s, v_num=110, train_loss=0.944, Acc%=79.70]


In [247]:
trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 96.82it/s]


[{'test/cal/ECE': 0.03635436296463013,
  'test/cal/aECE': 0.035149142146110535,
  'test/cls/Acc': 0.7967000007629395,
  'test/cls/Brier': 0.29376837611198425,
  'test/cls/NLL': 0.6377840638160706,
  'test/sc/AUGRC': 0.04629536718130112,
  'test/sc/AURC': 0.059700217097997665,
  'test/sc/Cov@5Risk': 0.5669000148773193,
  'test/sc/Risk@80Cov': 0.12212499976158142,
  'test/cls/Entropy': 0.6921000480651855}]

In [246]:
# Evaluate on test set
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for data, target in datamodule.test_dataloader()[0]:
        output = swag_model(data)
        preds = output.argmax(dim=1)
        all_preds.append(preds)
        all_labels.append(target)

all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
print(f'Test Accuracy: {accuracy * 100:.2f}%')


Test Accuracy: 79.67%


### Laplace

In [255]:
# Model
model = OptuNet()

def scheduler_laplace(optimizer):
    return MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )

# Routine
loss_fn = nn.CrossEntropyLoss()  # Standard cross-entropy loss
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    # scheduler_recipe=scheduler_laplace,
    is_ensemble=False
)

# Train the model to MAP estimate
trainer = TUTrainer(
    accelerator="gpu",
    max_epochs=EPOCHS,
    enable_progress_bar=True
)
trainer.fit(model=routine, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuNet          | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup            | Identity         | 0      | train
--------------------------------------------------------------
392       Trainable params
0         Non-trainable params
392       Total params
0.002     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 938/938 [00:09<00:00, 98.00it/s, v_num=111, train_loss=1.990]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 163.53it/s, v_num=111, train_loss=1.570, Acc%=40.30]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 148.50it/s, v_num=111, train_loss=1.060, Acc%=51.40]



Epoch 3: 100%|██████████| 938/938 [00:07<00:00, 126.88it/s, v_num=111, train_loss=1.060, Acc%=59.20]



Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 155.34it/s, v_num=111, train_loss=1.360, Acc%=65.70]



Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 146.63it/s, v_num=111, train_loss=1.070, Acc%=65.80]



Epoch 6: 100%|██████████| 938/938 [00:06<00:00, 154.08it/s, v_num=111, train_loss=1.150, Acc%=62.60]



Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 153.33it/s, v_num=111, train_loss=1.010, Acc%=62.30]



Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 150.81it/s, v_num=111, train_loss=1.250, Acc%=66.90]



Epoch 9: 100%|██████████| 938/938 [00:05<00:00, 157.10it/s, v_num=111, train_loss=1.290, Acc%=70.10]



Epoch 10: 100%|██████████| 938/938 [00:06<00:00, 151.55it/s, v_num=111, train_loss=1.180, Acc%=71.20]



Epoch 11: 100%|██████████| 938/938 [00:06<00:00, 152.28it/s, v_num=111, train_loss=1.310, Acc%=71.10]



Epoch 12: 100%|██████████| 938/938 [00:06<00:00, 151.89it/s, v_num=111, train_loss=1.130, Acc%=73.60]



Epoch 13: 100%|██████████| 938/938 [00:06<00:00, 139.59it/s, v_num=111, train_loss=0.788, Acc%=71.10]



Epoch 14: 100%|██████████| 938/938 [00:06<00:00, 143.90it/s, v_num=111, train_loss=0.957, Acc%=75.20]



Epoch 15: 100%|██████████| 938/938 [00:06<00:00, 152.56it/s, v_num=111, train_loss=0.793, Acc%=75.30]



Epoch 16: 100%|██████████| 938/938 [00:06<00:00, 150.52it/s, v_num=111, train_loss=0.925, Acc%=73.80]



Epoch 17: 100%|██████████| 938/938 [00:06<00:00, 154.15it/s, v_num=111, train_loss=0.819, Acc%=74.30]



Epoch 18: 100%|██████████| 938/938 [00:05<00:00, 158.87it/s, v_num=111, train_loss=0.597, Acc%=74.50]



Epoch 19: 100%|██████████| 938/938 [00:05<00:00, 158.12it/s, v_num=111, train_loss=0.736, Acc%=75.40]



Epoch 20: 100%|██████████| 938/938 [00:06<00:00, 153.78it/s, v_num=111, train_loss=0.776, Acc%=76.40]



Epoch 21: 100%|██████████| 938/938 [00:05<00:00, 157.26it/s, v_num=111, train_loss=0.687, Acc%=76.10]



Epoch 22: 100%|██████████| 938/938 [00:06<00:00, 151.80it/s, v_num=111, train_loss=0.602, Acc%=72.70]



Epoch 23: 100%|██████████| 938/938 [00:06<00:00, 150.79it/s, v_num=111, train_loss=0.942, Acc%=76.00]



Epoch 24: 100%|██████████| 938/938 [00:05<00:00, 157.47it/s, v_num=111, train_loss=0.862, Acc%=74.40]



Epoch 25: 100%|██████████| 938/938 [00:06<00:00, 154.81it/s, v_num=111, train_loss=1.180, Acc%=76.60]



Epoch 26: 100%|██████████| 938/938 [00:06<00:00, 153.44it/s, v_num=111, train_loss=1.600, Acc%=75.60]



Epoch 27: 100%|██████████| 938/938 [00:06<00:00, 152.66it/s, v_num=111, train_loss=1.180, Acc%=77.20]



Epoch 28: 100%|██████████| 938/938 [00:06<00:00, 152.92it/s, v_num=111, train_loss=1.200, Acc%=76.50]



Epoch 29: 100%|██████████| 938/938 [00:05<00:00, 156.67it/s, v_num=111, train_loss=0.617, Acc%=75.40]



Epoch 30: 100%|██████████| 938/938 [00:06<00:00, 144.65it/s, v_num=111, train_loss=1.210, Acc%=75.60]



Epoch 31: 100%|██████████| 938/938 [00:06<00:00, 148.81it/s, v_num=111, train_loss=0.837, Acc%=75.40]



Epoch 32: 100%|██████████| 938/938 [00:06<00:00, 146.23it/s, v_num=111, train_loss=1.410, Acc%=77.10]



Epoch 33: 100%|██████████| 938/938 [00:06<00:00, 136.56it/s, v_num=111, train_loss=0.536, Acc%=76.50]



Epoch 34: 100%|██████████| 938/938 [00:06<00:00, 146.90it/s, v_num=111, train_loss=0.813, Acc%=77.20]



Epoch 35: 100%|██████████| 938/938 [00:06<00:00, 152.77it/s, v_num=111, train_loss=0.887, Acc%=77.10]



Epoch 36: 100%|██████████| 938/938 [00:06<00:00, 145.62it/s, v_num=111, train_loss=0.781, Acc%=77.50]



Epoch 37: 100%|██████████| 938/938 [00:06<00:00, 134.45it/s, v_num=111, train_loss=0.650, Acc%=77.00]



Epoch 38: 100%|██████████| 938/938 [00:07<00:00, 126.57it/s, v_num=111, train_loss=0.866, Acc%=77.70]



Epoch 39: 100%|██████████| 938/938 [00:06<00:00, 147.64it/s, v_num=111, train_loss=1.100, Acc%=77.50]



Epoch 40: 100%|██████████| 938/938 [00:05<00:00, 159.41it/s, v_num=111, train_loss=0.762, Acc%=78.00]



Epoch 41: 100%|██████████| 938/938 [00:05<00:00, 156.69it/s, v_num=111, train_loss=0.779, Acc%=75.80]



Epoch 42: 100%|██████████| 938/938 [00:06<00:00, 140.07it/s, v_num=111, train_loss=1.180, Acc%=77.50]



Epoch 43: 100%|██████████| 938/938 [00:06<00:00, 139.94it/s, v_num=111, train_loss=0.852, Acc%=75.80]



Epoch 44: 100%|██████████| 938/938 [00:06<00:00, 138.22it/s, v_num=111, train_loss=1.190, Acc%=77.70]



Epoch 45: 100%|██████████| 938/938 [00:06<00:00, 143.15it/s, v_num=111, train_loss=0.915, Acc%=76.90]



Epoch 46: 100%|██████████| 938/938 [00:06<00:00, 150.63it/s, v_num=111, train_loss=0.921, Acc%=75.40]



Epoch 47: 100%|██████████| 938/938 [00:06<00:00, 150.21it/s, v_num=111, train_loss=0.806, Acc%=77.30]



Epoch 48: 100%|██████████| 938/938 [00:06<00:00, 154.64it/s, v_num=111, train_loss=0.798, Acc%=77.20]



Epoch 49: 100%|██████████| 938/938 [00:06<00:00, 146.48it/s, v_num=111, train_loss=0.626, Acc%=76.90]



Epoch 50: 100%|██████████| 938/938 [00:06<00:00, 150.42it/s, v_num=111, train_loss=0.702, Acc%=77.10]



Epoch 51: 100%|██████████| 938/938 [00:06<00:00, 147.36it/s, v_num=111, train_loss=1.090, Acc%=77.90]



Epoch 52: 100%|██████████| 938/938 [00:06<00:00, 141.51it/s, v_num=111, train_loss=0.914, Acc%=77.80]



Epoch 53: 100%|██████████| 938/938 [00:06<00:00, 146.21it/s, v_num=111, train_loss=1.080, Acc%=77.40]



Epoch 54: 100%|██████████| 938/938 [00:06<00:00, 142.31it/s, v_num=111, train_loss=1.000, Acc%=77.10]



Epoch 55: 100%|██████████| 938/938 [00:06<00:00, 143.81it/s, v_num=111, train_loss=0.917, Acc%=76.70]



Epoch 56: 100%|██████████| 938/938 [00:06<00:00, 138.98it/s, v_num=111, train_loss=0.645, Acc%=77.20]



Epoch 57: 100%|██████████| 938/938 [00:06<00:00, 143.60it/s, v_num=111, train_loss=0.672, Acc%=78.20]



Epoch 58: 100%|██████████| 938/938 [00:06<00:00, 154.10it/s, v_num=111, train_loss=1.070, Acc%=77.30]



Epoch 59: 100%|██████████| 938/938 [00:05<00:00, 162.55it/s, v_num=111, train_loss=0.675, Acc%=78.20]



Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.49it/s, v_num=111, train_loss=0.675, Acc%=77.20]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.35it/s, v_num=111, train_loss=0.675, Acc%=77.20]


In [256]:
trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 84.15it/s]



Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 83.52it/s]


[{'test/cal/ECE': 0.05171032249927521,
  'test/cal/aECE': 0.05260957032442093,
  'test/cls/Acc': 0.7716000080108643,
  'test/cls/Brier': 0.3349318504333496,
  'test/cls/NLL': 0.7045778632164001,
  'test/sc/AUGRC': 0.060908034443855286,
  'test/sc/AURC': 0.08823738992214203,
  'test/sc/Cov@5Risk': nan,
  'test/sc/Risk@80Cov': 0.15062500536441803,
  'test/cls/Entropy': 0.7887654304504395}]

In [257]:
# Apply Laplace approximation
laplace_model = Laplace(model, likelihood='classification', subset_of_weights='last_layer', hessian_structure='full')
laplace_model.fit(datamodule.train_dataloader())  # Fit the Laplace model on training data
laplace_model.optimize_prior_precision()  # Optimize prior precision



In [264]:
correct = 0
total = 0

# Test loop
for inputs, targets in datamodule.test_dataloader()[0]:
    # Compute predictive distribution using Laplace model
    predictive_probs = laplace_model(inputs)  # Returns predictive probabilities
    predicted_labels = predictive_probs.argmax(dim=1)

    correct += (predicted_labels == targets).sum().item()
    total += targets.size(0)

# Compute accuracy
accuracy = correct / total
print(f"Laplace Model Test Accuracy: {accuracy * 100:.2f}%")


Laplace Model Test Accuracy: 77.16%


### SGHMC

### pSGLD

### DE

In [None]:
def train_deep_ensemble(model_class, datamodule, num_ensembles=10, save_dir="models/deep_ensemble"):
    """
    Train a deep ensemble of models.
    Args:
        model_class: The model class (e.g., OptuNet).
        datamodule: Data module providing train/val/test splits.
        num_ensembles (int): Number of models in the ensemble.
        save_dir (str): Directory to save the trained models.
    Returns:
        list: Trained models.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    trained_models = []
    for i in range(num_ensembles):
        print(f"Training model {i + 1}/{num_ensembles}...")
        
        # Initialize a new model
        model = model_class()
        
        # Define loss and optimizer
        loss_fn = nn.CrossEntropyLoss()

        routine = ClassificationRoutine(
            model=model,
            num_classes=datamodule.num_classes,
            loss=loss_fn,
            optim_recipe=optim_optunet(model),  # Replace with appropriate optimizer
            is_ensemble=False
        )

        # Trainer
        trainer = TUTrainer(
            accelerator="gpu",
            max_epochs=EPOCHS,
            enable_progress_bar=True
        )

        # Train the model
        trainer.fit(model=routine, datamodule=datamodule)

        # Save the model
        model_path = save_path / f"model_{i+1}.pt"
        trainer.save_checkpoint(model_path)

        trained_models.append(model)

    return trained_models


In [35]:
def ensemble_predict(models, dataloader):
    """
    Perform inference using a deep ensemble.
    Args:
        models (list): List of trained models.
        dataloader (DataLoader): DataLoader for test data.
    Returns:
        np.ndarray: Averaged predictions from the ensemble.
    """
    all_preds = []

    for model in models:
        model.eval()
        preds = []

        for inputs, _ in dataloader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)  # Logits shape: (batch_size, num_classes)
                preds.append(outputs.cpu().numpy())
        
        all_preds.append(np.concatenate(preds, axis=0))  # Combine batches

    # Stack predictions from all models and average
    ensemble_preds = np.mean(np.stack(all_preds, axis=0), axis=0)
    return ensemble_preds


In [None]:
de_models = train_deep_ensemble(OptuNet, datamodule, num_ensembles=5)

## Scores

### Preds Sampling

In [272]:
def evaluate_confidence_interval(model, dataloader, confidence=0.95):
    # model.eval()

    all_preds = []
    all_targets = []
    lower_bounds = []
    upper_bounds = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            # inputs = inputs.to(DEVICE)
            # targets = targets.to(DEVICE)

            # For Monte carlo estimation of the ELBO using 3 samples
            preds = torch.stack([model(inputs) for _ in range(MC_SAMPLES)], dim=0)

            # Calculate mean and standard deviation
            preds_mean = preds.mean(dim=0)
            preds_std = preds.std(dim=0)

            # Apply softmax to the mean predictions for probabilities
            preds_mean = F.softmax(preds_mean, dim=1)


            # Compute confidence intervals
            z_value = st.norm.ppf(1 - (1 - confidence) / 2)  #dynamically computing the z-score for a given confidence level (e.g., 95%, 99%).
            ci_lower = preds_mean - z_value * preds_std
            ci_upper = preds_mean + z_value * preds_std

            all_preds.append(preds_mean)
            all_targets.append(targets)
            lower_bounds.append(ci_lower)
            upper_bounds.append(ci_upper)

    # Concatenate results for all batches
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    lower_bounds = torch.cat(lower_bounds)
    upper_bounds = torch.cat(upper_bounds)

    return all_preds, all_targets, lower_bounds, upper_bounds


### Posterior Estimation

In [309]:
def monte_carlo_sampling(model, data_loader, num_samples=100):
    # model.eval()
    predictions = []

    for _ in range(num_samples):
        sampled_preds = []
        for inputs, _ in data_loader:
            # inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)
                sampled_preds.append(outputs.cpu().numpy())
        predictions.append(np.concatenate(sampled_preds, axis=0))

    return np.array(predictions) # Shape: (num_samples, num_examples, num_classes)


In [299]:
def target_model_predictions(models, data_loader):
    all_predictions = []

    for model in models:
        model.eval()
        preds = []
        for inputs, _ in data_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)  # Logits
                preds.append(outputs.cpu().numpy())
        all_predictions.append(np.concatenate(preds, axis=0))  # Combine batches

    return np.array(all_predictions)  # Shape: (num_models, num_datapoints, num_classes)

In [300]:
def generate_target_samples(models):
    weight_samples = []

    for model in models:
        model.eval()
        weights = []
        for param in model.parameters():
            weights.append(param.detach().cpu().numpy().flatten())  # Flatten weights
        weight_samples.append(np.concatenate(weights))

    return np.array(weight_samples)  # Shape: (num_models, num_weights)


In [301]:
def generate_source_samples(model, dataloader, num_samples=N_SAMPLES):
    # model.eval()
    samples = []

    for _ in range(num_samples):
        sampled_weights = []
        for param in model.parameters():
            sampled_weights.append(param.detach().cpu().numpy().flatten())  # Flatten weights
        samples.append(np.concatenate(sampled_weights))

    return np.array(samples)  # Shape: (num_samples, num_weights)


### MMD

In [302]:
def calculate_mmd(model, posterior_models, test_dataset, num_samples=100):
    # # Posterior estimation with weights?
    # target_weights = generate_target_samples(posterior_models)
    # source_weights = generate_source_samples(model, test_dataset, num_samples=num_samples)

    # mmd_weights = mmdagg(
    #     X=source_weights,
    #     Y=target_weights,
    #     alpha=0.05,
    #     kernel="laplace_gaussian",
    #     number_bandwidths=10,
    #     weights_type="uniform",
    #     B1=2000,
    #     B2=2000,
    #     B3=50,
    #     seed=42424242
    # )

    # Posterior estimation with predictions?
    target_preds = target_model_predictions(posterior_models, test_dataset)
    source_preds = monte_carlo_sampling(model, test_dataset, num_samples=num_samples)
    target_avg = np.mean(target_preds, axis=1)
    source_avg = np.mean(source_preds, axis=1)

    mmd_preds = mmdagg(
        X=source_avg,
        Y=target_avg,
        alpha=0.05,
        kernel="laplace_gaussian",
        number_bandwidths=10,
        weights_type="uniform",
        B1=2000,
        B2=2000,
        B3=50,
        seed=42424242
    )

    return None, mmd_preds

### AUPR

In [303]:
def aupr_score(predictions, labels):
    n_classes = predictions.shape[1]
    precision = {}
    recall = {}
    aupr = {}

    for i in range(n_classes):
        # Binarize the labels for class i
        binary_labels = (labels == i).astype(int)
        precision[i], recall[i], _ = precision_recall_curve(binary_labels, predictions[:, i])
        aupr[i] = auc(recall[i], precision[i])

    # Optional: Aggregate AUPR
    mean_aupr = np.mean(list(aupr.values()))
    return mean_aupr, aupr

### FPR95

In [304]:
def fpr95_score(predictions, labels):
    # Calculate predicted classes and confidences
    predicted_classes = np.argmax(predictions, axis=1)  # Class with highest probability
    confidences = np.max(predictions, axis=1)           # Confidence scores (max probability)
    binary_labels = (predicted_classes == labels).astype(int) # Determine binary labels (1 for correct, 0 for incorrect)
    fpr, tpr, thresholds = roc_curve(binary_labels, confidences)

    # Find the threshold where TPR is closest to 95%
    idx = np.where(tpr >= 0.95)[0][0]
    return fpr[idx]


### ACE

In [314]:
def ace_score(predictions, labels, n_bins=10):
    # Convert predictions and labels to numpy arrays
    predicted_probs = predictions.cpu().numpy()
    true_labels = labels.cpu().numpy()

    # One-hot encode true labels for multi-class calibration
    num_classes = predicted_probs.shape[1]
    true_labels_one_hot = np.eye(num_classes)[true_labels]  # Shape: (num_samples, num_classes)

    # Initialize ACE
    ace = 0.0

    # Loop over each class
    for class_idx in range(num_classes):
        # Get predicted probabilities and true labels for the current class
        prob_pred = predicted_probs[:, class_idx]
        prob_true = true_labels_one_hot[:, class_idx]

        # Compute calibration curve
        fraction_of_positives, mean_predicted_value = calibration_curve(prob_true, prob_pred, n_bins=n_bins)

        # Compute ACE for this class
        ace += np.mean(np.abs(fraction_of_positives - mean_predicted_value))

    # Average over all classes
    ace /= num_classes
    return ace

### General Scoring Function

In [315]:
def score_model(model, posterior_models, test_loader):
    # model.to(DEVICE)
    preds, targets, lower_bounds, upper_bounds = evaluate_confidence_interval(model, test_loader)
    predictions = preds.cpu().numpy()
    labels = targets.cpu().numpy()

    # Scores
    mean_aupr, aupr = aupr_score(predictions, labels)
    fpr95 = fpr95_score(predictions, labels)
    ace = ace_score(preds, targets)
    mmd_weights, mmd_preds = calculate_mmd(model, posterior_models, test_loader, num_samples=N_SAMPLES)

    print(f"AUPR: {mean_aupr}")
    print(f"FPR95: {fpr95}")
    print(f"ACE: {ace:.4f}")
    print(f"MMD: {np.sum(mmd_preds)}")

In [313]:
score_model(laplace_model, posterior_models, datamodule.test_dataloader()[0])

ACE score: 0.1539476851886034
AUPR: 0.8135783680644151
FPR95: 0.7802101576182137
ACE: 0.3467
ACE2: 0.1539
MMD: 13.145983918683527
