# In [None]:
import os
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Tuple
import random
import math
import time
import numpy as np
from PIL import Image
import cv2
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.utils.tensorboard import SummaryWriter
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
from sklearn.metrics import roc_auc_score

def seed_everything(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs.

    Args:
        seed: Value passed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def coco_polys_to_mask(h: int, w: int, ann_list: list) -> np.ndarray:
    """Rasterize the union of the given COCO annotations into a binary mask.

    Supported ``segmentation`` formats per annotation:

    * a list of polygons, each a flat ``[x1, y1, x2, y2, ...]`` list or a
      comma-separated string of the same numbers;
    * an uncompressed RLE dict (``counts`` is a list), encoded first;
    * a compressed RLE dict (``counts`` is a string/bytes), used as-is.

    Polygons with fewer than 3 points (< 6 numbers) are skipped: pycocotools'
    ``frPyObjects`` would interpret a 4-number list as a bounding box and
    rejects shorter ones. All polygons of an annotation are merged into one
    RLE regardless of its ``iscrowd`` flag. Missing/empty/None segmentations
    are ignored.

    Args:
        h: Output mask height.
        w: Output mask width.
        ann_list: COCO annotation dicts.

    Returns:
        ``uint8`` array of shape ``(h, w)``: 1 inside any annotation, 0 elsewhere.
    """
    rles = []
    for ann in ann_list:
        seg = ann.get("segmentation") or []
        if isinstance(seg, dict):
            if "counts" not in seg or "size" not in seg:
                continue
            # Uncompressed RLE must be encoded first; compressed RLE is usable directly.
            if isinstance(seg["counts"], list):
                rles.append(maskUtils.frPyObjects(seg, h, w))
            else:
                rles.append(seg)
            continue

        polys = []
        for poly in seg:
            if isinstance(poly, str):
                poly = [float(v) for v in poly.split(",")]
            if len(poly) >= 6:
                polys.append(poly)
        if polys:
            # frPyObjects yields one RLE per polygon; merge them into a single RLE.
            rles.append(maskUtils.merge(maskUtils.frPyObjects(polys, h, w)))

    if not rles:
        return np.zeros((h, w), dtype=np.uint8)
    m = maskUtils.decode(maskUtils.merge(rles))
    if m.ndim == 3:
        m = np.any(m, axis=2)
    return (m > 0).astype(np.uint8)

def build_image_level_label(coco: COCO, img_id: int, cat_id_to_cls: Dict[int,int]) -> int:
    """Derive an image-level lesion label from the image's annotations.

    Each annotation's ``category_id`` is mapped through ``cat_id_to_cls``
    (unknown ids map to 0). Malignant takes precedence over benign.

    Args:
        coco: Loaded COCO index.
        img_id: Image whose annotations are inspected.
        cat_id_to_cls: Category id -> class (1 = benign, 2 = malignant).

    Returns:
        2 if any annotation is malignant, else 1 if any is benign, else 0.
    """
    annotations = coco.loadAnns(coco.getAnnIds(imgIds=[img_id]))
    present = {cat_id_to_cls.get(ann["category_id"], 0) for ann in annotations}
    if 2 in present:
        return 2
    if 1 in present:
        return 1
    return 0


class COCOMultiTaskDataset(Dataset):
    """COCO-format dataset for joint lesion segmentation and classification.

    Each item is a 4-tuple ``(image, mask, cls, path)`` where, with
    ``S = img_size``:

    * ``image`` - float tensor ``[3, S, S]``: aspect-ratio-preserving resize,
      centered on a zero-padded square canvas, ImageNet-normalized.
    * ``mask`` - long tensor ``[S, S]``: 1 where any benign/malignant
      annotation covers the pixel, 0 elsewhere (including the padding).
    * ``cls`` - long scalar: 0 = benign, 1 = malignant, -1 = the image has no
      lesion annotation (meant to be ignored by the classification loss).
    * ``path`` - resolved filesystem path of the source image.

    Args:
        root_dir: Split directory holding the images, either directly or in
            an ``images/`` sub-folder.
        ann_file: COCO annotation JSON for this split.
        img_size: Side length ``S`` of the square output.
        augment: If True, apply a random horizontal flip (p = 0.5) jointly to
            image and mask.
        cache_masks_dir: Optional directory where rasterized masks are cached
            as PNGs named after the image file stem. Give each split its own
            directory, otherwise identically named images collide.
    """

    def __init__(self, root_dir: str, ann_file: str, img_size: int = 512,
                 augment: bool = False, cache_masks_dir: str = None):
        self.root_dir = Path(root_dir)
        self.coco = COCO(ann_file)
        self.img_ids = self.coco.getImgIds()
        self.img_size = img_size
        self.augment = augment
        self.cache_masks_dir = Path(cache_masks_dir) if cache_masks_dir else None
        if self.cache_masks_dir:
            self.cache_masks_dir.mkdir(parents=True, exist_ok=True)

        # Map COCO category id -> lesion class: benign = 1, malignant = 2.
        # Categories matching neither keyword set are treated as benign.
        self.cat_id_to_cls = {}
        for c in self.coco.loadCats(self.coco.getCatIds()):
            name = c["name"].lower()
            malignant = "benign" not in name and ("malig" in name or "cancer" in name)
            self.cat_id_to_cls[c["id"]] = 2 if malignant else 1

        self.img_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.img_ids)

    def _read_image(self, info: Dict[str, Any]) -> Tuple[Image.Image, str]:
        """Load the image for ``info`` as RGB; return it with its resolved path.

        Raises:
            FileNotFoundError: if the file is neither in ``root_dir`` nor in
                ``root_dir/images``.
        """
        fname = info["file_name"]
        path = self.root_dir / fname
        if not path.exists():
            # The image may live in an images/ sub-folder instead.
            alt = self.root_dir / "images" / fname
            if not alt.exists():
                raise FileNotFoundError(f"Image not found: {path} or {alt}")
            path = alt
        with Image.open(path) as im:
            img = im.convert("RGB")
        return img, str(path)

    def _get_mask_path(self, image_info):
        """Cache file for this image's mask, or None when caching is disabled."""
        if not self.cache_masks_dir:
            return None
        base = Path(image_info["file_name"]).stem
        return self.cache_masks_dir / f"{base}.png"

    def _ensure_mask(self, img_info, anns, h, w):
        """Return the ``(h, w)`` uint8 lesion mask, using/filling the PNG cache."""
        mask_path = self._get_mask_path(img_info)
        if mask_path and mask_path.exists():
            with Image.open(mask_path) as im:
                m = np.array(im.convert("L"))
            # A cached mask of a different size is stale: rebuild it.
            if m.shape == (h, w):
                return (m > 0).astype(np.uint8)
        lesion_anns = [a for a in anns if self.cat_id_to_cls.get(a["category_id"], 0) in (1, 2)]
        m = coco_polys_to_mask(h, w, lesion_anns)
        if mask_path:
            Image.fromarray((m * 255).astype(np.uint8)).save(mask_path)
        return m

    def _resize_keep_ratio(self, img: Image.Image, mask: np.ndarray, size: int):
        """Resize so the longer side equals ``size`` and center on a zero-padded square.

        The image uses bilinear and the mask nearest-neighbour interpolation;
        odd padding puts the extra pixel on the bottom/right.
        """
        w, h = img.size
        scale = size / max(h, w)
        nh, nw = max(1, int(round(h * scale))), max(1, int(round(w * scale)))
        img_r = img.resize((nw, nh), Image.BILINEAR)
        mask_r = cv2.resize(mask.astype(np.uint8), (nw, nh), interpolation=cv2.INTER_NEAREST)
        top, left = (size - nh) // 2, (size - nw) // 2
        img_p = Image.new("RGB", (size, size))
        img_p.paste(img_r, (left, top))
        mask_p = np.zeros((size, size), dtype=np.uint8)
        mask_p[top:top + nh, left:left + nw] = mask_r
        return img_p, mask_p

    def to_tensor(self, img: Image.Image):
        """Apply ToTensor + ImageNet normalization."""
        return self.img_trans(img)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        info = self.coco.loadImgs([img_id])[0]
        img, img_path = self._read_image(info)
        h, w = info["height"], info["width"]
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=[img_id]))
        mask_full = self._ensure_mask(info, anns, h, w)
        img_label = build_image_level_label(self.coco, img_id, self.cat_id_to_cls)
        img_sq, mask_sq = self._resize_keep_ratio(img, mask_full, self.img_size)
        img_t = self.to_tensor(img_sq)
        mask_t = torch.from_numpy(mask_sq).long()
        if self.augment and torch.rand(1).item() < 0.5:
            # Same horizontal flip for the image ([3, S, S], width is dim 2)
            # and the mask ([S, S], width is dim 1).
            img_t = torch.flip(img_t, dims=[2])
            mask_t = torch.flip(mask_t, dims=[1])
        # Image label 0 (no lesion) -> -1 ("ignore"); benign -> 0; malignant -> 1.
        cls_t = torch.tensor(img_label - 1, dtype=torch.long)
        # Return the resolved path so consumers can re-open the image directly.
        return img_t, mask_t, cls_t, img_path


