# Posterior Symmetry Reproduction

In [None]:
from pathlib import Path

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.lenet import bayesian_lenet, lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.routines import ClassificationRoutine
from lightning.pytorch import LightningModule

from sklearn.metrics import precision_recall_curve, roc_curve, auc, accuracy_score
import numpy as np

from laplace import Laplace
from utils.swa_gaussian.swag.posteriors import SWAG
import utils.swa_gaussian.swag.posteriors as swag_posteriors
from utils.posterior_symmetry.mmd.mmd_torch import mmdagg

from pathlib import Path
from safetensors.torch import load_file

In [86]:
# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_PATH = "data"

# Parameters from paper
EPOCHS = 60
BATCH_SIZE = 64
LEARNING_RATE = 0.04
WEIGHT_DECAY = 2e-4

NUM_WORKERS = 4
## OptuNet params
DROPOUT_RATE = 0.2 # last layer dropout rate

# Models
NMODELS = 100
MODEL_PATH = Path("models", "trained_optunets")

In [50]:
# Load MNIST data
root = Path(DATA_PATH)
datamodule = MNISTDataModule(root=root, batch_size=BATCH_SIZE, eval_ood=False, num_workers=NUM_WORKERS)

## OptuNet

In [None]:
class OptuNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Add layers for OptuNet (use Section C.2.1 from the paper for details)
        # Layers: Conv2D (out_ch=2, ks=4, groups=1) -> Max Pooling (ks=3, stride=3) -> ReLU -> Conv2D (out_ch=10, ks=5, groups=2) -> Average Pooling -> ReLU -> Linear 10x10
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=4, groups=1, bias=False)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=5, groups=2, bias=False)
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=10, out_features=10)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.pool1(self.conv1(x)))  # First conv, max pooling, ReLU
        x = self.relu(self.pool2(self.conv2(x)))  # Second conv, avg pooling, ReLU
        x = torch.mean(x, dim=(2, 3))
        x = self.fc1(x)  # Linear layer
        return x

# Optimizer and LR scheduler
def optim_optunet(model: nn.Module):
    optimizer = optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Loss function
def loss_optunet(model: nn.Module):
    loss = ELBOLoss(
        model=model,
        inner_loss=nn.CrossEntropyLoss(),
        kl_weight= 1/10000,
        num_samples=3,
    )
    return loss

In [None]:
# Load functions
def load_optunet_model(version: int):
    model = OptuNet()
    path = Path(f"models/mnist-optunet-0-8191/version_{version}.safetensors")

    if not path.exists():
        raise ValueError("File does not exist")

    state_dict = load_file(path)
    model.load_state_dict(state_dict=state_dict)
    return model

def load_trained_optunet(path):
    checkpoint = torch.load(path)

    # Filter out unwanted keys (e.g., those related to loss)
    state_dict = {
        k.replace("model.", ""): v
        for k, v in checkpoint["state_dict"].items()
        if not k.startswith("loss.")
    }
    model = OptuNet()
    model.load_state_dict(state_dict)
    return model

In [173]:
# Compile posterior estimation models
posterior_models = []

for i in range(30):
    path = Path(f"models/trained_optunets/model_t{i}.pt")
    model = load_trained_optunet(path)
    model = model.to(DEVICE) # Needed?
    posterior_models.append(model)

print("Postreior models loaded:", len(posterior_models))


Postreior models loaded: 30


  checkpoint = torch.load(path)


### Model Training

In [None]:
trainer = TUTrainer(
    accelerator="gpu",
    enable_progress_bar=True,
    max_epochs=EPOCHS)

# model
model = OptuNet()#load_optunet_model(version=1000)

# loss
loss = ELBOLoss(
    model=model,
    inner_loss=nn.CrossEntropyLoss(),
    kl_weight=1 / 10000,
    num_samples=3,
)

routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss,
    optim_recipe=optim_optunet(model),
    is_ensemble=True
)

