In [None]:
# Imports
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import json
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import kagglehub
import spacy
from scipy.stats import pearsonr
from scipy.spatial.distance import cosine

# Seed NumPy's and PyTorch's global random number generators with a fixed value
# so that repeated runs of the notebook start from the same RNG state.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Device setup: prefer CUDA, then Apple MPS, and fall back to the CPU.
# The conditional expression short-circuits, so the MPS check only runs
# when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")

In [None]:
# Load COCO dataset
# Downloads (or reuses a cached copy of) the Kaggle dataset and returns its local directory.
path = kagglehub.dataset_download("nagasai524/mini-coco2014-dataset-for-image-captioning")

# JSON is UTF-8 by specification; be explicit so the platform default encoding is never used.
with open(os.path.join(path, "captions.json"), "r", encoding="utf-8") as f:
    data = json.load(f)

# The file is either a dict with an "annotations" list or the bare list itself.
annotations = data["annotations"] if isinstance(data, dict) else data

# Group captions by image id: {image_id: [caption, ...]}.
captions = {}
for item in annotations:
    captions.setdefault(item["image_id"], []).append(item["caption"])

# Use the first directory under the download path that contains .jpg files.
img_folder = None
for root, dirs, files in os.walk(path):
    if any(name.endswith(".jpg") for name in files):
        img_folder = root
        break

# Fail here with a clear message instead of a NameError in a later cell.
if img_folder is None:
    raise FileNotFoundError(f"No .jpg images found under {path}")

img_ids = list(captions.keys())
print(f"Ready: {len(captions)} images loaded")

In [None]:
# Image loading function
def get_image(idx):
    """Return ``(image, captions)`` for the image at position ``idx`` of ``img_ids``.

    The file is looked up in ``img_folder`` first under the COCO train2014
    naming scheme (zero-padded 12-digit id) and, if that path does not exist,
    under ``<image_id>.jpg``. The image is converted to RGB and resized to
    384x384; the captions are the list stored in ``captions`` for that id.
    """
    image_id = img_ids[idx]

    coco_style = os.path.join(img_folder, f"COCO_train2014_{image_id:012d}.jpg")
    if os.path.exists(coco_style):
        chosen_path = coco_style
    else:
        # Fallback naming scheme: the bare id as the file name.
        chosen_path = os.path.join(img_folder, f"{image_id}.jpg")

    picture = Image.open(chosen_path).convert("RGB")
    return picture.resize((384, 384)), captions[image_id]

# Show selected image
# `selected` holds positions into img_ids (get_image indexes img_ids[idx]),
# not raw COCO image ids.
selected = [87]
fig, axes = plt.subplots(1, len(selected), figsize=(8, 8))
# For a single selection the bare subplots result is wrapped in a list
# so that zip() below can pair it with the selected index.
if len(selected) == 1:
    axes = [axes]

for idx, ax in zip(selected, axes):
    img, caps = get_image(idx)
    ax.imshow(img)
    ax.axis("off")
plt.show()

# Print the reference captions of every displayed image
# (get_image is called a second time here, so each image is loaded twice).
for idx in selected:
    img, caps = get_image(idx)
    print(f"\nGround truth captions:")
    for cap in caps:
        print(f"  - {cap}")

In [None]:
# Normalize images
# Pipeline: convert the PIL image with ToTensor, then apply the per-channel
# mean / std normalization given below (first list = means, second = stds).
# NOTE(review): these constants look like the CLIP-style statistics used by BLIP
# preprocessing — confirm against the loaded processor's image_mean / image_std.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.48145466, 0.4578275, 0.40821073], 
                        [0.26862954, 0.26130258, 0.27577711])
])

def to_tensor(idx):
    """Load image ``idx`` and return it as a normalized batch of one on ``device``.

    The image comes from ``get_image`` (its captions are discarded), goes through
    the ``normalize`` pipeline, and gets a leading batch dimension.
    """
    image, _captions = get_image(idx)
    batch = normalize(image).unsqueeze(0)  # add the batch axis in front
    return batch.to(device)

# Smoke test on the first image: print the resulting tensor shape.
test = to_tensor(0)
print(f"Tensor ready for BLIP: {test.shape}")

