In [None]:
# Basic imports for SAT model and attack
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import zoom
import json
import os
import ssl

# Fix SSL certificate issue on macOS
# NOTE(review): this swaps the process-wide default HTTPS context factory for
# the *unverified* one, so every later HTTPS request made through the default
# context skips certificate validation. Presumably this is here so the
# pretrained-weight download succeeds on machines with a broken certificate
# store -- confirm, and prefer repairing the local certificates instead,
# because downloaded model files are then not authenticated.
ssl._create_default_https_context = ssl._create_unverified_context

# Set random seeds for reproducibility
# Seeds NumPy's legacy global generator and PyTorch's generator with the same
# fixed value.
np.random.seed(42)
torch.manual_seed(42)

print("Imports loaded")

In [None]:
# Device setup
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# SAT Encoder - uses pretrained ResNet101
class Encoder(nn.Module):
    """Image encoder built from a pretrained torchvision ResNet-101.

    The last two children of the ResNet (described in the original code as the
    pooling and linear layers) are dropped, and the remaining feature map is
    adaptively average-pooled to a fixed ``encoded_image_size`` square grid.

    Args:
        encoded_image_size: Side length of the pooled spatial grid.
    """

    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` -- confirm against the installed version.
        backbone = models.resnet101(pretrained=True)

        # Keep everything except the final two child modules.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # Pool to a fixed grid so the output size does not depend on the input size.
        pooled_grid = (encoded_image_size, encoded_image_size)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(pooled_grid)

        # Start with every backbone parameter frozen.
        self.fine_tune(False)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image batch; the original comment gives the shape as
                (batch_size, 3, 256, 256).

        Returns:
            Feature tensor with the channel axis moved last:
            (batch_size, enc_image_size, enc_image_size, channels).
        """
        pooled = self.adaptive_pool(self.resnet(images))
        # (batch, channels, H, W) -> (batch, H, W, channels)
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self, fine_tune=True):
        """Freeze the backbone, then optionally unfreeze its later stages.

        Every parameter is first marked as not requiring gradients; the child
        modules from index 5 onward (described in the original code as conv
        layers 2-4) then get ``requires_grad = fine_tune``.
        """
        for weight in self.resnet.parameters():
            weight.requires_grad = False

        trainable_stages = list(self.resnet.children())[5:]
        for stage in trainable_stages:
            for weight in stage.parameters():
                weight.requires_grad = fine_tune

print("Encoder defined")

In [None]:
# Attention module - computes soft attention weights
class Attention(nn.Module):
    """Additive soft attention over encoder pixels, conditioned on the decoder state.

    Args:
        encoder_dim: Feature size of each encoder pixel.
        decoder_dim: Size of the decoder hidden state.
        attention_dim: Size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Project both inputs into the same attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapse the combined projection into one score per pixel.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalise scores across the pixel axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Compute attention weights and the weighted encoder summary.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            decoder_hidden: (batch_size, decoder_dim)

        Returns:
            Tuple ``(attention_weighted_encoding, alpha)`` with shapes
            (batch_size, encoder_dim) and (batch_size, num_pixels).
        """
        pixel_proj = self.encoder_att(encoder_out)                # (batch, pixels, attention_dim)
        state_proj = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)

        # Broadcast the decoder projection over every pixel before scoring.
        scores = self.full_att(self.relu(pixel_proj + state_proj))  # (batch, pixels, 1)
        alpha = self.softmax(scores.squeeze(2))                     # (batch, pixels)

        weighted_pixels = encoder_out * alpha.unsqueeze(2)
        return weighted_pixels.sum(dim=1), alpha

print("Attention module defined")

In [None]:
# SAT Decoder with attention
class DecoderWithAttention(nn.Module):
    """LSTM caption decoder with soft attention over encoder pixels.

    This class only owns the layers and the hidden-state initialisation; the
    step-by-step generation loop lives in ``caption_with_attention``.

    Args:
        attention_dim: Size of the attention projection space.
        embed_dim: Word-embedding size.
        decoder_dim: LSTM hidden/cell state size.
        vocab_size: Number of entries in the vocabulary.
        encoder_dim: Feature size of each encoder pixel.
        dropout: Dropout probability for the ``dropout`` layer.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Assigned once as a module (previously a float was stored under the
        # same name and then immediately replaced by this layer).
        self.dropout = nn.Dropout(p=dropout)
        # Each LSTM step consumes the word embedding concatenated with the
        # attention-weighted encoding.
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # produces the attention gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)       # scores over the vocabulary
        self.init_weights()

    def init_weights(self):
        """Re-initialise embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        """Build the initial LSTM state from the mean encoder feature.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)

        Returns:
            Tuple ``(h, c)``, each (batch_size, decoder_dim).
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Training-time forward pass (not implemented in this notebook).

        Raises:
            NotImplementedError: Always. Previously this was a ``pass`` stub
                that silently returned ``None``, so calling the module looked
                successful but produced nothing usable.
        """
        raise NotImplementedError(
            "DecoderWithAttention.forward (training pass) is not implemented; "
            "use caption_with_attention() for caption generation."
        )

print("Decoder defined")

In [None]:
# Load or initialize SAT model
# Model hyperparameters (from sgrvinod's implementation)
encoder_dim, attention_dim, embed_dim, decoder_dim = 2048, 512, 512, 512
vocab_size = 10000  # Will be updated when we load actual vocab

# Build the encoder first, then the decoder, and place both on the chosen device
encoder = Encoder()
encoder = encoder.to(device)

decoder = DecoderWithAttention(
    attention_dim,
    embed_dim,
    decoder_dim,
    vocab_size,
    encoder_dim=encoder_dim,
)
decoder = decoder.to(device)

# Inference only: switch both networks to evaluation mode
for network in (encoder, decoder):
    network.eval()

print("SAT Model initialized")
print("Encoder output: 14x14 = 196 image patches")
print(f"Each patch: {encoder_dim} dimensions")

In [None]:
# Image preprocessing for SAT
# SAT uses 256x256 images with ImageNet normalization
preprocessing_steps = [
    # Force every image to the same spatial size.
    transforms.Resize((256, 256)),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Per-channel normalization with the ImageNet statistics.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

def load_image(image_path):
    """Load an image from disk and preprocess it for the encoder.

    Args:
        image_path: Path of the image file to open.

    Returns:
        Tuple ``(img, img_tensor)``: the RGB PIL image and the result of the
        module-level ``transform`` with a leading batch dimension added.
    """
    # Open inside a context manager so the file handle is released right
    # after the pixel data has been converted.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # Add batch dimension
    return img, img_tensor

# For testing, create a simple dummy image
# You can replace this with actual COCO images later
dummy_img = Image.new('RGB', (256, 256), color='blue')

# Same outputs load_image() would give: the PIL image plus a batched tensor
img_original = dummy_img
img_tensor = transform(dummy_img).unsqueeze(0)

print(f"Image tensor shape: {img_tensor.shape}")
print("Image preprocessing ready")

In [None]:
# Caption generation with attention extraction
def caption_with_attention(encoder, decoder, image_tensor, word_map, beam_size=1, max_len=20):
    """
    Greedily generate a caption for one image and collect the attention map
    used at every decoding step.

    Args:
        encoder: Module mapping the image batch to features shaped
            (1, enc_image_size, enc_image_size, encoder_dim).
        decoder: Module exposing ``init_hidden_state``, ``embedding``,
            ``attention``, ``f_beta``, ``sigmoid``, ``decode_step`` and ``fc``.
        image_tensor: Preprocessed image batch containing a single image.
        word_map: Mapping from token string to index. The ``'<start>'`` and
            ``'<end>'`` entries are used when present; otherwise indices 1
            and 2 are assumed. ``None`` is treated as an empty mapping.
        beam_size: Accepted for interface compatibility; currently unused
            (decoding is always greedy).
        max_len: Maximum number of decoding steps.

    Returns:
        - caption: list of predicted word indices (includes the end token if
          it was produced)
        - alphas: list of attention maps, one (1, num_pixels) CPU tensor per
          generated word
    """
    # Take the special-token ids from the vocabulary instead of hard-coding
    # them; fall back to the previous constants when the keys are missing.
    vocab = word_map if word_map is not None else {}
    start_token = vocab.get('<start>', 1)
    end_token = vocab.get('<end>', 2)

    # Run on the device the encoder lives on rather than relying on a
    # notebook-level global.
    encoder_param = next(encoder.parameters(), None)
    if encoder_param is not None:
        image_tensor = image_tensor.to(encoder_param.device)

    # Encode image
    encoder_out = encoder(image_tensor)  # (1, enc_image_size, enc_image_size, encoder_dim)
    encoder_dim = encoder_out.size(3)

    # Flatten the spatial grid; reshape also copes with non-contiguous inputs
    encoder_out = encoder_out.reshape(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)

    # Initialize LSTM state
    h, c = decoder.init_hidden_state(encoder_out)  # (1, decoder_dim)

    prev_word = torch.tensor([start_token], dtype=torch.long, device=encoder_out.device)

    # Store generated words and attention weights
    caption = []
    alphas_list = []

    # Generate caption word by word
    for _ in range(max_len):
        embeddings = decoder.embedding(prev_word)  # (1, embed_dim)

        # Attention over pixels given the current hidden state
        attention_weighted_encoding, alpha = decoder.attention(encoder_out, h)
        alphas_list.append(alpha.detach().cpu())

        # Element-wise gate over the encoder features (one value per feature)
        gate = decoder.sigmoid(decoder.f_beta(h))  # (1, encoder_dim)
        attention_weighted_encoding = gate * attention_weighted_encoding

        # LSTM step
        h, c = decoder.decode_step(
            torch.cat([embeddings, attention_weighted_encoding], dim=1),
            (h, c)
        )

        # Greedy choice of the next word
        predicted = decoder.fc(h).argmax(dim=1)  # (1,)
        token = predicted.item()

        caption.append(token)
        prev_word = predicted

        if token == end_token:
            break

    return caption, alphas_list

print("Caption generation with attention extraction ready")

In [None]:
# Sentence-based aggregation (Paper's main method)
def aggregate_sentence_based(alphas_list, k_percent=0.5):
    """
    Sum attention over all words, then mark the top-k% patches.

    Args:
        alphas_list: Non-empty list of attention tensors, one per word
            (each (1, num_pixels))
        k_percent: Fraction of patches to select (default 0.5 = 50%). The
            resulting count is clamped to [0, num_pixels].

    Returns:
        aggregated_2d: Summed attention as a square 2D numpy array
        mask_2d: Binary float mask of the selected patches, same shape;
            exactly k entries are 1 even when attention values tie

    Raises:
        ValueError: If ``alphas_list`` is empty.
    """
    if len(alphas_list) == 0:
        raise ValueError("alphas_list must contain at least one attention map")

    # Stack all attention maps and sum across words
    all_alphas = torch.stack([a.reshape(-1) for a in alphas_list])  # (num_words, num_pixels)
    aggregated = all_alphas.sum(dim=0)  # (num_pixels,)
    num_pixels = aggregated.shape[0]

    # Reshape to the square grid the encoder produced
    grid_size = int(np.sqrt(num_pixels))
    aggregated_2d = aggregated.reshape(grid_size, grid_size)

    # Select by index rather than by threshold: a `>= threshold` test picks
    # every tied patch and can exceed the requested budget.
    k = min(max(int(num_pixels * k_percent), 0), num_pixels)
    mask = torch.zeros_like(aggregated, dtype=torch.float32)
    if k > 0:
        mask[torch.topk(aggregated, k).indices] = 1.0
    mask_2d = mask.reshape(grid_size, grid_size)

    return aggregated_2d.numpy(), mask_2d.numpy()

print("Sentence-based aggregation ready")

In [None]:
# Word-based aggregation (Paper's baseline)
def aggregate_word_based(alphas_list, k_percent=0.5):
    """
    Select the top patches for each word separately, then take their union.

    The overall budget ``num_pixels * k_percent`` is split evenly across
    words (rounded down), so each word contributes at most that many patches.

    Args:
        alphas_list: Non-empty list of attention tensors, one per word
            (each (1, num_pixels))
        k_percent: Fraction of patches forming the total budget. The per-word
            count is clamped to [0, num_pixels].

    Returns:
        aggregated_2d: Attention heatmap (mean across words) as a square
            2D numpy array
        mask_2d: Binary float mask (union of per-word top patches)

    Raises:
        ValueError: If ``alphas_list`` is empty.
    """
    num_words = len(alphas_list)
    if num_words == 0:
        raise ValueError("alphas_list must contain at least one attention map")

    all_alphas = torch.stack([a.reshape(-1) for a in alphas_list])  # (num_words, num_pixels)
    num_pixels = all_alphas.shape[1]

    # For visualization: average attention across words
    aggregated = all_alphas.mean(dim=0)

    # For mask: top-k per word (one batched topk), then union
    k_per_word = min(max(int(num_pixels * k_percent / num_words), 0), num_pixels)
    selected_patches = torch.zeros(num_pixels)
    if k_per_word > 0:
        top_indices = torch.topk(all_alphas, k_per_word, dim=1).indices  # (num_words, k_per_word)
        selected_patches[top_indices.reshape(-1)] = 1

    # Reshape to 2D
    grid_size = int(np.sqrt(num_pixels))
    aggregated_2d = aggregated.reshape(grid_size, grid_size)
    mask_2d = selected_patches.reshape(grid_size, grid_size)

    return aggregated_2d.numpy(), mask_2d.numpy()

print("Word-based aggregation ready")

In [None]:
# Visualization function
def visualize_attention(img, caption_text, aggregated_2d, mask_2d, method_name="Sentence-based", display_size=256):
    """
    Show the image, the attention overlay, the raw heatmap and the selected-patch mask.

    Args:
        img: Original PIL image
        caption_text: Generated caption as string (first 40 characters are
            shown in the title)
        aggregated_2d: Attention heatmap as a 2D array (square grid)
        mask_2d: Binary mask with the same grid shape
        method_name: Name of aggregation method, used in the panel title
        display_size: Side length in pixels the image and the maps are
            resized to for the overlays (default 256)
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # Panel 1: original image with the caption
    axes[0].imshow(img)
    axes[0].set_title(f"Original Image\n{caption_text[:40]}")
    axes[0].axis("off")

    # Common background for the overlay panels
    img_array = np.array(img.resize((display_size, display_size)))

    # Upsample the grid to the display size; the factor is taken from the
    # first grid axis, so a square grid is assumed.
    zoom_factor = display_size / aggregated_2d.shape[0]
    attn_resized = zoom(aggregated_2d, zoom_factor, order=1)

    # Panel 2: attention overlaid on the image
    axes[1].imshow(img_array)
    axes[1].imshow(attn_resized, cmap="hot", alpha=0.5)
    axes[1].set_title(f"{method_name}\nAttention Overlay")
    axes[1].axis("off")

    # Panel 3: heatmap only; the title reports the actual grid size instead
    # of a hard-coded 14x14
    grid_rows, grid_cols = aggregated_2d.shape
    axes[2].imshow(aggregated_2d, cmap="hot", interpolation="bilinear")
    axes[2].set_title(f"Attention Heatmap\n({grid_rows}x{grid_cols} grid)")
    axes[2].axis("off")

    # Panel 4: selected patches; order=0 keeps the mask binary when upsampling
    mask_resized = zoom(mask_2d, zoom_factor, order=0)
    axes[3].imshow(img_array)
    axes[3].imshow(mask_resized, cmap="Reds", alpha=0.4)
    axes[3].set_title(f"Top {int(mask_2d.sum())} Patches\nSelected for Attack")
    axes[3].axis("off")

    plt.tight_layout()
    plt.show()

print("Visualization ready")

In [None]:
# Test the complete pipeline
# NOTE: This will generate random captions since we don't have pretrained weights yet
# Once we load proper SAT weights, this will work correctly

print("Testing SAT pipeline...")
print("=" * 50)

# Placeholder vocabulary: every index maps to its own decimal string
reverse_word_map = {i: str(i) for i in range(vocab_size)}
word_map = {token: idx for idx, token in reverse_word_map.items()}

# Run the greedy decoder without tracking gradients
with torch.no_grad():
    caption_indices, alphas_list = caption_with_attention(
        encoder, decoder, img_tensor, word_map, max_len=10
    )

# Indices -> text (meaningless until pretrained weights and a real vocab are loaded)
caption_text = " ".join(reverse_word_map.get(idx, "<unk>") for idx in caption_indices)
print(f"Generated caption: {caption_text}")
print(f"Number of words: {len(caption_indices)}")
print(f"Number of attention maps: {len(alphas_list)}")
print(f"Attention shape per word: {alphas_list[0].shape}")
print()

# Sentence-based aggregation
print("Testing sentence-based aggregation...")
agg_sent, mask_sent = aggregate_sentence_based(alphas_list, k_percent=0.5)
print(f"Aggregated attention shape: {agg_sent.shape}")
print(f"Number of selected patches: {mask_sent.sum()}")
print()

# Word-based aggregation
print("Testing word-based aggregation...")
agg_word, mask_word = aggregate_word_based(alphas_list, k_percent=0.5)
print(f"Aggregated attention shape: {agg_word.shape}")
print(f"Number of selected patches: {mask_word.sum()}")
print()

print("Pipeline test complete!")
print("=" * 50)
print(
    "\nNOTE: To get proper captions, you need to:",
    "1. Download pretrained SAT checkpoint from sgrvinod's repo",
    "2. Load the checkpoint weights",
    "3. Load the vocabulary (word_map.json)",
    sep="\n",
)

In [None]:
# Visualize both aggregation methods
# Each entry: banner to print, heatmap, mask, and the label used in the figure title
visualization_runs = [
    ("Visualizing sentence-based aggregation:", agg_sent, mask_sent, "Sentence-based"),
    ("\nVisualizing word-based aggregation:", agg_word, mask_word, "Word-based"),
]
for banner, heatmap, patch_mask, label in visualization_runs:
    print(banner)
    visualize_attention(img_original, caption_text, heatmap, patch_mask, label)