# CUE Meets LRP

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vincent-el/cue-meets-lrp/blob/main/cue_meets_lrp.ipynb)

I spent an afternoon with [zennit](https://github.com/chr5tphr/zennit) to see how different Layer-wise Relevance Propagation (LRP) composite rules explain the same prediction. Three rules, three visual signatures: which one is most understandable to a human viewer?

To compare them, I use the [CUE model](https://arxiv.org/abs/2506.14775) (Labarta et al., 2025) as a lens: not just *what* the network saw, but *how legible* each explanation is.

In [None]:
# Setup: install dependencies, load model and image
import subprocess, sys
for pkg in ["zennit", "torchvision"]:
    try: __import__(pkg)
    except ImportError: subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from torchvision.models import vgg16, VGG16_Weights
from torchvision import transforms
from PIL import Image
import urllib.request, io

from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox, EpsilonAlpha2Beta1Flat
from zennit.attribution import Gradient

# Load VGG16
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).eval()
labels = VGG16_Weights.IMAGENET1K_V1.meta["categories"]

# One image: border collie catching frisbee
url = "https://images.unsplash.com/photo-1503256207526-0d5d80fa2f47?w=640"
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
img = Image.open(io.BytesIO(urllib.request.urlopen(req).read())).convert("RGB").resize((224, 224))

# Preprocess and predict
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
x = preprocess(img).unsqueeze(0)

with torch.no_grad():
    logits = model(x)
    pred_idx = logits.argmax(1).item()
    pred_conf = torch.softmax(logits, 1)[0, pred_idx].item()

print(f"Prediction: {labels[pred_idx]} ({pred_conf:.0%})")
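`transforms.Normalize` computes `(x - mean) / std` per channel, so pixel values in [0, 1] end up in a channel-specific range. That range is what the `low`/`high` arguments of `EpsilonGammaBox` in the next cell are meant to enclose (they bound the input for the first-layer ZBox rule). A minimal NumPy check, reusing the ImageNet mean/std from `preprocess`:

```python
import numpy as np

# ImageNet statistics, as passed to transforms.Normalize in `preprocess`
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Normalize maps each channel via (x - mean) / std, so pixel values 0 and 1 become:
low = (0.0 - mean) / std
high = (1.0 - mean) / std

for ch, lo_v, hi_v in zip("RGB", low, high):
    print(f"{ch}: [{lo_v:.3f}, {hi_v:.3f}]")
print(f"overall: [{low.min():.3f}, {high.max():.3f}]")
```

The overall range comes out at roughly -2.12 to 2.64, so `low=-3.0, high=3.0` is a loose but enclosing bound; passing the exact per-channel values would be a tighter choice.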

In [None]:
# Generate LRP heatmaps with three composite rules
composites = {
    "Epsilon+Flat": EpsilonPlusFlat(),
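    # low/high must enclose the normalized pixel range (about -2.12 to 2.64 for ImageNet stats); ±3 is a loose bound
    "EpsilonGammaBox": EpsilonGammaBox(low=-3.0, high=3.0),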
    "EpsilonAlpha2Beta1": EpsilonAlpha2Beta1Flat(),
}

heatmaps = {}
x_grad = preprocess(img).unsqueeze(0).requires_grad_(True)
target = torch.eye(1000)[[pred_idx]]

for name, composite in composites.items():
    with Gradient(model=model, composite=composite) as attributor:
        _, relevance = attributor(x_grad, target)
    heatmaps[name] = relevance.sum(1).squeeze().detach().cpu().numpy()

# Visualize: Original + 3 heatmaps in a row
fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))

axes[0].imshow(img)
axes[0].set_title(f"{labels[pred_idx]}\n({pred_conf:.0%})", fontsize=11)
axes[0].axis("off")

for i, (name, hmap) in enumerate(heatmaps.items()):
    vmax = max(abs(hmap.min()), abs(hmap.max()))
    axes[i + 1].imshow(hmap, cmap="bwr", vmin=-vmax, vmax=vmax, interpolation="bilinear")
    axes[i + 1].set_title(name, fontsize=11)
    axes[i + 1].axis("off")

plt.tight_layout()
plt.show()
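According to zennit's documentation, the three composites differ mainly in their per-layer rules: all use the epsilon rule for fully connected layers, while convolutions get the z⁺ rule (`EpsilonPlusFlat`), the gamma rule with a ZBox first layer (`EpsilonGammaBox`), or the alpha2-beta1 rule (`EpsilonAlpha2Beta1Flat`); the two `Flat` variants use the flat rule for the first layer. To make two of these rules concrete, here is a minimal NumPy sketch for a single dense layer `z = a @ W + b`, assuming non-negative (post-ReLU) inputs. It is a toy re-derivation of the standard formulas, not zennit's implementation, and `lrp_epsilon` / `lrp_alpha_beta` are hypothetical names.

```python
import numpy as np

def lrp_epsilon(a, W, b, R_out, eps=1e-6):
    """Epsilon rule: R_j = a_j * sum_k W_jk * R_k / (z_k +/- eps)."""
    z = a @ W + b                              # pre-activations, shape (n_out,)
    z_stab = z + np.where(z >= 0, eps, -eps)   # push away from zero, keep the sign
    return a * (W @ (R_out / z_stab))          # relevance per input, shape (n_in,)

def lrp_alpha_beta(a, W, R_out, alpha=2.0, beta=1.0, eps=1e-9):
    """Alpha-beta rule for non-negative inputs; conservative when alpha - beta = 1."""
    W_pos = np.clip(W, 0, None)                # excitatory weights
    W_neg = np.clip(W, None, 0)                # inhibitory weights
    z_pos = a @ W_pos + eps                    # total positive contribution per output
    z_neg = a @ W_neg - eps                    # total negative contribution per output
    # Positive shares are scaled by alpha; negative shares are subtracted with weight beta
    return a * (alpha * (W_pos @ (R_out / z_pos)) - beta * (W_neg @ (R_out / z_neg)))

rng = np.random.default_rng(0)
a = rng.uniform(0, 1, size=16)                 # non-negative inputs, as after a ReLU
W = rng.normal(size=(16, 8))
b = np.zeros(8)                                # zero bias, so no relevance is absorbed by it
R_out = np.maximum(a @ W + b, 0)               # start from the ReLU'd output activations

R_eps = lrp_epsilon(a, W, b, R_out)
R_ab = lrp_alpha_beta(a, W, R_out)
print("output relevance:", R_out.sum().round(4))
print("epsilon sum:     ", R_eps.sum().round(4))
print("alpha2-beta1 sum:", R_ab.sum().round(4))
```

In this toy setting both rules conserve total relevance (up to the stabilizers); they differ in how they split it, with alpha2-beta1 over-weighting excitatory contributions by `alpha` and subtracting inhibitory ones by `beta`.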

## Evaluating Through the CUE Lens

The CUE model frames explanation quality as three cognitive stages:

- **Legibility**: Can you *see* the signal against the background?
- **Readability**: Can you *parse* the structure into coherent regions?
- **Interpretability**: Can you *derive meaning* — does the focus make sense?

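Below I use one simple quantitative proxy per stage: contrast ratio for legibility, blob count for readability, and the Gini coefficient for interpretability. These are rough approximations, not rigorous CUE measurements, but they give the comparison some structure.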
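The Gini coefficient has an equivalent pairwise definition, G = Σᵢ Σⱼ |xᵢ - xⱼ| / (2 n² x̄). A minimal sketch that checks the sorted closed form (the same formula as the `gini` function in the next cell, without its stabilizer) against that definition and at the two extremes; `gini_sorted` and `gini_pairwise` are hypothetical names:

```python
import numpy as np

def gini_sorted(x):
    """Closed form on ascending-sorted values: 2*sum(i*x_i)/(n*sum(x)) - (n+1)/n."""
    x = np.sort(np.abs(np.ravel(x)))
    n = x.size
    i = np.arange(1, n + 1)
    return 2 * np.sum(i * x) / (n * np.sum(x)) - (n + 1) / n

def gini_pairwise(x):
    """Definition: sum of absolute differences over all pairs, over 2 * n^2 * mean."""
    x = np.abs(np.ravel(x))
    return np.abs(x[:, None] - x[None, :]).sum() / (2 * x.size**2 * x.mean())

rng = np.random.default_rng(0)
sample = rng.exponential(size=500)
print(gini_sorted(sample), gini_pairwise(sample))  # agree up to float error

one_hot = np.zeros(100)
one_hot[0] = 1.0
print(gini_sorted(np.ones(100)))  # uniform: approximately 0
print(gini_sorted(one_hot))       # all mass in one entry: (n-1)/n, approximately 0.99
```

So the proxy's ceiling is (n-1)/n rather than exactly 1, which for a 224x224 heatmap is indistinguishable from 1 in practice.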

In [None]:
def contrast_ratio(hmap):
    """Legibility: How much does the hottest spot stand out?
    
    Higher = clearer signal against background.
    """
    abs_hmap = np.abs(hmap)
    return float(abs_hmap.max() / (np.median(abs_hmap) + 1e-8))

def blob_count(hmap, threshold_pct=75):
    """Readability: How many distinct high-relevance regions?
    
    Fewer blobs = easier to parse structure.
    """
    abs_hmap = np.abs(hmap)
    threshold = np.percentile(abs_hmap, threshold_pct)
    binary = abs_hmap >= threshold
    _, n_blobs = ndimage.label(binary)  # default structure: 4-connected (edge) neighbors
    return n_blobs

def gini(hmap):
    """Interpretability: How focused is the explanation?
    
    0 = uniform spread; approaches 1 ((n-1)/n) when all relevance sits in one pixel.
    Higher = more focused, potentially easier to interpret.
    """
    vals = np.abs(hmap).flatten()
    vals = np.sort(vals)
    n = len(vals)
    idx = np.arange(1, n + 1)
    return float((2 * np.sum(idx * vals) / (n * np.sum(vals) + 1e-8)) - (n + 1) / n)

# Compute scores
print(f"{'Rule':<20} {'Contrast':>10} {'Blobs':>8} {'Gini':>8}")
print("-" * 48)
for name, hmap in heatmaps.items():
    c = contrast_ratio(hmap)
    b = blob_count(hmap)
    g = gini(hmap)
    print(f"{name:<20} {c:>10.1f} {b:>8} {g:>8.2f}")
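`ndimage.label` uses a cross-shaped structuring element by default, so two pixels that touch only at a corner count as separate blobs. Whether that matches what a viewer perceives as one "region" is a judgment call. A minimal sketch of the difference, with `blob_count_8` as a hypothetical 8-connected variant of `blob_count`:

```python
import numpy as np
from scipy import ndimage

# Two "hot" pixels that touch only diagonally
binary = np.array([[1, 0],
                   [0, 1]], dtype=bool)

_, n_4 = ndimage.label(binary)                   # default: 4-connected (edges only)
eight = ndimage.generate_binary_structure(2, 2)  # 3x3 ones: edges and corners
_, n_8 = ndimage.label(binary, structure=eight)
print(n_4, n_8)                                  # 2 1

def blob_count_8(hmap, threshold_pct=75):
    """Hypothetical variant of blob_count that merges diagonally touching pixels."""
    abs_hmap = np.abs(hmap)
    mask = abs_hmap >= np.percentile(abs_hmap, threshold_pct)
    _, n_blobs = ndimage.label(mask, structure=eight)
    return n_blobs
```

On speckled heatmaps the 4-connected count can be noticeably higher, so the readability proxy is sensitive to this choice as well as to `threshold_pct`.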

## A Vision Model's Perspective

The quantitative proxies measure signal properties. As a second lens, I show each heatmap to a vision-language model (BLIP-2) and ask whether its focus looks clear or scattered, a rough stand-in for a human viewer's first impression.
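The cell below queries BLIP-2 about each heatmap independently, so the model never sees the alternatives side by side. One option for a direct comparison is to tile the heatmaps into a single labeled image and ask which panel is most focused. `tile_panels` is a hypothetical helper, and whether BLIP-2 can compare panels reliably this way is an assumption, not something shown here.

```python
from PIL import Image, ImageDraw

def tile_panels(images, labels, pad=8):
    """Place PIL images side by side (resized to the first one's size), each with a text label."""
    w, h = images[0].size
    canvas = Image.new("RGB", (len(images) * (w + pad) - pad, h), "white")
    draw = ImageDraw.Draw(canvas)
    for i, (im, label) in enumerate(zip(images, labels)):
        x = i * (w + pad)
        canvas.paste(im.convert("RGB").resize((w, h)), (x, 0))
        draw.text((x + 4, 4), label, fill="black")  # Pillow's default font
    return canvas

# Usage with solid-color stand-ins for three heatmap images
panels = [Image.new("RGB", (224, 224), c) for c in ("red", "white", "blue")]
tiled = tile_panels(panels, ["A", "B", "C"])
print(tiled.size)  # (688, 224)
```

In the notebook, the inputs would be `[heatmap_to_pil(h) for h in heatmaps.values()]` once the next cell has run, paired with a prompt that names the panels A, B and C.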

In [None]:
# BLIP-2: stable VLM for image captioning/QA
for pkg in ["transformers", "accelerate"]:
    try: __import__(pkg)
    except ImportError: subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

from transformers import Blip2Processor, Blip2ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

print(f"Loading BLIP-2 on {device}...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
vlm = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
).to(device).eval()

def ask_vlm(image, prompt):
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype)
    with torch.no_grad():
        out = vlm.generate(**inputs, max_new_tokens=50)
    return processor.decode(out[0], skip_special_tokens=True)

# Convert heatmap to PIL for VLM
def heatmap_to_pil(hmap, size=(224, 224)):
    vmax = max(abs(hmap.min()), abs(hmap.max()))
    fig, ax = plt.subplots(figsize=(2.24, 2.24), dpi=100)
    ax.imshow(hmap, cmap="bwr", vmin=-vmax, vmax=vmax, interpolation="bilinear")
    ax.axis("off")
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    buf = io.BytesIO()
    fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB").resize(size)

# Ask about each heatmap
print("\nAsking BLIP-2 about each explanation...\n")
for name, hmap in heatmaps.items():
    hmap_img = heatmap_to_pil(hmap)
    prompt = f"Question: This heatmap shows which image regions led to classifying something as '{labels[pred_idx]}'. Red supports the prediction, blue contradicts. Is the focus clear and concentrated, or scattered? Answer:"
    response = ask_vlm(hmap_img, prompt)
    print(f"{name}:\n  {response.strip()}\n")

## Reflection

The quantitative proxies and the VLM offer different perspectives on the same heatmaps. Neither substitutes for asking real people, which is what the CUE framework ultimately calls for.

This is the question I want to learn to answer properly: *what makes an explanation actually land?*

## What I'd Explore Next

This demo evaluates *static* explanations — one image, one prediction, one heatmap.

But many AI systems make *sequential* decisions: a robot navigating, an agent learning a policy, a trajectory unfolding over time. Explaining these means asking not "which pixels?" but "which past observations shaped this action?", which calls for temporal relevance and possibly counterfactual reasoning.

The CUE lens still applies: can you *trace* the sequence (legibility), *parse* the causal structure (readability), *understand* the strategy (interpretability)? But the *form* of explanation changes. That's the frontier I'm most curious about.

---

*Built with [zennit](https://github.com/chr5tphr/zennit) (Fraunhofer HHI) · Vincent Lange · vincentelange@gmail.com*