In [None]:
# Load BLIP model
# The processor is used later to preprocess images and decode generated token ids;
# the model is moved to the selected device and switched to evaluation mode.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
model = model.to(device)
model.eval()
print("BLIP model loaded")

In [None]:
# Generate caption baseline
def generate_blip_caption(idx, show_image=True):
    """Caption one dataset image with BLIP and print the predicted caption.

    Args:
        idx: Position in ``img_ids`` of the image to caption.
        show_image: When true, display the image with matplotlib first.

    Returns:
        Tuple ``(image, ground_truth_captions, predicted_caption)``. The
        ground-truth captions are returned but not printed.
    """
    image, references = get_image(idx)

    if show_image:
        plt.imshow(image)
        plt.axis("off")
        plt.show()

    model_inputs = processor(image, return_tensors="pt").to(device)

    # Generation only: no gradients are needed.
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, max_new_tokens=30)

    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    print(f"BLIP caption: {caption}")

    return image, references, caption

# Baseline caption for the image selected above, without re-plotting it.
for i in [87]:
    generate_blip_caption(i, show_image=False)

In [None]:
# Extract attention
def extract_blip_attention(idx):
    """Caption one image with BLIP and collect the vision encoder's self-attention.

    Forward hooks are attached to the ``self_attn`` module of every layer in
    ``model.vision_model.encoder.layers``. Each hook stores the second element of
    the module's output when that element exists and is not None. If the hooks
    capture nothing during ``generate`` (the attention modules did not return
    their weights), the vision encoder is run once more with
    ``output_attentions=True`` and its ``attentions`` output is used instead.

    Args:
        idx: Position in ``img_ids`` of the image to process.

    Returns:
        Tuple ``(img, gt_caps, pred_caption, vision_attentions)`` where
        ``vision_attentions`` is a non-empty list of detached attention tensors,
        one per captured layer.

    Raises:
        RuntimeError: If no attention weights could be obtained by either route.
    """
    img, gt_caps = get_image(idx)
    inputs = processor(img, return_tensors="pt").to(device)

    vision_attentions = []

    def hook_fn(module, input, output):
        # Only tuple/list outputs carry a separate attention-weights slot.
        if isinstance(output, (tuple, list)) and len(output) > 1 and output[1] is not None:
            vision_attentions.append(output[1].detach())

    hooks = [
        layer.self_attn.register_forward_hook(hook_fn)
        for layer in model.vision_model.encoder.layers
    ]

    # Always detach the hooks, even if generation fails; otherwise they would
    # stay registered on the shared model and fire on every later forward pass.
    try:
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=30)
    finally:
        for hook in hooks:
            hook.remove()

    pred_caption = processor.decode(outputs[0], skip_special_tokens=True)

    if not vision_attentions:
        # Fallback: ask the vision encoder for its attention weights explicitly.
        with torch.no_grad():
            vision_out = model.vision_model(
                pixel_values=inputs["pixel_values"],
                output_attentions=True,
            )
        returned = getattr(vision_out, "attentions", None)
        if returned:
            vision_attentions = [attn.detach() for attn in returned if attn is not None]

    if not vision_attentions:
        raise RuntimeError(
            "No vision attention weights were captured; the attention modules "
            "did not return their weights for this model configuration."
        )

    print(f"Caption: {pred_caption}")
    print(f"Captured {len(vision_attentions)} layers of attention")

    return img, gt_caps, pred_caption, vision_attentions

