In [None]:
# Imports
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import torch.nn.functional as F
import json
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import kagglehub
import spacy

np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Device setup: take CUDA when present, otherwise Apple MPS, otherwise the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device: {}".format(device))

In [None]:
# Load COCO dataset
path = kagglehub.dataset_download("nagasai524/mini-coco2014-dataset-for-image-captioning")

with open(os.path.join(path, "captions.json"), "r", encoding="utf-8") as f:
    data = json.load(f)
# The file is either a dict holding an "annotations" list or a bare list of records.
annotations = data["annotations"] if isinstance(data, dict) else data

# Group captions by image: image_id -> list of caption strings.
captions = {}
for item in annotations:
    img_id = item["image_id"]
    captions.setdefault(img_id, []).append(item["caption"])

# Use the first directory under `path` that contains at least one .jpg file.
img_folder = None
for root, dirs, files in os.walk(path):
    if any(name.endswith(".jpg") for name in files):
        img_folder = root
        break
if img_folder is None:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise FileNotFoundError(f"No .jpg images found under {path}")

img_ids = list(captions.keys())
print(f"Ready: {len(captions)} images loaded")

In [None]:
# Image loading function
def get_image(idx):
    """Return the RGB image stored at position ``idx`` of ``img_ids`` and its captions.

    The COCO-style file name is tried first; if that file is absent the
    plain ``<image_id>.jpg`` name is used instead. The image is not resized.
    """
    img_id = img_ids[idx]
    coco_style_path = os.path.join(img_folder, f"COCO_train2014_{img_id:012d}.jpg")
    plain_path = os.path.join(img_folder, f"{img_id}.jpg")
    chosen_path = coco_style_path if os.path.exists(coco_style_path) else plain_path
    rgb_image = Image.open(chosen_path).convert("RGB")
    return rgb_image, captions[img_id]

# Display the image at index 87 together with its reference captions
img, caps = get_image(87)
plt.imshow(img)
plt.title("Test Image (Index 87)")
plt.axis("off")
plt.show()

print("\nGround truth captions:")
for cap in caps:
    print("  -", cap)

In [None]:
# Load SAT model
# The checkpoint stores whole pickled encoder/decoder objects, so the tutorial
# repo's `models` module must be importable before torch.load runs.
import sys
sys.path.insert(0, "a-PyTorch-Tutorial-to-Image-Captioning")  # Adjust path if needed
from models import Encoder, Decoder

checkpoint_path = "checkpoint/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar"
# SECURITY: weights_only=False performs a full unpickle, which can execute
# arbitrary code -- only load checkpoints from a source you trust.
# It is passed explicitly because PyTorch >= 2.6 defaults to weights_only=True,
# which refuses the pickled nn.Module objects this checkpoint contains.
# (The keyword needs PyTorch >= 1.13.)
checkpoint = torch.load(checkpoint_path, map_location=str(device), weights_only=False)

sat_encoder = checkpoint["encoder"].to(device).eval()
sat_decoder = checkpoint["decoder"].to(device).eval()

with open("checkpoint/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json", "r") as j:
    word_map = json.load(j)
# Inverse vocabulary: token id -> word, used to turn predictions back into text.
rev_word_map = {v: k for k, v in word_map.items()}

print("SAT model loaded successfully!")