class MultiTaskDeepLab(nn.Module):
    """DeepLabV3-ResNet50 with an extra image-level classification head.

    The segmentation branch is torchvision's ``deeplabv3_resnet50`` (COCO/VOC
    pretrained weights) with its last 1x1 conv replaced to emit
    ``num_seg_classes`` channels. The classification branch global-average-
    pools the backbone's ``"out"`` feature map and feeds it to a small MLP.
    Both branches share a single backbone pass per forward call.

    Args:
        num_seg_classes: Number of segmentation output channels.
        num_cls_classes: Number of image-level classification logits.
        freeze_until: Optional backbone stage name (``"layer1"`` ...
            ``"layer4"``). Backbone parameters before that stage (the stem
            and earlier stages) get ``requires_grad=False``; the named stage
            and later ones stay trainable. ``None``/empty trains everything.
            Note that ``requires_grad=False`` does not stop BatchNorm running
            statistics from updating in train mode.

    Raises:
        ValueError: if ``freeze_until`` is set to an unknown stage name.
    """

    _STAGES = ("layer1", "layer2", "layer3", "layer4")

    def __init__(self, num_seg_classes:int=2, num_cls_classes:int=2, freeze_until: str = None):
        super().__init__()
        # aux_loss=True builds the auxiliary head, which keeps state_dict keys
        # stable; its output is not used in forward().
        self.seg_model = deeplabv3_resnet50(weights="COCO_WITH_VOC_LABELS_V1", aux_loss=True)
        in_ch = self.seg_model.classifier[-1].in_channels
        self.seg_model.classifier[-1] = nn.Conv2d(in_ch, num_seg_classes, kernel_size=1)
        # Same module object as seg_model.backbone (shared trunk).
        self.backbone = self.seg_model.backbone

        if freeze_until:
            if freeze_until not in self._STAGES:
                raise ValueError(
                    f"freeze_until must be one of {self._STAGES} or None, got {freeze_until!r}"
                )
            trainable = self._STAGES[self._STAGES.index(freeze_until):]
            for name, p in self.backbone.named_parameters():
                p.requires_grad = name.startswith(trainable)

        self.cls_gap = nn.AdaptiveAvgPool2d((1, 1))
        self.cls_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 512),  # ResNet-50 layer4 emits 2048 channels
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_cls_classes),
        )

    def forward(self, x):
        """Return ``(seg_logits [B, num_seg_classes, H, W], cls_logits [B, num_cls_classes])``."""
        input_hw = x.shape[-2:]
        # Run the backbone once and reuse its features for both heads
        # (mirrors torchvision's segmentation forward: head, then bilinear upsample).
        feats = self.backbone(x)["out"]
        seg_out = self.seg_model.classifier(feats)
        seg_out = F.interpolate(seg_out, size=input_hw, mode="bilinear", align_corners=False)
        cls_logits = self.cls_head(self.cls_gap(feats))
        return seg_out, cls_logits

