# OOD Evaluation with Ensembles on FashionMNIST

This notebook demonstrates how to:

- train a small ensemble on FashionMNIST,
- use ensemble-based uncertainty (mutual information) as an OOD score,
- evaluate OOD performance with the unified `evaluate_ood` metric API alongside the legacy AUROC function,
- interpret OOD-related scores and metrics in practice.

Reference: `fashionmnist_ood_ensemble.ipynb` in `probly/notebooks/examples`.

## Dataset setup and library imports

In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.evaluation.tasks import out_of_distribution_detection
from probly.quantification.classification import mutual_information
from probly.transformation import ensemble
from probly.evaluation.ood import evaluate_ood  # unified OOD metric API

transforms = T.Compose([T.ToTensor(), torch.flatten])

train = torchvision.datasets.FashionMNIST(
    root="~/datasets/", train=True, download=True, transform=transforms
)
test = torchvision.datasets.FashionMNIST(
    root="~/datasets/", train=False, download=True, transform=transforms
)
train_loader = DataLoader(train, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

ood = torchvision.datasets.MNIST(
    root="~/datasets/", train=False, download=True, transform=transforms
)
ood_loader = DataLoader(ood, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Define the network architecture and construct an ensemble

In [None]:
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(784, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.fc3(x)
        return x

ensemble_model = ensemble(Net().to(device), 5)
ensemble_model

## Train the ensemble members on FashionMNIST

In [None]:
criterion = nn.CrossEntropyLoss()

for model in tqdm(ensemble_model, desc="Training ensemble"):
    # Each member is trained independently with its own optimizer.
    optimizer = optim.Adam(model.parameters())
    for _ in range(10):  # 10 epochs per member
        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

## Evaluate ensemble accuracy on FashionMNIST test data

In [None]:
correct = 0
total = 0
ensemble_model.eval()

with torch.no_grad():  # inference only, so skip gradient tracking
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Average the members' softmax probabilities to get the ensemble prediction.
        outputs = []
        for model in ensemble_model:
            outputs.append(torch.softmax(model(inputs), dim=1))
        outputs = torch.stack(outputs, dim=1).mean(dim=1)

        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)

accuracy = correct / total
print(f"Accuracy on FashionMNIST test set: {accuracy:.3f}")

## Collect ensemble outputs on ID and OOD data

In [None]:
@torch.no_grad()
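def torch_get_outputs(model: nn.Module, loader: DataLoader) -> torch.Tensor:
    """Return member softmax outputs with shape (n_samples, n_members, n_classes)."""
    outputs = []
    for data, _ in loader:
        data = data.to(device)
        # Softmax probabilities of each ensemble member for this batch.
        out_members = []
        for m in model:
            out_members.append(torch.softmax(m(data), dim=1))
        out_members = torch.stack(out_members, dim=1)  # (batch, n_members, n_classes)
        outputs.append(out_members)
    outputs = torch.cat(outputs, dim=0)
    return outputs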

## Compute uncertainty scores using mutual information

In [None]:
ensemble_model.eval()

outputs_id = torch_get_outputs(ensemble_model, test_loader)
outputs_ood = torch_get_outputs(ensemble_model, ood_loader)

outputs_id = outputs_id.cpu().numpy()
outputs_ood = outputs_ood.cpu().numpy()

uncertainty_id = mutual_information(outputs_id)
uncertainty_ood = mutual_information(outputs_ood)

len(uncertainty_id), len(uncertainty_ood)
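`torch_get_outputs` stacks the member probabilities along dimension 1, so `outputs_id` and `outputs_ood` should have shape `(n_samples, 5, 10)`. The `mutual_information` call above appears to reduce over the member and class axes, which is consistent with `len(uncertainty_id)` equalling the number of test images. A quick sanity check on that layout can catch a transposed or un-normalized array before it silently distorts the scores. The helper `check_ensemble_probs` is a hypothetical name written for this notebook, not part of `probly`; the toy array stands in for the real outputs.

```python
import numpy as np


def check_ensemble_probs(probs: np.ndarray, n_members: int, n_classes: int, atol: float = 1e-5) -> None:
    """Raise ValueError unless probs is a (n_samples, n_members, n_classes) array of distributions."""
    if probs.ndim != 3:
        raise ValueError(f"expected 3 dims, got shape {probs.shape}")
    if probs.shape[1:] != (n_members, n_classes):
        raise ValueError(f"expected (*, {n_members}, {n_classes}), got {probs.shape}")
    if (probs < 0).any():
        raise ValueError("negative probabilities")
    # Each member's prediction for each sample must sum to 1 over the classes.
    if not np.allclose(probs.sum(axis=-1), 1.0, atol=atol):
        raise ValueError("member distributions do not sum to 1")


# Toy stand-in for outputs_id: 4 samples, 5 members, 10 classes (softmax of random logits).
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5, 10))
toy = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
check_ensemble_probs(toy, n_members=5, n_classes=10)  # passes silently
```

On the notebook's data the same call would be `check_ensemble_probs(outputs_id, n_members=5, n_classes=10)`.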

### Interpreting Ensemble-based Uncertainty

We use the **mutual information** between the predicted label and the ensemble member
(the entropy of the averaged prediction minus the average entropy of the individual
members) as an uncertainty measure:

- low mutual information → predictions are consistent across ensemble members
- high mutual information → ensemble members disagree, indicating epistemic (model) uncertainty

When FashionMNIST acts as in-distribution (ID) and MNIST as out-of-distribution (OOD),
we expect:

- **ID samples**: lower uncertainty  
- **OOD samples**: higher uncertainty  
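The mutual information described above can be written as MI = H(mean over members of p_m) − mean over members of H(p_m). The NumPy sketch below implements that standard definition on an `(n_samples, n_members, n_classes)` array, assuming `probly`'s `mutual_information` follows it. We use base-2 logarithms; if the library uses a different base, the scores differ only by a constant factor, which leaves AUROC and the other ranking-based OOD metrics unchanged. The function names are our own.

```python
import numpy as np


def entropy(p: np.ndarray, axis: int = -1, base: float = 2.0) -> np.ndarray:
    """Shannon entropy along `axis`, with 0 * log(0) treated as 0."""
    safe = np.where(p > 0, p, 1.0)  # log(1) = 0, so zero entries contribute nothing
    return -(p * np.log(safe)).sum(axis=axis) / np.log(base)


def mutual_information_sketch(probs: np.ndarray, base: float = 2.0) -> np.ndarray:
    """probs has shape (n_samples, n_members, n_classes); returns shape (n_samples,)."""
    total = entropy(probs.mean(axis=1), base=base)     # entropy of the averaged prediction
    expected = entropy(probs, base=base).mean(axis=1)  # average entropy of the members
    return total - expected


# Two samples, two members, two classes.
probs = np.array([
    [[0.9, 0.1], [0.9, 0.1]],  # members agree -> MI = 0
    [[0.9, 0.1], [0.1, 0.9]],  # members disagree -> MI > 0
])
print(mutual_information_sketch(probs))  # approximately 0 and 0.531 bits
```

Both samples have confident members, yet only the second gets a high score: mutual information responds to disagreement, not to the confidence of any single member.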

Next, we visualize the distributions of these uncertainty scores.

## Visualize uncertainty distributions for ID vs. OOD

In [None]:
plt.hist(uncertainty_id, bins=50, alpha=0.5, label="In-Distribution (FashionMNIST)")
plt.hist(uncertainty_ood, bins=50, alpha=0.5, label="Out-of-Distribution (MNIST)")
plt.xlabel("Mutual Information (uncertainty)")
plt.ylabel("Count")
plt.legend()
plt.title("Uncertainty distributions: ID vs OOD")
plt.show()

## Compute AUROC using the legacy OOD detection function

In [None]:
auroc_legacy = out_of_distribution_detection(uncertainty_id, uncertainty_ood)
print(f"Legacy AUROC with FashionMNIST as ID and MNIST as OOD: {auroc_legacy:.3f}")
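AUROC has a direct probabilistic reading: it is the probability that a randomly drawn OOD sample scores higher than a randomly drawn ID sample, with ties counting half. Assuming `out_of_distribution_detection` computes the standard AUROC with OOD as the positive class (we have not checked its internals), the rank-based sketch below should closely reproduce `auroc_legacy` when called as `auroc_rank(uncertainty_id, uncertainty_ood)`. `auroc_rank` is a hypothetical helper; sorting keeps it O(n log n), so it handles the 10,000 × 10,000 pairs here without building a pairwise matrix.

```python
import numpy as np


def auroc_rank(id_scores, ood_scores) -> float:
    """P(OOD score > ID score) + 0.5 * P(tie), over all ID/OOD pairs."""
    sorted_id = np.sort(np.asarray(id_scores, dtype=float))
    ood = np.asarray(ood_scores, dtype=float)
    # For each OOD score, count ID scores strictly below it and those equal to it.
    below = np.searchsorted(sorted_id, ood, side="left")
    below_or_equal = np.searchsorted(sorted_id, ood, side="right")
    ties = below_or_equal - below
    return float((below + 0.5 * ties).sum() / (len(sorted_id) * len(ood)))


print(auroc_rank([0.1, 0.2, 0.3], [0.25, 0.4]))  # 5 of 6 pairs ordered correctly -> 0.8333...
```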

## Updated OOD Metrics (Unified API)

We now use the unified `evaluate_ood` API, which supports:

- **Static metrics**
  - `auroc` – Area under the ROC curve
  - `aupr` – Area under the Precision–Recall curve
  - `fpr@95` – False Positive Rate when the True Positive Rate is 95%

- **Dynamic metrics** (string specifications `<rate>@<target TPR>`, with the target written as a fraction or a percentage)
  - `fpr@0.8`   → FPR at TPR = 0.8
  - `fnr@90%`   → FNR at TPR = 90%
  - `tnr@0.99`  → True Negative Rate at TPR = 0.99
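The dynamic string specifications above follow a simple `<rate>@<target>` pattern. The sketch below shows one way such strings could be parsed into a rate name and a target TPR in [0, 1]. It is hypothetical: we have not checked how `evaluate_ood` parses these strings, and the rule that a `%` suffix or a bare number above 1 (as in `fpr@95`) means a percentage is inferred only from the examples in this notebook. Static names such as `auroc` are deliberately rejected.

```python
import re

# Hypothetical grammar: rate name, '@', a number, optional '%'.
_SPEC = re.compile(r"^(fpr|fnr|tnr)@(\d+(?:\.\d+)?)(%?)$")


def parse_metric_spec(spec: str) -> tuple[str, float]:
    """Hypothetical parser: 'fpr@0.8', 'fnr@90%', 'fpr@95' -> (rate name, target TPR)."""
    match = _SPEC.match(spec.strip().lower())
    if match is None:
        raise ValueError(f"unrecognized metric spec: {spec!r}")
    name, number, percent = match.groups()
    value = float(number)
    # Assumption: a '%' suffix, or a bare number above 1, is read as a percentage.
    if percent or value > 1:
        value /= 100.0
    if not 0.0 < value <= 1.0:
        raise ValueError(f"target TPR out of range in {spec!r}")
    return name, value


print(parse_metric_spec("fnr@90%"))  # ('fnr', 0.9)
```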

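The API expects:

`evaluate_ood(in_distribution, out_distribution, metrics=...)`

where `in_distribution` and `out_distribution` are arrays of per-sample scores, oriented so that **higher scores indicate a more likely OOD sample**. OOD is the positive class: TPR is the fraction of OOD samples flagged, and FPR is the fraction of ID samples wrongly flagged. Mutual information already follows this convention.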
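Mutual information already points the right way. A confidence-style score, such as the maximum of the averaged softmax probabilities (higher means more ID), must be negated first; otherwise AUROC comes out as 1 minus its true value. The sketch below illustrates this on toy arrays. Both `max_prob_ood_score` and the small `pairwise_auroc` helper are hypothetical and written only for this illustration; neither is part of `probly`.

```python
import numpy as np


def max_prob_ood_score(probs: np.ndarray) -> np.ndarray:
    """(n_samples, n_members, n_classes) -> OOD score where higher means more likely OOD.

    The max of the averaged distribution is a confidence (higher = more ID),
    so it is negated to follow the 'higher = more likely OOD' convention.
    """
    return -probs.mean(axis=1).max(axis=-1)


def pairwise_auroc(id_scores, ood_scores) -> float:
    """P(OOD > ID) + 0.5 * P(tie) over all pairs; fine for small arrays."""
    diff = np.asarray(ood_scores)[:, None] - np.asarray(id_scores)[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


# Toy data: confident members for ID, flatter predictions for OOD.
id_probs = np.array([[[0.9, 0.1], [0.8, 0.2]], [[0.95, 0.05], [0.9, 0.1]]])
ood_probs = np.array([[[0.6, 0.4], [0.5, 0.5]], [[0.55, 0.45], [0.4, 0.6]]])

confidence_id = id_probs.mean(axis=1).max(axis=-1)
confidence_ood = ood_probs.mean(axis=1).max(axis=-1)
print(pairwise_auroc(confidence_id, confidence_ood))  # 0.0: wrong orientation
print(pairwise_auroc(max_prob_ood_score(id_probs), max_prob_ood_score(ood_probs)))  # 1.0
```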

## Evaluate updated OOD metrics using the unified API

In [None]:
metrics_all = evaluate_ood(
    in_distribution=uncertainty_id,
    out_distribution=uncertainty_ood,
    metrics="all",
)

metrics_all

## Evaluate selected OOD metrics and interpret operational meaning

In [None]:
metrics_selected = evaluate_ood(
    in_distribution=uncertainty_id,
    out_distribution=uncertainty_ood,
    metrics=["auroc", "aupr", "fpr@95", "fpr@0.8", "fnr@90%", "tnr@0.95"],
)

for name, value in metrics_selected.items():
    print(f"{name:8s}: {value:.4f}")
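Many libraries compute the `aupr` value printed above as average precision with OOD as the positive class. The sketch below is a minimal NumPy version built on that assumption. We have not checked whether `evaluate_ood` uses average precision or a trapezoidal estimate of the precision–recall curve, and the two can differ slightly. scikit-learn's `average_precision_score` computes the same quantity for tie-free scores. The second call shows the chance baseline: a detector whose scores ignore the labels lands near the OOD fraction of the evaluation set.

```python
import numpy as np


def average_precision(id_scores, ood_scores) -> float:
    """Average precision with OOD as the positive class (ties broken by sort order)."""
    scores = np.concatenate([np.asarray(id_scores, float), np.asarray(ood_scores, float)])
    labels = np.concatenate([np.zeros(len(id_scores)), np.ones(len(ood_scores))])
    order = np.argsort(-scores, kind="stable")  # most OOD-looking samples first
    labels = labels[order]
    precision = np.cumsum(labels) / np.arange(1, len(labels) + 1)
    # Mean of the precision values at the ranks where an OOD sample is retrieved.
    return float((precision * labels).sum() / labels.sum())


print(average_precision([0.1, 0.4], [0.35, 0.8]))  # (1 + 2/3) / 2 = 0.8333...

# Label-independent scores with equal ID/OOD counts: AUPR close to 0.5.
rng = np.random.default_rng(0)
print(average_precision(rng.random(5000), rng.random(5000)))
```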

### Interpreting OOD Metrics

Given `uncertainty_id` (FashionMNIST) and `uncertainty_ood` (MNIST):

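- **AUROC**
  - The probability that a randomly chosen OOD sample receives a higher score than a randomly chosen ID sample; 0.5 is chance level and 1.0 is perfect separation.
  - Here, a high AUROC means MNIST images tend to receive higher mutual information than FashionMNIST images.

- **AUPR**
  - Treats OOD as the positive class and summarizes precision over all recall levels; it is especially informative when OOD samples are rare.
  - A high AUPR means that, across thresholds, most samples flagged as OOD are truly OOD. A random detector scores roughly the OOD fraction of the evaluation set, about 0.5 here because both test sets contain 10,000 images.

- **FPR@95**
  - False Positive Rate at TPR = 95%.
  - Operational view: “If we want to catch 95% of OOD samples, what fraction of ID samples
    will we incorrectly flag as OOD?”

- **FPR@0.8**
  - False Positive Rate at TPR = 80%.
  - Relaxing the TPR requirement raises the threshold, which can only keep or lower the FPR, so FPR@0.8 is never larger than FPR@95.

- **FNR@90%**
  - False Negative Rate at TPR = 90%: the fraction of OOD samples that are missed.
  - Because FNR = 1 − TPR at any fixed threshold, this value is pinned near 0.10 by construction and says little about the detector; it is informative only if the library anchors it to a different rate, so check its definition.

- **TNR@0.95**
  - True Negative Rate at TPR = 0.95: the fraction of ID samples correctly kept as ID while we still detect 95% of OOD samples.
  - TNR = 1 − FPR at the same threshold, so this is the complement of FPR@95.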
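The operating-point metrics in this list can be reproduced from raw scores with a simple threshold rule. The minimal sketch below assumes OOD is the positive class and that a sample is flagged when its score is at least the threshold. We chose the rule (the highest threshold whose TPR reaches the target); `evaluate_ood` may interpolate the ROC curve instead, so small differences are possible. `rates_at_tpr` is a hypothetical helper. The toy sweep shows both identities from the list, FNR = 1 − TPR and TNR = 1 − FPR, and shows FPR shrinking as the TPR target is relaxed.

```python
import numpy as np


def rates_at_tpr(id_scores, ood_scores, target_tpr: float) -> dict:
    """Flag score >= threshold as OOD, using the highest threshold whose TPR reaches target_tpr."""
    if not 0.0 < target_tpr <= 1.0:
        raise ValueError("target_tpr must lie in (0, 1]")
    id_scores = np.asarray(id_scores, dtype=float)
    ood_sorted = np.sort(np.asarray(ood_scores, dtype=float))
    n_ood = len(ood_sorted)
    # Smallest number of OOD samples that must be flagged (epsilon guards float round-off).
    k = max(1, int(np.ceil(target_tpr * n_ood - 1e-9)))
    threshold = ood_sorted[n_ood - k]
    return {
        "threshold": float(threshold),
        "tpr": float((ood_sorted >= threshold).mean()),  # OOD caught
        "fnr": float((ood_sorted < threshold).mean()),   # OOD missed
        "fpr": float((id_scores >= threshold).mean()),   # ID wrongly flagged
        "tnr": float((id_scores < threshold).mean()),    # ID correctly kept
    }


id_toy = [0.1, 0.3, 0.4, 0.55, 0.8]
ood_toy = [0.2, 0.5, 0.6, 0.7, 0.9]
for target in (0.6, 0.8, 1.0):
    r = rates_at_tpr(id_toy, ood_toy, target)
    print(f"target TPR {target:.1f}: FPR={r['fpr']:.2f} FNR={r['fnr']:.2f} TNR={r['tnr']:.2f}")
```

On the notebook's data, `rates_at_tpr(uncertainty_id, uncertainty_ood, 0.95)` gives a hand-computed counterpart to `fpr@95` and `tnr@0.95`.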

These metrics together provide a richer view than AUROC alone:
they describe not only separability but also **operating points** that matter in practice.