In [None]:
# Generate caption with SAT (test if model works)
def generate_sat_caption(idx, show_image=True):
    """Greedy-decode a caption for one COCO image with the SAT model and print it.

    Args:
        idx: Position in ``img_ids`` of the image to caption.
        show_image: When True, display the image with matplotlib first.

    Returns:
        ``(img, gt_caps, pred_caption)`` -- the PIL image and ground-truth
        captions from ``get_image`` plus the predicted caption, a
        space-joined string without the ``<start>``/``<end>`` markers.
    """
    max_steps = 50  # upper bound on the number of generated tokens

    img, gt_caps = get_image(idx)

    if show_image:
        plt.imshow(img)
        plt.axis("off")
        plt.show()

    # SAT preprocessing: force 3 channels, resize to 256x256, scale to [0, 1],
    # reorder HWC -> CHW, then apply channel-wise mean/std normalisation.
    img_resized = img.convert("RGB").resize((256, 256))
    img_array = np.array(img_resized).transpose(2, 0, 1) / 255.0
    img_tensor = torch.FloatTensor(img_array).to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_tensor = normalize(img_tensor).unsqueeze(0)

    start_token = word_map["<start>"]
    end_token = word_map["<end>"]
    word_indices = [start_token]

    # Inference only: no autograd graph is needed for captioning.
    with torch.no_grad():
        # Encode, then collapse dims 1 and 2 so each location is one row.
        encoder_out = sat_encoder(img_tensor)
        encoder_dim = encoder_out.size(3)
        encoder_out = encoder_out.view(1, -1, encoder_dim)

        # Greedy decode: feed the previously chosen token back in at each step.
        h, c = sat_decoder.init_hidden_state(encoder_out)
        for _ in range(max_steps):
            prev_word = torch.LongTensor([word_indices[-1]]).to(device)
            embeddings = sat_decoder.embedding(prev_word)

            awe, alpha = sat_decoder.attention(encoder_out, h)
            gate = sat_decoder.sigmoid(sat_decoder.f_beta(h))
            awe = gate * awe

            h, c = sat_decoder.decode_step(
                torch.cat([embeddings, awe], dim=1), (h, c)
            )

            scores = F.log_softmax(sat_decoder.fc(h), dim=1)
            next_word = scores.argmax(1).item()
            word_indices.append(next_word)

            if next_word == end_token:
                break

    # Drop <start>; drop the trailing token only when it really is <end>.
    # If the step limit was reached first, the last token is a real word.
    if word_indices[-1] == end_token:
        caption_ids = word_indices[1:-1]
    else:
        caption_ids = word_indices[1:]

    pred_caption = " ".join(rev_word_map[token_id] for token_id in caption_ids)

    print(f"SAT caption: {pred_caption}")

    return img, gt_caps, pred_caption

# Smoke-test the captioner on image 87 without displaying the picture
for i in (87,):
    generate_sat_caption(i, show_image=False)

In [None]:
# Load spaCy for POS tagging
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package is not installed.
    # Install it with the interpreter running this kernel (not whatever
    # `python` is first on PATH) and stop if the download fails.
    import importlib
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    # Let the running process see the freshly installed package.
    importlib.invalidate_caches()
    nlp = spacy.load("en_core_web_sm")

print("spaCy loaded!")