def pixel_iou_and_dice(pred: torch.Tensor, target: torch.Tensor, num_classes: int = 2) -> Dict[str, float]:
    """Mean foreground IoU and Dice over classes ``1 .. num_classes-1``.

    Args:
        pred: Logits ``[B, C, H, W]`` (argmax over dim 1 is taken) or label
            map ``[B, H, W]``.
        target: Label map ``[B, H, W]``.
        num_classes: Total class count including background (class 0).

    Returns:
        ``{"miou_fg": ..., "mdice_fg": ...}``. Classes absent from both pred
        and target are left out of the mean; if every class is absent the
        value is 0.0.
    """
    labels = pred.argmax(1) if pred.ndim == 4 else pred
    iou_scores: List[float] = []
    dice_scores: List[float] = []
    for cls_idx in range(1, num_classes):
        hit_pred = labels == cls_idx
        hit_true = target == cls_idx
        overlap = (hit_pred & hit_true).sum().item()
        combined = (hit_pred | hit_true).sum().item()
        area_sum = hit_pred.sum().item() + hit_true.sum().item()
        if combined > 0:
            iou_scores.append(overlap / combined)
        if area_sum > 0:
            dice_scores.append((2 * overlap) / area_sum)
    return {
        "miou_fg": float(np.mean(iou_scores)) if iou_scores else 0.0,
        "mdice_fg": float(np.mean(dice_scores)) if dice_scores else 0.0,
    }