In [None]:
# Aggregation methods
# Load the small English spaCy pipeline used for POS tagging; download it on first use.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package is not installed.
    # Run the download with the kernel's own interpreter (sys.executable) so the
    # package lands in this environment, and fail loudly if the download fails.
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# Sentence-based
def aggregate_sentence_based(vision_attentions, k_percent=0.5):
    """Mean-pool attention over heads and layers and keep the top-k patches.

    For every layer the tensor is averaged over dim 1, then row ``[0, 0, 1:]`` is
    taken (first batch item, first query token, every key token except the first).
    The per-layer rows are averaged into one score per patch.

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].

    Returns:
        ``(attention_map, mask)`` as numpy arrays of shape ``(side, side)``, where
        ``side * side`` is the number of patches (24 x 24 for 576 patches). The
        mask is 1.0 where the score is >= the k-th largest score, so ties at the
        threshold may select more than k patches.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    per_layer = [layer_attn.mean(dim=1)[0, 0, 1:] for layer_attn in vision_attentions]
    aggregated = torch.stack(per_layer).mean(dim=0)

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = aggregated.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    aggregated_2d = aggregated.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(aggregated.flatten(), k).values.min()
    mask = (aggregated_2d >= threshold).float()

    return aggregated_2d.cpu().numpy(), mask.cpu().numpy()


# Word-based (Stability-based)
def aggregate_word_based(vision_attentions, k_percent=0.5):
    """Select patches whose attention is both large and stable across layers.

    Per-layer scores are the head-averaged row ``[0, 0, 1:]`` of each attention
    tensor. For every patch, stability is ``mean / (std + 1e-10)`` over layers
    (the inverse coefficient of variation), and the final score is
    ``stability * mean`` so that a patch must be consistently *and* strongly
    attended to rank high.

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].

    Returns:
        ``(score_map, mask)`` as numpy arrays of shape ``(side, side)`` where
        ``side * side`` is the number of patches. The mask is 1.0 where the score
        is >= the k-th largest score.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    # Rows: layers, columns: patches.
    stacked = torch.stack(
        [layer_attn.mean(dim=1)[0, 0, 1:] for layer_attn in vision_attentions]
    )

    mean_attn = stacked.mean(dim=0)
    std_attn = stacked.std(dim=0)

    epsilon = 1e-10  # avoids division by zero for perfectly constant patches
    stability = mean_attn / (std_attn + epsilon)

    # Weight stability by magnitude: important AND consistently attended patches.
    stability_weighted = stability * mean_attn

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = stability_weighted.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    stability_2d = stability_weighted.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(stability_weighted.flatten(), k).values.min()
    mask = (stability_2d >= threshold).float()

    return stability_2d.cpu().numpy(), mask.cpu().numpy()


# POS-weighted
def aggregate_pos_weighted(vision_attentions, caption, k_percent=0.5):
    """Aggregate attention with layer weights driven by the caption's content ratio.

    The caption is parsed with the module-level spaCy pipeline ``nlp``. The share
    of tokens tagged NOUN, PROPN, VERB or ADJ (``content_ratio``) scales a
    parabolic layer weight
    ``1 + content_ratio * (1 - 4 * (layer_pos - 0.5) ** 2)`` with
    ``layer_pos = i / num_layers``, so layers near the middle of the stack are
    weighted more for content-heavy captions.

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        caption: Caption text used to compute the content-word ratio.
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].

    Returns:
        ``(attention_map, mask)`` as numpy arrays of shape ``(side, side)`` where
        ``side * side`` is the number of patches. The mask is 1.0 where the score
        is >= the k-th largest score.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    doc = nlp(caption)

    # Content words: nouns, proper nouns, verbs, adjectives.
    content_words = sum(
        1 for token in doc if token.pos_ in ("NOUN", "PROPN", "VERB", "ADJ")
    )
    # max(..., 1) guards against an empty caption.
    content_ratio = content_words / max(len(doc), 1)

    num_layers = len(vision_attentions)
    weighted_layers = []
    for i, layer_attn in enumerate(vision_attentions):
        cls_to_patches = layer_attn.mean(dim=1)[0, 0, 1:]

        # Parabola peaking at layer_pos == 0.5; equals 1.0 at layer_pos == 0.
        layer_pos = i / num_layers
        middle_bias = 1.0 + content_ratio * (1.0 - 4 * (layer_pos - 0.5) ** 2)

        weighted_layers.append(cls_to_patches * middle_bias)

    aggregated = torch.stack(weighted_layers).mean(dim=0)

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = aggregated.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    aggregated_2d = aggregated.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(aggregated.flatten(), k).values.min()
    mask = (aggregated_2d >= threshold).float()

    return aggregated_2d.cpu().numpy(), mask.cpu().numpy()


# Temporal-decay
def aggregate_temporal_decay(vision_attentions, k_percent=0.5, decay=0.95):
    """Aggregate attention with exponentially decaying weights for later layers.

    Layer ``i`` (0-based, in list order) is weighted by ``decay ** i`` before the
    per-layer scores are averaged, so earlier layers contribute more when
    ``decay < 1``. The average divides by the number of layers, not by the sum of
    the weights; this rescales the map but does not change the patch ranking.

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].
        decay: Per-layer multiplicative decay factor.

    Returns:
        ``(attention_map, mask)`` as numpy arrays of shape ``(side, side)`` where
        ``side * side`` is the number of patches. The mask is 1.0 where the score
        is >= the k-th largest score.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    weighted_layers = [
        layer_attn.mean(dim=1)[0, 0, 1:] * (decay ** i)
        for i, layer_attn in enumerate(vision_attentions)
    ]
    aggregated = torch.stack(weighted_layers).mean(dim=0)

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = aggregated.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    aggregated_2d = aggregated.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(aggregated.flatten(), k).values.min()
    mask = (aggregated_2d >= threshold).float()

    return aggregated_2d.cpu().numpy(), mask.cpu().numpy()