trainer.fit(model=routine, datamodule=datamodule)

In [None]:
n_models = 100 # models to train

for i in range(n_models):
    model = OptuNet()

    trainer = TUTrainer(
        accelerator="gpu",
        enable_progress_bar=False,
        max_epochs=EPOCHS)

    routine = ClassificationRoutine(
        model=model,
        num_classes=datamodule.num_classes,
        loss=loss,
        optim_recipe=optim_optunet(model),
        is_ensemble=True
    )

    trainer.fit(model=routine, datamodule=datamodule)

    # Save the trained model
    save_path = Path(MODEL_PATH, f"model_t{i+1}.pt")
    trainer.save_checkpoint(save_path)

In [62]:
results = trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:02<00:00, 77.76it/s]



Testing DataLoader 0: 100%|██████████| 157/157 [00:02<00:00, 77.14it/s]


## Baselines

### Dropout

In [167]:
class OptuDrop(OptuNet):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=DROPOUT_RATE)
    
    def forward(self, x):
        x = self.relu(self.pool1(self.conv1(x)))  # First conv, max pooling, ReLU
        x = self.relu(self.pool2(self.conv2(x)))  # Second conv, avg pooling, ReLU
        x = torch.mean(x, dim=(2, 3)) # Global average pooling
        x = self.dropout(self.fc1(x))  # Linear layer with dropout
        return x

In [168]:
model = OptuDrop()
loss_fn = nn.CrossEntropyLoss()
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    is_ensemble=False
)

trainer = TUTrainer(
    accelerator="gpu",
    enable_progress_bar=True,
    max_epochs=EPOCHS
)

trainer.fit(model=routine, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuDrop         | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup            | Identity         | 0      | train
--------------------------------------------------------------
392       Trainable params
0         Non-trainable params
392       Total params
0.002     Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 938/938 [00:10<00:00, 91.55it/s, v_num=99, train_loss=1.690]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 161.97it/s, v_num=99, train_loss=1.460, Acc%=44.20]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 147.66it/s, v_num=99, train_loss=1.370, Acc%=72.50]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 136.12it/s, v_num=99, train_loss=0.996, Acc%=74.60]



Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 144.97it/s, v_num=99, train_loss=1.680, Acc%=74.70]



Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 139.50it/s, v_num=99, train_loss=1.060, Acc%=67.10]



Epoch 6: 100%|██████████| 938/938 [00:06<00:00, 148.99it/s, v_num=99, train_loss=1.440, Acc%=75.70]



Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 151.31it/s, v_num=99, train_loss=1.280, Acc%=69.70]



Epoch 8: 100%|██████████| 938/938 [00:05<00:00, 159.07it/s, v_num=99, train_loss=1.520, Acc%=74.60]



Epoch 9: 100%|██████████| 938/938 [00:07<00:00, 132.41it/s, v_num=99, train_loss=1.370, Acc%=72.30]



Epoch 10: 100%|██████████| 938/938 [00:07<00:00, 131.08it/s, v_num=99, train_loss=1.300, Acc%=75.70]



Epoch 11: 100%|██████████| 938/938 [00:06<00:00, 138.74it/s, v_num=99, train_loss=1.310, Acc%=78.30]



Epoch 12: 100%|██████████| 938/938 [00:06<00:00, 136.94it/s, v_num=99, train_loss=1.430, Acc%=77.30]



Epoch 13: 100%|██████████| 938/938 [00:07<00:00, 132.03it/s, v_num=99, train_loss=1.300, Acc%=73.60]



Epoch 14: 100%|██████████| 938/938 [00:07<00:00, 131.19it/s, v_num=99, train_loss=1.590, Acc%=77.70]



Epoch 15: 100%|██████████| 938/938 [00:06<00:00, 143.57it/s, v_num=99, train_loss=1.360, Acc%=74.00]