def train_one_epoch(model, loader, optim, scaler, device, seg_loss_w=1.0, cls_loss_w=1.0):
    """Run one training epoch and return batch-size-weighted mean losses.

    Args:
        model: Module returning ``(seg_logits [B, C, H, W], cls_logits [B, K])``.
        loader: Yields ``(imgs, masks, cls_lbls, paths)``. Samples with
            ``cls_lbls == -1`` are excluded from the classification loss.
        optim: Optimizer to step. Inside this function the name shadows the
            ``torch.optim`` module.
        scaler: A ``GradScaler`` to train with mixed precision, or ``None``
            for full precision. Autocast is only enabled on CUDA devices.
        device: Target device (``torch.device`` or device string).
        seg_loss_w: Weight of the segmentation cross-entropy term.
        cls_loss_w: Weight of the classification cross-entropy term.

    Returns:
        dict with keys ``"loss"``, ``"seg"``, ``"cls"``.

    Raises:
        RuntimeError: if ``loader`` yields no batches (e.g. the dataset is
            smaller than the batch size while ``drop_last=True``).
    """
    model.train()
    seg_criterion = nn.CrossEntropyLoss()
    cls_criterion = nn.CrossEntropyLoss()
    device_type = torch.device(device).type
    # Passing a scaler requests mixed precision; autocast only applies on CUDA.
    use_autocast = scaler is not None and device_type == "cuda"

    running = {"loss": 0.0, "seg": 0.0, "cls": 0.0}
    n = 0
    pbar = tqdm(loader, desc="train", leave=False)
    for imgs, masks, cls_lbls, _ in pbar:
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)           # [B,H,W] values 0/1
        cls_targets = cls_lbls.to(device, non_blocking=True)  # -1 (ignore) / 0 / 1
        valid = cls_targets >= 0

        optim.zero_grad()
        with torch.autocast(device_type=device_type, enabled=use_autocast):
            seg_logits, cls_logits = model(imgs)
            loss_seg = seg_criterion(seg_logits, masks)
            if valid.any():
                loss_cls = cls_criterion(cls_logits[valid], cls_targets[valid])
            else:
                loss_cls = torch.tensor(0.0, device=device)
            loss = seg_loss_w * loss_seg + cls_loss_w * loss_cls

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            optim.step()

        bs = imgs.size(0)
        loss_val = loss.item()
        running["loss"] += loss_val * bs
        running["seg"] += loss_seg.item() * bs
        running["cls"] += loss_cls.item() * bs
        n += bs
        pbar.set_postfix(loss=f"{loss_val:.4f}")

    if n == 0:
        raise RuntimeError(
            "training loader yielded no batches (dataset smaller than batch_size with drop_last=True?)"
        )
    return {k: v / n for k, v in running.items()}

