<a href="https://colab.research.google.com/github/s4908819/Colab-for-COMP3710/blob/main/4_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!nvidia-smi
!python -V
# 常用包（nibabel 读 NIfTI；tqdm 进度条；umap-learn 可视化可选；opencv-python 可选）
!pip -q install nibabel tqdm umap-learn opencv-python


Thu Sep 18 09:34:57 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   33C    P0             44W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
%cd /content
!mkdir -p datasets scripts data results
!touch datasets/__init__.py scripts/__init__.py
!unzip -q /content/keras_png_slices_data.zip -d /content/data


/content


In [6]:
%%writefile datasets/oasis2d.py
# datasets/oasis2d.py
from typing import List, Dict, Optional, Tuple
from torch.utils.data import Dataset
import numpy as np, torch
import nibabel as nib
from skimage.transform import resize as sk_resize
# --- 新增：PNG/JPG 读取支持 ---
from PIL import Image

def _load_array(path: str) -> np.ndarray:
    p = path.lower()
    if p.endswith(".npy"):
        return np.load(path)
    if p.endswith(".npz"):
        data = np.load(path); key = list(data.keys())[0]
        return data[key]
    if p.endswith(".nii") or p.endswith(".nii.gz"):
        return np.asarray(nib.load(path).get_fdata())
    # --- 新增：PNG/JPG（灰度） ---
    if p.endswith(".png") or p.endswith(".jpg") or p.endswith(".jpeg"):
        img = Image.open(path)
        if img.mode != "L":  # 统一转灰度
            img = img.convert("L")
        return np.array(img)
    raise ValueError(f"Unsupported file: {path}")

def _zscore(x: np.ndarray, eps=1e-6) -> np.ndarray:
    return (x - x.mean()) / (x.std() + eps)

def _minmax01(x: np.ndarray, eps=1e-6) -> np.ndarray:
    mn, mx = x.min(), x.max()
    return (x - mn) / (mx - mn + eps)

# --- 新增：把 2D 数据补成 3D（沿指定轴插入一个长度为1的维度） ---
def _ensure_3d(arr: np.ndarray, slice_axis: int) -> np.ndarray:
    if arr.ndim == 2:
        return np.expand_dims(arr, axis=slice_axis)  # 让 shape 在 slice_axis 位置变成 1
    return arr