# Variance-based
def aggregate_variance_based(vision_attentions, k_percent=0.5):
    """Select the patches whose attention varies most across layers.

    Per-layer scores are the head-averaged row ``[0, 0, 1:]`` of each attention
    tensor; the score of a patch is the variance of its values over the layers
    (``torch.var`` along the layer axis).

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].

    Returns:
        ``(variance_map, mask)`` as numpy arrays of shape ``(side, side)`` where
        ``side * side`` is the number of patches. The mask is 1.0 where the
        variance is >= the k-th largest variance.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    # Rows: layers, columns: patches.
    stacked = torch.stack(
        [layer_attn.mean(dim=1)[0, 0, 1:] for layer_attn in vision_attentions]
    )
    variance = stacked.var(dim=0)

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = variance.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    variance_2d = variance.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(variance.flatten(), k).values.min()
    mask = (variance_2d >= threshold).float()

    return variance_2d.cpu().numpy(), mask.cpu().numpy()


# Entropy-based
def aggregate_entropy_based(vision_attentions, k_percent=0.5):
    """Select the patches with the largest summed ``-p * log(p)`` term across layers.

    Per-layer scores are the head-averaged row ``[0, 0, 1:]`` of each attention
    tensor. Each layer's scores are passed through a softmax over the patch axis,
    and the score of a patch is the sum over layers of ``-p * log(p + 1e-10)``.

    NOTE(review): the softmax normalizes over patches while the sum runs over
    layers, so the result is a sum of per-layer entropy contributions rather than
    the entropy of a distribution over layers. If the inputs are already
    attention probabilities, the extra softmax will also flatten them towards
    uniform — confirm that this is the intended definition.

    Args:
        vision_attentions: Non-empty sequence of per-layer attention tensors.
            # presumably [batch, heads, tokens, tokens] with token 0 = CLS — confirm
        k_percent: Fraction of patches to select. The resulting count is clamped
            to the range [1, number of patches].

    Returns:
        ``(entropy_map, mask)`` as numpy arrays of shape ``(side, side)`` where
        ``side * side`` is the number of patches. The mask is 1.0 where the score
        is >= the k-th largest score.

    Raises:
        ValueError: If ``vision_attentions`` is empty or the patch count is not
            a perfect square.
    """
    if len(vision_attentions) == 0:
        raise ValueError("vision_attentions is empty; no attention maps to aggregate")

    # Rows: layers, columns: patches.
    stacked = torch.stack(
        [layer_attn.mean(dim=1)[0, 0, 1:] for layer_attn in vision_attentions]
    )

    # Turn each layer's scores into a distribution over patches.
    probs = torch.softmax(stacked, dim=1)

    epsilon = 1e-10  # keeps log() finite for zero probabilities
    entropy = -(probs * torch.log(probs + epsilon)).sum(dim=0)

    # Infer the square patch grid from the data instead of assuming 24 x 24.
    num_patches = entropy.numel()
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
    entropy_2d = entropy.reshape(side, side)

    # Clamp k so torch.topk always receives a valid, non-zero count.
    k = min(max(int(num_patches * k_percent), 1), num_patches)
    threshold = torch.topk(entropy.flatten(), k).values.min()
    mask = (entropy_2d >= threshold).float()

    return entropy_2d.cpu().numpy(), mask.cpu().numpy()

print("All 6 aggregation methods loaded!")

In [None]:
# Visualization function with adjustable k_percent
def visualize_attention_overlay(img, pred_caption, agg_attn, mask, method_name="Sentence-based", k_viz=0.3):
    """Plot an attention map as four panels: image, overlay, raw heatmap, top-k mask.

    Args:
        img: PIL image to draw on (it is resized to 384x384 for the overlays).
        pred_caption: Caption shown (truncated to 40 characters) in the first title.
        agg_attn: 2-D attention/score map as a numpy array or torch tensor.
        mask: Accepted for interface compatibility only. The displayed mask is
            recomputed from ``agg_attn`` using ``k_viz``.
        method_name: Label used in the overlay title.
        k_viz: Fraction of patches highlighted in the last panel. The resulting
            count is clamped to the range [1, number of patches].

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Work on a float numpy copy regardless of the input container.
    if isinstance(agg_attn, torch.Tensor):
        agg_np = agg_attn.detach().cpu().numpy()
    else:
        agg_np = np.asarray(agg_attn)
    agg_np = agg_np.astype(np.float64)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(img)
    axes[0].set_title(f"Original\n{pred_caption[:40]}")
    axes[0].axis("off")

    img_array = np.array(img.resize((384, 384)))

    # Min-max scale to [0, 1] before the uint8 conversion. Scaling raw scores by
    # 255 would truncate small values to 0 and wrap values above 1.
    lo, hi = float(agg_np.min()), float(agg_np.max())
    if hi > lo:
        agg_unit = (agg_np - lo) / (hi - lo)
    else:
        agg_unit = np.zeros_like(agg_np)  # constant map: nothing to highlight
    heat_img = Image.fromarray((agg_unit * 255).astype(np.uint8)).resize((384, 384))
    attn_resized = np.array(heat_img) / 255.0

    axes[1].imshow(img_array)
    axes[1].imshow(attn_resized, cmap="hot", alpha=0.5)
    axes[1].set_title(f"{method_name}\nHeatmap Overlay")
    axes[1].axis("off")

    axes[2].imshow(agg_np, cmap="hot", interpolation="bilinear")
    axes[2].set_title("Attention Heatmap\n(Smoothed)")
    axes[2].axis("off")

    # Recompute the mask with k_viz: keep patches >= the k-th largest score.
    flat_scores = agg_np.ravel()
    k_viz_patches = min(max(int(flat_scores.size * k_viz), 1), flat_scores.size)
    threshold_viz = np.sort(flat_scores)[-k_viz_patches]
    mask_viz = (agg_np >= threshold_viz).astype(np.float32)

    # Nearest-neighbour resize keeps the patch boundaries crisp.
    mask_img = Image.fromarray((mask_viz * 255).astype(np.uint8)).resize((384, 384), Image.NEAREST)
    mask_resized = np.array(mask_img) / 255.0

    axes[3].imshow(img_array)
    axes[3].imshow(mask_resized, cmap="Reds", alpha=0.4)
    axes[3].set_title(f"Top {int(k_viz*100)}% Patches\nSelected for Attack")
    axes[3].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Statistical analysis functions
