# In [None]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


class CsvImageDataset(Dataset):
    """Image-classification dataset backed by a metadata CSV.

    Each CSV row must provide an ``isic_id`` column naming an image file
    (without extension) under ``img_root`` and a label column whose string
    values are mapped to integer class indices. Items are returned as
    ``(transformed_image, label_index, image_path)``.

    Args:
        csv_path: Path to the metadata CSV.
        img_root: Directory containing the image files.
        transform: Callable applied to each RGB PIL image. Defaults to
            resize to 512x512, ``ToTensor`` and the standard ImageNet
            mean/std normalization.
        label_map: Mapping from label string to class index. Defaults to
            ``{"Benign": 0, "Malignant": 1, "Indeterminate": 1}``.
        label_col: Name of the CSV column holding the label string.
        default_label: Index used for label values missing from ``label_map``.

    Raises:
        FileNotFoundError: If any ``isic_id`` has no matching image file.
    """

    # Extensions tried, in order, when resolving an ``isic_id`` to a file.
    # Upper-case variants cover case-sensitive filesystems.
    _EXTENSIONS = (".jpg", ".png", ".jpeg", ".JPG", ".PNG", ".JPEG")

    def __init__(self, csv_path, img_root, transform=None, label_map=None, *,
                 label_col="diagnosis_1", default_label=1):
        self.df = pd.read_csv(csv_path)
        self.img_root = Path(img_root)
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        self.transform = transform
        self.label_col = label_col
        # By default "Indeterminate" is grouped with "Malignant" as class 1.
        self.label2idx = (
            {"Benign": 0, "Malignant": 1, "Indeterminate": 1}
            if label_map is None else label_map
        )
        self.default_label = default_label

        # Resolve every image path eagerly so missing files fail at construction.
        self.df["img_path"] = self.df["isic_id"].apply(self._find_file)

    def _find_file(self, name):
        """Return the path of the first existing ``<img_root>/<name><ext>``.

        Raises:
            FileNotFoundError: If no candidate extension exists on disk.
        """
        for ext in self._EXTENSIONS:
            candidate = self.img_root / f"{name}{ext}"
            if candidate.exists():
                return str(candidate)
        raise FileNotFoundError(f"Image file not found for {name} in {self.img_root}")

    def _encode_label(self, label_str):
        """Map a raw label value to its class index, falling back to ``default_label``."""
        return self.label2idx.get(label_str, self.default_label)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row["img_path"]
        label = self._encode_label(row[self.label_col])
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        return self.transform(img), label, img_path

class MultiTaskDeepLab(nn.Module):
    """DeepLabV3-ResNet50 segmentation model with an image-level classification head.

    The torchvision model is built with the ``COCO_WITH_VOC_LABELS_V1`` weights
    and ``aux_loss=True``. Its final 1x1 classifier conv is replaced by one that
    emits ``num_seg_classes`` channels. The classification head global-average-pools
    those segmentation logits and maps them linearly to ``num_cls_classes``, so the
    class prediction is derived from the segmentation output.
    """

    def __init__(self, num_seg_classes: int = 2, num_cls_classes: int = 2):
        super().__init__()
        self.seg_model = deeplabv3_resnet50(weights="COCO_WITH_VOC_LABELS_V1", aux_loss=True)
        head_in_channels = self.seg_model.classifier[-1].in_channels
        self.seg_model.classifier[-1] = nn.Conv2d(
            head_in_channels, num_seg_classes, kernel_size=1
        )
        pool = nn.AdaptiveAvgPool2d(1)
        flatten = nn.Flatten()
        project = nn.Linear(num_seg_classes, num_cls_classes)
        self.cls_head = nn.Sequential(pool, flatten, project)

    def forward(self, x):
        """Return ``(segmentation_logits, classification_logits)`` for batch ``x``."""
        seg_logits = self.seg_model(x)["out"]
        return seg_logits, self.cls_head(seg_logits)

def visualize_segmentation_comparison(orig_pil, pred_mask, save_path, true_label=None, pred_label=None):
    """Save a side-by-side figure of an image and its predicted-mask overlay.

    Pixels where ``pred_mask == 1`` are painted red in an overlay that is
    blended 30% over the original in the right-hand panel.

    Args:
        orig_pil: Original image; must convert via ``np.array`` to an
            H x W x 3 array (callers in this file pass an RGB PIL image).
        pred_mask: 2-D integer class mask. If its shape differs from the
            image's (H, W) it is resized with nearest-neighbour interpolation.
        save_path: Destination file for the figure.
        true_label: Ground-truth label shown in the left panel title.
        pred_label: Predicted label shown in the right panel title.
    """
    orig = np.array(orig_pil)
    if pred_mask.shape != orig.shape[:2]:
        # Nearest-neighbour keeps mask values as valid class ids.
        pred_mask = np.array(
            Image.fromarray(pred_mask.astype(np.uint8)).resize(
                (orig.shape[1], orig.shape[0]), Image.NEAREST
            )
        )
    overlay = orig.copy()
    overlay[pred_mask == 1] = [255, 0, 0]
    blended = (0.7 * orig + 0.3 * overlay).astype(np.uint8)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        axes[0].imshow(orig)
        axes[0].set_title(f"Original\nTrue: {true_label}")
        axes[0].axis("off")

        axes[1].imshow(blended)
        axes[1].set_title(f"Prediction Overlay\nPred: {pred_label}")
        axes[1].axis("off")

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # (e.g. one per misclassified sample) do not accumulate open figures.
        plt.close(fig)