In [None]:
# Extract SAT attention
def extract_sat_attention(idx):
    """Greedy-decode a caption with SAT and collect one attention map per word.

    Args:
        idx: Position in ``img_ids`` of the image to process.

    Returns:
        ``(img, gt_caps, pred_caption, word_attentions, words)`` where
        ``word_attentions`` is a gradient-free tensor of shape
        [num_words, 24, 24] (the attention used to predict each word,
        bilinearly resized) and ``words`` is the list of predicted words.

    Raises:
        RuntimeError: If the model emits ``<end>`` as its very first token,
            so there is no word to attach attention to.
    """
    max_steps = 50          # upper bound on the number of generated tokens
    target_size = (24, 24)  # spatial size of the returned attention maps

    img, gt_caps = get_image(idx)

    # SAT preprocessing: force 3 channels, resize to 256x256, scale to [0, 1],
    # reorder HWC -> CHW, then apply channel-wise mean/std normalisation.
    img_resized = img.convert("RGB").resize((256, 256))
    img_array = np.array(img_resized).transpose(2, 0, 1) / 255.0
    img_tensor = torch.FloatTensor(img_array).to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_tensor = normalize(img_tensor).unsqueeze(0)

    start_token = word_map["<start>"]
    end_token = word_map["<end>"]
    word_indices = [start_token]
    alphas_list = []

    # Run without autograd: the maps are only inspected, and tensors that
    # require grad cannot be converted with .numpy() by the aggregators.
    with torch.no_grad():
        # Encode, then collapse dims 1 and 2 so each location is one row.
        encoder_out = sat_encoder(img_tensor)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)
        encoder_out = encoder_out.view(1, -1, encoder_dim)

        # Decode with attention tracking
        h, c = sat_decoder.init_hidden_state(encoder_out)
        for _ in range(max_steps):
            prev_word = torch.LongTensor([word_indices[-1]]).to(device)
            embeddings = sat_decoder.embedding(prev_word)

            awe, alpha = sat_decoder.attention(encoder_out, h)
            # Back to a square grid so it can be resized as an image later.
            alpha = alpha.view(enc_image_size, enc_image_size)

            gate = sat_decoder.sigmoid(sat_decoder.f_beta(h))
            awe = gate * awe

            h, c = sat_decoder.decode_step(
                torch.cat([embeddings, awe], dim=1), (h, c)
            )

            scores = F.log_softmax(sat_decoder.fc(h), dim=1)
            next_word = scores.argmax(1).item()

            word_indices.append(next_word)
            alphas_list.append(alpha)

            if next_word == end_token:
                break

    # Drop <start>. Drop the last token and its map only when that token is
    # <end>; if the step limit was hit first, the last token is a real word.
    if word_indices[-1] == end_token:
        word_indices = word_indices[1:-1]
        alphas_list = alphas_list[:-1]
    else:
        word_indices = word_indices[1:]

    if not alphas_list:
        raise RuntimeError(
            f"SAT produced an empty caption for image index {idx}; no attention to extract"
        )

    alphas = torch.stack(alphas_list)

    # Resize every map to 24x24 (align_corners=False is the PyTorch default,
    # spelled out to make the choice explicit).
    alphas_resized = F.interpolate(
        alphas.unsqueeze(1),
        size=target_size,
        mode="bilinear",
        align_corners=False,
    ).squeeze(1)

    words = [rev_word_map[token_id] for token_id in word_indices]
    pred_caption = " ".join(words)

    print(f"Caption: {pred_caption}")
    print(f"Captured {len(alphas_resized)} words of attention")
    print(f"Attention shape: {alphas_resized.shape}")

    return img, gt_caps, pred_caption, alphas_resized, words

In [None]:
# Aggregation methods (6 total: 2 baselines + 4 novel)

# Baseline 1: Sentence-based (mean across all words)
def aggregate_sentence_based(word_attentions, k_percent=0.5):
    """Average the per-word attention maps and mark the top ``k_percent`` pixels.

    Args:
        word_attentions: Tensor of shape [num_words, H, W].
        k_percent: Fraction of the H*W pixels to select.

    Returns:
        ``(aggregated, mask)`` as numpy arrays of shape [H, W]. ``mask`` is
        1.0 wherever ``aggregated`` is at least the k-th largest value, so
        ties at the threshold can select more than k pixels. When
        ``int(H*W*k_percent)`` is 0 the mask is all zeros.
    """
    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()
    aggregated = word_attentions.mean(dim=0)

    # Derive k from the actual map size instead of assuming 24x24 (576).
    num_pixels = aggregated.numel()
    k = min(int(num_pixels * k_percent), num_pixels)
    if k > 0:
        threshold = torch.topk(aggregated.flatten(), k).values.min()
        mask = (aggregated >= threshold).float()
    else:
        # topk with k=0 returns nothing to take a minimum of.
        mask = torch.zeros_like(aggregated, dtype=torch.float32)

    return aggregated.cpu().numpy(), mask.cpu().numpy()


