# DriftLens + FireRisk + Vision Transformer

This notebook reproduces the DriftLens image experiments originally done on STL-10,
but using the FireRisk dataset and Vision Transformers (ViT):

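1. Setup and data loading for FireRisk  
2. Fine-tune ViTs on FireRisk (one on grid codes 1–6 for Experiment A, one on all classes for Experiment B)  
3. Extract embeddings and fit the DriftLens baseline and thresholds  
4. Experiment A: "new class" drift – the held-out class 7 (water) appears mid-stream (analogous to DriftLens use case 7.1)  
5. Experiment B: Gaussian blur drift (analogous to DriftLens use case 8)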

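Before running:
- Download FireRisk and put all its PNG images in a single flat directory, `data/FireRisk/images` (or change `DATA_ROOT` below). The loader reads each image's class (grid code 1–7) from the second `_`-separated field of its filename.
- Use a GPU if possible; fine-tuning ViT-B/16 on a CPU is impractically slow.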


In [None]:
%pip install -q driftlens timm torch torchvision matplotlib scikit-learn

In [None]:
import os, glob, math, random
from pathlib import Path
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms

import timm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from driftlens.driftlens import DriftLens

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Seed every RNG the notebook touches (stdlib `random`, NumPy, PyTorch) so that
# Restart & Run All reproduces the same training runs, loader shuffles and streams.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = "data/FireRisk"   # change if needed
IMG_DIR = os.path.join(DATA_ROOT, "images")
assert os.path.isdir(IMG_DIR), f"Expected images under {IMG_DIR}"


In [None]:
class FireRiskDataset(Dataset):
    """FireRisk images with a deterministic, stratified train/val/test split.

    Expects a flat directory ``<root>/images/*.png`` whose filenames carry the
    integer grid code (1-based class id) as the second ``_``-separated field.
    Labels are returned 0-based (``grid_code - 1``).

    The split is computed once over *all* images and only then filtered by
    ``classes_subset``. An image therefore lands in the same split whatever
    subset is requested, so a model trained on a class subset never trains on
    images that a full-class ``"test"`` dataset later serves as unseen data.

    Args:
        root: dataset root containing the ``images`` directory.
        split: ``"train"``, ``"val"`` or ``"test"``; anything else raises ValueError.
        transform: optional callable applied to the PIL image in ``__getitem__``.
        classes_subset: optional collection of 1-based grid codes to keep.
        seed: ``random_state`` for both stratified splits.
        split_ratios: (train, val, test) fractions; must sum to 1.
    """

    def __init__(self, root, split="train", transform=None, classes_subset=None, seed=42, split_ratios=(0.7, 0.15, 0.15)):
        if split not in ("train", "val", "test"):
            raise ValueError(split)
        self.root = root
        self.transform = transform

        pattern = os.path.join(root, "images", "*.png")
        all_paths = sorted(glob.glob(pattern))
        assert all_paths, f"No PNGs found under {pattern}"

        grid_codes = np.array([self._parse_grid_code(p) for p in all_paths])

        train_ratio, val_ratio, test_ratio = split_ratios
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

        # Split over the full population first (stratified by grid code) ...
        idx = np.arange(len(all_paths))
        train_idx, temp_idx = train_test_split(
            idx, test_size=(1 - train_ratio), stratify=grid_codes, random_state=seed
        )
        rel_val_ratio = val_ratio / (val_ratio + test_ratio)
        val_idx, test_idx = train_test_split(
            temp_idx, test_size=(1 - rel_val_ratio), stratify=grid_codes[temp_idx], random_state=seed
        )
        use_idx = {"train": train_idx, "val": val_idx, "test": test_idx}[split]

        # ... then keep only the requested classes, so split membership is
        # independent of `classes_subset`.
        keep = None if classes_subset is None else set(classes_subset)
        self.paths, self.labels = [], []
        for i in use_idx:
            code = int(grid_codes[i])
            if keep is not None and code not in keep:
                continue
            self.paths.append(all_paths[i])
            self.labels.append(code - 1)

    @staticmethod
    def _parse_grid_code(path):
        """Return the 1-based grid code stored as the 2nd ``_`` field of the filename."""
        fname = os.path.basename(path)
        parts = fname.split("_")
        try:
            return int(parts[1])
        except (IndexError, ValueError) as exc:
            raise ValueError(f"Cannot parse grid code from filename {fname!r}") from exc

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(self.paths[i]) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]


In [None]:
# Shared preprocessing for every split: resize the short side to 224, centre-crop
# to 224x224, convert to a tensor and normalise per channel.
# NOTE(review): confirm these statistics match the pretrained checkpoint's own
# config (timm.data.resolve_data_config); some timm ViT weights expect mean=std=0.5.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
img_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

batch_size = 64
num_workers = 4

# One dataset per split, all classes kept (classes_subset=None).
train_ds_full, val_ds_full, test_ds_full = (
    FireRiskDataset(DATA_ROOT, split=split_name, transform=img_transform, classes_subset=None)
    for split_name in ("train", "val", "test")
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader_full, val_loader_full, test_loader_full = (
    DataLoader(ds, batch_size=batch_size, shuffle=(ds is train_ds_full), num_workers=num_workers)
    for ds in (train_ds_full, val_ds_full, test_ds_full)
)

len(train_ds_full), len(val_ds_full), len(test_ds_full)


In [None]:
def create_vit(num_classes):
    """Build a pretrained timm ``vit_base_patch16_224`` with a new ``num_classes``-way head.

    The pretrained ``head`` is replaced by a fresh ``nn.Linear`` with the same
    input width, and the model is moved to the global ``device``.

    Raises:
        RuntimeError: if the timm model has no ``head`` attribute.
    """
    vit = timm.create_model("vit_base_patch16_224", pretrained=True)
    if not hasattr(vit, "head"):
        raise RuntimeError("Unexpected ViT model structure.")
    vit.head = nn.Linear(vit.head.in_features, num_classes)
    vit.to(device)
    return vit


In [None]:
def run_epoch(model, loader, optimizer=None, criterion=None):
    """Run one pass over ``loader``: train if an optimizer is given, otherwise evaluate.

    Args:
        model: classifier whose ``model(x)`` returns logits of shape
            (batch, num_classes). For timm ViTs, ``model(x)`` pools the token
            sequence before the head; ``forward_features`` alone does not.
        loader: iterable of ``(inputs, targets)`` tensor batches.
        optimizer: if not None, the model is set to train mode and updated.
        criterion: loss function; required when training, optional for evaluation.

    Returns:
        ``(avg_loss, accuracy)`` over all samples. ``avg_loss`` is 0.0 when no
        criterion is given; both are 0.0 for an empty loader.

    Raises:
        ValueError: if an optimizer is given without a criterion.
    """
    train = optimizer is not None
    if train and criterion is None:
        raise ValueError("criterion is required when an optimizer is given")
    model.train(train)
    # Send batches to wherever the model's parameters live instead of relying on a global.
    model_device = next(model.parameters()).device

    total_loss = 0.0
    total_correct = 0
    total = 0

    with torch.set_grad_enabled(train):
        for x, y in loader:
            x = x.to(model_device)
            y = y.to(model_device)
            if train:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y) if criterion is not None else None

            if train:
                loss.backward()
                optimizer.step()

            if loss is not None:
                total_loss += loss.item() * x.size(0)
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total += x.size(0)

    avg_loss = total_loss / total if total > 0 else 0.0
    acc = total_correct / total if total > 0 else 0.0
    return avg_loss, acc

def train_model(model, train_loader, val_loader, epochs=10, lr=3e-4, wd=0.05):
    """Fine-tune ``model`` with AdamW + cross-entropy and keep the best-validation weights.

    Each epoch trains on ``train_loader`` and then evaluates on ``val_loader``
    with the same criterion, so the recorded validation loss is real. After the
    last epoch, the weights from the epoch with the highest validation accuracy
    are restored into ``model`` in place. If no epoch beats 0.0 accuracy, the
    final weights are kept.

    Returns:
        list of ``(epoch, train_loss, train_acc, val_loss, val_acc)`` tuples.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    best_val_acc = 0.0
    best_state = None
    history = []

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = run_epoch(model, val_loader, criterion=criterion)
        history.append((epoch, train_loss, train_acc, val_loss, val_acc))
        print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # clone(): for tensors already on the CPU, .cpu() returns the same
            # tensor, so without a copy the snapshot would track later updates.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        # load_state_dict copies into the existing tensors, whatever their device.
        model.load_state_dict(best_state)
    return history


In [None]:
@torch.no_grad()
def get_embeddings_and_preds(model, loader):
    """Collect pooled embeddings, predicted labels and true labels over ``loader``.

    For timm ViTs, ``forward_features`` returns the whole token sequence.
    ``forward_head(..., pre_logits=True)`` pools it into one vector per image,
    i.e. the classifier head's input. That vector is the embedding DriftLens
    needs, and ``head`` on it gives the logits.

    Returns:
        ``(E, Y_hat, Y_true)`` as NumPy arrays: embeddings with one row per
        sample, argmax predictions, and the labels yielded by ``loader``.
    """
    model.eval()
    model_device = next(model.parameters()).device
    all_E, all_Y_hat, all_Y_true = [], [], []
    for x, y in loader:
        tokens = model.forward_features(x.to(model_device))
        emb = model.forward_head(tokens, pre_logits=True)
        logits = model.head(emb)
        all_E.append(emb.cpu())
        all_Y_hat.append(logits.argmax(dim=1).cpu())
        all_Y_true.append(torch.as_tensor(y))
    E = torch.cat(all_E, dim=0).numpy()
    Y_hat = torch.cat(all_Y_hat, dim=0).numpy()
    Y_true = torch.cat(all_Y_true, dim=0).numpy()
    return E, Y_hat, Y_true


In [None]:
def fit_driftlens_baseline(E_train, Y_pred_train, E_thr, Y_pred_thr,
                           batch_n_pc=150, per_label_n_pc=75,
                           window_size=1000, n_samples=10000,
                           label_list=None):
    """Fit a DriftLens baseline and estimate its threshold distance distributions.

    Args:
        E_train, Y_pred_train: embeddings and *predicted* labels passed to
            ``DriftLens.estimate_baseline``.
        E_thr, Y_pred_thr: held-out embeddings and predicted labels passed to
            ``random_sampling_threshold_estimation`` (with ``window_size``,
            ``n_samples``, ``flag_shuffle=True`` and ``flag_replacement=True``).
        batch_n_pc, per_label_n_pc: forwarded to both DriftLens calls as the
            per-batch / per-label component counts.
        label_list: labels to model. Defaults to the distinct values of
            ``Y_pred_train``; pass the model's full class list explicitly when a
            class might never be predicted on the training data, since it would
            otherwise be missing from the baseline.

    Returns:
        ``(dl, baseline, per_batch_sorted, per_label_sorted)``: the DriftLens
        object, the value returned by ``estimate_baseline``, and the two values
        returned by the threshold estimation.
    """
    label_list = sorted(np.unique(Y_pred_train)) if label_list is None else sorted(label_list)
    dl = DriftLens()
    baseline = dl.estimate_baseline(
        E=E_train,
        Y=Y_pred_train,
        label_list=label_list,
        batch_n_pc=batch_n_pc,
        per_label_n_pc=per_label_n_pc,
    )
    per_batch_sorted, per_label_sorted = dl.random_sampling_threshold_estimation(
        label_list=label_list,
        E=E_thr,
        Y=Y_pred_thr,
        batch_n_pc=batch_n_pc,
        per_label_n_pc=per_label_n_pc,
        window_size=window_size,
        n_samples=n_samples,
        flag_shuffle=True,
        flag_replacement=True,
    )
    return dl, baseline, per_batch_sorted, per_label_sorted


In [None]:
def compute_window_distances(dl, E_stream, Yp_stream, window_size):
    """Score consecutive, non-overlapping stream windows with ``dl``.

    The stream is cut into ``len(E_stream) // window_size`` windows of exactly
    ``window_size`` rows, and any trailing partial window is ignored. Each
    window's embeddings and predicted labels go to
    ``dl.compute_window_distribution_distances``. The results are returned in
    stream order as a NumPy array.

    NOTE(review): if DriftLens returns a dict per window (e.g. separate per-batch
    and per-label distances), this produces an object array; confirm before plotting.
    """
    usable_rows = (E_stream.shape[0] // window_size) * window_size
    return np.array([
        dl.compute_window_distribution_distances(
            E_stream[start:start + window_size],
            Yp_stream[start:start + window_size],
        )
        for start in range(0, usable_rows, window_size)
    ])


In [None]:
# Experiment A: train only on grid codes 1–6; grid code 7 is held back and only
# enters the stream after the drift point.
classes_train_nc = set(range(1, 7))
classes_new_nc = {7}

nc_ds_kwargs = dict(transform=img_transform, classes_subset=classes_train_nc)
train_ds_nc = FireRiskDataset(DATA_ROOT, split="train", **nc_ds_kwargs)
val_ds_nc = FireRiskDataset(DATA_ROOT, split="val", **nc_ds_kwargs)
test_ds_nc = FireRiskDataset(DATA_ROOT, split="test", **nc_ds_kwargs)

nc_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader_nc = DataLoader(train_ds_nc, shuffle=True, **nc_loader_kwargs)
val_loader_nc = DataLoader(val_ds_nc, shuffle=False, **nc_loader_kwargs)
test_loader_nc = DataLoader(test_ds_nc, shuffle=False, **nc_loader_kwargs)

len(train_ds_nc), len(val_ds_nc), len(test_ds_nc)


In [None]:
# Experiment A model: a ViT fine-tuned on the training classes only, so the
# held-out class is never seen during training.
num_classes_nc = len(classes_train_nc)
model_nc = create_vit(num_classes=num_classes_nc)
history_nc = train_model(
    model_nc,
    train_loader_nc,
    val_loader_nc,
    epochs=10,
    lr=3e-4,
    wd=0.05,
)


In [None]:
# Baseline: embeddings/predictions on the training split.
# Thresholds: estimated on the *validation* split. The drift stream below is
# drawn from the test split, so the pre-drift windows are not scored against
# the very samples that set the thresholds.
E_train_nc, Yp_train_nc, _ = get_embeddings_and_preds(model_nc, train_loader_nc)
E_thr_nc,   Yp_thr_nc,   _ = get_embeddings_and_preds(model_nc, val_loader_nc)

dl_nc, baseline_nc, per_batch_sorted_nc, per_label_sorted_nc = fit_driftlens_baseline(
    E_train_nc, Yp_train_nc, E_thr_nc, Yp_thr_nc,
    batch_n_pc=150, per_label_n_pc=75,
    window_size=1000, n_samples=10000,
)

sorted(np.unique(Yp_train_nc)), sorted(np.unique(Yp_thr_nc))


In [None]:
# Experiment A stream, built from the test split of *all* classes:
#   pre-drift  – training classes only
#   post-drift – training classes plus the held-out new class
# Samples are drawn round-robin across classes (ascending grid code) without
# replacement, so each phase stays roughly class-balanced until a class runs out.
test_ds_all = FireRiskDataset(DATA_ROOT, split="test", transform=img_transform, classes_subset=None)
labels_all = test_ds_all.labels

stream_rng = random.Random(42)  # local RNG so the stream composition is reproducible
indices_by_class = {}
for idx, y in enumerate(labels_all):
    indices_by_class.setdefault(y + 1, []).append(idx)  # keyed by 1-based grid code
for c in sorted(indices_by_class):
    stream_rng.shuffle(indices_by_class[c])

requested_pre_len = 40000
requested_post_len = 40000
window_size_nc = 1000


def _round_robin_take(pools, classes, n_target):
    """Pop up to ``n_target`` indices from ``pools``, cycling over ``classes`` in ascending order.

    Returns early with fewer than ``n_target`` indices once every listed pool is
    empty, instead of looping forever.
    """
    taken = []
    order = sorted(classes)
    while len(taken) < n_target:
        progressed = False
        for c in order:
            if len(taken) >= n_target:
                break
            pool = pools.get(c)
            if pool:
                taken.append(pool.pop())
                progressed = True
        if not progressed:
            break
    return taken


# The test split may hold fewer samples than requested, and the previous
# unbounded loops never terminated in that case. Cap each phase by what is
# available:
# - Reserve half of the known-class samples for the post-drift phase;
#   otherwise it could contain only the new class.
# - Round the pre-drift length down to whole windows so no window straddles
#   the drift point.
n_known = sum(len(indices_by_class.get(c, [])) for c in classes_train_nc)
pre_target = (min(requested_pre_len, n_known // 2) // window_size_nc) * window_size_nc
pre_indices = _round_robin_take(indices_by_class, classes_train_nc, pre_target)
post_indices = _round_robin_take(indices_by_class, classes_train_nc | classes_new_nc, requested_post_len)

if len(pre_indices) < requested_pre_len or len(post_indices) < requested_post_len:
    print(f"Stream capped by available test data: pre={len(pre_indices)}, post={len(post_indices)} "
          f"(requested {requested_pre_len}/{requested_post_len})")

# Effective phase lengths; the plot below uses pre_drift_len to mark the drift point.
pre_drift_len = len(pre_indices)
post_drift_len = len(post_indices)

stream_indices_nc = pre_indices + post_indices
stream_ds_nc = Subset(test_ds_all, stream_indices_nc)
stream_loader_nc = DataLoader(stream_ds_nc, batch_size=batch_size, shuffle=False, num_workers=num_workers)

len(stream_ds_nc)


In [None]:
# Embed the whole new-class stream, then score it window by window against dl_nc.
E_stream_nc, Yp_stream_nc, Ytrue_stream_nc = get_embeddings_and_preds(
    model_nc, stream_loader_nc
)
distances_nc = compute_window_distances(
    dl_nc, E_stream_nc, Yp_stream_nc, window_size=window_size_nc
)
len(distances_nc)


In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(distances_nc, marker="o")
# Dashed line between the last pre-drift window and the first post-drift window.
ax.axvline(x=(pre_drift_len // window_size_nc) - 0.5, linestyle="--")
ax.set_xlabel("Window index")
ax.set_ylabel("DriftLens distance")
ax.set_title("New-class drift (water) – FireRisk + ViT")
fig.tight_layout()
plt.show()


In [None]:
# Experiment B model: a ViT trained on every class. The head size is the number
# of distinct labels in the full training split.
num_classes_full = len(set(train_ds_full.labels))
model_full = create_vit(num_classes=num_classes_full)
history_full = train_model(
    model_full,
    train_loader_full,
    val_loader_full,
    epochs=10,
    lr=3e-4,
    wd=0.05,
)


In [None]:
# Baseline: embeddings/predictions on the training split.
# Thresholds: estimated on the *validation* split. The clean part of the blur
# stream below comes from the test split, so it is not scored against the
# samples that set the thresholds.
E_train_full, Yp_train_full, _ = get_embeddings_and_preds(model_full, train_loader_full)
E_thr_full,   Yp_thr_full,   _ = get_embeddings_and_preds(model_full, val_loader_full)

dl_full, baseline_full, per_batch_sorted_full, per_label_sorted_full = fit_driftlens_baseline(
    E_train_full, Yp_train_full, E_thr_full, Yp_thr_full,
    batch_n_pc=150, per_label_n_pc=75,
    window_size=1000, n_samples=10000,
)

sorted(np.unique(Yp_train_full)), sorted(np.unique(Yp_thr_full))


In [None]:
# Blurred pipeline = the clean `img_transform` with a Gaussian blur inserted just
# before ToTensor, i.e. after the PIL-level resize/crop. Deriving it from
# img_transform keeps the two pipelines from diverging, so the clean and blurred
# streams differ only by the blur and no spurious "drift" can creep in.
blur_kernel_size = 11
blur_sigma = 3.0
clean_steps = list(img_transform.transforms)
to_tensor_pos = next(i for i, t in enumerate(clean_steps) if isinstance(t, transforms.ToTensor))
blur_transform = transforms.Compose(
    clean_steps[:to_tensor_pos]
    + [transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)]
    + clean_steps[to_tensor_pos:]
)

base_ds_clean = FireRiskDataset(DATA_ROOT, split="test", transform=img_transform,    classes_subset=None)
base_ds_blur  = FireRiskDataset(DATA_ROOT, split="test", transform=blur_transform,  classes_subset=None)

# Same images in the same order, so row i of both embedding arrays is one image.
assert base_ds_clean.paths == base_ds_blur.paths
assert base_ds_clean.labels == base_ds_blur.labels

clean_loader = DataLoader(base_ds_clean, batch_size=batch_size, shuffle=False, num_workers=num_workers)
blur_loader  = DataLoader(base_ds_blur,  batch_size=batch_size, shuffle=False, num_workers=num_workers)

len(base_ds_clean), len(base_ds_blur)


In [None]:
E_clean, Yp_clean, _ = get_embeddings_and_preds(model_full, clean_loader)
E_blur,  Yp_blur,  _ = get_embeddings_and_preds(model_full, blur_loader)

window_size_blur = 1000
requested_clean = 40000
requested_blur = 40000

# Row i of E_clean and E_blur is the same test image (clean vs. blurred).
# - Take a clean prefix, then the blurred copies of the *remaining* images, so
#   the stream never replays an image.
# - Slicing past the end silently truncates, so cap both phases by what is
#   available.
# - Round the clean phase down to whole windows so the drift boundary (used by
#   the plot via n_clean) falls exactly between windows.
n_total = len(E_clean)
n_clean = (min(requested_clean, n_total // 2) // window_size_blur) * window_size_blur
n_blur = min(requested_blur, n_total - n_clean)

E_stream_blur = np.concatenate([E_clean[:n_clean], E_blur[n_clean:n_clean + n_blur]], axis=0)
Yp_stream_blur = np.concatenate([Yp_clean[:n_clean], Yp_blur[n_clean:n_clean + n_blur]], axis=0)

len(E_stream_blur), len(Yp_stream_blur)


In [None]:
# Score the clean→blurred stream window by window against the all-class baseline.
distances_blur = compute_window_distances(
    dl_full, E_stream_blur, Yp_stream_blur, window_size=window_size_blur
)
len(distances_blur)


In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(distances_blur, marker="o")
# Dashed line between the last clean window and the first blurred window.
ax.axvline(x=(n_clean // window_size_blur) - 0.5, linestyle="--")
ax.set_xlabel("Window index")
ax.set_ylabel("DriftLens distance")
ax.set_title("Gaussian blur drift – FireRisk + ViT")
fig.tight_layout()
plt.show()