class OASIS2DSeg(Dataset):
    def __init__(
        self,
        items: List[Dict],               # [{'pid','img','mask'}, ...] —— 来自 split JSON 某一组
        slice_axis: int = 2,             # 2=轴状切片（axial）
        target_hw: Optional[Tuple[int,int]] = (256, 256),
        min_fg_pixels: int = 64,         # 掩码前景像素下限
        norm: str = "zscore",            # 或 "minmax"
        augment: bool = False,
        rng_seed: int = 42,
    ):
        self.items = items
        self.slice_axis = slice_axis
        self.target_hw = target_hw
        self.min_fg_pixels = min_fg_pixels
        self.norm = norm
        self.augment = augment
        self.rng = np.random.default_rng(rng_seed)

        # 预先构建 (volume_idx, slice_idx) 列表（只保留有前景的切片）
        self._slices = []
        for vi, it in enumerate(self.items):
            img = _load_array(it["img"])
            msk = _load_array(it["mask"]).astype(np.int64)
            # --- 新增：兼容 2D png，自动补成 3D ---
            img = _ensure_3d(img, self.slice_axis)
            msk = _ensure_3d(msk, self.slice_axis)
            if img.shape != msk.shape:
                raise ValueError(f"Shape mismatch: {img.shape} vs {msk.shape} for {it}")
            num_slices = img.shape[self.slice_axis]
            for si in range(num_slices):
                sl_msk = np.take(msk, si, axis=self.slice_axis)
                if (sl_msk > 0).sum() >= self.min_fg_pixels:
                    self._slices.append((vi, si))
        if not self._slices:
            raise RuntimeError("No valid slices; consider lowering min_fg_pixels.")

    def __len__(self):
        return len(self._slices)

    def _resize_pair(self, img2d: np.ndarray, msk2d: np.ndarray, target_hw):
        if target_hw is None:
            return img2d, msk2d
        H, W = target_hw
        # 图像双线性；掩码最近邻（避免标签污染）
        img_r = sk_resize(img2d, (H, W), order=1, preserve_range=True, anti_aliasing=True).astype(np.float32)
        msk_r = sk_resize(msk2d, (H, W), order=0, preserve_range=True, anti_aliasing=False).astype(np.int64)
        return img_r, msk_r

    def __getitem__(self, idx):
        vi, si = self._slices[idx]
        it = self.items[vi]
        img = _load_array(it["img"])
        msk = _load_array(it["mask"]).astype(np.int64)
        # --- 新增：兼容 2D png，自动补成 3D ---
        img = _ensure_3d(img, self.slice_axis)
        msk = _ensure_3d(msk, self.slice_axis)

        sl_img = np.take(img, si, axis=self.slice_axis).astype(np.float32)
        sl_msk = np.take(msk, si, axis=self.slice_axis)

        # 归一化
        sl_img = _zscore(sl_img) if self.norm == "zscore" else _minmax01(sl_img)

        # 简单增强（可选）：左右/上下翻转（保持图像与掩码同步）
        if self.augment:
            if self.rng.random() < 0.5:
                sl_img = np.flip(sl_img, axis=1); sl_msk = np.flip(sl_msk, axis=1)
            if self.rng.random() < 0.5:
                sl_img = np.flip(sl_img, axis=0); sl_msk = np.flip(sl_msk, axis=0)

        # 尺寸对齐
        sl_img, sl_msk = self._resize_pair(sl_img, sl_msk, self.target_hw)

        # ===== 关键修复：把掩码压成类别索引 =====
        # 二分类场景：非零即前景 -> {0,1}
        sl_msk = (sl_msk > 0).astype(np.int64)

        # 如果将来是多分类（例：0,85,170,255 -> {0,1,2,3}），可改为：
        # mapping = {0:0, 85:1, 170:2, 255:3}
        # sl_msk = np.vectorize(lambda v: mapping.get(int(v), 0), otypes=[np.int64])(sl_msk)

        # 转张量
        img_t = torch.from_numpy(sl_img[None, ...].astype(np.float32))  # [1,H,W]
        msk_t = torch.from_numpy(sl_msk.astype(np.int64))               # [H,W]
        meta = {"pid": it["pid"], "slice": int(si)}
        return img_t, msk_t, meta


Writing datasets/oasis2d.py


In [9]:
%%writefile scripts/prepare_split.py
# scripts/prepare_split.py
# 用法：
#   python scripts/prepare_split.py \
#     --data_root /home/groups/comp3710/OASIS_preprocessed \
#     --out data/split_42.json

import argparse, json, os, re, random
from pathlib import Path

# === 扩展：PNG/JPG 支持（原 NIfTI/NumPy 也保留） ===
IMG_EXTS = (".nii", ".nii.gz", ".npy", ".npz", ".png", ".jpg", ".jpeg")
MASK_EXTS = IMG_EXTS

def is_img(p: str) -> bool:
    p = p.lower()
    return any(p.endswith(ext) for ext in IMG_EXTS)

def is_mask(p: str) -> bool:
    p = p.lower()
    return any(p.endswith(ext) for ext in MASK_EXTS)

# ---- 关键：把文件名统一为“可配对”的规范形态 ----
# 例： "seg_441_slice_0.nii.png" -> "441_slice_0"
#      "441_slice_0.nii.png"     -> "441_slice_0"
#      "case123_img.png"         -> "case123"
#      "case123_mask.png"        -> "case123"
def norm_stem(name_or_path: str) -> str:
    name = Path(name_or_path).name
    stem = Path(name).stem            # 去掉最后一个扩展名（如 .png）
    stem = re.sub(r"\.nii$", "", stem, flags=re.I)  # 去掉尾部 .nii
    stem = re.sub(r"^(seg_|mask_|label_|case_)", "", stem, flags=re.I)
    stem = re.sub(r"(_mask|_seg(mentation)?)$", "", stem, flags=re.I)  # 去掉后缀
    return stem