def compute_attention_statistics(agg_attn, mask, method_name):
    """Compute summary statistics for one attention map and its binary mask.

    Args:
        agg_attn: Numpy array holding the attention/score map.
        mask: Numpy array holding the binary selection mask.
        method_name: Label stored under the ``'method'`` key.

    Returns:
        Dict with keys ``method``, ``mean``, ``std``, ``min``, ``max``,
        ``median``, ``sparsity`` (fraction of mask entries that are set) and
        ``entropy``. The entropy is the Shannon entropy (natural log) of the map
        after normalizing it to sum to 1, so maps on different scales are
        comparable. It is NaN when the map has negative entries or a
        non-positive sum, because no distribution can be formed from it.
    """
    flat = np.asarray(agg_attn, dtype=np.float64).ravel()
    total = flat.sum()
    if total > 0 and np.all(flat >= 0):
        probs = flat / total
        entropy = -np.sum(probs * np.log(probs + 1e-10))
    else:
        entropy = float("nan")

    stats = {
        'method': method_name,
        'mean': np.mean(agg_attn),
        'std': np.std(agg_attn),
        'min': np.min(agg_attn),
        'max': np.max(agg_attn),
        'median': np.median(agg_attn),
        'sparsity': np.sum(mask) / mask.size,  # Fraction of patches selected
        'entropy': entropy,  # Shannon entropy of the normalized map
    }
    return stats