def evaluate_csv_model(model, dataset, device, save_dir="misclassified_results", max_wrong=30,
                       *, batch_size=4):
    """Evaluate a (segmentation, classification) model and dump misclassified samples.

    Args:
        model: Module whose forward returns ``(seg_logits, cls_logits)``, with
            ``seg_logits`` shaped (B, C, H, W) and ``cls_logits`` shaped (B, K).
        dataset: Map-style dataset yielding ``(image_tensor, label, image_path)``.
        device: Device to run inference on.
        save_dir: Directory (created if missing) for misclassification figures.
        max_wrong: Maximum number of misclassified samples to visualize and save.
        batch_size: DataLoader batch size.

    Returns:
        ``(accuracy, average_loss)`` over all samples. The loss is classification
        cross-entropy plus segmentation cross-entropy against an all-zero dummy
        mask (no ground-truth masks are available here).

    Raises:
        ValueError: If the dataset yields no samples.
    """
    os.makedirs(save_dir, exist_ok=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_seg = nn.CrossEntropyLoss()  # scored against the dummy all-zero mask

    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    wrong_count = 0  # every misclassified sample seen
    saved_count = 0  # misclassified samples actually written to disk

    with torch.no_grad():
        for imgs, labels, paths in loader:
            imgs = imgs.to(device)
            # as_tensor accepts the collated label tensor (or a plain list) without
            # the copy-construct warning that torch.tensor() emits for tensor input.
            labels_tensor = torch.as_tensor(labels, dtype=torch.long, device=device)

            seg_logits, cls_logits = model(imgs)
            loss_cls = criterion_cls(cls_logits, labels_tensor)
            masks_dummy = torch.zeros(
                seg_logits.shape[0], seg_logits.shape[2], seg_logits.shape[3],
                dtype=torch.long, device=device,
            )
            loss_seg = criterion_seg(seg_logits, masks_dummy)
            n = imgs.size(0)
            total_loss += (loss_cls + loss_seg).item() * n

            # argmax over raw logits equals argmax over their softmax.
            preds_cls = cls_logits.argmax(dim=1)
            total_correct += (preds_cls == labels_tensor).sum().item()
            total_samples += n

            # Plain Python ints, so filenames/titles read "true1", not "truetensor(1)".
            pairs = zip(labels_tensor.tolist(), preds_cls.tolist())
            for i, (true_label, pred_label) in enumerate(pairs):
                if pred_label == true_label:
                    continue
                wrong_count += 1
                if wrong_count > max_wrong:
                    continue

                # Per-pixel argmax over classes gives the predicted segmentation mask.
                pred_mask = torch.argmax(seg_logits[i], dim=0).cpu().numpy()
                with Image.open(paths[i]) as im:
                    orig = im.convert("RGB")

                # Encode the true/predicted class in the file name.
                save_path = Path(save_dir) / f"wrong_{wrong_count}_true{true_label}_pred{pred_label}.jpg"
                visualize_segmentation_comparison(orig, pred_mask, save_path, true_label, pred_label)
                saved_count += 1
                print(f"[{wrong_count}] Saved: {save_path}")

    if total_samples == 0:
        raise ValueError("evaluate_csv_model: dataset yielded no samples")

    acc = total_correct / total_samples
    avg_loss = total_loss / total_samples
    print(f"Accuracy: {acc:.4f}, Avg Loss: {avg_loss:.4f}, "
          f"Wrong Samples: {wrong_count}, Saved: {saved_count}")
    return acc, avg_loss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MultiTaskDeepLab(num_seg_classes=2, num_cls_classes=2).to(device)
ckpt = torch.load("../runs_multitask_1000/best_multitask.pt", map_location=device)
# The checkpoint is expected to hold the weights under "model"; fall back to
# treating the file itself as a bare state_dict.
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
# strict=False tolerates key mismatches, so surface them: a silent mismatch
# (e.g. a "module." prefix) would leave the model on its COCO weights with
# randomly initialised segmentation/classification heads.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"[warning] checkpoint/model key mismatch: "
          f"{len(load_result.missing_keys)} missing, "
          f"{len(load_result.unexpected_keys)} unexpected")
    print("  missing (first 5):", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])
model.eval()

# CSV evaluation
dataset = CsvImageDataset(
    csv_path="../ISIC_dataset/challenge-2016-test_metadata_2025-09-01.csv",
    img_root="../ISIC_dataset/img"
)

acc, avg_loss = evaluate_csv_model(model, dataset, device,
                                   save_dir="misclassified_results_ISIC_True_false",
                                   max_wrong=30)