def evaluate(model, loader, device):
    """Evaluate segmentation and classification quality over ``loader``.

    Segmentation metrics treat class 1 as foreground. They are computed from
    pixel counts accumulated over the whole loader (dataset-level IoU/Dice),
    so they do not depend on the batch size or on how the last batch is
    filled. Classification metrics and loss use only samples whose label is
    >= 0; label -1 is ignored everywhere, including the AUC.

    Args:
        model: Module returning ``(seg_logits [B, C, H, W], cls_logits [B, K])``.
        loader: Yields ``(imgs, masks, cls_lbls, paths)``.
        device: Device the model lives on.

    Returns:
        dict with
        * ``seg`` - mean per-pixel segmentation cross-entropy;
        * ``cls`` - mean cross-entropy over labelled samples (0.0 if none);
        * ``loss`` - ``seg + cls``;
        * ``miou_fg`` / ``mdice_fg`` - foreground IoU / Dice (0.0 when neither
          predictions nor targets contain foreground);
        * ``cls_acc`` - accuracy over labelled samples (0.0 if none);
        * ``cls_auc`` - ROC-AUC of the class-1 probability over labelled
          samples (0.0 unless both classes are present).
    """
    model.eval()
    seg_criterion = nn.CrossEntropyLoss(reduction="sum")
    cls_criterion = nn.CrossEntropyLoss(reduction="sum")
    seg_loss_sum = 0.0
    cls_loss_sum = 0.0
    n_pixels = 0
    # Foreground pixel counts accumulated over the whole loader.
    inter = union = pred_fg = tgt_fg = 0
    all_labels = []
    all_probs = []
    correct = 0
    cls_count = 0

    with torch.no_grad():
        for imgs, masks, cls_lbls, _ in tqdm(loader, desc="eval", leave=False):
            imgs = imgs.to(device)
            masks = masks.to(device)
            cls_targets = cls_lbls.to(device)
            valid = cls_targets >= 0

            seg_logits, cls_logits = model(imgs)
            seg_loss_sum += seg_criterion(seg_logits, masks).item()
            n_pixels += masks.numel()

            pred_c = seg_logits.argmax(1) == 1
            tgt_c = masks == 1
            inter += (pred_c & tgt_c).sum().item()
            union += (pred_c | tgt_c).sum().item()
            pred_fg += pred_c.sum().item()
            tgt_fg += tgt_c.sum().item()

            if valid.any():
                logits_v = cls_logits[valid]
                targets_v = cls_targets[valid]
                cls_loss_sum += cls_criterion(logits_v, targets_v).item()
                correct += (logits_v.argmax(1) == targets_v).sum().item()
                cls_count += targets_v.numel()
                all_probs.extend(F.softmax(logits_v, dim=1)[:, 1].cpu().tolist())
                all_labels.extend(targets_v.cpu().tolist())

    seg_mean = seg_loss_sum / n_pixels if n_pixels else 0.0
    cls_mean = cls_loss_sum / cls_count if cls_count else 0.0
    area_sum = pred_fg + tgt_fg
    avg = {
        "loss": seg_mean + cls_mean,
        "seg": seg_mean,
        "cls": cls_mean,
        "miou_fg": inter / union if union else 0.0,
        "mdice_fg": 2 * inter / area_sum if area_sum else 0.0,
        "cls_acc": correct / cls_count if cls_count else 0.0,
        "cls_auc": 0.0,
    }
    if len(set(all_labels)) > 1:
        try:
            avg["cls_auc"] = float(roc_auc_score(all_labels, all_probs))
        except ValueError:
            # Best effort: keep 0.0 if sklearn rejects the inputs.
            pass
    return avg

def _unletterbox_mask(mask: np.ndarray, orig_size: Tuple[int, int]) -> np.ndarray:
    """Map a square, centered keep-ratio letterboxed mask back to the original grid.

    Crops away the padding that a "scale longer side to S, center on an S x S
    canvas" resize adds. It then nearest-neighbour resizes to ``orig_size``,
    given in PIL ``(width, height)`` order.
    """
    w, h = orig_size
    size = mask.shape[0]
    scale = size / max(h, w)
    nh, nw = max(1, int(round(h * scale))), max(1, int(round(w * scale)))
    top, left = (size - nh) // 2, (size - nw) // 2
    content = np.ascontiguousarray(mask[top:top + nh, left:left + nw])
    return cv2.resize(content, (w, h), interpolation=cv2.INTER_NEAREST)