# Baseline 2: Word-based (union of per-word top-k)
def aggregate_word_based(word_attentions, k_percent=0.5):
    """Select the top pixels of each word's map separately and take the union.

    The pixel budget ``H*W*k_percent`` is split evenly across words, so each
    word contributes its ``int(H*W*k_percent / num_words)`` strongest pixels.

    Args:
        word_attentions: Tensor of shape [num_words, H, W].
        k_percent: Fraction of the H*W pixels used as the total budget.

    Returns:
        ``(aggregated, mask)`` as numpy arrays of shape [H, W].
        ``aggregated`` is the mean map over words; ``mask`` is 1.0 where at
        least one word selected the pixel (ties at a word's threshold can
        select extra pixels). If the per-word budget truncates to 0 the mask
        is all zeros.

    Raises:
        ValueError: If ``word_attentions`` contains no words.
    """
    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()
    num_words = len(word_attentions)
    if num_words == 0:
        raise ValueError("word_attentions must contain at least one word")

    spatial_shape = word_attentions.shape[1:]
    flat = word_attentions.flatten(1)  # [num_words, H*W]
    num_pixels = flat.shape[1]
    k_per_word = min(int(num_pixels * k_percent / num_words), num_pixels)

    if k_per_word > 0:
        # k-th largest value of every word at once; all tensors stay on the
        # input's device (the mask is no longer allocated on the CPU).
        kth_largest = torch.topk(flat, k_per_word, dim=1).values.min(dim=1, keepdim=True).values
        mask = (flat >= kth_largest).any(dim=0).view(spatial_shape).float()
    else:
        mask = torch.zeros(spatial_shape, dtype=torch.float32, device=word_attentions.device)

    aggregated = word_attentions.mean(dim=0)

    return aggregated.cpu().numpy(), mask.cpu().numpy()


# Novel 1: POS-weighted
def aggregate_pos_weighted(word_attentions, words, k_percent=0.5):
    """Weight each word's attention map by its part of speech, then take top-k.

    Weights: NOUN/PROPN 2.0, VERB 1.5, ADJ 1.2, everything else 1.0. The
    weighted sum is divided by ``len(words)``.

    Args:
        word_attentions: Tensor of shape [num_words, H, W].
        words: The caption words, one per attention map, in order.
        k_percent: Fraction of the H*W pixels to select.

    Returns:
        ``(aggregated, mask)`` as numpy arrays of shape [H, W]. ``mask`` is
        1.0 where ``aggregated`` is at least the k-th largest value (ties can
        select more than k pixels); all zeros when ``int(H*W*k_percent)`` is 0.

    Raises:
        ValueError: If ``words`` is empty.
    """
    from spacy.tokens import Doc

    if len(words) == 0:
        raise ValueError("words must contain at least one word")

    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()

    # Build the Doc from the already-split words so there is exactly one
    # spaCy token per attention map. Tokenising the joined string instead can
    # split a single word (e.g. "<unk>") into several tokens and shift every
    # later POS tag onto the wrong map.
    doc = Doc(nlp.vocab, words=list(words))
    for _, component in nlp.pipeline:
        doc = component(doc)

    pos_weights = {"NOUN": 2.0, "PROPN": 2.0, "VERB": 1.5, "ADJ": 1.2}

    weighted_attn = torch.zeros_like(word_attentions[0])
    for word_attn, token in zip(word_attentions, doc):
        weighted_attn += word_attn * pos_weights.get(token.pos_, 1.0)

    aggregated = weighted_attn / len(words)

    # Derive k from the actual map size instead of assuming 24x24 (576).
    num_pixels = aggregated.numel()
    k = min(int(num_pixels * k_percent), num_pixels)
    if k > 0:
        threshold = torch.topk(aggregated.flatten(), k).values.min()
        mask = (aggregated >= threshold).float()
    else:
        mask = torch.zeros_like(aggregated, dtype=torch.float32)

    return aggregated.cpu().numpy(), mask.cpu().numpy()


