# 🔬 Stat-OOD: Step-by-Step Analysis

This notebook breaks down the Stat-OOD pipeline into modular steps. Use this to inspect data, check model outputs, and understand how OOD scores are calculated internally.

In [1]:
!nvidia-smi

Mon Jan 12 06:23:32 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P8              8W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
# 1. Setup Environment (Colab Only)
# Uncomment if running on Colab
# !pip install -q uv
# !git clone https://github.com/sucpark/stat-ood.git
# %cd stat-ood
# !uv sync
# NOTE(review): `uv sync` installs into the project's .venv, which the Colab
# kernel does not use by default — confirm the dependencies are importable from
# the kernel (e.g. install them into the kernel's environment instead).

import sys
import os
from pathlib import Path


def find_project_root(start):
    """Return the nearest directory at or above ``start`` that contains ``src/``.

    The notebook imports ``src.*`` modules, so the directory holding ``src/``
    must be on ``sys.path``. Using the working directory directly fails with
    ``ModuleNotFoundError: No module named 'src'`` whenever the kernel is not
    started from the repository root (e.g. from a notebooks/ subfolder).

    Args:
        start: Directory (str or Path) to start searching from.

    Returns:
        The first of ``start`` and its parents that has a ``src`` directory,
        or ``start`` itself (resolved) when none does.
    """
    start = Path(start).resolve()
    for candidate in (start, *start.parents):
        if (candidate / "src").is_dir():
            return candidate
    return start


PROJECT_ROOT = find_project_root(os.getcwd())
# Prepend (not append) so this repo's `src` package wins over any other
# importable module named `src`; skip if it is already present.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

In [2]:
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from src.data.loader import DataLoader
from src.models.wrapper import ModelWrapper
from src.ood.calculator import OODCalculator

# Fix every RNG up front so re-running the notebook reproduces its outputs.
# Possible randomness includes which training batches are read first (if
# train_loader shuffles) and any freshly initialized layers — TODO confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Manual Configuration
cfg = OmegaConf.create({
    "name": "stat-ood-analysis",
    "dataset": {
        "name": "clinc_oos",
        "subset": "plus",
        "maxlen": 64,
        "loader": {
            "batch_size": 32,
            "num_workers": 0,
            "pin_memory": False
        }
    },
    "model": {
        "name": "bert-base-uncased",
        # NOTE(review): presumably the 150 in-domain CLINC intents; confirm the
        # loader removes/remaps the out-of-scope label.
        "num_labels": 150,
        "pooling": "cls"
    },
    "ood_method": "mahalanobis", # or 'energy'
    "experiment": {
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
})

device = torch.device(cfg.experiment.device)
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'src'

## 2. Load Data

In [None]:
# Tokenizer for the same checkpoint the model is built from (cfg.model.name).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=cfg.model.name)

# The project loader hands back four splits: train, validation,
# in-distribution test, and OOD test.
loader = DataLoader(cfg.dataset, tokenizer)
(
    train_loader,
    val_loader,
    test_id_loader,
    test_ood_loader,
) = loader.load()

# Sanity check: how many training batches exist, then peek at one of them.
print(f'Train Batches: {len(train_loader)}')
batch = next(iter(train_loader))
print(f'Sample Batch Keys: {batch.keys()}')
print(f'Input Shape: {batch["input_ids"].shape}')

## 3. Initialize Model

In [None]:
# Build the encoder wrapper from the model config, move it to the selected
# device, and switch it to evaluation mode: every later cell only runs
# forward passes under torch.no_grad().
# NOTE(review): assumes ModelWrapper is a torch.nn.Module, so .to() and
# .eval() modify it in place (eval() would disable dropout) — confirm.
model = ModelWrapper(cfg.model)
model.to(device)
model.eval()
print("Model initialized successfully.")

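## 4. Extract Features (Quick Fit on a Training Subset)

To keep the demo fast, features are extracted from only the first few training batches. The method determines the feature: logits when `ood_method` is `energy`, the pooled output otherwise. Statistics fitted on such a small sample are only a rough approximation of a full fit.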

In [None]:
# Collect training-set features and their intent labels to fit the OOD
# statistics. Only the first N_FIT_BATCHES batches are read to keep the demo fast.
from itertools import islice

N_FIT_BATCHES = 5  # set to None to use the whole training set

# Energy scoring uses the logits; the distance-based method uses the pooled
# encoder output. The choice does not change inside the loop, so pick it once.
feature_key = 'logits' if cfg.ood_method == 'energy' else 'pooled_output'

print("Extracting features from Training set for fitting statistics...")
train_features = []
train_labels = []

with torch.no_grad():
    # islice reads exactly N_FIT_BATCHES batches. The old `if i > 5: break`
    # read 6, contradicting both its own comment and get_scores' limit.
    for batch in islice(train_loader, N_FIT_BATCHES):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)

        # The forward pass's return value is unused; features are read back
        # through get_features() afterwards.
        model(input_ids, mask)
        feats = model.get_features(feature_key)

        train_features.append(feats.cpu())
        # Labels are only needed on the CPU for fitting; skip the device round trip.
        train_labels.append(batch['intent'].cpu())

if not train_features:
    raise RuntimeError("No training batches were read; check train_loader and N_FIT_BATCHES.")

train_features = torch.cat(train_features)
train_labels = torch.cat(train_labels)
print(f"Extracted Features Shape: {train_features.shape}")
# A small subset covers only part of the label space; make that visible.
print(f"Distinct intents seen: {train_labels.unique().numel()} / {cfg.model.num_labels}")

## 5. Fit OOD Calculator

In [None]:
# Fit the OOD scorer on the CPU training features and labels collected above.
# NOTE(review): only a few training batches are used, so many of the
# cfg.model.num_labels intents may have few or no samples. If the Mahalanobis
# scorer estimates per-class means and a covariance (the usual formulation —
# TODO confirm), those estimates may be unreliable or singular. Check how
# OODCalculator.fit handles this.
ood_calc = OODCalculator(cfg)
ood_calc.fit(train_features, train_labels)
print("OOD Calculator fitted.")

## 6. Score Analysis (ID vs OOD)

In [None]:
def get_scores(loader, limit=5):
    """Compute OOD scores for the first ``limit`` batches of ``loader``.

    Runs the notebook-global ``model`` on each batch and reads back the
    features selected by ``cfg.ood_method``: 'logits' for 'energy',
    'pooled_output' otherwise. Those features are then scored with the fitted
    notebook-global ``ood_calc``.

    Args:
        loader: Iterable of batch dicts with 'input_ids' and 'attention_mask'.
        limit: Maximum number of batches to score (default 5). ``None`` scores
            every batch.

    Returns:
        CPU tensor of the per-batch scores concatenated along the first
        dimension, or an empty tensor when no batch was scored.
    """
    from itertools import islice

    # Loop-invariant: which feature the configured OOD method consumes.
    feature_key = 'logits' if cfg.ood_method == 'energy' else 'pooled_output'
    scores = []
    with torch.no_grad():
        for batch in islice(loader, limit):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            model(input_ids, mask)
            feats = model.get_features(feature_key)

            # NOTE(review): ood_calc was fitted on CPU features while these
            # feats live on `device`. Confirm predict() handles the device mismatch.
            dists = ood_calc.predict(feats)
            # Normalize to a CPU tensor so callers can call .numpy() for plotting.
            scores.append(torch.as_tensor(dists).cpu())

    if not scores:
        # torch.cat([]) raises; an empty loader (or limit=0) yields no scores.
        return torch.empty(0)
    return torch.cat(scores)

id_scores = get_scores(test_id_loader)
ood_scores = get_scores(test_ood_loader)

# Plot both score distributions on one explicit Axes. Scores are moved to the
# CPU first because Tensor.numpy() cannot read CUDA tensors.
fig, ax = plt.subplots(figsize=(10, 6))
sns.kdeplot(id_scores.cpu().numpy(), label='ID (Known Intents)', fill=True, ax=ax)
sns.kdeplot(ood_scores.cpu().numpy(), label='OOD (Unknown)', fill=True, ax=ax)
ax.set_title(f"OOD Score Distribution ({cfg.ood_method})")
ax.set_xlabel("Score (Distance/Energy)")
ax.legend()
plt.show()