In [None]:
import os
import numpy as np
import torch
from PIL import Image
import cv2
from tqdm import tqdm
from transformers import SwinModel, AutoImageProcessor
import matplotlib.pyplot as plt

In [None]:
def attention_to_spatial_map(attn):
    """Collapse one image's windowed self-attention into a normalized 2-D grid.

    Each key token is scored by the attention it receives, averaged over all
    heads and all query tokens; windows are then stitched back into a square
    spatial grid.

    Args:
        attn: tensor of shape ``(num_windows, num_heads, N, N)`` for a single
            image. Assumes the N tokens are laid out row-major inside each
            window and the windows are laid out row-major on a square grid
            (the HF Swin ``window_partition`` convention) -- TODO confirm for
            non-square inputs or shifted-window layers.

    Returns:
        float32 numpy array of shape ``(grid * window, grid * window)`` scaled
        to ``[0, 1)``.

    Raises:
        ValueError: if the token count or window count is not a perfect square.
    """
    scores = attn.detach().float().cpu()
    num_windows, _, num_tokens, _ = scores.shape
    # Mean over heads (dim 1), then over query tokens: attention each key token receives.
    received = scores.mean(dim=1).mean(dim=1)  # (num_windows, N)

    win = int(round(num_tokens ** 0.5))
    grid = int(round(num_windows ** 0.5))
    if win * win != num_tokens or grid * grid != num_windows:
        raise ValueError(
            f"cannot arrange {num_windows} windows of {num_tokens} tokens into a square map"
        )

    # Inverse of window partitioning: (grid_r, grid_c, tok_r, tok_c) -> (grid_r, tok_r, grid_c, tok_c).
    spatial = (
        received.reshape(grid, grid, win, win)
        .permute(0, 2, 1, 3)
        .reshape(grid * win, grid * win)
        .numpy()
    )
    spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-5)
    return spatial.astype(np.float32)


def extract_attention_map(model, img_pil, image_processor, device, alpha=0.5):
    """Blend a last-stage Swin attention heatmap over ``img_pil``.

    Args:
        model: backbone called as ``model(pixel_values, output_attentions=True)``;
            ``outputs.attentions[-1]`` is used (presumably the last stage's
            attention for HF ``SwinModel`` -- verify for other backbones).
        img_pil: PIL image; converted to RGB first, so grayscale/RGBA inputs work.
        image_processor: callable returning an object with ``.pixel_values``.
        device: device the model's parameters live on.
        alpha: heatmap weight in the blend; the image gets ``1 - alpha``.

    Returns:
        uint8 RGB array of shape ``(H, W, 3)`` matching ``img_pil``'s size.
    """
    img_rgb = img_pil.convert("RGB")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        pixel_values = image_processor(images=img_rgb, return_tensors="pt").pixel_values.to(device)
        outputs = model(pixel_values, output_attentions=True)

    # The raw attention is token x token, not an image; reduce it to a spatial grid first.
    attn_map = attention_to_spatial_map(outputs.attentions[-1])

    # PIL .size is (width, height), which is also cv2.resize's dsize order.
    attn_resized = cv2.resize(attn_map, img_rgb.size, interpolation=cv2.INTER_LINEAR)
    heatmap_bgr = cv2.applyColorMap(
        np.uint8(255 * np.clip(attn_resized, 0.0, 1.0)), cv2.COLORMAP_JET
    )
    # applyColorMap emits BGR; the source image and the saved PNG are RGB.
    heatmap = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    overlay = (
        np.asarray(img_rgb, dtype=np.float32) * (1.0 - alpha)
        + heatmap.astype(np.float32) * alpha
    )
    return np.uint8(np.clip(overlay, 0, 255))

In [None]:
# Shared inference setup: prefer the GPU when one is available, load the
# swin-tiny checkpoint and its matching image processor, and put the model in eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(device).eval()
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

def run_visualization(input_dir, output_dir, model=None, processor=None, device=None):
    """Write an attention-overlay PNG for every image in ``input_dir``.

    Files ending in .png/.jpg/.jpeg (case-insensitive) are processed in sorted
    order; each result is saved as ``<output_dir>/<stem>.png``. Per-file
    failures are reported and skipped so one bad image does not stop the batch.

    Args:
        input_dir: directory containing the source images.
        output_dir: directory for the overlays; created if missing.
        model: optional already-loaded Swin backbone. When None, the
            swin-tiny checkpoint is loaded (previous behavior).
        processor: optional image processor; loaded from the same checkpoint when None.
        device: optional torch device. When None it is taken from ``model``'s
            parameters if a model was passed, otherwise CUDA if available, else CPU.

    Note:
        ``model.eval()`` is always called, including on a caller-supplied model.
    """
    os.makedirs(output_dir, exist_ok=True)
    image_list = sorted(
        f for f in os.listdir(input_dir) if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    # Reuse caller-supplied objects instead of re-downloading the model on every call.
    if device is None:
        if model is not None:
            device = next(model.parameters()).device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(device)
    if processor is None:
        processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    model.eval()

    written = set()
    for fname in tqdm(image_list):
        stem = os.path.splitext(fname)[0]
        out_path = os.path.join(output_dir, f"{stem}.png")
        # e.g. "a.jpg" and "a.png" both map to "a.png"; flag instead of overwriting silently.
        if out_path in written:
            tqdm.write(f"⚠️ {fname} overwrites an earlier output: {out_path}")
        try:
            img_path = os.path.join(input_dir, fname)
            with Image.open(img_path) as img:
                img_pil = img.convert("RGB")

            overlay = extract_attention_map(model, img_pil, processor, device)
            Image.fromarray(overlay).save(out_path)
            written.add(out_path)
        except Exception as e:
            # tqdm.write keeps the progress bar intact while reporting the failure.
            tqdm.write(f"❌ Error with {fname}: {e}")

In [None]:
# Source images and destination for the overlay PNGs (absolute paths; adjust per environment).
INPUT_DIR = "/raw/cropped_image_all"
OUTPUT_DIR = "/visualizations"
run_visualization(INPUT_DIR, OUTPUT_DIR)