Epoch 17: 100%|██████████| 938/938 [00:07<00:00, 124.88it/s, v_num=99, train_loss=1.240, Acc%=77.20]



Epoch 18: 100%|██████████| 938/938 [00:07<00:00, 129.06it/s, v_num=99, train_loss=1.180, Acc%=77.60]



Epoch 19: 100%|██████████| 938/938 [00:06<00:00, 135.98it/s, v_num=99, train_loss=1.270, Acc%=76.40]



Epoch 23: 100%|██████████| 938/938 [00:07<00:00, 124.06it/s, v_num=99, train_loss=1.150, Acc%=77.80]



Epoch 25: 100%|██████████| 938/938 [00:06<00:00, 136.95it/s, v_num=99, train_loss=1.070, Acc%=78.40]



Epoch 26: 100%|██████████| 938/938 [00:06<00:00, 137.06it/s, v_num=99, train_loss=1.330, Acc%=77.70]



Epoch 59: 100%|██████████| 938/938 [00:08<00:00, 113.92it/s, v_num=99, train_loss=1.120, Acc%=78.90]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:08<00:00, 113.71it/s, v_num=99, train_loss=1.120, Acc%=78.90]


In [170]:
# Testing
results = trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:02<00:00, 77.84it/s]


### viBNN

### SWAG

In [None]:
# Train OptuNet for SWAG
model = OptuNet()

routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=nn.CrossEntropyLoss(),
    optim_recipe=optim_optunet(model),
    is_ensemble=False # Single model (not ensemble here)
)

trainer = TUTrainer(
    accelerator="gpu",
    max_epochs=2*EPOCHS, # Train twice as long for SWAG
    enable_progress_bar=True
)

# Fit the model
trainer.fit(model=routine, datamodule=datamodule)

# Save checkpoints every 10 epochs from epoch 80 onward
checkpoint_dir = Path("models/swag_checkpoints/")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

checkpoints = []
for epoch in range(80, 121, 10):
    checkpoint_path = checkpoint_dir / f"model_epoch_{epoch}.pt"
    trainer.save_checkpoint(checkpoint_path)
    checkpoints.append(checkpoint_path)


In [159]:
# Create the SWAG object
swag_model = SWAG(
    base=OptuNet,
    max_num_models=20,
    var_clamp=1e-30
)

# Add the collected checkpoints to the SWAG posterior
for checkpoint_path in checkpoints:
    checkpoint = torch.load(checkpoint_path)
    swag_model.collect_model(load_trained_optunet(checkpoint_path)) 

  checkpoint = torch.load(checkpoint_path)
  checkpoint = torch.load(path)


In [None]:
# For consistent results
class SWAGLightningWrapper(LightningModule):
    def __init__(self, swag_model, num_samples=10, scale=0.1):
        super().__init__()
        self.swag_model = swag_model
        self.num_samples = num_samples
        self.scale = scale

    def forward(self, x):
        preds = []
        for _ in range(self.num_samples):
            sampled_model = self.swag_model.sample(scale=self.scale)  # Sample from SWAG posterior
            sampled_model.eval()
            with torch.no_grad():
                preds.append(sampled_model(x))
        return torch.stack(preds).mean(dim=0)  # Aggregate predictions

    def predict_step(self, batch, batch_idx):
        x, _ = batch
        return self.forward(x)


In [164]:
# Wrap SWAG model for testing
swag_wrapper = SWAGLightningWrapper(
    swag_model=swag_model,
    num_samples=10,  # Number of posterior samples as per the paper
    scale=0.1        # Scale parameter for SWAG sampling
)

# Test the SWAG model
results = trainer.test(model=swag_wrapper, datamodule=datamodule)
print("SWAG Test Results:", results)


MisconfigurationException: No `test_step()` method defined to run `Trainer.test`.

In [155]:
trainer.test(model=routine, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 97.30it/s] 