def compute_mask_overlap(mask1, mask2):
    """Return the Intersection over Union (IoU) of two binary masks.

    The intersection is the sum of the element-wise product and the union is the
    sum of the element-wise maximum. A small constant in the denominator makes
    two empty masks yield 0 instead of dividing by zero.
    """
    overlap = (mask1 * mask2).sum()
    combined = np.maximum(mask1, mask2).sum()
    return overlap / (combined + 1e-10)

def _pairwise_score_matrix(items, score_fn):
    """Return the square matrix whose (i, j) entry is score_fn(items[i], items[j])."""
    size = len(items)
    matrix = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            matrix[i, j] = score_fn(items[i], items[j])
    return matrix


def _draw_annotated_heatmap(ax, matrix, labels, title, vmin, vmax):
    """Draw matrix on ax as a labelled heatmap with per-cell values and a colorbar."""
    image = ax.imshow(matrix, cmap='RdYlGn', vmin=vmin, vmax=vmax)
    ticks = range(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)
    ax.set_title(title)
    for i in ticks:
        for j in ticks:
            ax.text(j, i, f'{matrix[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=9)
    plt.colorbar(image, ax=ax)


def analyze_attention_diversity(methods_results):
    """
    Analyze diversity between different aggregation methods.

    Prints per-method statistics, the pairwise Pearson correlation of the
    flattened attention maps and the pairwise IoU of the binary masks, then
    plots both matrices as annotated heatmaps.

    Args:
        methods_results: dict of {method_name: (agg_attn, mask)} with numpy arrays.

    Returns:
        (stats_df, corr_df, iou_df): pandas DataFrames holding the per-method
        statistics, the correlation matrix and the IoU matrix. The two matrices
        are indexed by method name on both axes.
    """
    method_names = list(methods_results.keys())
    attention_maps = [methods_results[name][0] for name in method_names]
    masks = [methods_results[name][1] for name in method_names]

    # Compute statistics for each method
    print("=" * 80)
    print("ATTENTION STATISTICS PER METHOD")
    print("=" * 80)
    stats_df = pd.DataFrame([
        compute_attention_statistics(agg, mask, name)
        for name, (agg, mask) in methods_results.items()
    ])
    print(stats_df.to_string(index=False))
    print()

    # Compute pairwise correlations between attention maps
    print("=" * 80)
    print("PAIRWISE CORRELATIONS (Pearson) - Attention Maps")
    print("=" * 80)
    corr_matrix = _pairwise_score_matrix(
        attention_maps,
        lambda a, b: pearsonr(a.flatten(), b.flatten())[0],
    )
    corr_df = pd.DataFrame(corr_matrix, index=method_names, columns=method_names)
    print(corr_df.to_string())
    print()

    # Compute pairwise IoU between binary masks
    print("=" * 80)
    print("PAIRWISE IoU (Intersection over Union) - Binary Masks")
    print("=" * 80)
    iou_matrix = _pairwise_score_matrix(masks, compute_mask_overlap)
    iou_df = pd.DataFrame(iou_matrix, index=method_names, columns=method_names)
    print(iou_df.to_string())
    print()

    # Visualize both matrices side by side.
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    # Pearson correlation ranges over [-1, 1]; a lower bound of 0 would render
    # every negative correlation the same as zero.
    _draw_annotated_heatmap(axes[0], corr_matrix, method_names,
                            'Correlation Matrix\n(Attention Maps)', vmin=-1, vmax=1)
    _draw_annotated_heatmap(axes[1], iou_matrix, method_names,
                            'IoU Matrix\n(Binary Masks)', vmin=0, vmax=1)
    plt.tight_layout()
    plt.show()

    return stats_df, corr_df, iou_df

print("Statistical analysis functions loaded!")

In [None]:
# Compare all 6 methods on selected image
# Runs BLIP once on image 87 and keeps the captured per-layer attention in `vision_attn`.
idx = 87
img, gt_caps, pred_caption, vision_attn = extract_blip_attention(idx)

# Generate attention with k=0.5 (50% for actual attack)
# Each entry is a zero-argument callable. The lambdas look up `vision_attn` and
# `pred_caption` when they are called, not when this dict is built, so they use
# whatever those globals are bound to at call time.
methods = {
    "Sentence-based": lambda: aggregate_sentence_based(vision_attn, k_percent=0.5),
    "Word-based": lambda: aggregate_word_based(vision_attn, k_percent=0.5),
    "POS-weighted": lambda: aggregate_pos_weighted(vision_attn, pred_caption, k_percent=0.5),
    "Temporal-decay": lambda: aggregate_temporal_decay(vision_attn, k_percent=0.5),
    "Variance-based": lambda: aggregate_variance_based(vision_attn, k_percent=0.5),
    "Entropy-based": lambda: aggregate_entropy_based(vision_attn, k_percent=0.5)
}

# Store results for analysis
# Maps method name -> (attention map, mask), the format analyze_attention_diversity expects.
methods_results = {}

# Visualize with k=0.3 (30% for clearer view like paper)
# The stored mask keeps k=0.5; only the plotted mask is recomputed with k_viz.
for name, method in methods.items():
    print(f"\n=== {name} ===")
    agg, mask = method()
    methods_results[name] = (agg, mask)
    visualize_attention_overlay(img, pred_caption, agg, mask, name, k_viz=0.3)

# Perform statistical analysis
# Prints per-method statistics plus pairwise correlation / IoU and returns them as DataFrames.
print("\n\n")
print("#" * 80)
print("QUANTITATIVE ANALYSIS OF ATTENTION METHODS")
print("#" * 80)
stats_df, corr_df, iou_df = analyze_attention_diversity(methods_results)

In [None]:
# Multi-image analysis: Test diversity across different images
test_indices = [87, 0, 42, 100, 150]  # Sample of diverse images

# Dispatch table: method name -> callable(vision_attn, caption) -> (agg_map, mask).
# It is built here, with explicit arguments, so that this cell does not depend on
# the `methods` lambdas (and the globals they read) from the single-image cell.
aggregators = {
    "Sentence-based": lambda attn, caption: aggregate_sentence_based(attn, k_percent=0.5),
    "Word-based": lambda attn, caption: aggregate_word_based(attn, k_percent=0.5),
    "POS-weighted": lambda attn, caption: aggregate_pos_weighted(attn, caption, k_percent=0.5),
    "Temporal-decay": lambda attn, caption: aggregate_temporal_decay(attn, k_percent=0.5),
    "Variance-based": lambda attn, caption: aggregate_variance_based(attn, k_percent=0.5),
    "Entropy-based": lambda attn, caption: aggregate_entropy_based(attn, k_percent=0.5),
}
method_names = list(aggregators)
n_methods = len(method_names)

print("=" * 80)
print("MULTI-IMAGE ANALYSIS: Average Correlation and IoU Across Images")
print("=" * 80)

all_correlations = []
all_ious = []

for idx in test_indices:
    print(f"\nProcessing image {idx}...")
    img, gt_caps, pred_caption, vision_attn = extract_blip_attention(idx)

    # Run every aggregation method on this image's attention.
    methods_results = {
        name: aggregate(vision_attn, pred_caption)
        for name, aggregate in aggregators.items()
    }

    # Compute correlations and IoU for this image
    corr_matrix = np.zeros((n_methods, n_methods))
    iou_matrix = np.zeros((n_methods, n_methods))

    for i, name1 in enumerate(method_names):
        agg1, mask1 = methods_results[name1]
        for j, name2 in enumerate(method_names):
            agg2, mask2 = methods_results[name2]

            corr, _ = pearsonr(agg1.flatten(), agg2.flatten())
            corr_matrix[i, j] = corr
            iou_matrix[i, j] = compute_mask_overlap(mask1, mask2)

    all_correlations.append(corr_matrix)
    all_ious.append(iou_matrix)

# Average across all images
avg_corr = np.mean(all_correlations, axis=0)
avg_iou = np.mean(all_ious, axis=0)
std_corr = np.std(all_correlations, axis=0)
std_iou = np.std(all_ious, axis=0)

print("\n" + "=" * 80)
print("AVERAGE CORRELATION ACROSS IMAGES")
print("=" * 80)
corr_df_avg = pd.DataFrame(avg_corr, index=method_names, columns=method_names)
print(corr_df_avg.to_string())

print("\n" + "=" * 80)
print("AVERAGE IoU ACROSS IMAGES")
print("=" * 80)
iou_df_avg = pd.DataFrame(avg_iou, index=method_names, columns=method_names)
print(iou_df_avg.to_string())

# Visualize average correlations and IoU side by side.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# (axis, matrix, title, lower colour bound). Pearson correlation ranges over
# [-1, 1], so its panel must not clip at 0; IoU is confined to [0, 1].
panels = [
    (axes[0], avg_corr, f'Average Correlation\n({len(test_indices)} images)', -1),
    (axes[1], avg_iou, f'Average IoU\n({len(test_indices)} images)', 0),
]
for ax, matrix, title, lower_bound in panels:
    heat = ax.imshow(matrix, cmap='RdYlGn', vmin=lower_bound, vmax=1)
    ax.set_xticks(range(n_methods))
    ax.set_yticks(range(n_methods))
    ax.set_xticklabels(method_names, rotation=45, ha='right')
    ax.set_yticklabels(method_names)
    ax.set_title(title)

    for i in range(n_methods):
        for j in range(n_methods):
            ax.text(j, i, f'{matrix[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=9)

    plt.colorbar(heat, ax=ax)

plt.tight_layout()
plt.show()

print("\n" + "=" * 80)
print("KEY INSIGHTS")
print("=" * 80)
print("Lower correlation/IoU = More diverse attention selection")
print("Methods should show DIFFERENT selections to justify multiple approaches")
print("\nOff-diagonal values (comparing different methods):")
# Only the upper triangle is listed: both matrices are symmetric.
for i in range(n_methods):
    for j in range(i+1, n_methods):
        print(f"  {method_names[i]} vs {method_names[j]}: "
              f"Corr={avg_corr[i,j]:.3f} (±{std_corr[i,j]:.3f}), "
              f"IoU={avg_iou[i,j]:.3f} (±{std_iou[i,j]:.3f})")

In [None]:
# Compare all 6 methods on selected image
# NOTE(review): this cell repeats the earlier single-image comparison, minus
# storing the results and the statistical analysis. Re-running it rebinds `idx`,
# `img`, `pred_caption` and `vision_attn` to image 87 after the multi-image loop
# reassigned them — consider folding both cells into one helper function.
idx = 87
img, gt_caps, pred_caption, vision_attn = extract_blip_attention(idx)

# Generate attention with k=0.5 (50% for actual attack)
# The lambdas read `vision_attn` / `pred_caption` when they are called.
methods = {
    "Sentence-based": lambda: aggregate_sentence_based(vision_attn, k_percent=0.5),
    "Word-based": lambda: aggregate_word_based(vision_attn, k_percent=0.5),
    "POS-weighted": lambda: aggregate_pos_weighted(vision_attn, pred_caption, k_percent=0.5),
    "Temporal-decay": lambda: aggregate_temporal_decay(vision_attn, k_percent=0.5),
    "Variance-based": lambda: aggregate_variance_based(vision_attn, k_percent=0.5),
    "Entropy-based": lambda: aggregate_entropy_based(vision_attn, k_percent=0.5)
}

# But visualize with k=0.3 (30% for clearer view like paper)
for name, method in methods.items():
    print(f"\n=== {name} ===")
    agg, mask = method()
    visualize_attention_overlay(img, pred_caption, agg, mask, name, k_viz=0.3)