# In [13]:
import os
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from model import AttentionUNet
from sklearn.metrics import jaccard_score

# Output folder for the rendered figures; created up front (exist_ok keeps re-runs harmless).
save_dir = r"C:/Users/EXCALIBUR/Desktop/arxiv/best_samples"
os.makedirs(save_dir, exist_ok=True)

# Test split inputs: each image in image_dir has its mask in mask_dir under the same file name.
image_dir = r"C:/Users/EXCALIBUR/Desktop/project/data/combined/test/images"
mask_dir = r"C:/Users/EXCALIBUR/Desktop/project/data/combined/test/masks"

# Trained weights.
# NOTE(review): "KOMPLEXPROJEattention_unet_model.pth" looks like a folder name and a file
# name joined without a separator — confirm this matches the file actually on disk.
model_path = r"C:/Users/EXCALIBUR/Desktop/project/KOMPLEXPROJEattention_unet_model.pth"

# Dataset class
class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Every regular file in ``image_dir`` is treated as an input image; its mask
    is looked up in ``mask_dir`` under the *same file name*.

    Args:
        image_dir: Folder containing the input images.
        mask_dir: Folder containing the masks, named exactly like the images.
        transform: Applied to the RGB image, and also to the mask unless
            ``mask_transform`` is given.
        mask_transform: Optional separate transform for the mask (for example
            a nearest-neighbour resize so mask labels are not blended).
            Defaults to ``transform``, which keeps the original behaviour.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # Sorted so sample indices are deterministic between runs (os.listdir
        # order is arbitrary). Sub-directories are skipped: they cannot be
        # opened as images and would crash __getitem__.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask, file_name)`` for sample ``idx``."""
        name = self.images[idx]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, name)

        # Open via context managers so the underlying files are closed;
        # convert() returns an independent image object.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask, name

# Preprocessing used by the dataset: resize to 256x256, then convert to a tensor.
# NOTE(review): SegmentationDataset applies this same pipeline to the masks, so they
# are resized with Resize's default interpolation rather than nearest-neighbour.
_preprocess_steps = [transforms.Resize((256, 256)), transforms.ToTensor()]
transform = transforms.Compose(_preprocess_steps)

# Test split: one (image, mask, file name) triple per image file.
dataset = SegmentationDataset(image_dir, mask_dir, transform)

# Build the Attention U-Net on the GPU when one is available (CPU otherwise),
# load the trained weights and put the network in evaluation mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AttentionUNet()
model = model.to(device)
# NOTE(review): torch.load unpickles the file — only point model_path at trusted
# checkpoints (newer torch versions support weights_only=True for plain state dicts).
checkpoint_state = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint_state)
model.eval()

# Compute the IoU of every test sample.
# NOTE(review): model outputs are binarised at this very low cut-off. Whether the
# AttentionUNet output is a logit or a probability is not visible here (the model is
# imported from `model`) — confirm this threshold against the training setup.
PRED_THRESHOLD = 0.0002

# One (iou, idx, fname, binary prediction, raw model output) tuple per scored sample.
scores = []
# Not populated below; kept so any later cell that references it still works.
outputs = []

for idx in range(len(dataset)):
    img, true_mask, fname = dataset[idx]
    input_tensor = img.unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        output = model(input_tensor)
        pred = (output > PRED_THRESHOLD).float().squeeze().cpu().numpy()

    # Ground truth: resize to the 256x256 evaluation size and binarise at 0.5.
    gt = TF.resize(true_mask, [256, 256]).squeeze().cpu().numpy()
    gt_bin = (gt > 0.5).astype(np.uint8).flatten()
    # `pred` already holds only 0.0 / 1.0, so a cast is all that is needed.
    pred_bin = pred.astype(np.uint8).flatten()

    if pred_bin.shape != gt_bin.shape:
        # Report the mismatch instead of silently leaving the sample out of the ranking.
        print(f"Skipping {fname}: prediction shape {pred.shape} does not match mask shape {gt.shape}")
        continue

    # NOTE(review): check how jaccard_score's zero_division setting scores an
    # empty-mask / empty-prediction pair; it decides whether such samples can rank as best.
    iou = jaccard_score(gt_bin, pred_bin)
    scores.append((iou, idx, fname, pred, output.squeeze().cpu().numpy()))

# Keep the three highest-IoU samples. Ranking on (iou, idx) reproduces plain tuple
# ordering exactly (idx is unique per sample) while never comparing the stored arrays.
best_samples = sorted(scores, key=lambda entry: (entry[0], entry[1]), reverse=True)[:3]

# Render one 4-panel figure per selected sample and save it as a PNG in save_dir.
for rank, (iou_score, idx, fname, pred_mask, raw_output) in enumerate(best_samples):
    img, true_mask, _ = dataset[idx]

    # Input image: CHW float tensor -> HWC uint8 array for imshow.
    img_np = (TF.resize(img, [256, 256]).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    # Ground-truth mask, binarised at 0.5.
    true_mask_np = (TF.resize(true_mask, [256, 256]).squeeze().cpu().numpy() > 0.5).astype(np.uint8)

    # Binary prediction scaled to 0/255 grey levels.
    pred_mask_uint8 = (pred_mask * 255).astype(np.uint8)

    # Heatmap of the raw model output: min-max normalise to [0, 1] (the epsilon
    # avoids dividing by zero for a constant map), apply the JET colour map, then
    # reorder OpenCV's BGR channels to RGB for matplotlib.
    cam_map = cv2.resize(raw_output, (256, 256))
    cam_lo, cam_hi = cam_map.min(), cam_map.max()
    cam_map = (cam_map - cam_lo) / (cam_hi - cam_lo + 1e-8)
    cam_heatmap = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET)
    cam_heatmap = cv2.cvtColor(cam_heatmap, cv2.COLOR_BGR2RGB)

    # Each panel: (title, [(array, imshow keyword arguments), ...]) drawn in order.
    panel_specs = (
        ("Input Image", [(img_np, {})]),
        ("Ground Truth", [(true_mask_np, {"cmap": 'gray'})]),
        ("Prediction", [(pred_mask_uint8, {"cmap": 'gray'})]),
        ("Model Output (Heatmap)", [(img_np, {}), (cam_heatmap, {"alpha": 0.5})]),
    )

    plt.figure(figsize=(15, 5))
    for panel_no, (panel_title, layers) in enumerate(panel_specs, start=1):
        plt.subplot(1, 4, panel_no)
        for pixels, imshow_kwargs in layers:
            plt.imshow(pixels, **imshow_kwargs)
        plt.title(panel_title)
        plt.axis('off')

    # Figure title carries only the sample's IoU.
    plt.suptitle(f"IoU: {iou_score:.4f}", fontsize=12)
    plt.tight_layout()
    save_path = os.path.join(save_dir, f"result_{rank+1}_{fname}.png")
    plt.savefig(save_path)
    plt.close()


  # The weights are already loaded from `model_path` right after the model is built.
  # A stray duplicate load with a two-space indent (IndentationError) was removed here.