[{'test/cal/ECE': 0.05344296991825104,
  'test/cal/aECE': 0.05416256561875343,
  'test/cls/Acc': 0.8267999887466431,
  'test/cls/Brier': 0.25886476039886475,
  'test/cls/NLL': 0.570395290851593,
  'test/sc/AUGRC': 0.037595879286527634,
  'test/sc/AURC': 0.050700489431619644,
  'test/sc/Cov@5Risk': 0.6302000284194946,
  'test/sc/Risk@80Cov': 0.09300000220537186,
  'test/cls/Entropy': 0.6775237917900085}]

### Laplace

In [77]:
# Model
model = OptuNet()

def scheduler_laplace(optimizer):
    return MultiStepLR(
        optimizer,
        milestones=[15, 30],
        gamma=0.5
    )

# Routine
loss_fn = nn.CrossEntropyLoss()  # Standard cross-entropy loss
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss_fn,
    optim_recipe=optim_optunet(model),
    # scheduler_recipe=scheduler_laplace,
    is_ensemble=False
)

# Train the model to MAP estimate
trainer = TUTrainer(
    accelerator="gpu",
    max_epochs=EPOCHS,
    enable_progress_bar=True
)
trainer.fit(model=routine, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params | Mode 
--------------------------------------------------------------
0 | model            | OptuNet          | 392    | train
1 | loss             | CrossEntropyLoss | 0      | train
2 | format_batch_fn  | Identity         | 0      | train
3 | val_cls_metrics  | MetricCollection | 0      | train
4 | test_cls_metrics | MetricCollection | 0      | train
5 | test_id_entropy  | Entropy          | 0      | train
6 | mixup            | Identity         | 0      | train
--------------------------------------------------------------
392       Trainable params
0         Non-trainable params
392       Total params
0.002     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 938/938 [00:09<00:00, 96.90it/s, v_num=57, train_loss=1.770]



Epoch 1: 100%|██████████| 938/938 [00:05<00:00, 160.81it/s, v_num=57, train_loss=1.070, Acc%=52.30]



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 143.72it/s, v_num=57, train_loss=1.740, Acc%=59.20]



Epoch 3: 100%|██████████| 938/938 [00:06<00:00, 139.31it/s, v_num=57, train_loss=0.841, Acc%=53.10]



Epoch 4: 100%|██████████| 938/938 [00:06<00:00, 144.08it/s, v_num=57, train_loss=1.010, Acc%=73.50]



Epoch 5: 100%|██████████| 938/938 [00:06<00:00, 148.42it/s, v_num=57, train_loss=1.680, Acc%=65.40]



Epoch 6: 100%|██████████| 938/938 [00:06<00:00, 148.22it/s, v_num=57, train_loss=0.599, Acc%=76.90]



Epoch 7: 100%|██████████| 938/938 [00:06<00:00, 143.06it/s, v_num=57, train_loss=0.713, Acc%=79.40]



Epoch 8: 100%|██████████| 938/938 [00:06<00:00, 138.37it/s, v_num=57, train_loss=1.010, Acc%=79.20]



Epoch 9: 100%|██████████| 938/938 [00:06<00:00, 144.06it/s, v_num=57, train_loss=0.803, Acc%=80.20]



Epoch 10: 100%|██████████| 938/938 [00:06<00:00, 144.23it/s, v_num=57, train_loss=0.867, Acc%=74.70]



Epoch 19: 100%|██████████| 938/938 [00:06<00:00, 151.63it/s, v_num=57, train_loss=1.390, Acc%=80.30]



Epoch 24: 100%|██████████| 938/938 [00:06<00:00, 152.22it/s, v_num=57, train_loss=0.696, Acc%=82.30]



Epoch 27: 100%|██████████| 938/938 [00:07<00:00, 123.66it/s, v_num=57, train_loss=0.513, Acc%=83.00]



Epoch 28: 100%|██████████| 938/938 [00:06<00:00, 134.27it/s, v_num=57, train_loss=1.040, Acc%=75.80]



