This cell installs the `captum` library, which provides attribution methods such as Integrated Gradients. Only `IntegratedGradients` is imported from it in this notebook.

In [None]:
!pip install captum

This cell forcefully reinstalls `numpy` and `scikit-image`. This step is often necessary to resolve potential dependency conflicts that can arise with different library versions, ensuring a stable environment for image processing and numerical operations.

In [None]:
!pip install --force-reinstall numpy scikit-image --no-cache-dir

This cell imports all the necessary Python libraries for the project. These include `torch` for deep learning, `torchvision` for computer vision utilities, `numpy` for numerical operations, `matplotlib` for plotting, `PIL` for image manipulation, `captum` for model interpretability, `scikit-image` and `sklearn` for image processing and clustering, and `cv2` (OpenCV) for additional image functionalities.

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import patches
from matplotlib.colors import ListedColormap
from PIL import Image
from captum.attr import IntegratedGradients
from skimage.measure import label, regionprops
from sklearn.cluster import MeanShift, AgglomerativeClustering
from itertools import combinations
import cv2

In [None]:
!pip install torchxrayvision

This cell initializes a ResNet-18 model from `torchvision` with weights pre-trained on the ImageNet dataset. The `weights=` argument replaces the older `pretrained=True` flag, which has been deprecated since torchvision 0.13. `model.eval()` sets the model to evaluation mode, so batch normalization layers use their stored running statistics and any dropout layers are disabled, which is standard practice when performing inference.

In [None]:
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()

This cell defines a series of image transformations to be applied to the input image. These transformations include resizing the image to 224x224 pixels, converting it to a PyTorch tensor, and normalizing its pixel values using the mean and standard deviation typical for ImageNet-trained models. This ensures the input format is compatible with the pre-trained ResNet-18 model.

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

This cell loads an image from the specified `image_path` (in this case, 'ambulance.png'), converts it to RGB format, and applies the predefined `transform` pipeline. The `unsqueeze(0)` call adds a batch dimension, making the tensor ready for input into the neural network. Setting `input_tensor.requires_grad = True` is needed so that gradients with respect to the input pixels are available to the gradient-based attribution step that follows.

In [None]:
image_path = "/content/ambulance.png"
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0)
input_tensor.requires_grad = True

This cell performs a forward pass through the ResNet-18 model with the prepared `input_tensor`. It then calculates the softmax probabilities for the output, identifies the predicted class (the class with the highest probability), and prints the prediction details, including the top 5 classes and their probabilities.

In [None]:
output = model(input_tensor)
probabilities = torch.softmax(output, dim=1)
prediction = torch.argmax(output).item()
predicted_prob = probabilities[0, prediction].item()
print(f"Predicted Class: {prediction} with probability {predicted_prob:.4f}")

topk_probs, topk_indices = torch.topk(probabilities, k=5, dim=1)
print("Top 5 classes and probabilities:")
for i in range(5):
    print(f"Class {topk_indices[0, i].item()}: {topk_probs[0, i].item():.4f}")
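
To make the softmax and top-k steps concrete, here is a minimal NumPy sketch on made-up logits (the values are illustrative, not ResNet-18 output). It mirrors, for a single row, what `torch.softmax` and `torch.topk` compute above.

```python
import numpy as np

# Hypothetical logits for a 4-class problem (illustrative values only).
logits = np.array([2.0, 1.0, 0.1, 3.0])

# Numerically stable softmax: subtract the max before exponentiating.
exp_shifted = np.exp(logits - logits.max())
probs = exp_shifted / exp_shifted.sum()

# Top-k: sort indices by descending probability and keep the first k.
k = 2
topk_indices = np.argsort(probs)[::-1][:k]
topk_probs = probs[topk_indices]

predicted_class = int(np.argmax(probs))
print(predicted_class, topk_indices.tolist())
```

Because softmax is monotonic, taking the argmax of the logits or of the probabilities selects the same class.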

This cell defines the `compute_lrp` function. Layer-wise Relevance Propagation (LRP) is an interpretability technique that redistributes a network's prediction backwards layer by layer, assigning relevance scores to input features. Note that the function below does not apply LRP propagation rules; it computes a plain gradient saliency map, used here as a simple stand-in. It backpropagates a one-hot vector for the target class and sums the absolute input gradients across the color channels.

In [None]:
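def compute_lrp(model, input_tensor, target_class):
    # Clear stale gradients on both the model and the input.
    model.zero_grad()
    if input_tensor.grad is not None:
        input_tensor.grad.zero_()
    # Run the forward pass here instead of relying on a global `output`.
    output = model(input_tensor)
    one_hot = torch.zeros_like(output)
    one_hot[0, target_class] = 1
    output.backward(gradient=one_hot)
    relevance = input_tensor.grad.clone().detach()
    # Sum absolute gradients over the color channels to get an (H, W) map.
    relevance = relevance.abs().sum(dim=1)[0]
    return relevance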
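
The one-hot backpropagation step can be checked on a tiny linear model, where the input gradient is known in closed form. The sketch below is illustrative only: `toy_model` and `weights` are hypothetical stand-ins, not part of the notebook.

```python
import torch

# Toy "model": a fixed linear map from a flattened 3x2x2 image to 2 classes.
# The weights are made up for illustration.
weights = torch.arange(24, dtype=torch.float32).reshape(2, 12) - 10.0

def toy_model(x):
    return x.flatten(start_dim=1) @ weights.T

x = torch.ones(1, 3, 2, 2, requires_grad=True)
out = toy_model(x)

# Backpropagate a one-hot vector that selects class 1.
one_hot = torch.zeros_like(out)
one_hot[0, 1] = 1
out.backward(gradient=one_hot)

# For a linear model the input gradient equals the weight row of the target class.
saliency = x.grad.abs().sum(dim=1)[0]  # sum |grad| over channels -> (2, 2)
print(saliency)
```

Here the result depends only on the weights of class 1, not on the input values, which is one reason plain gradients are a coarser signal than rule-based relevance propagation.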

This cell computes the LRP heatmap for the predicted class using the `compute_lrp` function. The resulting relevance map is then normalized to a range of 0 to 1, making it easier to visualize and compare the importance of different image regions.

In [None]:
lrp_map = compute_lrp(model, input_tensor, prediction).cpu().numpy()
lrp_map = (lrp_map - lrp_map.min()) / (lrp_map.max() - lrp_map.min() + 1e-8)

This cell computes the Integrated Gradients (IG) attribution map. Integrated Gradients attributes a model's prediction to its input features by integrating gradients along a straight path from a baseline input (here, an all-zero tensor) to the actual input. The resulting `ig_map` is averaged over the color channels, clipped at zero to keep only positive contributions, and divided by its maximum for visualization.

In [None]:
ig = IntegratedGradients(model)
baseline = torch.zeros_like(input_tensor)
attributions, _ = ig.attribute(input_tensor, baseline, target=prediction, return_convergence_delta=True)

ig_map = attributions.squeeze().detach().numpy()
ig_map = np.transpose(ig_map, (1, 2, 0))
ig_map = np.mean(ig_map, axis=2)
ig_map = np.maximum(ig_map, 0)
ig_map /= ig_map.max() + 1e-8
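
The path integral behind Integrated Gradients can be sketched in NumPy on a toy function with a known gradient. This is an illustrative approximation with a midpoint Riemann sum, not Captum's implementation; `f`, `grad_f`, and `integrated_gradients` are hypothetical names.

```python
import numpy as np

def f(x):
    # Toy scalar "model" used only for illustration.
    return np.sum(x ** 2)

def grad_f(x):
    # Analytic gradient of f.
    return 2 * x

def integrated_gradients(x, baseline, steps=50):
    # Midpoint Riemann sum over the straight path baseline -> x.
    alphas = (np.arange(steps) + 0.5) / steps
    avg_grad = np.mean([grad_f(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline)

# Completeness: the attributions sum to f(x) - f(baseline).
print(attr, attr.sum(), f(x) - f(baseline))
```

The completeness check at the end is the same quantity that `return_convergence_delta=True` reports as an approximation error in the cell above.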

This cell defines the `threshold_map` function, a utility to create a binary mask from a saliency map. It identifies pixels whose relevance scores are above a specified percentile, effectively highlighting the most important regions according to the saliency map.

In [None]:
def threshold_map(saliency_map, percentile=80):
    threshold = np.percentile(saliency_map, percentile)
    return (saliency_map >= threshold).astype(np.uint8)
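
A quick check of the percentile rule on a toy map shows how many pixels survive. The map below is made up for illustration; the function is repeated so the snippet runs on its own.

```python
import numpy as np

def threshold_map(saliency_map, percentile=80):
    threshold = np.percentile(saliency_map, percentile)
    return (saliency_map >= threshold).astype(np.uint8)

# Toy 10x10 saliency map with values 0..99 (illustrative only).
toy_map = np.arange(100, dtype=float).reshape(10, 10)
mask = threshold_map(toy_map, percentile=80)

# The 80th percentile of 0..99 is 79.2, so the values 80..99 survive.
print(mask.sum())  # 20
```

With `percentile=95`, as used later in the notebook, roughly the top 5% of pixels are kept, provided the map has few tied values at the threshold.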

This cell applies the `threshold_map` function to both the LRP and IG heatmaps, creating binary masks (`lrp_thresh` and `ig_thresh`). It then computes the `intersection` of these two masks, representing pixels commonly highlighted by both methods. Finally, it visualizes the LRP heatmap, IG heatmap, their intersection, and an overlay of the individual thresholds, providing a comparative view of the model's focus.

In [None]:
lrp_thresh = threshold_map(lrp_map, percentile=95)
ig_thresh = threshold_map(ig_map, percentile=95)

intersection = (lrp_thresh & ig_thresh).astype(np.uint8)

plt.figure(figsize=(15, 5))

plt.subplot(1, 4, 1)
plt.imshow(lrp_map, cmap='gray')
plt.title("LRP Heatmap")
plt.axis("off")

plt.subplot(1, 4, 2)
plt.imshow(ig_map, cmap='gray')
plt.title("IG Heatmap")
plt.axis("off")

plt.subplot(1, 4, 3)
plt.imshow(intersection, cmap='gray')
plt.title("Common Pixels (IG ∩ LRP)")
plt.axis("off")

plt.subplot(1, 4, 4)
plt.imshow((lrp_thresh + ig_thresh), cmap='gray')
plt.title("Overlay of LRP + IG Thresholds")
plt.axis("off")

plt.tight_layout()
plt.show()
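
The difference between the intersection panel and the overlay panel is easy to see on two small masks. The arrays below are hypothetical stand-ins for `lrp_thresh` and `ig_thresh`.

```python
import numpy as np

# Two hypothetical 1-D binary masks standing in for lrp_thresh and ig_thresh.
mask_a = np.array([1, 1, 0, 0, 1], dtype=np.uint8)
mask_b = np.array([1, 0, 1, 0, 1], dtype=np.uint8)

intersection = mask_a & mask_b   # 1 only where both masks fire
overlay = mask_a + mask_b        # 0 = neither, 1 = one method, 2 = both

print(intersection.tolist())  # [1, 0, 0, 0, 1]
print(overlay.tolist())       # [2, 1, 1, 0, 2]
```

In the grayscale overlay plot, pixels with value 2 therefore appear brightest, and they coincide with the intersection mask.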

This cell defines the `cluster_agglomerative` function, which runs Agglomerative Clustering on the (row, column) coordinates of the pixels set in the binary map of common pixels. Because the features are pixel coordinates, this unsupervised technique groups spatially nearby relevant pixels into distinct clusters, which the notebook treats as 'causal concepts'. Background pixels are labeled -1; if no relevant pixels are found, the function returns a map that is entirely -1.

In [None]:
def cluster_agglomerative(binary_map, n_clusters=10, linkage='ward'):
    # (row, col) coordinates of the pixels selected by both methods.
    coords = np.argwhere(binary_map == 1)
    if coords.shape[0] == 0:
        # No relevant pixels: mark everything as background (-1).
        return np.full(binary_map.shape, -1, dtype=int)

    agg = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage)
    agg.fit(coords)

    # Write each pixel's cluster label into the map; -1 stays as background.
    labels = agg.labels_
    labeled_map = np.full(binary_map.shape, -1, dtype=int)
    for i, coord in enumerate(coords):
        labeled_map[coord[0], coord[1]] = labels[i]

    return labeled_map
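
As a sanity check, clustering pixel coordinates should separate two well-separated blobs. The sketch below uses a made-up 8x8 map and calls `AgglomerativeClustering` directly rather than the notebook's function.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering

# Toy binary map with two well-separated 2x2 blobs (illustrative only).
binary_map = np.zeros((8, 8), dtype=np.uint8)
binary_map[0:2, 0:2] = 1
binary_map[6:8, 6:8] = 1

coords = np.argwhere(binary_map == 1)  # (row, col) of active pixels
labels = AgglomerativeClustering(n_clusters=2, linkage='ward').fit_predict(coords)

labeled_map = np.full(binary_map.shape, -1, dtype=int)  # -1 marks background
labeled_map[coords[:, 0], coords[:, 1]] = labels

print(labeled_map)
```

Which blob receives label 0 and which receives label 1 is arbitrary, so downstream code should treat the label values only as identifiers.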

This cell applies the `cluster_agglomerative` function to the `intersection` map (common pixels from LRP and IG) to identify 10 distinct clusters using the 'ward' linkage method. It then masks out the background and visualizes these clusters using a distinct color map, showing the spatial grouping of the model's most relevant input regions.

In [None]:
# Example usage
agg_labels = cluster_agglomerative(intersection, n_clusters=10, linkage='ward')

# Mask out background (-1) and set its color
masked_labels = np.ma.masked_where(agg_labels == -1, agg_labels)

# Custom colormap for 10 clusters
cmap = plt.get_cmap('tab10', 10)
cmap.set_bad(color='white')  # Change this to any background color you want

# Plotting
fig, ax = plt.subplots(figsize=(12, 6))
im = ax.imshow(masked_labels, cmap=cmap)
ax.set_title("Agglomerative Clustering (10 clusters)")
ax.axis("off")

plt.tight_layout()
plt.show()

This cell defines the `unnormalize` function, a utility for visualizing images that have been normalized for neural network input. It reverses the normalization channel by channel using the same mean and standard deviation, returning the tensor to approximately the 0-1 pixel range produced by `ToTensor()`.

In [None]:
def unnormalize(img_tensor, mean, std):
    img_tensor = img_tensor.clone()
    for i in range(img_tensor.shape[0]):
        img_tensor[i] = img_tensor[i] * std[i] + mean[i]
    return img_tensor
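
The round trip between normalization and its inverse can be verified numerically. This NumPy sketch uses a made-up 3x2x2 image and reproduces the per-channel arithmetic only; it does not call torchvision.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Hypothetical 3x2x2 image with pixel values in [0, 1].
rng = np.random.default_rng(0)
img = rng.random((3, 2, 2))

normalized = (img - mean) / std      # per-channel normalization
recovered = normalized * std + mean  # the inverse applied by unnormalize

print(np.allclose(recovered, img))  # True
```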

This cell uses the `unnormalize` function to convert the model's input tensor back into a human-readable NumPy array. It applies the inverse of the normalization process and clips the pixel values to the valid range (0-1), preparing the image for display without the normalization artifacts.

In [None]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
transformed_img = unnormalize(input_tensor[0], mean, std)
transformed_img_np = transformed_img.detach().cpu().numpy().transpose(1, 2, 0)
transformed_img_np = np.clip(transformed_img_np, 0, 1)

This cell defines the `draw_cluster_boundaries` function. This function takes an image and a labeled map of clusters, then draws bounding boxes around each identified cluster. It also labels each bounding box with its cluster ID, providing a clear visual representation of the segmented causal concepts on the original image.

In [None]:
def draw_cluster_boundaries(image_np, labeled_map, title="Cluster Boundaries", skip_noise=True):
    fig, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(image_np)

    clusters = np.unique(labeled_map)
    for cl in clusters:
        if skip_noise and cl == -1:
            continue
        coords = np.argwhere(labeled_map == cl)
        if coords.size == 0:
            continue
        min_row, min_col = coords.min(axis=0)
        max_row, max_col = coords.max(axis=0)
        width = max_col - min_col
        height = max_row - min_row

        rect = patches.Rectangle((min_col, min_row), width, height, linewidth=2,
                                 edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        ax.text(min_col, min_row, f"{cl}", color='yellow', fontsize=12, weight='bold')

    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()
    plt.show()

This cell calls the `draw_cluster_boundaries` function to visualize the identified clusters (`agg_labels`) on the unnormalized and transformed image (`transformed_img_np`). This creates an overlay of bounding boxes, clearly showing where each causal concept is located on the original image.

In [None]:
draw_cluster_boundaries(transformed_img_np, agg_labels, title="Agglomerative Cluster Boundaries on Transformed Image")

This cell defines the `zero_bbox` function, a utility used for ablation studies. It takes an image and a bounding box given as `(min_row, min_col, max_row, max_col)` with exclusive upper bounds, then sets all pixel values within that box to zero on a copy of the image (effectively blacking out that region). This allows us to assess the impact of specific image regions on the model's prediction.

In [None]:
def zero_bbox(image_np, bbox):
    modified_image = image_np.copy()
    modified_image[bbox[0]:bbox[2], bbox[1]:bbox[3], :] = 0
    return modified_image
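
A toy image makes the bounding-box convention visible. The values below are illustrative; the function is repeated so the snippet runs on its own.

```python
import numpy as np

def zero_bbox(image_np, bbox):
    modified_image = image_np.copy()
    modified_image[bbox[0]:bbox[2], bbox[1]:bbox[3], :] = 0
    return modified_image

# Toy 4x4 RGB image filled with ones (illustrative only).
toy_image = np.ones((4, 4, 3), dtype=np.uint8)

# bbox is (min_row, min_col, max_row, max_col) with exclusive upper bounds,
# so this blanks rows 1-2 and columns 1-2.
ablated = zero_bbox(toy_image, (1, 1, 3, 3))

print(ablated[:, :, 0])
print(toy_image.sum(), ablated.sum())  # 48 36
```

The input image is left untouched because the function works on a copy, which is what lets the ablation loop reuse the same original for every cluster.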

This cell simply prints the baseline prediction: the predicted class and its probability for the original, unmodified image. This serves as a reference point for comparing how predictions change when different causal concepts are removed.

In [None]:
print(f"Baseline Prediction: Class {prediction} with probability {predicted_prob:.4f}")

This cell performs a 'single-concept ablation' study. It iterates through each identified cluster (causal concept), zeros out its bounding box on a copy of the original image, and then re-evaluates the model's prediction. For each ablated image, it displays the modified image, its new prediction, and then summarizes the results by comparing them to the baseline prediction. This helps understand the individual contribution of each concept to the model's decision.

In [None]:
unique_clusters = np.unique(agg_labels)
results = []

original_np = (transformed_img_np * 255).astype(np.uint8)

for cl in unique_clusters:
    if cl == -1:
        continue

    coords = np.argwhere(agg_labels == cl)
    min_row, min_col = coords.min(axis=0)
    max_row, max_col = coords.max(axis=0)
    bbox = (min_row, min_col, max_row + 1, max_col + 1)

    modified_np = zero_bbox(original_np, bbox)
    modified_img = Image.fromarray(modified_np)

    input_tensor_modified = transform(modified_img).unsqueeze(0)
    input_tensor_modified.requires_grad = True

    output_modified = model(input_tensor_modified)
    probabilities_modified = torch.softmax(output_modified, dim=1)
    prediction_modified = torch.argmax(output_modified).item()
    predicted_prob_modified = probabilities_modified[0, prediction_modified].item()

    results.append((cl, prediction_modified, predicted_prob_modified))

    plt.figure(figsize=(6, 6))
    plt.imshow(modified_np)
    plt.title(f"Causal Concept {cl} Zeroed Out:\nPredicted Class {prediction_modified} (Prob: {predicted_prob_modified:.4f})")
    plt.axis("off")
    plt.show()

print("\nComparison with Baseline:")
for cl, pred, prob in results:
    print(f"Causal Concept {cl}: Modified Prediction - Class {pred} (Probability: {prob:.4f}) vs Baseline Class {prediction} (Probability: {predicted_prob:.4f})")
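
The structure of the ablation loop can be followed on a toy problem without a neural network. In this sketch `toy_score` is a hypothetical stand-in for the model (it just returns mean brightness), and the image and label map are made up.

```python
import numpy as np

def toy_score(image):
    # Hypothetical stand-in for the model: "probability" = mean brightness.
    return float(image.mean())

# 4x4 grayscale toy image, bright on the left half only.
image = np.zeros((4, 4))
image[:, :2] = 1.0

# Toy label map: cluster 0 covers the bright half, cluster 1 a dark corner.
labels = np.full((4, 4), -1)
labels[:, :2] = 0
labels[2:, 2:] = 1

baseline = toy_score(image)
results = {}
for cl in np.unique(labels):
    if cl == -1:
        continue  # skip background
    coords = np.argwhere(labels == cl)
    (r0, c0), (r1, c1) = coords.min(axis=0), coords.max(axis=0)
    ablated = image.copy()
    ablated[r0:r1 + 1, c0:c1 + 1] = 0  # zero the cluster's bounding box
    results[int(cl)] = toy_score(ablated) - baseline

print(baseline, results)
```

Removing cluster 0 drops the score by the full baseline, while removing cluster 1 changes nothing, which is the kind of contrast the per-concept comparison is meant to surface.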

This cell defines the `compute_pcs_and_bf` function, which calculates two metrics to quantify the impact of removing a causal concept: Probability Change Score (PCS) and Bayes Factor (BF). PCS is the relative drop in the baseline-class probability, `(p_orig - p_mod) / p_orig`. BF, as computed here, is the ratio of the odds of the baseline class before ablation to its odds after ablation, so values above 1 mean that removing the concept made the baseline class less likely. Both probabilities are clipped away from 0 and 1 to avoid division by zero.

In [None]:
def compute_pcs_and_bf(p_orig, p_mod, epsilon=1e-6):
    p_orig = np.clip(p_orig, epsilon, 1 - epsilon)
    p_mod = np.clip(p_mod, epsilon, 1 - epsilon)

    pcs = (p_orig - p_mod) / p_orig

    bf = (p_orig / (1 - p_orig)) / (p_mod / (1 - p_mod))

    return pcs, bf
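
A worked example with round numbers shows the scale of the two metrics. The probabilities are illustrative; the function is repeated so the snippet runs on its own.

```python
import numpy as np

def compute_pcs_and_bf(p_orig, p_mod, epsilon=1e-6):
    p_orig = np.clip(p_orig, epsilon, 1 - epsilon)
    p_mod = np.clip(p_mod, epsilon, 1 - epsilon)
    pcs = (p_orig - p_mod) / p_orig
    bf = (p_orig / (1 - p_orig)) / (p_mod / (1 - p_mod))
    return pcs, bf

# Illustrative probabilities: ablation halves the baseline-class probability.
pcs, bf = compute_pcs_and_bf(0.9, 0.45)
print(round(float(pcs), 3), round(float(bf), 3))  # 0.5 11.0
```

Halving the probability gives a PCS of 0.5, but the odds fall from 9 to about 0.82, so the odds ratio is 11. The two metrics can therefore rank concepts differently when the baseline probability is close to 1.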

This cell defines two functions. The helper `unnormalize_and_to_np` reverses the input normalization and returns a `uint8` image array. `plot_concept_summary` then visualizes the impact of removing each causal concept (cluster) on the model's prediction of the baseline class: it zeroes out each cluster's bounding box, computes the change in the baseline-class probability (modified minus baseline), and plots these changes as a bar chart. A negative change is reported as a positive contribution of that concept, and a non-negative change as a negative effect.

In [None]:
def unnormalize_and_to_np(img_tensor, mean, std):
    x = img_tensor.clone()
    for i in range(x.shape[0]):
        x[i] = x[i] * std[i] + mean[i]
    x_np = x.detach().cpu().numpy().transpose(1, 2, 0)
    x_np = np.clip(x_np, 0, 1)
    return (x_np * 255).astype(np.uint8)

def plot_concept_summary(
    input_tensor, agg_labels, model, transform,
    prediction, predicted_prob
):
    """
    Plots the per-cluster change in baseline-class probability after zeroing
    out each cluster's bounding box. Uses the same cluster-label map as the
    single-concept ablation cell.
    """

    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]
    original_np = unnormalize_and_to_np(input_tensor[0], mean, std)

    baseline_class = prediction
    baseline_prob = predicted_prob

    clusters = [cl for cl in np.unique(agg_labels) if cl != -1]
    effects = []
    concept_labels = []

    def get_bbox_for_cluster(cluster_label, label_map):
        coords = np.argwhere(label_map == cluster_label)
        min_row, min_col = coords.min(axis=0)
        max_row, max_col = coords.max(axis=0)
        return (min_row, min_col, max_row + 1, max_col + 1)

    for cl in clusters:
        bbox = get_bbox_for_cluster(cl, agg_labels)
        modified_np = zero_bbox(original_np, bbox)
        modified_img = Image.fromarray(modified_np)
        input_mod = transform(modified_img).unsqueeze(0)
        input_mod.requires_grad = True

        output_mod = model(input_mod)
        probs_mod = torch.softmax(output_mod, dim=1)
        mod_prob = probs_mod[0, baseline_class].item()

        diff = mod_prob - baseline_prob
        effects.append(diff)
        concept_labels.append(cl)

    # Plot the summary
    colors = ['green' if diff < 0 else 'red' for diff in effects]

    plt.figure(figsize=(10, 6))
    bars = plt.bar([str(c) for c in concept_labels], effects, color=colors)
    plt.xlabel("Causal Concept")
    plt.ylabel("Change in Baseline Class Probability\n(Modified - Baseline)")
    plt.title("Effect of Removing Each Causal Concept on Model Prediction")
    plt.axhline(0, color='black', linewidth=0.8)

    for bar, diff in zip(bars, effects):
        height = bar.get_height()
        plt.annotate(f"{diff:.3f}",
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),
                     textcoords="offset points",
                     ha='center', va='bottom')

    plt.show()

    # Print summary
    print("Summary of Causal Concept Effects (per concept):")
    for cl, diff in zip(concept_labels, effects):
        effect_type = "Positive Contribution" if diff < 0 else "Negative Effect"
        print(f"Concept {cl}: Change = {diff:.4f} → {effect_type}")
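
The sign convention of the summary is worth spelling out, since a negative bar means the concept supported the prediction. The sketch below applies the same rule to made-up changes; `effects` and `effect_type` are hypothetical names.

```python
# Hypothetical probability changes (modified - baseline) for three concepts.
effects = {0: -0.30, 1: 0.05, 2: 0.0}

def effect_type(diff):
    # Same rule as in plot_concept_summary: a drop means the concept helped.
    return "Positive Contribution" if diff < 0 else "Negative Effect"

summary = {cl: effect_type(diff) for cl, diff in effects.items()}
print(summary)
```

Note that a change of exactly zero falls into the "Negative Effect" branch under this rule, so very small differences should be read with some caution.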

This cell calls the `plot_concept_summary` function, which visualizes the impact of removing each causal concept (cluster) on the model's prediction for the baseline class. The plot shows the change in probability, helping to identify which concepts positively or negatively influence the model's decision.

In [None]:
plot_concept_summary(
    input_tensor=input_tensor,
    agg_labels=agg_labels,
    model=model,
    transform=transform,
    prediction=prediction,
    predicted_prob=predicted_prob
)

This cell defines the `plot_pcs_visualization` function. This function extends the causal concept analysis by visualizing the Probability Change Score (PCS) and Bayes Factor (BF) for each cluster directly on the image. It draws bounding boxes around each concept, color-coded by their PCS, providing an intuitive understanding of which regions have the most significant impact on the model's output when removed.

In [None]:
from matplotlib.patches import Rectangle
from matplotlib.colors import Normalize
import matplotlib.cm as cm

def plot_pcs_visualization(
    input_tensor,
    agg_labels,
    model,
    transform,
    prediction,
    predicted_prob,
    compute_pcs_and_bf  # a function that returns (pcs, bf)
):
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]

    original_np = unnormalize_and_to_np(input_tensor[0], mean, std)

    def get_bbox(cluster_label):
        coords = np.argwhere(agg_labels == cluster_label)
        min_row, min_col = coords.min(axis=0)
        max_row, max_col = coords.max(axis=0)
        return (min_row, min_col, max_row + 1, max_col + 1)

    detailed_results = []

    clusters = [cl for cl in np.unique(agg_labels) if cl != -1]

    for cl in clusters:
        bbox = get_bbox(cl)

        modified_np = zero_bbox(original_np, bbox)
        modified_img = Image.fromarray(modified_np)
        input_mod = transform(modified_img).unsqueeze(0)
        input_mod.requires_grad = True

        output = model(input_mod)
        probs = torch.softmax(output, dim=1)
        mod_prob = probs[0, prediction].item()
        mod_pred = torch.argmax(probs).item()

        pcs, bf = compute_pcs_and_bf(predicted_prob, mod_prob)

        detailed_results.append({
            'cluster': cl,
            'modified_pred': mod_pred,
            'modified_prob': mod_prob,
            'pcs': pcs,
            'bf': bf,
            'bbox': bbox
        })

    # Sort by PCS descending
    detailed_results.sort(key=lambda x: x['pcs'], reverse=True)

    # Plot
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(original_np)
    ax.axis("off")

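    # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Reds")
    # Guard against an empty result list or a non-positive maximum PCS.
    max_pcs = max((r['pcs'] for r in detailed_results), default=0.0)
    norm = Normalize(vmin=0, vmax=max(max_pcs, 1e-8))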

    for res in detailed_results:
        cl = res['cluster']
        pcs = res['pcs']
        bf = res['bf']
        bbox = res['bbox']
        color = cmap(norm(pcs))

        rect = Rectangle(
            (bbox[1], bbox[0]),
            bbox[3] - bbox[1],
            bbox[2] - bbox[0],
            linewidth=2,
            edgecolor=color,
            facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(
            bbox[1], bbox[0] - 5,
            f"#{cl} | PCS: {pcs:.2f} | BF: {bf:.1f}",
            color='black', fontsize=9,
            bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.3')
        )

    plt.show()

This cell calls the `plot_pcs_visualization` function, which visualizes the Probability Change Score (PCS) and Bayes Factor (BF) for each causal concept (cluster). The plot displays bounding boxes around each concept on the original image, with colors indicating their PCS, offering a detailed visual analysis of their individual causal contributions.

In [None]:
plot_pcs_visualization(
    input_tensor=input_tensor,
    agg_labels=agg_labels,
    model=model,
    transform=transform,
    prediction=prediction,
    predicted_prob=predicted_prob,
    compute_pcs_and_bf=compute_pcs_and_bf
)

### DECLARATION:

Generative AI tools were used to write some parts of the code and the explanations of the code in this notebook.