def visualize_and_save(model, loader, device, out_dir: str, n_samples: int = 16):
    """Save up to ``n_samples`` red lesion overlays as ``out_dir/vis_<k>.png``.

    Predicted foreground (class 1) is blended 50/50 with red on top of the
    original image. The letterbox padding is removed before upsampling, so
    non-square images line up. Images whose path (4th item of each batch)
    cannot be opened are skipped with a warning.
    """
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    red = np.array([255, 0, 0])
    saved = 0
    with torch.no_grad():
        for imgs, _masks, _cls_lbls, paths in loader:
            if saved >= n_samples:
                break
            seg_logits, _ = model(imgs.to(device))
            seg_pred = seg_logits.argmax(1).cpu().numpy().astype(np.uint8)
            for pred, p in zip(seg_pred, paths):
                if saved >= n_samples:
                    break
                try:
                    with Image.open(p) as im:
                        orig = im.convert("RGB")
                except FileNotFoundError:
                    print(f"[WARNING] Could not find image: {p}")
                    continue
                lesion = _unletterbox_mask(pred, orig.size) == 1
                overlay = np.array(orig)
                overlay[lesion] = (overlay[lesion] * 0.5 + red * 0.5).astype(np.uint8)  # red lesion
                Image.fromarray(overlay).save(Path(out_dir) / f"vis_{saved}.png")
                saved += 1

def parse_args(argv=None):
    """Parse command-line options for multi-task training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the options defined below.

    When ``argv`` is None inside a Jupyter/IPython kernel, unrecognised
    arguments are ignored instead of aborting. The kernel's own
    ``-f <connection-file>`` is one such argument, so the cell can run with
    the defaults.
    """
    import sys

    p = argparse.ArgumentParser(description="Multi-task DeepLabV3 lesion segmentation + classification")
    p.add_argument("--data_root", type=str, default="dataset_1000",
                   help="dataset root containing the train/ and valid/ split folders")
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--img_size", type=int, default=512, help="side of the square network input")
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--freeze_until", type=str, default=None,
                   help="freeze backbone stages before this one (e.g. layer3)")
    p.add_argument("--seg_loss_w", type=float, default=1.0, help="segmentation loss weight")
    p.add_argument("--cls_loss_w", type=float, default=1.0, help="classification loss weight")
    p.add_argument("--cache_masks_dir", type=str, default="masks_cache_1000_ver2",
                   help="directory for cached rasterized masks (empty to disable)")
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--output_dir", type=str, default="runs_multitask_1000_ver2",
                   help="directory for checkpoints, TensorBoard logs and visualizations")
    p.add_argument("--use_amp", action="store_true", help="enable CUDA mixed precision")

    if argv is None and "ipykernel" in sys.modules:
        args, _unknown = p.parse_known_args()
        return args
    return p.parse_args(argv)

def _collate_multitask(batch):
    """Stack the (image, mask, class) tensors of a batch; keep paths as a list."""
    imgs, masks, clss, paths = zip(*batch)
    return torch.stack(imgs), torch.stack(masks), torch.stack(clss), list(paths)