Epoch 30: 100%|██████████| 938/938 [00:07<00:00, 120.40it/s, v_num=57, train_loss=0.573, Acc%=83.70]



Epoch 38: 100%|██████████| 938/938 [00:06<00:00, 151.51it/s, v_num=57, train_loss=0.517, Acc%=81.90]



Epoch 43: 100%|██████████| 938/938 [00:06<00:00, 143.95it/s, v_num=57, train_loss=0.956, Acc%=83.40]



Epoch 45: 100%|██████████| 938/938 [00:06<00:00, 155.05it/s, v_num=57, train_loss=0.769, Acc%=79.90]



Epoch 48: 100%|██████████| 938/938 [00:06<00:00, 146.01it/s, v_num=57, train_loss=0.626, Acc%=81.50]



Epoch 51: 100%|██████████| 938/938 [00:06<00:00, 145.88it/s, v_num=57, train_loss=0.469, Acc%=82.30]



Epoch 54: 100%|██████████| 938/938 [00:06<00:00, 147.30it/s, v_num=57, train_loss=1.710, Acc%=83.80]



Epoch 55: 100%|██████████| 938/938 [00:05<00:00, 158.17it/s, v_num=57, train_loss=0.833, Acc%=76.70]



Epoch 56: 100%|██████████| 938/938 [00:06<00:00, 154.98it/s, v_num=57, train_loss=0.886, Acc%=75.40]



Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.42it/s, v_num=57, train_loss=0.419, Acc%=81.80]

`Trainer.fit` stopped: `max_epochs=60` reached.


Epoch 59: 100%|██████████| 938/938 [00:06<00:00, 140.28it/s, v_num=57, train_loss=0.419, Acc%=81.80]


In [78]:
# Apply Laplace approximation
laplace_model = Laplace(model, likelihood='classification', subset_of_weights='last_layer', hessian_structure='full')
laplace_model.fit(datamodule.train_dataloader())  # Fit the Laplace model on training data
laplace_model.optimize_prior_precision()  # Optimize prior precision



In [79]:
# Test
results = trainer.test(model=routine, datamodule=datamodule)
print("Laplace Test Results:", results)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 88.21it/s]


Laplace Test Results: [{'test/cal/ECE': 0.020094318315386772, 'test/cal/aECE': 0.02005278505384922, 'test/cls/Acc': 0.8176000118255615, 'test/cls/Brier': 0.2700233459472656, 'test/cls/NLL': 0.5702756643295288, 'test/sc/AUGRC': 0.04408245161175728, 'test/sc/AURC': 0.05958322063088417, 'test/sc/Cov@5Risk': 0.5226000547409058, 'test/sc/Risk@80Cov': 0.10912500321865082, 'test/cls/Entropy': 0.5809915661811829}]


### SGHMC

### pSGLD

## Scores

### Posterior Estimation

In [196]:
def monte_carlo_sampling(model, data_loader, num_samples=50):
    """
    Perform Monte Carlo Dropout sampling on the model.
    Args:
        model (nn.Module): Trained OptuDrop model with dropout.
        data_loader (DataLoader): DataLoader for test data.
        num_samples (int): Number of MC samples.
    Returns:
        np.ndarray: Array of predictions from all samples.
    """
    model.eval()
    predictions = []

    for _ in range(num_samples):
        sampled_preds = []
        for inputs, _ in data_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)
                sampled_preds.append(outputs.cpu().numpy())
        predictions.append(np.concatenate(sampled_preds, axis=0))

    return np.array(predictions)


In [194]:
def extract_model_weights(models):
    """
    Extract model weights to use as samples for posterior comparison.
    Args:
        models (list): List of trained models.
    Returns:
        np.ndarray: Flattened weight arrays for each model.
    """
    weight_samples = []
    for model in models:
        weights = []
        for param in model.parameters():
            weights.append(param.detach().cpu().numpy().flatten())
        weight_samples.append(np.concatenate(weights))
    return np.array(weight_samples)


