# NAF: Neighborhood Attention Map

This notebook provides tools to inspect and visualize attention weights from the NAF model.

**Features:**
- Load pre-trained NAF model and test images
- Select query points on the image
- Extract and visualize attention weights for the selected position
- Compare attention patterns across different image regions

## 1. Setup and Imports

In [None]:
# Use the interactive widget (ipympl) backend: the attention explorer below
# registers a 'button_press_event' handler, and a static inline figure would
# never receive those mouse clicks.
%matplotlib widget

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from scipy.ndimage import zoom
import torchvision.transforms as T

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from IPython.display import clear_output

from src.model.naf import NAF
from utils.training import load_multiple_backbones

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_kind = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_kind)
print(f"Using device: {device}")

## 2. Load Model and Configuration

In [None]:
# --- Hydra configuration --------------------------------------------------
# Hydra keeps global state between cell executions, so initialise it only once.
hydra_singleton = GlobalHydra.instance()
if not hydra_singleton.is_initialized():
    initialize(config_path="../config", version_base=None)

# Base config with the NAF model selected and a 1024 px input resolution.
cfg = compose(config_name="base", overrides=["model=naf", "img_size=1024"])

# --- NAF upsampler --------------------------------------------------------
# Pre-trained weights fetched through torch.hub, then switched to inference mode.
model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
model.eval()

# --- Feature backbone -----------------------------------------------------
# Only one backbone is requested; the first returned entry is the one used below.
backbone_configs = [{"name": "vit_base_patch16_dinov3.lvd1689m"}]
backbones, names, _ = load_multiple_backbones(cfg, backbone_configs, device)
backbone = backbones[0]
backbone.eval()

# --- Summary --------------------------------------------------------------
print(f"Model loaded: {type(model).__name__}")
print(f"Backbone loaded: {names[0]}")
print(f"Kernel size: {model.upsampler.kernel_size}")

## 3. Load and Preprocess Image

In [None]:
# Load the test image as 3-channel RGB; the context manager closes the file handle
# (convert() returns a new, fully loaded image that outlives the `with` block).
image_path = "../asset/dog0.jpg"
with Image.open(image_path) as raw_image:
    pil_image = raw_image.convert("RGB")

# Resize to the configured square model input size; ToTensor scales values to [0, 1].
H, W = cfg.img_size, cfg.img_size
transform = T.Compose([T.Resize((H, W)), T.ToTensor()])
image = transform(pil_image).unsqueeze(0).to(device)

# Normalize for the backbone with the statistics from its own config.
mean_bck = backbone.config["mean"]
std_bck = backbone.config["std"]
normalize_bck = T.Normalize(mean=mean_bck, std=std_bck)
image_bck = normalize_bck(image)

# Normalize for the upsampler (ImageNet mean/std).
normalize_ups = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_ups = normalize_ups(image)

# HxWx3 uint8 copy for display. Round before casting: ToTensor divided by 255,
# so multiplying back can land just below an integer, and a bare
# astype(np.uint8) would then truncate that pixel down by one level.
image_vis = np.clip(
    np.rint(image.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255), 0, 255
).astype(np.uint8)

print(f"Loaded image from: {image_path}")
print(f"Image size: {image.shape}")

## 4. Extract Backbone Features

In [None]:
# Single backbone forward pass; gradients are not needed for visualisation.
with torch.no_grad():
    hr_feats = backbone(image_bck)

# The upsampler is asked for an output at the backbone's own spatial grid
# (dims 2 and 3 of the feature tensor), so attention is inspected per feature cell.
feat_rows = hr_feats.shape[2]
feat_cols = hr_feats.shape[3]
output_size = (feat_rows, feat_cols)

# Feature cells per image pixel along each axis (multiply pixel coords by these).
scale_h, scale_w = feat_rows / H, feat_cols / W

print(f"Backbone features shape: {hr_feats.shape}")
print(f"Feature map size: {output_size}")
print(f"Scaling factors - h: {scale_h:.3f}, w: {scale_w:.3f}")

## 5. Helper Functions for Query Position and Attention Extraction

In [None]:
# Cache for the full attention tensor. The NAF forward pass does not depend on the
# query position, so it only needs to run once per set of inputs instead of once
# per click. The tensor stays on `device` (presumably small at the feature-grid
# output size used in this notebook).
_attention_cache = {"key": None, "weights": None}


def _kernel_size_hw():
    """Return the upsampler kernel size as a (kernel_h, kernel_w) pair."""
    ks = model.upsampler.kernel_size
    # Accept either a single int (square kernel) or an (h, w) pair.
    if isinstance(ks, int):
        return ks, ks
    kernel_h, kernel_w = ks
    return int(kernel_h), int(kernel_w)


