Detect unreliable predictions at inference time using entropy and confidence, without retraining the model.

We use confidence & entropy
We define failure
We measure how well failures are detected

In [3]:
from google.colab import files
files.upload()

Output hidden; open in https://colab.research.google.com to view.

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance, ImageFilter
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
!mkdir -p failure-aware-cv/results/Model
!mv resnet18_cifar10_fc_only.pth failure-aware-cv/results/Model/

In [6]:
%cd /content/failure-aware-cv

/content/failure-aware-cv


In [7]:
model = torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)

model.load_state_dict(
    torch.load("results/Model/resnet18_cifar10_fc_only.pth", map_location=device)
)

model = model.to(device)
model.eval()

for p in model.parameters():
    p.requires_grad = False

print("✅ Trained model loaded")



✅ Trained model loaded


In [8]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

test_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,   # keep 1 for clean failure analysis
    shuffle=False
)

100%|██████████| 170M/170M [00:18<00:00, 9.08MB/s]


In [9]:
def add_gaussian_noise(img, severity=0.3):
    arr = np.array(img).astype(np.float32) / 255.0
    noise = np.random.normal(0, severity, arr.shape)
    noisy = np.clip(arr + noise, 0, 1)
    return Image.fromarray((noisy * 255).astype(np.uint8))

def apply_blur(img, radius=2):
    return img.filter(ImageFilter.GaussianBlur(radius))

def low_light(img, factor=0.4):
    return ImageEnhance.Brightness(img).enhance(factor)

In [10]:
def compute_entropy(probs):
    probs = probs + 1e-12
    return -torch.sum(probs * torch.log(probs)).item()

In [11]:
def run_failure_analysis(loader, degradation=None):
    records = []

    for img, label in tqdm(loader):
        label = label.item()
        pil_img = transforms.ToPILImage()(img[0])

        if degradation == "noise":
            pil_img = add_gaussian_noise(pil_img)
        elif degradation == "blur":
            pil_img = apply_blur(pil_img)
        elif degradation == "low_light":
            pil_img = low_light(pil_img)

        img_tensor = transform(pil_img).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(img_tensor)
            probs = F.softmax(logits, dim=1)[0]

        confidence = probs.max().item()
        entropy = compute_entropy(probs)
        prediction = probs.argmax().item()
        correct = int(prediction == label)

        records.append({
            "confidence": confidence,
            "entropy": entropy,
            "correct": correct
        })

    return pd.DataFrame(records)

In [12]:
df_clean = run_failure_analysis(test_loader)
df_noise = run_failure_analysis(test_loader, "noise")
df_blur = run_failure_analysis(test_loader, "blur")
df_low = run_failure_analysis(test_loader, "low_light")

100%|██████████| 10000/10000 [01:01<00:00, 162.68it/s]
100%|██████████| 10000/10000 [02:05<00:00, 79.74it/s]
100%|██████████| 10000/10000 [01:26<00:00, 116.12it/s]
100%|██████████| 10000/10000 [01:04<00:00, 155.28it/s]


Define Failure Rule

In [13]:
CONF_THRESH = 0.6
ENTROPY_THRESH = 1.0

def is_failure(row):
    return (row["confidence"] < CONF_THRESH) or (row["entropy"] > ENTROPY_THRESH)

Apply Failure Detection

In [14]:
def add_failure_flag(df):
    df = df.copy()
    df["failure"] = df.apply(is_failure, axis=1)
    return df

df_clean_f = add_failure_flag(df_clean)
df_noise_f = add_failure_flag(df_noise)
df_blur_f = add_failure_flag(df_blur)
df_low_f = add_failure_flag(df_low)

Evaluate Failure Detection Quality

In [15]:
def failure_stats(df):
    failure_rate = df["failure"].mean()
    acc_all = df["correct"].mean()
    acc_confident = df[df["failure"] == False]["correct"].mean()
    acc_failed = df[df["failure"] == True]["correct"].mean()

    return {
        "Accuracy (All)": acc_all,
        "Failure Rate": failure_rate,
        "Accuracy (Confident Only)": acc_confident,
        "Accuracy (Failure Region)": acc_failed
    }

In [16]:
failure_summary = pd.DataFrame.from_dict({
    "Clean": failure_stats(df_clean_f),
    "Gaussian Noise": failure_stats(df_noise_f),
    "Blur": failure_stats(df_blur_f),
    "Low Light": failure_stats(df_low_f)
}, orient="index")

failure_summary

Unnamed: 0,Accuracy (All),Failure Rate,Accuracy (Confident Only),Accuracy (Failure Region)
Clean,0.8048,0.2479,0.904667,0.501815
Gaussian Noise,0.0899,0.857,0.091608,0.089615
Blur,0.6079,0.4898,0.795962,0.412005
Low Light,0.5987,0.6375,0.831724,0.466196


In [17]:
failure_summary.to_csv("results/failure_detection_summary.csv")