In [214]:
def target_model_predictions(models, data_loader):
    """
    Generate predictions from trained models to form the target posterior in prediction space.
    Args:
        models (list): List of trained models.
        data_loader (DataLoader): DataLoader for test data.
    Returns:
        np.ndarray: Array of predictions for all models.
    """
    all_predictions = []

    for model in models:
        model.eval()
        preds = []
        for inputs, _ in data_loader:
            inputs = inputs.cuda()
            with torch.no_grad():
                outputs = model(inputs)  # Logits
                preds.append(outputs.cpu().numpy())
        all_predictions.append(np.concatenate(preds, axis=0))  # Combine batches

    return np.array(all_predictions)  # Shape: (num_models, num_datapoints, num_classes)

In [None]:
# Sample model posterior
dropout_samples = monte_carlo_sampling(model, datamodule.test_dataloader()[0], num_samples=50)
dropout_samples_flat = dropout_samples.reshape(dropout_samples.shape[0], -1)

In [None]:
# Target posterior
target_predictions = target_model_predictions(posterior_models, datamodule.test_dataloader()[0])
# target_posterior_samples = extract_model_weights(posterior_models)
target_posterior_samples = target_predictions.reshape(30, -1)  # Shape: (30, 10000 * 10)

In [216]:
print(dropout_samples.shape)
print(dropout_samples_flat.shape)
print(target_posterior_samples.shape)

(50, 10000, 10)
(50, 100000)
(30, 100000)


In [217]:
print(dropout_samples_flat[0][:10])
print(target_posterior_samples[0][:10])

[ -5.2107353   -0.95131683   3.909545     3.1656933   -1.405637
  -5.1011004  -10.773238    10.522338     0.7293606    4.6583385 ]
[-1.2191579  -0.7090111   4.4251957   1.871046   -0.4916743  -5.1103597
 -7.6788454   7.466436   -1.2615769  -0.34655005]


In [218]:
# Use mmdagg function
mmd_score = mmdagg(
    X=dropout_samples_flat,
    Y=target_posterior_samples,
    alpha=0.05,
    kernel="laplace_gaussian",
    number_bandwidths=10,
    weights_type="uniform",
    B1=2000,
    B2=2000,
    B3=50,
    seed=42424242,
)
print("MMD Score:", mmd_score)

MMD Score: [0.84391623 0.79542857 0.73696664 0.67138301 0.60202212 0.53211281
 0.4643618  0.40076836 0.34261158 0.29054258 0.0176627  0.58616206
 0.88334639 0.92725724 0.93261007 0.93324742 0.93332312 0.93326705
 0.64932177 0.11887361]


In [None]:
print("MMD avg:", np.mean(mmd_score))
print("MMD Shape:", mmd_score.shape)

MMD avg: 0.6297592563799301
MMD Shape: (20,)


: 

### TEST CODE

In [None]:
test_outputs = trainer.test(model=routine, datamodule=datamodule)[0]

# Example inputs
probs = torch.softmax(model_outputs, dim=1)  # Mean probabilities from SWAG samples
individual_entropies = compute_entropies(model_samples)  # Implement entropy computation for sampled models
in_confidences = torch.max(probs, dim=1)[0]  # Confidence scores for in-distribution data
ood_confidences = torch.max(ood_probs, dim=1)[0]  # Replace with OOD predictions
labels = torch.cat([torch.ones(len(in_confidences)), torch.zeros(len(ood_confidences))])

# Compute scores
scores = score_methods(test_outputs, probs, individual_entropies, in_confidences, ood_confidences, labels)
print(scores)


In [None]:
predictions, labels = [], []

for batch in datamodule.test_dataloader()[0]:
    images, true_labels = batch
    with torch.no_grad():
        probs = model(images)  # Assuming the model outputs probabilities
        predictions.append(probs)
        labels.append(true_labels)