def find_pairs(data_root: Path):
    root = Path(data_root)
    pairs = []

    # 情况 A：images/ 和 masks/ 兄弟目录（保持，改用 norm_stem 匹配）
    img_dir, msk_dir = root / "images", root / "masks"
    if img_dir.is_dir() and msk_dir.is_dir():
        imgs = sorted([p for p in img_dir.rglob("*") if p.is_file() and is_img(str(p))])
        idx = {norm_stem(p.name): p for p in imgs}
        for m in sorted([p for p in msk_dir.rglob("*") if p.is_file() and is_mask(str(p))]):
            key = norm_stem(m.name)
            cand = idx.get(key)
            if cand:
                pairs.append({"pid": key, "img": str(cand), "mask": str(m)})
        if pairs:
            return pairs

    # 情况 B：每个病人一个文件夹（保持不变）
    for d in sorted([p for p in root.iterdir() if p.is_dir()]):
        pid = d.name
        imgs = sorted([p for p in d.rglob("*")
                       if p.is_file() and is_img(str(p)) and re.search(r"(?:^|[_-])(img|image|t1|t2)(?:[_-]|$)", p.stem, re.I)])
        msks = sorted([p for p in d.rglob("*")
                       if p.is_file() and is_mask(str(p)) and re.search(r"(?:^|[_-])(mask|label|seg)(?:[_-]|$)", p.stem, re.I)])
        if imgs and msks:
            pairs.append({"pid": pid, "img": str(imgs[0]), "mask": str(msks[0])})
    if pairs:
        return pairs

    # 情况 C：扁平目录，<PID>_img.* / <PID>_mask.*（保持，改用 norm_stem）
    files = sorted([p for p in root.rglob("*") if p.is_file() and (is_img(str(p)) or is_mask(str(p)))])
    if files:
        imgs = [p for p in files if re.search(r"(?:^|[_-])(img|image|t1|t2)(?:[_-]|$)", p.stem, re.I)]
        msks = [p for p in files if re.search(r"(?:^|[_-])(mask|label|seg)(?:[_-]|$)", p.stem, re.I)]
        idx = {norm_stem(p.name): p for p in imgs}
        tmp_pairs = []
        for m in msks:
            key = norm_stem(m.name)
            if key in idx:
                tmp_pairs.append({"pid": key, "img": str(idx[key]), "mask": str(m)})
        if tmp_pairs:
            return tmp_pairs

    # 情况 D：Keras PNG 切片版目录（新增；核心匹配规则改为 norm_stem）
    #   keras_png_slices_train/            keras_png_slices_seg_train/
    #   keras_png_slices_validate/         keras_png_slices_seg_validate/
    #   keras_png_slices_test/             keras_png_slices_seg_test/
    png_pairs = []
    phases = ["train", "validate", "val", "test"]  # validate/val 都兼容
    for ph in phases:
        imgd = root / f"keras_png_slices_{ph}"
        mskd = root / f"keras_png_slices_seg_{ph}"
        if not (imgd.is_dir() and mskd.is_dir()):
            continue

        # 掩码映射：用“规范 stem”做键
        mask_idx = {}
        for m in mskd.rglob("*"):
            if m.is_file() and m.suffix.lower() in (".png", ".jpg", ".jpeg"):
                mask_idx[norm_stem(m.name)] = m

        # 找图片并配对
        for im in imgd.rglob("*"):
            if not (im.is_file() and im.suffix.lower() in (".png", ".jpg", ".jpeg")):
                continue
            key = norm_stem(im.name)
            cand = mask_idx.get(key)
            if cand is not None:
                # 每个切片当作独立“样本”；pid 用 key（稳妥）
                png_pairs.append({"pid": key, "img": str(im), "mask": str(cand)})

    if png_pairs:
        return sorted(png_pairs, key=lambda x: x["pid"])

    # 实在找不到
    return []