# Novel 2: Temporal-decay
def aggregate_temporal_decay(word_attentions, k_percent=0.5, decay=0.9):
    """Weight earlier words more heavily (``decay ** position``), then take top-k.

    Args:
        word_attentions: Floating-point tensor of shape [num_words, H, W],
            ordered as the words were generated.
        k_percent: Fraction of the H*W pixels to select.
        decay: Per-step multiplier; word ``i`` gets weight ``decay ** i``.

    Returns:
        ``(aggregated, mask)`` as numpy arrays of shape [H, W].
        ``aggregated`` is the weighted sum divided by ``num_words``; ``mask``
        is 1.0 where it is at least the k-th largest value (ties can select
        more than k pixels), all zeros when ``int(H*W*k_percent)`` is 0.

    Raises:
        ValueError: If ``word_attentions`` contains no words.
    """
    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()
    num_words = len(word_attentions)
    if num_words == 0:
        raise ValueError("word_attentions must contain at least one word")

    # One weight per word, shaped to broadcast over the spatial dimensions.
    weights = torch.tensor(
        [decay ** position for position in range(num_words)],
        dtype=word_attentions.dtype,
        device=word_attentions.device,
    ).view(-1, *([1] * (word_attentions.dim() - 1)))

    aggregated = (word_attentions * weights).sum(dim=0) / num_words

    # Derive k from the actual map size instead of assuming 24x24 (576).
    num_pixels = aggregated.numel()
    k = min(int(num_pixels * k_percent), num_pixels)
    if k > 0:
        threshold = torch.topk(aggregated.flatten(), k).values.min()
        mask = (aggregated >= threshold).float()
    else:
        mask = torch.zeros_like(aggregated, dtype=torch.float32)

    return aggregated.cpu().numpy(), mask.cpu().numpy()


# Novel 3: Variance-based
def aggregate_variance_based(word_attentions, k_percent=0.5):
    """Score each pixel by the variance of its attention across words, then take top-k.

    Args:
        word_attentions: Tensor of shape [num_words, H, W].
        k_percent: Fraction of the H*W pixels to select.

    Returns:
        ``(variance, mask)`` as numpy arrays of shape [H, W]. ``mask`` is 1.0
        where ``variance`` is at least the k-th largest value (ties can
        select more than k pixels); all zeros when ``int(H*W*k_percent)`` is 0.
    """
    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()

    # Sample variance (the default) is undefined for a single word and would
    # yield NaN everywhere; fall back to population variance (zeros) then.
    variance = word_attentions.var(dim=0, unbiased=len(word_attentions) > 1)

    # Derive k from the actual map size instead of assuming 24x24 (576).
    num_pixels = variance.numel()
    k = min(int(num_pixels * k_percent), num_pixels)
    if k > 0:
        threshold = torch.topk(variance.flatten(), k).values.min()
        mask = (variance >= threshold).float()
    else:
        mask = torch.zeros_like(variance, dtype=torch.float32)

    return variance.cpu().numpy(), mask.cpu().numpy()


# Novel 4: Entropy-based
def aggregate_entropy_based(word_attentions, k_percent=0.5):
    """Score pixels by summed ``-p * log(p)`` terms over words, then take top-k.

    Each word's map is first turned into a distribution over its pixels with
    a softmax across the flattened spatial positions. The per-pixel score is
    the sum over words of ``-p * log(p + 1e-10)`` for that pixel.

    Args:
        word_attentions: Tensor of shape [num_words, H, W].
        k_percent: Fraction of the H*W pixels to select.

    Returns:
        ``(entropy, mask)`` as numpy arrays of shape [H, W]. ``mask`` is 1.0
        where the score is at least the k-th largest value (ties can select
        more than k pixels); all zeros when ``int(H*W*k_percent)`` is 0.
    """
    # detach() so inputs that still carry autograd history can go to numpy.
    word_attentions = word_attentions.detach()

    # Softmax over spatial positions, one distribution per word.
    probs = torch.softmax(word_attentions.flatten(1), dim=1).view_as(word_attentions)

    epsilon = 1e-10  # keeps log() finite if a probability underflows to 0
    entropy = -(probs * torch.log(probs + epsilon)).sum(dim=0)

    # Derive k from the actual map size instead of assuming 24x24 (576).
    num_pixels = entropy.numel()
    k = min(int(num_pixels * k_percent), num_pixels)
    if k > 0:
        threshold = torch.topk(entropy.flatten(), k).values.min()
        mask = (entropy >= threshold).float()
    else:
        mask = torch.zeros_like(entropy, dtype=torch.float32)

    return entropy.cpu().numpy(), mask.cpu().numpy()

# Confirmation that every aggregation function in this cell has been defined.
print("All 6 aggregation methods loaded!")

