# 🔮 Stat-OOD: Interactive Inference & Visualization

This notebook focuses on **using** the trained model.
1. **Interactive Demo**: Type a sentence and see if it's detected as OOD.
2. **Cluster Visualization**: Visualize the embedding space using t-SNE to see how OOD samples separate from ID samples.

In [None]:
# Setup (Colab)
# !pip install -q uv
# !git clone https://github.com/sucpark/stat-ood.git
# %cd stat-ood
# !uv sync
# !pip install -q scikit-learn matplotlib seaborn

import sys
import os
sys.path.append(os.getcwd()) 

import torch
from omegaconf import OmegaConf
from src.data.loader import DataLoader
from src.models.wrapper import ModelWrapper
from src.ood.calculator import OODCalculator
from transformers import AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE

# Resolve the accelerator name first so the config literal below stays purely declarative.
device_name = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment configuration (pre-set for E5 + Mahalanobis).
cfg = OmegaConf.create(
    {
        "name": "stat-ood-viz",
        "dataset": {
            "name": "clinc_oos",
            "subset": "plus",
            "maxlen": 64,
            "loader": {"batch_size": 32, "num_workers": 0, "pin_memory": False},
        },
        "model": {"name": "intfloat/e5-base-v2", "num_labels": 150, "pooling": "mean"},
        "ood_method": "mahalanobis",
        "experiment": {"device": device_name},
    }
)

# torch.device handle used by every later cell to place tensors and the model.
device = torch.device(cfg.experiment.device)
print(f"Device: {device}")

## 1. Prepare Model & Calculator

In [None]:
# Tokenizer and dataset splits; the second value returned by load() is not used in this notebook.
tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
loader = DataLoader(cfg.dataset, tokenizer)
train_loader, _, test_id_loader, test_ood_loader = loader.load()

# Build the model, move it to the selected device and switch it to eval mode.
model = ModelWrapper(cfg.model)
model = model.to(device).eval()

# Which model output feeds the OOD calculator: logits for 'energy', pooled output otherwise.
if cfg.ood_method != 'energy':
    feature_key = 'pooled_output'
else:
    feature_key = 'logits'

print("Fitting OOD Calculator (this may take a moment)...")
train_features, train_labels = [], []

with torch.no_grad():
    for step, batch in enumerate(train_loader):
        # Demo shortcut: only batches 0..50 are used for fitting, to keep this cell fast.
        if step > 50:
            break
        ids = batch['input_ids'].to(device)
        attn = batch['attention_mask'].to(device)
        # The return value is discarded; features are read back through get_features()
        # (presumably cached by the wrapper during the forward pass — confirm in ModelWrapper).
        model(ids, attn)
        batch_features = model.get_features(key=feature_key)
        train_features.append(batch_features.cpu())
        train_labels.append(batch['intent'].cpu())

# Fit the calculator on the concatenated CPU features and their intent labels.
ood_calc = OODCalculator(cfg)
all_features = torch.cat(train_features)
all_labels = torch.cat(train_labels)
ood_calc.fit(all_features, all_labels)
print("Ready!")

## 2. Interactive Inference Mode
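Pass any sentence to `predict_sentence` (or add it to the `examples` list below) to see its OOD score; a score above `threshold` is reported as OOD.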

In [None]:
def predict_sentence(text):
    """Return the OOD score of a single sentence as a Python scalar.

    The text is tokenized (truncated to at most 64 tokens), pushed through
    the model under ``torch.no_grad()``, and the feature selected by
    ``cfg.ood_method`` ('logits' for 'energy', 'pooled_output' otherwise)
    is handed to ``ood_calc.predict``.

    Args:
        text: Sentence to score.

    Returns:
        The value of ``ood_calc.predict(...)`` converted with ``.item()``.
    """
    if cfg.ood_method != 'energy':
        feature_key = 'pooled_output'
    else:
        feature_key = 'logits'

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=64,
    ).to(device)

    with torch.no_grad():
        # Forward pass first; the features are then fetched from the wrapper by key.
        model(encoded['input_ids'], encoded['attention_mask'])
        features = model.get_features(key=feature_key)
        return ood_calc.predict(features).item()

# Decision threshold on the OOD score: scores strictly above it are reported as OOD.
# (Simulated value - usually determined by FPR@95 on the validation set.)
threshold = 20.0

# Demo sentences; the trailing comments are the author's expected verdicts, not ground truth.
examples = [
    "I want to transfer money to my savings account", # ID
    "Can you book a flight to Paris?", # ID (if flight is ID)
    "How do I cook a steak?", # OOD (Likely)
    "Tell me a joke about dogs" # OOD
]

print(f"--- Threshold: {threshold} ---\n")
for text in examples:
    score = predict_sentence(text)
    # Red circle = flagged OOD, green circle = treated as in-distribution.
    verdict = "🔴 OOD" if score > threshold else "🟢 ID"
    print(f"'{text}'\n  -> Score: {score:.2f} [{verdict}]\n")

## 3. t-SNE Visualization
We compress the high-dimensional embeddings into 2D points. 
You should see distinct clusters for ID classes, and OOD samples scattered or far away.

In [None]:
def extract_for_viz(loader, limit=200):
    """Collect at most ``limit`` feature rows from ``loader`` for plotting.

    Batches are run through the model under ``torch.no_grad()`` and the
    feature selected by ``cfg.ood_method`` ('logits' for 'energy',
    'pooled_output' otherwise) is moved to the CPU and accumulated.
    Iteration stops once ``limit`` rows have been gathered, and the result is
    cut to exactly ``limit`` rows if the last batch overshoots.

    Args:
        loader: Iterable of batches exposing 'input_ids' and 'attention_mask'.
        limit: Maximum number of rows (samples) to return.

    Returns:
        A CPU tensor with the collected features concatenated along dim 0,
        holding no more than ``limit`` rows.

    Note:
        If no batch is processed (empty loader or ``limit <= 0``) there is
        nothing to concatenate and ``torch.cat`` raises.
    """
    feature_key = 'pooled_output' if cfg.ood_method != 'energy' else 'logits'
    feats = []
    n_rows = 0  # samples gathered so far; the cap is on rows, not on batches
    with torch.no_grad():
        for batch in loader:
            if n_rows >= limit:
                break
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            model(input_ids, mask)
            f = model.get_features(key=feature_key).cpu()
            feats.append(f)
            n_rows += f.shape[0]
    # The final batch may push the total past `limit`; trim so the cap is exact.
    return torch.cat(feats)[:limit]

print("Extracting ID and OOD features...")
id_feats = extract_for_viz(test_id_loader, limit=500)
ood_feats = extract_for_viz(test_ood_loader, limit=500)

# One matrix holding both groups; y records each row's origin (0 = ID, 1 = OOD).
X = torch.cat([id_feats, ood_feats]).numpy()
y = np.repeat([0, 1], [len(id_feats), len(ood_feats)])

print("Running t-SNE...")
# Fixed random_state keeps the 2-D layout reproducible across runs.
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
X_emb = tsne.fit_transform(X)

fig, ax = plt.subplots(figsize=(10, 8))
# Draw ID first, then OOD, so the drawing order and legend order stay ID -> OOD.
for group_id, color, name in ((0, 'blue', 'ID'), (1, 'red', 'OOD')):
    points = X_emb[y == group_id]
    ax.scatter(points[:, 0], points[:, 1], c=color, label=name, alpha=0.5, s=10)
ax.legend()
ax.set_title("t-SNE of ID vs OOD Embeddings")
plt.show()