def _cached_attention_weights():
    """Run NAF for the current inputs once and reuse the attention tensor afterwards."""
    key = (model, image_ups, hr_feats, output_size)
    cached_key = _attention_cache["key"]
    # Identity comparison: re-running an earlier cell rebinds these globals to new
    # objects and therefore invalidates the cache. Holding references to the old
    # objects prevents id() reuse from producing false hits.
    if cached_key is None or any(a is not b for a, b in zip(cached_key, key)):
        with torch.no_grad():
            _, attn_weights = model(image_ups, hr_feats, output_size=output_size, return_weights=True)
        _attention_cache["key"] = key
        _attention_cache["weights"] = attn_weights
    return _attention_cache["weights"]


def extract_attention_at_position(query_img_h, query_img_w):
    """
    Extract the head-averaged attention weights for a query position.

    Args:
        query_img_h: Query row in image coordinates (0 to H-1).
        query_img_w: Query column in image coordinates (0 to W-1).
            Out-of-range values are clamped to the nearest feature cell.

    Returns:
        attn_map: 2D numpy array of shape (kernel_h, kernel_w) holding the
            attention weights averaged over heads.
        pos_feat: (row, col) of the query in feature-map coordinates.

    Raises:
        ValueError: if the attention vector length does not match the kernel size.
    """
    # Convert image coordinates to a feature-map cell, clamped to the valid grid.
    feat_h, feat_w = output_size
    query_h = min(max(int(query_img_h * scale_h), 0), feat_h - 1)
    query_w = min(max(int(query_img_w * scale_w), 0), feat_w - 1)
    pos_feat = (query_h, query_w)

    # Full attention tensor for the whole image (computed once, then cached).
    attn_weights = _cached_attention_weights()

    # Layout assumed (from the original notebook, not verified here):
    # [B, num_heads, H_out, W_out, kernel_h * kernel_w]
    attn_at_query = attn_weights[0, :, query_h, query_w, :]  # [num_heads, kernel_h*kernel_w]

    # Average across heads.
    attn_mean = attn_at_query.mean(dim=0)  # [kernel_h*kernel_w]

    # Reshape to a 2D attention map.
    kernel_h, kernel_w = _kernel_size_hw()
    if attn_mean.numel() != kernel_h * kernel_w:
        raise ValueError(
            f"Attention vector has {attn_mean.numel()} entries, expected "
            f"{kernel_h}x{kernel_w}={kernel_h * kernel_w}"
        )
    attn_map = attn_mean.reshape(kernel_h, kernel_w).float().cpu().numpy()

    return attn_map, pos_feat


print("Helper functions loaded")

## 6. Visualize Attention for a Single Query Point

In [None]:
# Interactive attention visualization: image | kernel heatmap | overlay.
# figsize is two-thirds of a 12 x 4 inch layout.
fig, axes = plt.subplots(1, 3, figsize=(8, 8 / 3))
ax_image, ax_attn, ax_overlay = axes

# Default query position (row, col) in image pixels, drawn before any click.
query_img_h, query_img_w = 650, 270

# Left panel: the image, a star marker for the query, and a dashed box for its neighborhood.
img_display = ax_image.imshow(image_vis)
(query_point,) = ax_image.plot([], [], 'r*', markersize=15, markeredgecolor='white', markeredgewidth=1)
rect_patch = patches.Rectangle(
    (0, 0), 0, 0, linewidth=2, edgecolor='red', facecolor='none', linestyle='--'
)
ax_image.add_patch(rect_patch)
ax_image.set_xlabel('Image Width')
ax_image.set_ylabel('Image Height')
ax_image.set_title('Click on image to select query point')
ax_image.set_aspect('equal')

# Middle panel: kernel-sized attention heatmap (placeholder data until the first update).
attn_display = ax_attn.imshow(np.zeros((5, 5)), cmap='RdBu_r', aspect='equal')
ax_attn.set_xlabel('Kernel Width')
ax_attn.set_ylabel('Kernel Height')
ax_attn.set_title('Attention Map')

cbar1 = plt.colorbar(attn_display, ax=ax_attn, fraction=0.046, pad=0.04)
cbar1.set_label('Attention Weight')

# Right panel: image crop blended with the attention map (placeholder until the first update).
overlay_display = ax_overlay.imshow(np.zeros((10, 10, 3), dtype=np.uint8))
ax_overlay.set_xlabel('Width')
ax_overlay.set_ylabel('Height')
ax_overlay.set_title('Attention Overlay on Neighborhood')
ax_overlay.set_aspect('equal')

# Statistics box anchored at the top-left of the heatmap panel (axes-fraction coordinates).
stats_text = ax_attn.text(
    0.02, 0.98, '',
    transform=ax_attn.transAxes,
    verticalalignment='top',
    fontsize=8,
    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
)