def main():
    """Train the multi-task model and keep the best checkpoint by validation score.

    Reads options via ``parse_args``. Builds datasets from
    ``<data_root>/train`` and ``<data_root>/valid`` (falling back to
    ``<data_root>/val``). Trains with AdamW and a cosine LR schedule and logs
    to TensorBoard under ``<output_dir>/tb``. The checkpoint with the best
    ``0.5 * val miou_fg + 0.5 * val cls_acc`` is saved to
    ``<output_dir>/best_multitask.pt``; overlay previews are written every
    5 epochs.
    """
    args = parse_args()
    seed_everything(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    os.makedirs(args.output_dir, exist_ok=True)
    data_root = Path(args.data_root)

    def split_paths(split):
        """Return ``(image dir, annotation json)`` for a split folder."""
        split_dir = data_root / split
        ann = split_dir / "_annotations.coco.json"
        if not ann.exists():
            # Fallback: any JSON in the split folder (sorted for a deterministic pick).
            candidates = sorted(split_dir.glob("*.json"))
            if not candidates:
                raise FileNotFoundError(f"No COCO json under {args.data_root}/{split}")
            ann = candidates[0]
        return str(split_dir), str(ann)

    def split_cache_dir(split):
        # One sub-directory per split: masks are cached by file stem, so a
        # shared directory would let same-named images in different splits
        # read each other's masks.
        return os.path.join(args.cache_masks_dir, split) if args.cache_masks_dir else None

    train_dir, train_ann = split_paths("train")
    val_split = next((s for s in ("valid", "val") if (data_root / s).exists()), "valid")
    val_dir, val_ann = split_paths(val_split)

    ds_train = COCOMultiTaskDataset(train_dir, train_ann, img_size=args.img_size, augment=True,
                                    cache_masks_dir=split_cache_dir("train"))
    ds_val = COCOMultiTaskDataset(val_dir, val_ann, img_size=args.img_size, augment=False,
                                  cache_masks_dir=split_cache_dir(val_split))

    loader_kw = dict(batch_size=args.batch_size, num_workers=args.num_workers,
                     collate_fn=_collate_multitask, pin_memory=use_cuda)
    # drop_last=True on the training loader avoids a ragged final batch (a
    # batch of one would fail in DeepLabV3's ASPP pooling BatchNorm in train mode).
    dl_train = DataLoader(ds_train, shuffle=True, drop_last=True, **loader_kw)
    dl_val = DataLoader(ds_val, shuffle=False, drop_last=False, **loader_kw)

    model = MultiTaskDeepLab(num_seg_classes=2, num_cls_classes=2, freeze_until=args.freeze_until).to(device)

    # Only trainable (non-frozen) parameters go to the optimizer.
    optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp and use_cuda)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, "tb"))

    best_score = -1.0
    best_path = os.path.join(args.output_dir, "best_multitask.pt")

    try:
        for epoch in range(1, args.epochs + 1):
            t0 = time.time()
            tr = train_one_epoch(model, dl_train, optimizer, scaler if args.use_amp else None, device,
                                 seg_loss_w=args.seg_loss_w, cls_loss_w=args.cls_loss_w)
            va = evaluate(model, dl_val, device)
            score = 0.5 * va.get("miou_fg", 0.0) + 0.5 * va.get("cls_acc", 0.0)
            dt = time.time() - t0

            print(f"Epoch {epoch}/{args.epochs}  time={dt:.1f}s | train_loss={tr['loss']:.4f} seg={tr['seg']:.4f} cls={tr['cls']:.4f} | val_mIoU={va.get('miou_fg',0):.3f} val_cls_acc={va.get('cls_acc',0):.3f} val_auc={va.get('cls_auc',0):.3f} score={score:.4f}")
            writer.add_scalar("train/loss", tr["loss"], epoch)
            writer.add_scalar("train/seg_loss", tr["seg"], epoch)
            writer.add_scalar("train/cls_loss", tr["cls"], epoch)
            writer.add_scalar("val/miou", va.get("miou_fg", 0), epoch)
            writer.add_scalar("val/mdice", va.get("mdice_fg", 0), epoch)
            writer.add_scalar("val/cls_acc", va.get("cls_acc", 0), epoch)
            writer.add_scalar("val/cls_auc", va.get("cls_auc", 0), epoch)
            writer.add_scalar("val/score", score, epoch)
            writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)

            if score > best_score:
                best_score = score
                torch.save({"model": model.state_dict(), "epoch": epoch, "score": best_score}, best_path)
                print(f"✔ Saved best to {best_path} (score {best_score:.4f})")

            scheduler.step()
            if epoch % 5 == 0:
                vis_dir = os.path.join(args.output_dir, f"vis_epoch{epoch}")
                visualize_and_save(model, dl_val, device, vis_dir, n_samples=16)
    finally:
        writer.close()

    print("Training finished. Best score:", best_score, "->", best_path)


if __name__ == "__main__":
    main()