In [None]:
# Visualization function
def visualize_attention_overlay(img, pred_caption, agg_attn, mask, method_name, k_viz=0.3):
    """Show a 4-panel figure: image, heatmap overlay, raw heatmap, top-k patches.

    NOTE: a later cell defines a function with this same name; after running
    the notebook top to bottom, that later definition is the one in effect.

    Args:
        img: PIL image to draw on.
        pred_caption: Caption text; its first 40 characters go in the title.
        agg_attn: 2-D aggregated attention map (numpy array or torch tensor).
        mask: Unused. The highlighted region is recomputed from ``agg_attn``
            with ``k_viz``; the parameter is kept for call compatibility.
        method_name: Label shown above the overlay panel.
        k_viz: Fraction of map cells highlighted in the last panel
            (30% by default for clarity).
    """
    # Accept tensors as well as arrays and work in float32 numpy throughout.
    if torch.is_tensor(agg_attn):
        agg_attn = agg_attn.detach().cpu().numpy()
    agg_attn = np.asarray(agg_attn, dtype=np.float32)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(img)
    axes[0].set_title(f"Original\n{pred_caption[:40]}")
    axes[0].axis("off")

    img_array = np.array(img.resize((384, 384)))

    # Min-max scale to [0, 1] and resize as a float image. Casting
    # `agg_attn * 255` to uint8 would collapse small values to a few integer
    # levels and wrap values above 1 (the map is not guaranteed to lie in [0, 1]).
    lowest, highest = float(agg_attn.min()), float(agg_attn.max())
    if highest > lowest:
        agg_scaled = (agg_attn - lowest) / (highest - lowest)
    else:
        agg_scaled = np.zeros_like(agg_attn)
    attn_resized = np.array(Image.fromarray(agg_scaled).resize((384, 384)))
    # The resampling filter may overshoot slightly; keep the result in range.
    attn_resized = np.clip(attn_resized, 0.0, 1.0)

    axes[1].imshow(img_array)
    axes[1].imshow(attn_resized, cmap="hot", alpha=0.5)
    axes[1].set_title(f"{method_name}\nHeatmap Overlay")
    axes[1].axis("off")

    axes[2].imshow(agg_attn, cmap="hot", interpolation="bilinear")
    axes[2].set_title("Attention Heatmap")
    axes[2].axis("off")

    # Recompute mask for visualization, sized from the map itself rather than
    # assuming 24x24 (576 cells).
    num_cells = agg_attn.size
    k_viz_pixels = min(int(num_cells * k_viz), num_cells)
    if k_viz_pixels > 0:
        threshold_viz = np.sort(agg_attn, axis=None)[-k_viz_pixels]
        mask_viz = agg_attn >= threshold_viz
    else:
        mask_viz = np.zeros(agg_attn.shape, dtype=bool)

    mask_resized = np.array(
        Image.fromarray(mask_viz.astype(np.uint8) * 255).resize((384, 384), Image.NEAREST)
    )
    mask_resized = mask_resized / 255.0

    axes[3].imshow(img_array)
    axes[3].imshow(mask_resized, cmap="Reds", alpha=0.4)
    # round() rather than int(): e.g. 0.29 * 100 is 28.999..., which int() would show as 28.
    axes[3].set_title(f"Top {round(k_viz * 100)}% Patches\nSelected for Attack")
    axes[3].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Visualization function with adjustable k_percent