def _kernel_window_1d(center, kernel, n_feat, scale, n_pix):
    """Return the in-bounds part of a kernel window along one axis.

    The window spans ``kernel`` feature cells starting at ``center - kernel // 2``;
    cells outside ``[0, n_feat)`` are dropped. Feature cell ``c`` covers image
    pixels ``[c / scale, (c + 1) / scale)``.

    Returns:
        (k_start, k_stop, px_start, px_stop): the slice of kernel indices that lie
        inside the feature map, and the matching pixel extent in the image.
    """
    first = center - kernel // 2
    cell_start = max(0, first)
    cell_stop = min(n_feat, first + kernel)
    px_start = int(round(cell_start / scale))
    px_stop = min(n_pix, int(round(cell_stop / scale)))
    return cell_start - first, cell_stop - first, px_start, px_stop


def update_attention(query_h, query_w):
    """Update all three panels for a new query position.

    Args:
        query_h: Query row in image pixels.
        query_w: Query column in image pixels.
    """
    # Extract the head-averaged attention map for this query.
    attn_map, pos_feat = extract_attention_at_position(query_h, query_w)
    kernel_size = model.upsampler.kernel_size
    # Accept either a single int (square kernel) or an (h, w) pair.
    kernel_h, kernel_w = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

    # Star marker at the clicked pixel (plot takes x=column, y=row).
    query_point.set_data([query_w], [query_h])

    # Neighborhood box aligned to the feature grid. It covers the full kernel of
    # feature cells around pos_feat, rather than a box centred on the raw click
    # that is one cell too narrow, and it is clipped to the image.
    # NOTE(review): this assumes the window is centred on the query cell and
    # truncated at the borders. If NAF instead shifts windows to stay inside the
    # map (NATTEN-style), the border case should shift the box rather than crop
    # it; confirm against the model implementation.
    k_r0, k_r1, box_y0, box_y1 = _kernel_window_1d(pos_feat[0], kernel_h, output_size[0], scale_h, H)
    k_c0, k_c1, box_x0, box_x1 = _kernel_window_1d(pos_feat[1], kernel_w, output_size[1], scale_w, W)

    rect_patch.set_xy((box_x0, box_y0))
    rect_patch.set_width(box_x1 - box_x0)
    rect_patch.set_height(box_y1 - box_y0)

    # Heatmap: the full kernel, with the colour range fitted to this query.
    attn_display.set_data(attn_map)
    attn_display.set_clim(vmin=attn_map.min(), vmax=attn_map.max())
    axes[1].set_title(f'Attention Map ({kernel_h}×{kernel_w})')

    # Overlay: only the kernel cells that fall inside the image, stretched onto
    # the matching image crop, so the crop and the attention cells line up.
    neighborhood = image_vis[box_y0:box_y1, box_x0:box_x1, :]
    attn_visible = attn_map[k_r0:k_r1, k_c0:k_c1]
    attn_resized = zoom(attn_visible, (neighborhood.shape[0] / attn_visible.shape[0],
                                       neighborhood.shape[1] / attn_visible.shape[1]), order=1)

    # Normalize attention to 0-1 for colour mapping (epsilon guards a flat map).
    attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min() + 1e-8)

    # Colour with the heatmap's own colormap (RGB only), then blend 50/50 with the image.
    attn_colored = attn_display.get_cmap()(attn_norm)[:, :, :3] * 255
    overlay_blended = (0.5 * neighborhood.astype(float) + 0.5 * attn_colored).astype(np.uint8)

    # Show the overlay in image pixel coordinates so the axes match the left panel.
    overlay_display.set_data(overlay_blended)
    overlay_display.set_extent([box_x0, box_x1, box_y1, box_y0])
    axes[2].set_xlim(box_x0, box_x1)
    axes[2].set_ylim(box_y1, box_y0)
    axes[2].set_title('Attention Overlay on Neighborhood')

    # Update the statistics box on the heatmap panel.
    center_attn = attn_map[kernel_h // 2, kernel_w // 2]
    stats_text.set_text(
        f"max: {attn_map.max():.4f}\nmin: {attn_map.min():.4f}\ncenter: {center_attn:.4f}"
    )
    fig.canvas.draw_idle()

    # Print to console.
    print(f"Query: ({query_h}, {query_w}) | Feature: {pos_feat} | Center attn: {center_attn:.4f}")

def onclick(event):
    """Mouse-press callback: re-query attention at the clicked pixel of the image panel."""
    # Ignore clicks outside the image panel or without data coordinates.
    if event.inaxes != axes[0] or event.xdata is None or event.ydata is None:
        return
    # Snap the click onto a valid pixel index (x is the column, y is the row).
    col = int(np.clip(event.xdata, 0, W - 1))
    row = int(np.clip(event.ydata, 0, H - 1))
    update_attention(row, col)

# Route mouse presses on the figure to the click handler (cid kept for disconnecting).
click_cid = fig.canvas.mpl_connect('button_press_event', onclick)
fig.tight_layout()

# Populate every panel once with the default query before any click arrives.
update_attention(query_img_h, query_img_w)
plt.show()

print("Click on the left image to explore attention patterns!")