# Dependencies

In [None]:
from pathlib import Path

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.lenet import bayesian_lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.routines import ClassificationRoutine

from laplace import Laplace

from pathlib import Path
from safetensors.torch import load_file

# OptuNet Posterior Approximation

In [None]:
# Constants
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DATA_PATH = "data"

# Training hyper-parameters taken from the paper
EPOCHS = 60
BATCH_SIZE = 64
LEARNING_RATE = 0.04
WEIGHT_DECAY = 2e-4

# Number of DataLoader worker processes
NUM_WORKERS = 4

## OptuNet params
# Last-layer dropout rate.
# NOTE(review): the OptuNet class in this notebook has no dropout layer.
DROPOUT_RATE = 0.2

## Load Data

In [None]:
# Build the MNIST datamodule under DATA_PATH. eval_ood=False: no
# out-of-distribution test set is requested.
root = Path(DATA_PATH)
datamodule = MNISTDataModule(
    root=root,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    eval_ood=False,
)

## OptuNet Model

In [32]:
class OptuNet(nn.Module):
    """OptuNet: a tiny two-convolution classifier (Section C.2.1 of the paper).

    Layers:
        Conv2d(in_channels -> 2, k=4) -> MaxPool2d(k=3, s=3) -> ReLU
        -> Conv2d(2 -> 10, k=5, groups=2) -> global average pooling -> ReLU
        -> flatten -> Linear(10 -> num_classes)

    ``forward`` returns raw logits (no softmax is applied).

    Args:
        num_classes: Number of output logits produced by ``fc1``.
        in_channels: Number of channels of the input images (1 for MNIST).
    """

    def __init__(self, num_classes: int, in_channels: int = 1):
        super().__init__()
        # Attribute names and bias=False determine the state_dict keys. Those
        # keys are loaded strictly from the released checkpoints, so keep them.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=2, kernel_size=4, groups=1, bias=False)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=3)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=10, kernel_size=5, groups=2, bias=False)
        # Global average pooling collapses each of the 10 feature maps to a
        # single value, so fc1 receives exactly 10 features. A fixed
        # AvgPool2d(2) leaves a 2x2 map on 28x28 inputs (28 -> 25 -> 8 -> 4 -> 2).
        # This layer has no parameters, so checkpoint loading is unaffected.
        self.pool2 = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc1 = nn.Linear(in_features=10, out_features=num_classes)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, in_channels, H, W) image batch to (B, num_classes) logits."""
        x = self.relu(self.pool1(self.conv1(x)))  # first conv, max pooling, ReLU
        x = self.relu(self.pool2(self.conv2(x)))  # second conv, global avg pooling, ReLU -> (B, 10, 1, 1)
        x = torch.flatten(x, start_dim=1)         # (B, 10)
        return self.fc1(x)                        # linear layer -> logits

def load_optunet_model(
    version: int,
    num_classes: "int | None" = None,
    directory: "str | Path" = "models/mnist-optunet-0-8191",
):
    """Build an OptuNet and load the weights of checkpoint ``version``.

    Args:
        version: Checkpoint index. The file read is
            ``<directory>/version_<version>.safetensors``.
        num_classes: Output size of the network. Defaults to
            ``datamodule.num_classes`` (the notebook's MNIST datamodule).
        directory: Folder that holds the safetensors checkpoints.

    Returns:
        The OptuNet with the checkpoint weights loaded (strict key matching).

    Raises:
        ValueError: If the checkpoint file does not exist.
    """
    path = Path(directory) / f"version_{version}.safetensors"

    # Check before building the model so a bad version fails fast and the
    # message tells the user which file was expected.
    if not path.is_file():
        raise ValueError(f"File does not exist: {path}")

    if num_classes is None:
        num_classes = datamodule.num_classes
    model = OptuNet(num_classes=num_classes)

    state_dict = load_file(path)
    model.load_state_dict(state_dict=state_dict)
    return model

## Train / Test

In [None]:
def optim_lenet(model: nn.Module, lr: float = 0.04, weight_decay: float = 2e-4) -> optim.SGD:
    """Return the SGD optimizer used to train ``model``.

    The defaults are the paper's values (the same as ``LEARNING_RATE`` and
    ``WEIGHT_DECAY`` in the constants cell). Pass other values to experiment.

    Args:
        model: Network whose parameters are optimized.
        lr: Learning rate.
        weight_decay: L2 weight-decay coefficient.
    """
    return optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

# Choose the accelerator from DEVICE instead of hard-coding "gpu", which
# fails on machines without CUDA.
trainer = TUTrainer(
    accelerator="gpu" if DEVICE == "cuda" else "cpu",
    enable_progress_bar=False,
    max_epochs=1,  # NOTE(review): the paper setting is EPOCHS; only relevant if fit is re-enabled
)

# model
# Alternative (untrained) Bayesian LeNet:
# model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes)
model = load_optunet_model(version=1000)

# loss
# NOTE(review): OptuNet has no stochastic layers, so the 3 samples are
# identical and the ELBO's KL term presumably vanishes. Confirm this loss is
# intended; it is presumably only exercised if trainer.fit is re-enabled.
loss = ELBOLoss(
    model=model,
    inner_loss=nn.CrossEntropyLoss(),
    kl_weight=1 / 10000,
    num_samples=3,
)

def scheduler_lenet(optimizer, milestones=(15, 30), gamma: float = 0.5):
    """Return a MultiStepLR schedule (by default halving the LR at epochs 15 and 30).

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        milestones: Epoch indices at which the learning rate is decayed.
        gamma: Multiplicative factor applied to the learning rate at each milestone.
    """
    return MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)

# Reuse the helpers so that training (if re-enabled) follows the paper recipe:
# SGD with lr=0.04 and wd=2e-4, with the LR halved at epochs 15 and 30.
optimizer = optim_lenet(model)
routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss,
    optim_recipe={
        "optimizer": optimizer,
        "lr_scheduler": scheduler_lenet(optimizer),
    },
    # One checkpoint is a single deterministic network, not an ensemble.
    is_ensemble=False,
)

# trainer.fit(model=routine, datamodule=datamodule)
results = trainer.test(model=routine, datamodule=datamodule)

In [None]:
# NOTE: the training split is expected to be built by setup("fit"), while
# trainer.test() above only triggers setup("test"). Without this call,
# train_dataloader() has no train split to load.
datamodule.setup("fit")
train_dataloader = datamodule.train_dataloader()

# Laplace approximation: fit a Gaussian posterior around the loaded weights.
# These arguments spell out laplace-torch's defaults: the posterior covers
# only the last layer and uses a Kronecker-factored Hessian.
la = Laplace(
    model,                           # the pretrained OptuNet checkpoint
    likelihood="classification",     # categorical likelihood
    subset_of_weights="last_layer",  # set to "all" for the full network
    hessian_structure="kron",
    prior_precision=1.0,             # initial prior precision, tuned below
)

# Accumulate the curvature (Hessian approximation) over the training data.
la.fit(train_dataloader)

# Tune the prior precision (marginal-likelihood maximisation by default).
# The fitted Hessian is reused; it is not recomputed.
la.optimize_prior_precision()

In [None]:
# Collect test-set predictions from the Laplace posterior predictive. The raw
# OptuNet returns unnormalized logits (fc1 output, no softmax), whereas
# `la(...)` returns class probabilities for a classification likelihood.
prob_batches, label_batches = [], []

with torch.no_grad():
    for images, true_labels in datamodule.test_dataloader()[0]:
        probs = la(images, pred_type="glm", link_approx="probit")
        # detach() guards against laplace re-enabling grad internally.
        prob_batches.append(probs.detach().cpu())
        label_batches.append(true_labels.cpu())

predictions = torch.cat(prob_batches).numpy()  # (N, num_classes) class probabilities
labels = torch.cat(label_batches).numpy()      # (N,) integer class labels


## Scoring

### AUPR

In [None]:
# Compute AUPR score
from sklearn.metrics import average_precision_score, auc, roc_curve, precision_recall_curve

# Hard class predictions (index of the highest score), kept for inspection.
ypreds = predictions.argmax(axis=1)

print(ypreds.shape)

# precision_recall_curve is binary-only, needs continuous scores and returns
# (precision, recall, thresholds). So build one-vs-rest PR curves per class
# from the predicted scores and macro-average their areas.
n_classes = predictions.shape[1]
per_class_aupr = []
for k in range(n_classes):
    precision, recall, _ = precision_recall_curve(labels == k, predictions[:, k])
    per_class_aupr.append(auc(recall, precision))
aupr = sum(per_class_aupr) / n_classes

print(f"AUPR: {aupr}")

In [None]:
# AUPR for misclassification detection. Positives are wrongly classified test
# samples, scored by 1 - max predicted score (higher = more uncertain).
from sklearn.metrics import precision_recall_curve, auc

labels_np = (predictions.argmax(axis=1) != labels).astype(int)
positive_probs = 1.0 - predictions.max(axis=1)

precision, recall, _ = precision_recall_curve(labels_np, positive_probs)
aupr = auc(recall, precision)
print(f"AUPR (misclassification detection): {aupr}")


### FPR95

In [None]:
from sklearn.metrics import roc_curve
import numpy as np

# Binary misclassification-detection task. Positives are misclassified test
# samples, scored by 1 - max predicted score (higher = more uncertain).
labels_np = (predictions.argmax(axis=1) != labels).astype(int)
positive_probs = 1.0 - predictions.max(axis=1)

fpr, tpr, thresholds = roc_curve(labels_np, positive_probs)
# Take the first ROC point whose TPR reaches 95% (tpr is non-decreasing). If
# there are no positives, tpr is undefined and FPR95 is reported as NaN.
reached = tpr >= 0.95
if reached.any():
    idx = int(np.argmax(reached))
    threshold_at_95_tpr = thresholds[idx]
    fpr95 = fpr[idx]
else:
    threshold_at_95_tpr = float("nan")
    fpr95 = float("nan")
print(f"FPR95: {fpr95}")


### Accuracy

In [None]:
from sklearn.metrics import accuracy_score

# Multi-class accuracy: the predicted class is the index of the highest score.
# A separate name is used so the `predictions` score matrix, which the other
# scoring cells read, is not overwritten.
class_preds = predictions.argmax(axis=1)

accuracy = accuracy_score(labels, class_preds)
print(f"Accuracy: {accuracy}")


### SWAG

In [2]:
from utils.swa_gaussian.swag.posteriors import SWAG