predictions = torch.cat(predictions).numpy()
labels = torch.cat(labels).numpy()


### AUPR

In [None]:
n_classes = predictions.shape[1]  # Number of classes
precision = {}
recall = {}
aupr = {}

for i in range(n_classes):
    # Binarize the labels for class i
    binary_labels = (labels == i).astype(int)
    precision[i], recall[i], _ = precision_recall_curve(binary_labels, predictions[:, i])
    aupr[i] = auc(recall[i], precision[i])

# Optional: Aggregate AUPR
mean_aupr = np.mean(list(aupr.values()))
print(f"AUPR for each class: {aupr}")
print(f"Mean AUPR: {mean_aupr}")


AUPR for each class: {0: 0.5868727984354928, 1: 0.9735599255356758, 2: 0.7819963789640088, 3: 0.6210323998041789, 4: 0.9572730746676985, 5: 0.8067717806007965, 6: 0.921096030318654, 7: 0.8066637236494992, 8: 0.5346925620695928, 9: 0.686130600547784}
Mean AUPR: 0.7676089274593382


In [None]:
def compute_aupr(labels, scores):
    """
    Computes Area Under Precision-Recall Curve (AUPR).
    Args:
        labels (torch.Tensor): Ground truth labels (1 for in-distribution, 0 for OOD).
        scores (torch.Tensor): Confidence scores.
    Returns:
        float: AUPR.
    """
    precision, recall, _ = precision_recall_curve(labels.cpu().numpy(), scores.cpu().numpy())
    return auc(recall, precision)

### FPR95

In [41]:
from sklearn.metrics import roc_curve, auc
import numpy as np

# Assuming `labels` are integers representing classes and `predictions` are probabilities
n_classes = predictions.shape[1]  # Number of classes
fpr = {}
tpr = {}
roc_auc = {}

for i in range(n_classes):
    # Binarize the labels for class `i`
    binary_labels = (labels == i).astype(int)
    fpr[i], tpr[i], _ = roc_curve(binary_labels, predictions[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Optional: Compute micro-average ROC curve and AUC
labels_one_hot = np.eye(n_classes)[labels]  # Convert labels to one-hot encoding
fpr["micro"], tpr["micro"], _ = roc_curve(labels_one_hot.ravel(), predictions.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Print the results
print(f"ROC AUC for each class: {roc_auc}")
print(f"Micro-averaged ROC AUC: {roc_auc['micro']}")


ROC AUC for each class: {0: 0.9461464319652472, 1: 0.9950144979389818, 2: 0.9628590068737078, 3: 0.944034350598575, 4: 0.99230256391494, 5: 0.9727843604695817, 6: 0.9862095652900706, 7: 0.9559805386754469, 8: 0.9133061186233155, 9: 0.9450068943516803, 'micro': 0.959201013888889}
Micro-averaged ROC AUC: 0.959201013888889


In [None]:
def compute_fpr95(in_confidences, ood_confidences):
    """
    Computes False Positive Rate at 95% Recall.
    Args:
        in_confidences (torch.Tensor): Confidence scores for in-distribution data.
        ood_confidences (torch.Tensor): Confidence scores for out-of-distribution data.
    Returns:
        float: FPR at 95% recall.
    """
    labels = torch.cat([torch.ones_like(in_confidences), torch.zeros_like(ood_confidences)])
    scores = torch.cat([in_confidences, ood_confidences])
    fpr, tpr, _ = roc_curve(labels.cpu().numpy(), scores.cpu().numpy())
    idx_95 = (tpr >= 0.95).nonzero()[0][0]
    return fpr[idx_95]

### Accuracy

In [None]:
# For multi-class classification
predictions = np.argmax(predictions, axis=1)

accuracy = accuracy_score(labels, predictions)
print(f"Accuracy: {accuracy}")


Accuracy: 0.8439