def split_by_patient(pairs, val_ratio=0.15, test_ratio=0.15, seed=42):
    # 聚合 pid → 该 pid 下的全部条目（PNG 场景里通常 1 条/切片）
    by_pid = {}
    for it in pairs:
        by_pid.setdefault(it["pid"], []).append(it)
    pids = sorted(by_pid.keys())
    rng = random.Random(seed)
    rng.shuffle(pids)

    n = len(pids)
    n_test = int(n * test_ratio)
    n_val  = int(n * val_ratio)
    test_p = set(pids[:n_test])
    val_p  = set(pids[n_test:n_test + n_val])

    split = {"train": [], "val": [], "test": []}
    for pid in pids:
        bucket = "test" if pid in test_p else ("val" if pid in val_p else "train")
        split[bucket].extend(by_pid[pid])
    return split

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--out", default="data/split_42.json")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--val_ratio", type=float, default=0.15)
    ap.add_argument("--test_ratio", type=float, default=0.15)
    args = ap.parse_args()

    pairs = find_pairs(Path(args.data_root))
    if not pairs:
        raise SystemExit(f"[ERR] No image/mask pairs found under {args.data_root}")

    split = split_by_patient(pairs, args.val_ratio, args.test_ratio, args.seed)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(split, f, indent=2, ensure_ascii=False)

    print(f"[OK] volumes: train={len(split['train'])}, val={len(split['val'])}, test={len(split['test'])}")
    print(f"[OK] saved to {str(out_path)}")

if __name__ == "__main__":
    main()


Writing scripts/prepare_split.py


In [7]:
%%writefile scripts/train_unet.py
# scripts/train_unet.py
import os, json, math, argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets.oasis2d import OASIS2DSeg