def visualize_attention_overlay(img, pred_caption, agg_attn, mask, method_name="Sentence-based", k_viz=0.3):
    """Show a 4-panel figure: image, heatmap overlay, raw heatmap, top-k patches.

    NOTE: an earlier cell defines a function with this same name; this
    definition replaces it once the cell has run.

    Args:
        img: PIL image to draw on.
        pred_caption: Caption text; its first 40 characters go in the title.
        agg_attn: 2-D aggregated attention map (numpy array or torch tensor).
        mask: Unused. The highlighted region is recomputed from ``agg_attn``
            with ``k_viz``; the parameter is kept for call compatibility.
        method_name: Label shown above the overlay panel.
        k_viz: Fraction of map cells highlighted in the last panel
            (default 30% for a clearer view).
    """
    # Accept tensors as well as arrays and work in float32 numpy throughout.
    if torch.is_tensor(agg_attn):
        agg_attn = agg_attn.detach().cpu().numpy()
    agg_attn = np.asarray(agg_attn, dtype=np.float32)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(img)
    axes[0].set_title(f"Original\n{pred_caption[:40]}")
    axes[0].axis("off")

    img_array = np.array(img.resize((384, 384)))

    # Min-max scale to [0, 1] and resize as a float image. Casting
    # `agg_attn * 255` to uint8 would collapse small values to a few integer
    # levels and wrap values above 1 (the map is not guaranteed to lie in [0, 1]).
    lowest, highest = float(agg_attn.min()), float(agg_attn.max())
    if highest > lowest:
        agg_scaled = (agg_attn - lowest) / (highest - lowest)
    else:
        agg_scaled = np.zeros_like(agg_attn)
    attn_resized = np.array(Image.fromarray(agg_scaled).resize((384, 384)))
    # The resampling filter may overshoot slightly; keep the result in range.
    attn_resized = np.clip(attn_resized, 0.0, 1.0)

    axes[1].imshow(img_array)
    axes[1].imshow(attn_resized, cmap="hot", alpha=0.5)
    axes[1].set_title(f"{method_name}\nHeatmap Overlay")
    axes[1].axis("off")

    axes[2].imshow(agg_attn, cmap="hot", interpolation="bilinear")
    axes[2].set_title("Attention Heatmap\n(Smoothed)")
    axes[2].axis("off")

    # Recompute mask with k_viz for visualization, sized from the map itself
    # rather than assuming 24x24 (576 cells).
    num_cells = agg_attn.size
    k_viz_pixels = min(int(num_cells * k_viz), num_cells)
    if k_viz_pixels > 0:
        threshold_viz = np.sort(agg_attn, axis=None)[-k_viz_pixels]
        mask_viz = agg_attn >= threshold_viz
    else:
        mask_viz = np.zeros(agg_attn.shape, dtype=bool)

    mask_resized = np.array(
        Image.fromarray(mask_viz.astype(np.uint8) * 255).resize((384, 384), Image.NEAREST)
    )
    mask_resized = mask_resized / 255.0

    axes[3].imshow(img_array)
    axes[3].imshow(mask_resized, cmap="Reds", alpha=0.4)
    # round() rather than int(): e.g. 0.29 * 100 is 28.999..., which int() would show as 28.
    axes[3].set_title(f"Top {round(k_viz * 100)}% Patches\nSelected for Attack")
    axes[3].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Run the attention extractor once on image 87 and inspect what came back
idx = 87
img, gt_caps, pred_caption, word_attns, words = extract_sat_attention(idx)

print("\nWords:", words)
print("Attention shape:", word_attns.shape)

In [None]:
# Run every aggregation method on image 87 and plot the results one after another
idx = 87
img, gt_caps, pred_caption, word_attns, words = extract_sat_attention(idx)

# Display name -> zero-argument callable returning (aggregated_map, mask).
methods = {}
methods["Sentence-based"] = lambda: aggregate_sentence_based(word_attns, k_percent=0.5)
methods["Word-based"] = lambda: aggregate_word_based(word_attns, k_percent=0.5)
methods["POS-weighted"] = lambda: aggregate_pos_weighted(word_attns, words, k_percent=0.5)
methods["Temporal-decay"] = lambda: aggregate_temporal_decay(word_attns, k_percent=0.5)
methods["Variance-based"] = lambda: aggregate_variance_based(word_attns, k_percent=0.5)
methods["Entropy-based"] = lambda: aggregate_entropy_based(word_attns, k_percent=0.5)

for name in methods:
    method = methods[name]
    print("\n=== {} ===".format(name))
    agg, mask = method()
    visualize_attention_overlay(img, pred_caption, agg, mask, name, k_viz=0.3)