# -------------------------
# 1) UNet 定义（简洁稳定版）
# -------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x): return self.conv(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
            self.conv = DoubleConv(in_ch, out_ch)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_ch, out_ch)
        self.bilinear = bilinear
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pad/crop to match skip
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diffX//2, diffX - diffX//2, diffY//2, diffY - diffY//2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_ch=1, n_classes=4, base=32, bilinear=True):
        super().__init__()
        self.inc  = DoubleConv(in_ch, base)
        self.down1= Down(base, base*2)
        self.down2= Down(base*2, base*4)
        self.down3= Down(base*4, base*8)
        factor = 2 if bilinear else 1
        self.down4= Down(base*8, base*16//factor)
        self.up1  = Up(base*16, base*8//factor, bilinear)
        self.up2  = Up(base*8,  base*4//factor, bilinear)
        self.up3  = Up(base*4,  base*2//factor, bilinear)
        self.up4  = Up(base*2,  base, bilinear)
        self.outc = nn.Conv2d(base, n_classes, kernel_size=1)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x,  x3)
        x = self.up3(x,  x2)
        x = self.up4(x,  x1)
        return self.outc(x)  # logits [B,C,H,W]

# -------------------------
# 2) 损失 & 评估
# -------------------------
def soft_dice_loss(logits, target, smooth=1e-5, ignore_background=True):
    """
    logits: [B,C,H,W], target: [B,H,W] long
    返回 (1 - mean foreground dice)
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                         # [B,C,H,W]
    one_hot = F.one_hot(target, num_classes).permute(0,3,1,2).float()
    dims = (0,2,3)
    inter = (probs * one_hot).sum(dims)                      # [C]
    denom = probs.sum(dims) + one_hot.sum(dims)              # [C]
    dice = (2*inter + smooth) / (denom + smooth)             # [C]
    if ignore_background and num_classes > 1:
        dice = dice[1:]
    return 1.0 - dice.mean()

@torch.no_grad()
def dsc_per_class(logits, target, num_classes, ignore_background=False):
    pred = logits.argmax(1)  # [B,H,W]
    dsc = []
    classes = range(num_classes)
    if ignore_background and num_classes > 1:
        classes = range(1, num_classes)
    for c in classes:
        p = (pred == c)
        t = (target == c)
        tp = (p & t).sum().item()
        denom = p.sum().item() + t.sum().item()
        dsc_c = (2*tp) / (denom + 1e-5) if denom > 0 else 1.0
        dsc.append(dsc_c)
    return dsc  # 前景类列表

# -------------------------
# 3) 训练循环
# -------------------------
def train_one_epoch(model, loader, opt, scaler, device, num_classes, lambda_dice=1.0):
    model.train()
    run_loss, n = 0.0, 0
    for img, msk, _ in loader:
        img, msk = img.to(device, non_blocking=True), msk.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type=='cuda')):
            logits = model(img)
            ce = F.cross_entropy(logits, msk)                     # CE 不需要 one-hot
            dice = soft_dice_loss(logits, msk, ignore_background=True)
            loss = ce + lambda_dice * dice
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        run_loss += loss.item() * img.size(0)
        n += img.size(0)
    return run_loss / max(1,n)

@torch.no_grad()
def validate(model, loader, device, num_classes):
    model.eval()
    dsc_sum = np.zeros(max(1, num_classes-1), dtype=np.float64)  # 默认不计背景
    n_batches = 0
    for img, msk, _ in loader:
        img, msk = img.to(device), msk.to(device)
        logits = model(img)
        dsc = dsc_per_class(logits, msk, num_classes, ignore_background=True)
        dsc_sum += np.array(dsc, dtype=np.float64)
        n_batches += 1
    mean_per_class = (dsc_sum / max(1, n_batches)).tolist()
    macro = float(np.mean(mean_per_class)) if len(mean_per_class) else 0.0
    return macro, mean_per_class

# -------------------------
# 4) 主函数 & 参数
# -------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=False, help="仅用于信息记录，Dataset 使用 split JSON 内路径")
    ap.add_argument("--split_json", required=True, help="Step2 生成的 JSON")
    ap.add_argument("--save_dir", default="results/exp1")
    ap.add_argument("--in_channels", type=int, default=1)
    ap.add_argument("--num_classes", type=int, default=4, help="分割类别数（含背景）")
    ap.add_argument("--base", type=int, default=32)
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--weight_decay", type=float, default=1e-5)
    ap.add_argument("--lambda_dice", type=float, default=1.0)
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--target_h", type=int, default=256)
    ap.add_argument("--target_w", type=int, default=256)
    ap.add_argument("--min_fg_pixels", type=int, default=64)
    ap.add_argument("--patience", type=int, default=15)
    args = ap.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # 读取 split
    split = json.load(open(args.split_json, "r"))
    tr_items = split["train"]
    va_items = split["val"]

    train_ds = OASIS2DSeg(
        tr_items, target_hw=(args.target_h, args.target_w),
        augment=True,  min_fg_pixels=args.min_fg_pixels
    )
    val_ds = OASIS2DSeg(
        va_items, target_hw=(args.target_h, args.target_w),
        augment=False, min_fg_pixels=args.min_fg_pixels
    )
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=True)

    # 模型 & 优化
    model = UNet(in_ch=args.in_channels, n_classes=args.num_classes, base=args.base).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", patience=5, factor=0.5)
    scaler = torch.cuda.amp.GradScaler(enabled=(args.amp and device.type=="cuda"))

    best_macro = -1.0
    epochs_no_improve = 0

    for epoch in range(1, args.epochs+1):
        t0 = time.time()
        tr_loss = train_one_epoch(model, train_loader, opt, scaler, device,
                                  num_classes=args.num_classes, lambda_dice=args.lambda_dice)
        val_macro, val_per_class = validate(model, val_loader, device, args.num_classes)
        scheduler.step(val_macro)

        dt = time.time() - t0
        per_class_str = ", ".join([f"{d:.4f}" for d in val_per_class])
        print(f"[Epoch {epoch:03d}] loss={tr_loss:.4f}  val_macro_DSC={val_macro:.4f}  "
              f"per-class (fg)=[{per_class_str}]  time={dt:.1f}s")

        # 早停 & 保存最优
        if val_macro > best_macro:
            best_macro = val_macro
            epochs_no_improve = 0
            ckpt_path = os.path.join(args.save_dir, "best.ckpt")
            torch.save({"model": model.state_dict(),
                        "num_classes": args.num_classes,
                        "in_channels": args.in_channels}, ckpt_path)
            print(f"  ↳ Saved best checkpoint to {ckpt_path}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= args.patience:
                print(f"Early stopping at epoch {epoch}. Best macro DSC={best_macro:.4f}")
                break

    print("Training finished. Best macro DSC (val) =", best_macro)

if __name__ == "__main__":
    main()


Writing scripts/train_unet.py


In [10]:
!python scripts/prepare_split.py \
  --data_root "/content/data/keras_png_slices_data" \
  --out "data/split_42.json"


[OK] volumes: train=7930, val=1699, test=1699
[OK] saved to data/split_42.json


In [13]:
%cd /content
!python -m scripts.train_unet \
  --split_json data/split_42.json \
  --save_dir results/exp1 \
  --epochs 30 \
  --batch_size 16 \
  --lr 1e-3 \
  --amp \
  --num_workers 2 \
  --num_classes 2 \
  --target_h 256 --target_w 256 \
  --min_fg_pixels 0




/content
Device: cuda
  scaler = torch.cuda.amp.GradScaler(enabled=(args.amp and device.type=="cuda"))
[Epoch 001] loss=0.0837  val_macro_DSC=0.9864  per-class (fg)=[0.9864]  time=27.7s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 002] loss=0.0307  val_macro_DSC=0.9881  per-class (fg)=[0.9881]  time=24.0s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 003] loss=0.0231  val_macro_DSC=0.9886  per-class (fg)=[0.9886]  time=24.0s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 004] loss=0.0191  val_macro_DSC=0.9927  per-class (fg)=[0.9927]  time=24.0s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 005] loss=0.0167  val_macro_DSC=0.9928  per-class (fg)=[0.9928]  time=24.0s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 006] loss=0.0149  val_macro_DSC=0.9944  per-class (fg)=[0.9944]  time=24.2s
  ↳ Saved best checkpoint to results/exp1/best.ckpt
[Epoch 007] loss=0.0139  val_macro_DSC=0.9941  per-class (fg)=[0.9941]  time=24.0s
[Epo

In [14]:
%%writefile scripts/infer_unet.py
import os, json, argparse
import numpy as np
import torch, torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from datasets.oasis2d import OASIS2DSeg
from scripts.train_unet import UNet, dsc_per_class

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--split_json", required=True)
    ap.add_argument("--subset", default="test", choices=["train","val","test"])
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--out_dir", default="results/vis")
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--target_h", type=int, default=256)
    ap.add_argument("--target_w", type=int, default=256)
    ap.add_argument("--num_images", type=int, default=24)    # 保存前多少张三联图
    ap.add_argument("--ignore_bg", action="store_true")
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # 读 split
    split = json.load(open(args.split_json))
    items = split[args.subset]

    # 加载 ckpt
    ckpt = torch.load(args.ckpt, map_location="cpu")
    num_classes = int(ckpt.get("num_classes", 2))
    in_ch       = int(ckpt.get("in_channels", 1))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(in_ch=in_ch, n_classes=num_classes, base=32).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    # Dataset/DataLoader
    ds = OASIS2DSeg(items, target_hw=(args.target_h, args.target_w), augment=False, min_fg_pixels=0)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                    num_workers=args.num_workers, pin_memory=True)

    # 评测
    dsc_sum = np.zeros(max(1, num_classes-1), dtype=np.float64)
    n_batches = 0

    saved = 0
    with torch.no_grad():
        for (img, msk, meta) in dl:
            img = img.to(device)
            msk = msk.to(device)
            logits = model(img)

            # 统计 DSC（前景类；可选忽略背景）
            dsc_list = dsc_per_class(logits, msk, num_classes, ignore_background=args.ignore_bg or (num_classes>1))
            dsc_sum += np.array(dsc_list, dtype=np.float64)
            n_batches += 1

            # 前 num_images 张保存三联图
            for b in range(img.size(0)):
                if saved >= args.num_images: break
                pred = logits[b:b+1].argmax(1).squeeze(0).detach().cpu().numpy()
                im   = img[b,0].detach().cpu().numpy()    # [H,W]
                gt   = msk[b].detach().cpu().numpy()

                # 保存
                fig = plt.figure(figsize=(9,3))
                ax1 = fig.add_subplot(1,3,1); ax1.set_title("Image"); ax1.imshow(im, cmap="gray"); ax1.axis("off")
                ax2 = fig.add_subplot(1,3,2); ax2.set_title("Mask");  ax2.imshow(gt, vmin=0, vmax=num_classes-1); ax2.axis("off")
                ax3 = fig.add_subplot(1,3,3); ax3.set_title("Pred");  ax3.imshow(pred, vmin=0, vmax=num_classes-1); ax3.axis("off")
                pid = meta["pid"][b] if isinstance(meta["pid"], list) else meta["pid"]
                out_path = os.path.join(args.out_dir, f"{saved:03d}_{pid}.png")
                plt.tight_layout()
                plt.savefig(out_path, dpi=120)
                plt.close(fig)
                saved += 1
            if saved >= args.num_images: break

    mean_per_class = (dsc_sum / max(1, n_batches)).tolist()
    macro = float(np.mean(mean_per_class)) if len(mean_per_class) else 0.0

    print(f"[{args.subset}] macro DSC = {macro:.4f}")
    print(f"[{args.subset}] per-class (fg) = {[round(x,4) for x in mean_per_class]}")
    print(f"Saved {saved} preview images under: {args.out_dir}")

if __name__ == "__main__":
    main()


Writing scripts/infer_unet.py


In [15]:
%cd /content
!python -m scripts.infer_unet \
  --split_json data/split_42.json \
  --subset test \
  --ckpt results/exp1/best.ckpt \
  --out_dir results/exp1/vis_test \
  --num_images 24 \
  --ignore_bg \
  --target_h 256 --target_w 256


/content
[test] macro DSC = 0.9984
[test] per-class (fg) = [0.9984]
Saved 24 preview images under: results/exp1/vis_test


In [17]:
import torch, json, os, numpy as np
from datasets.oasis2d import OASIS2DSeg
from scripts.train_unet import UNet
import torch.nn.functional as F

ckpt = torch.load("results/exp1/best.ckpt", map_location="cpu")
num_classes = int(ckpt.get("num_classes", 2))
model = UNet(in_ch=int(ckpt.get("in_channels",1)), n_classes=num_classes, base=32).eval().cuda()
model.load_state_dict(ckpt["model"])

split = json.load(open("data/split_42.json"))
ds = OASIS2DSeg(split["test"][:1], target_hw=(256,256), augment=False)
img, msk, meta = ds[0]
with torch.no_grad():
    logits = model(img.unsqueeze(0).cuda())      # [1,C,H,W]
    probs  = F.softmax(logits, dim=1)            # categorical 概率
    pred   = probs.argmax(1)                     # [1,H,W] 类索引
    onehot = F.one_hot(pred, num_classes).permute(0,3,1,2).byte()  # [1,C,H,W] one-hot

os.makedirs("results/exp1/onehot_demo", exist_ok=True)
np.save("results/exp1/onehot_demo/pred_onehot.npy", onehot[0].cpu().numpy())  # 保存 one-hot
np.save("results/exp1/onehot_demo/pred_index.npy",  pred[0].cpu().numpy())    # 保存类索引
print("Saved:", "results/exp1/onehot_demo/pred_onehot.npy")


Saved: results/exp1/onehot_demo/pred_onehot.npy
