In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
scan_pairs.py
- Scan /workspace/cem_mitolab recursively.
- Find "images/" and "masks/" sibling folders (case-insensitive).
- Pair image <-> mask by relative path; fallback to stem matching with common-sense normalization.
- Validate shape consistency; write report + pairs.txt + 16 overlay PNGs.

Run:
  python /workspace/scan_pairs.py
"""

from __future__ import annotations
import os
import re
import json
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
from skimage import io

try:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    HAS_PLT = True
except Exception:
    HAS_PLT = False


# ---------------------------- configurable ----------------------------
# Dataset root that is scanned; can be overridden with the EMP_DIR env variable.
ROOT = Path(os.environ.get("EMP_DIR", "/workspace/cem_mitolab")).resolve()
# Output directory for pairs.txt, the JSON reports and the overlay PNGs.
QA_DIR = Path("/workspace/qa")
# Directory-name patterns. They are anchored only at the end and are applied
# with .search(), so any name that merely *ends* in image(s)/mask(s) matches
# (e.g. "raw_images"), not just the exact names.
IMG_DIR_RE = re.compile(r"images?$", re.I)   # images / image
MSK_DIR_RE = re.compile(r"masks?$",  re.I)   # masks / mask
# File suffixes (compared lower-cased) that count as image/mask files.
EXTS = {".tif", ".tiff", ".png", ".jpg", ".jpeg"}

# patterns to strip from stems when falling back to name-based matching
# Matches ONE trailing token: a '_' or '-' separator, an optional
# ch/loc/slice/z/t tag, then digits (e.g. "_ch0", "-0001", "_z12").
# NOTE(review): the second alternative ([-_]\d+) is already covered by the
# first one because the tag group is optional.
STEM_CLEAN_RE = re.compile(r"([_-](ch|loc|slice|z|t)?\d+|[-_](\d+))$", re.I)
# ---------------------------------------------------------------------


def is_image_file(p: Path) -> bool:
    """Return True when *p* is an existing regular file with a known image suffix.

    The suffix comparison is case-insensitive and uses the module-level EXTS set.
    """
    if not p.is_file():
        return False
    return p.suffix.lower() in EXTS


def find_sibling(dirpath: Path, pat: re.Pattern) -> Path | None:
    """
    Return the first directory next to *dirpath* whose name matches *pat*.

    Every directory inside ``dirpath.parent`` is considered (in the order the
    file system lists them, which includes *dirpath* itself); ``pat.search`` is
    applied to the bare directory name. Returns None when nothing matches.
    """
    matching_dirs = (
        entry
        for entry in dirpath.parent.iterdir()
        if entry.is_dir() and pat.search(entry.name)
    )
    return next(matching_dirs, None)


def normalize_stem(stem: str) -> str:
    """
    Strip a single trailing index token from a file stem.

    The token is whatever STEM_CLEAN_RE matches at the very end of the string,
    e.g. "cell_ch0" -> "cell" and "img-0001" -> "img". Only one token is
    removed per call ("a_1_2" -> "a_1").
    """
    return STEM_CLEAN_RE.sub("", stem)


def try_mask_path(msk_root: Path, rel: Path) -> Path | None:
    """
    Locate the mask that mirrors relative path *rel* under *msk_root*.

    Candidates are probed in a fixed order -- ``.tiff`` suffix, ``.tif``
    suffix, then *rel* with its original suffix -- and the first one that
    exists on disk is returned. Returns None when none of them exists.
    """
    relative_candidates = (rel.with_suffix(".tiff"), rel.with_suffix(".tif"), rel)
    for relative in relative_candidates:
        candidate = msk_root / relative
        if candidate.exists():
            return candidate
    return None


def build_stem_index(root: Path) -> Dict[str, Path]:
    """
    Index every image file below *root* by its normalized, lower-cased stem.

    The tree is walked recursively. When several files normalize to the same
    key, the one encountered first wins and later ones are ignored.
    """
    index: Dict[str, Path] = {}
    for candidate in root.rglob("*"):
        if not is_image_file(candidate):
            continue
        # setdefault keeps the existing entry, i.e. "first file wins".
        index.setdefault(normalize_stem(candidate.stem).lower(), candidate)
    return index


def pair_images_masks(root: Path) -> Tuple[List[Tuple[str, str]], dict]:
    """
    Walk *root*, pair every image with its mask and validate the pair.

    For each directory level that contains both an images-like and a
    masks-like sub-directory (IMG_DIR_RE / MSK_DIR_RE), the first of each
    (in sorted name order) is taken as an (images_root, masks_root) couple.
    Every image file below images_root is then matched to a mask by, in order:

      1. the same relative path (see try_mask_path),
      2. the exact, case-insensitive file stem,
      3. the normalized stem (see normalize_stem / build_stem_index).

    Step 2 exists so that numbered files such as ``img_002.jpg`` are not paired
    with ``img_001.png`` merely because both normalize to ``img``.

    A matched pair is kept only if both files can be read and, after dropping
    a trailing singleton channel, have identical shapes.

    Returns:
        (pairs, info) where ``pairs`` is a list of (image_path, mask_path)
        strings and ``info`` is ``{"stats": ..., "details": ...}`` holding the
        counters, the shape histogram and the first few offending paths.
    """
    pairs: List[Tuple[str, str]] = []
    bad_missing_mask: List[str] = []
    bad_read_error: List[str] = []
    bad_size_mismatch: List[Tuple[str, str, tuple, tuple]] = []
    size_hist: Dict[str, int] = {}

    # find all candidate (images, masks) siblings
    candidate_pairs: List[Tuple[Path, Path]] = []
    for dirpath, dirnames, _ in os.walk(root):
        # Sorting in place makes both the traversal order and the
        # "first match at this level" choice reproducible across runs.
        dirnames.sort()
        here = Path(dirpath)
        imgs = [here / d for d in dirnames if IMG_DIR_RE.search(d)]
        msks = [here / d for d in dirnames if MSK_DIR_RE.search(d)]
        if imgs and msks:
            candidate_pairs.append((imgs[0], msks[0]))

    # go through each (images_root, masks_root)
    for img_root, msk_root in candidate_pairs:
        stem_index = build_stem_index(msk_root)

        # Exact (un-normalized) stems; consulted before the lossy normalized index.
        exact_index: Dict[str, Path] = {}
        for cand in sorted(msk_root.rglob("*")):
            if is_image_file(cand):
                exact_index.setdefault(cand.stem.lower(), cand)

        for ip in sorted(img_root.rglob("*")):
            if not is_image_file(ip):
                continue

            # 1) match by relative path
            rel = ip.relative_to(img_root)
            mp = try_mask_path(msk_root, rel)

            # 2) fallback: exact stem, ignoring case and extension
            if mp is None:
                mp = exact_index.get(ip.stem.lower())

            # 3) fallback: normalized stem
            if mp is None:
                mp = stem_index.get(normalize_stem(ip.stem).lower())

            if mp is None or not mp.exists():
                bad_missing_mask.append(str(ip))
                continue

            # read both and validate shape
            try:
                im = io.imread(str(ip))
                mk = io.imread(str(mp))
                if im.ndim == 3 and im.shape[-1] == 1:
                    im = im[..., 0]
                if mk.ndim == 3 and mk.shape[-1] == 1:
                    mk = mk[..., 0]
            except Exception:
                # Deliberately broad: reader plugins raise many different error
                # types and one unreadable file must not abort the whole scan.
                bad_read_error.append(str(ip))
                continue

            if im.shape != mk.shape:
                bad_size_mismatch.append((str(ip), str(mp), im.shape, mk.shape))
                continue

            key_shape = str(im.shape)
            size_hist[key_shape] = size_hist.get(key_shape, 0) + 1
            pairs.append((str(ip), str(mp)))

    stats = {
        "root": str(root),
        "num_sibling_levels": len(candidate_pairs),
        "valid_pairs": len(pairs),
        # most frequent shape first
        "size_histogram": dict(sorted(size_hist.items(), key=lambda x: -x[1])),
        "bad_counts": {
            "missing_mask": len(bad_missing_mask),
            "read_error": len(bad_read_error),
            "size_mismatch": len(bad_size_mismatch),
        },
    }
    # also save a little more detail (first few)
    details = {
        "missing_mask_head": bad_missing_mask[:10],
        "read_error_head": bad_read_error[:10],
        "size_mismatch_head": bad_size_mismatch[:5],
    }
    return pairs, {"stats": stats, "details": details}


def save_overlays(pairs: List[Tuple[str, str]], outdir: Path, n: int = 16):
    """
    Write up to *n* randomly chosen image/mask overlays as PNGs into *outdir*.

    Each overlay is an RGB picture: red = percentile-normalized image,
    green = binary mask (mask > 0), blue = 0. Files are named
    ``sample_XX.png``. Nothing is written when matplotlib is unavailable,
    *pairs* is empty or *n* is not positive.

    Multi-channel inputs are reduced to one plane first (otherwise the
    stacked array would have more than three channels and imshow would
    reject it); pairs that still are not 2-D afterwards are skipped.
    """
    if not HAS_PLT or not pairs or n <= 0:
        return
    outdir.mkdir(parents=True, exist_ok=True)
    samp = random.sample(pairs, min(n, len(pairs)))
    for i, (ip, mp) in enumerate(samp):
        im = io.imread(ip)
        mk = io.imread(mp)

        # Collapse a channel-like trailing axis (1..4 channels).
        # NOTE(review): assumes a 4th channel is alpha and drops it -- confirm.
        if im.ndim == 3 and im.shape[-1] <= 4:
            im = im[..., :3].astype(np.float32).mean(axis=-1)
        if mk.ndim == 3 and mk.shape[-1] <= 4:
            mk = mk[..., :3].max(axis=-1)
        if im.ndim != 2 or mk.shape != im.shape:
            print(f"[scan] overlay skipped (unsupported shapes {im.shape} / {mk.shape}): {ip}")
            continue

        imf = im.astype(np.float32)
        lo, hi = np.percentile(imf, [1, 99])
        if hi <= lo:
            # Degenerate percentiles (near-constant image): fall back to min-max.
            lo, hi = imf.min(), imf.max()
        if hi > lo:
            imf = np.clip((imf - lo) / (hi - lo), 0, 1)
        else:
            # Truly constant image: keep the red channel inside imshow's [0, 1] range.
            imf = np.zeros_like(imf)

        rgb = np.dstack([imf, (mk > 0).astype(np.float32), np.zeros_like(imf)])
        fig = plt.figure(figsize=(4, 4))
        plt.axis("off")
        plt.imshow(rgb)
        plt.tight_layout(pad=0)
        plt.savefig(outdir / f"sample_{i:02d}.png", dpi=150)
        plt.close(fig)


def main():
    """
    Scan ROOT for image/mask pairs and write the QA artifacts into QA_DIR:
    pairs.txt (tab-separated paths), the two JSON reports and the overlays.
    """
    print(f"[scan] ROOT = {ROOT}")
    QA_DIR.mkdir(parents=True, exist_ok=True)

    pairs, info = pair_images_masks(ROOT)
    stats = info["stats"]

    # pairs.txt: one "<image>\t<mask>" line per valid pair
    pair_lines = [f"{a}\t{b}" for a, b in pairs]
    (QA_DIR / "pairs.txt").write_text("\n".join(pair_lines), encoding="utf-8")

    # reports: compact stats, then stats + details
    reports = (
        ("report-smart.json", stats),
        ("report-smart-details.json", info),
    )
    for filename, payload in reports:
        with open(QA_DIR / filename, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    print(json.dumps(stats, indent=2))

    # overlays
    save_overlays(pairs, QA_DIR, n=16)
    if not pairs:
        print("[scan] No valid pairs found. Check details JSON and directory naming (images/ vs masks/).")
        return
    print(f"[scan] Wrote {len(pairs)} pairs to {QA_DIR/'pairs.txt'} and overlays to {QA_DIR}")


# Run the scan only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tile_pairs.py
Cut paired (image, mask) files into MoDL-style 512×512 tiles.

Input:
  - a pairs list file, each line: <image_path>\t<mask_path>
    (we already generated /workspace/qa/pairs.txt via scan_pairs.py)
Output:
  - tiled image pngs to out_img_dir
  - tiled mask  pngs to out_msk_dir
  - a summary json

Usage (inside the container):
  python /workspace/tile_pairs.py \
      --pairs /workspace/qa/pairs.txt \
      --out-img /workspace/deform/train \
      --out-msk /workspace/deform/label \
      --tile 512 --stride 512 --keep-bg false

Note:
  - Mask is binarized as (mask > 0).
  - Image is percentile (1,99) normalized to uint8.
  - Non-multiple sizes are reflect-padded to the next grid.
"""

from __future__ import annotations
import argparse
import json
import os
from math import ceil
from pathlib import Path
from typing import Tuple, List

import numpy as np
from skimage import io

try:
    from tqdm import tqdm
except Exception:
    tqdm = lambda x, **k: x  # fallback: no progress bar


def to_uint8_percentile(im: np.ndarray, p1: float = 1.0, p99: float = 99.0) -> np.ndarray:
    """Rescale *im* to uint8 using its (p1, p99) percentiles as black/white points.

    Values outside the percentile window are clipped. If the two percentiles
    coincide, the full min/max range is used instead; a constant image maps
    to all zeros.
    """
    data = im.astype(np.float32)
    low, high = np.percentile(data, [p1, p99])
    if high > low:
        unit = np.clip((data - low) / (high - low), 0, 1)
        return (unit * 255.0 + 0.5).astype(np.uint8)

    # Degenerate percentile window: stretch over the full value range instead.
    low, high = data.min(), data.max()
    if high > low:
        unit = (data - low) / (high - low)
    else:
        unit = np.zeros_like(data, dtype=np.float32)
    # +0.5 rounds to nearest before the truncating integer cast
    return (unit * 255.0 + 0.5).astype(np.uint8)


def reflect_pad_to_grid(im: np.ndarray, tile: int) -> np.ndarray:
    """Grow *im* at the bottom/right edge so height and width are multiples of *tile*.

    New pixels are filled with numpy's "reflect" padding. The input array is
    returned unchanged (same object) when it already fits the grid. A third
    axis, if present, is treated as channels and is never padded.
    """
    height, width = im.shape[:2]
    extra_rows = int(ceil(height / tile) * tile) - height
    extra_cols = int(ceil(width / tile) * tile) - width
    if extra_rows == 0 and extra_cols == 0:
        return im

    pad_spec = ((0, extra_rows), (0, extra_cols))
    if im.ndim != 2:  # (H, W, C): leave the channel axis untouched
        pad_spec += ((0, 0),)
    return np.pad(im, pad_spec, mode="reflect")


def iter_tiles(im: np.ndarray, mk: np.ndarray, tile: int, stride: int, keep_bg: bool):
    """Yield ``(x, y, image_tile, mask_tile)`` for a sliding tile x tile window.

    Windows are visited row by row (y outer, x inner) in steps of *stride* and
    only where they fit completely inside the image. Unless *keep_bg* is true,
    windows whose mask tile has a maximum of 0 are skipped.
    """
    height, width = im.shape[:2]
    origins = (
        (top, left)
        for top in range(0, height - tile + 1, stride)
        for left in range(0, width - tile + 1, stride)
    )
    for top, left in origins:
        rows = slice(top, top + tile)
        cols = slice(left, left + tile)
        mask_tile = mk[rows, cols]
        if keep_bg or mask_tile.max() != 0:
            yield left, top, im[rows, cols], mask_tile


def load_pair(ip: Path, mp: Path) -> Tuple[np.ndarray, np.ndarray] | None:
    """Load an image and its mask and return them as ``(image, mask)``.

    A trailing channel axis of length 1 is dropped from either array.
    Returns None when reading fails for any reason or when the two arrays
    end up with different shapes (the caller cannot tell the cases apart).
    """
    try:
        loaded = []
        for path in (ip, mp):
            arr = io.imread(str(path))
            # (H, W, 1) -> (H, W)
            if arr.ndim == 3 and arr.shape[-1] == 1:
                arr = arr[..., 0]
            loaded.append(arr)
        image, mask = loaded
        return (image, mask) if image.shape == mask.shape else None
    except Exception:
        return None


def main():
    """
    Command-line entry point: cut every (image, mask) pair into fixed-size tiles.

    Reads the tab-separated pairs file, percentile-normalizes each image to
    uint8, binarizes each mask to {0, 255}, reflect-pads both to a multiple of
    the tile size and writes the tiles as PNGs with identical names into the
    image and mask output directories. A JSON summary is written to
    ``--summary`` and printed.

    Tile files are named ``em11037_<stem>_x<X>_y<Y>.png``. When two different
    source images share the same stem, the later one gets a ``_dupK`` suffix
    on its stem so that it does not overwrite the tiles of the earlier one.

    Exits with an argparse error when ``--tile`` or ``--stride`` is not positive.
    Raises FileNotFoundError when the pairs file does not exist.
    """
    parser = argparse.ArgumentParser(description="Cut paired images/masks into MoDL tiles.")
    parser.add_argument("--pairs", type=str, required=True,
                        help="pairs.txt generated by scan_pairs.py")
    parser.add_argument("--out-img", type=str, default="/workspace/deform/train",
                        help="output dir for image tiles")
    parser.add_argument("--out-msk", type=str, default="/workspace/deform/label",
                        help="output dir for mask tiles")
    parser.add_argument("--tile", type=int, default=512, help="tile size")
    parser.add_argument("--stride", type=int, default=512, help="stride")
    parser.add_argument("--keep-bg", type=str, default="false",
                        help="keep background-only tiles? (true/false)")
    parser.add_argument("--limit", type=int, default=0,
                        help="optional: only process first N pairs (for quick test)")
    parser.add_argument("--summary", type=str, default="/workspace/qa/tiling_report.json",
                        help="where to write summary json")
    args = parser.parse_args()

    # A zero/negative tile or stride would otherwise surface as an obscure
    # range()/division error (or silently produce no tiles) deep in the loop.
    if args.tile <= 0 or args.stride <= 0:
        parser.error("--tile and --stride must be positive integers")

    keep_bg = str(args.keep_bg).lower() in {"1", "true", "yes", "y"}

    pairs_file = Path(args.pairs).resolve()
    if not pairs_file.exists():
        raise FileNotFoundError(f"pairs file not found: {pairs_file}")

    out_img = Path(args.out_img).resolve()
    out_msk = Path(args.out_msk).resolve()
    out_img.mkdir(parents=True, exist_ok=True)
    out_msk.mkdir(parents=True, exist_ok=True)

    # read lines
    lines = [ln.strip() for ln in pairs_file.read_text(encoding="utf-8").splitlines() if ln.strip()]
    if args.limit and args.limit > 0:
        lines = lines[: args.limit]

    total_pairs = len(lines)
    n_tiles = 0
    n_readerr = 0
    n_mismatch = 0
    n_malformed = 0
    # base name -> source image that owns it; used to detect stem collisions
    used_bases = {}

    for ln in tqdm(lines, desc="tiling"):
        if "\t" not in ln:
            # Count the line so it is not reported as a processed pair.
            n_malformed += 1
            continue
        ip_str, mp_str = ln.split("\t", 1)
        ip, mp = Path(ip_str), Path(mp_str)

        pair = load_pair(ip, mp)
        if pair is None:
            # load_pair does not say why it failed; re-read to distinguish
            # mismatch vs readerr (best-effort)
            try:
                _im = io.imread(str(ip))
                _mk = io.imread(str(mp))
                if _im.ndim == 3 and _im.shape[-1] == 1:
                    _im = _im[..., 0]
                if _mk.ndim == 3 and _mk.shape[-1] == 1:
                    _mk = _mk[..., 0]
                if _im.shape != _mk.shape:
                    n_mismatch += 1
                else:
                    n_readerr += 1
            except Exception:
                n_readerr += 1
            continue

        im, mk = pair
        # normalize & binarize
        im8 = to_uint8_percentile(im)
        mkb = (mk > 0).astype(np.uint8) * 255

        # pad to grid
        im8 = reflect_pad_to_grid(im8, args.tile)
        mkb = reflect_pad_to_grid(mkb, args.tile)

        # File stem as base name. Images from different folders may share a
        # stem; give later ones a unique suffix instead of overwriting tiles.
        base = ip.stem
        if used_bases.setdefault(base, ip) != ip:
            k = 1
            while f"{base}_dup{k}" in used_bases:
                k += 1
            base = f"{base}_dup{k}"
            used_bases[base] = ip

        for x, y, imt, mkt in iter_tiles(im8, mkb, args.tile, args.stride, keep_bg):
            name = f"em11037_{base}_x{x}_y{y}.png"
            io.imsave(str(out_img / name), imt, check_contrast=False)
            io.imsave(str(out_msk / name), mkt, check_contrast=False)
            n_tiles += 1

    summary = {
        "pairs_file": str(pairs_file),
        "total_pairs_read": total_pairs,
        "processed_pairs": total_pairs - (n_readerr + n_mismatch + n_malformed),
        "read_errors": n_readerr,
        "shape_mismatch": n_mismatch,
        "malformed_lines": n_malformed,
        "tiles_written": n_tiles,
        "out_img": str(out_img),
        "out_msk": str(out_msk),
        "tile": args.tile,
        "stride": args.stride,
        "keep_bg": keep_bg,
    }
    Path(args.summary).parent.mkdir(parents=True, exist_ok=True)
    with open(args.summary, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(json.dumps(summary, indent=2))


# Run the tiling CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()


In [None]:
# split_npy_80_10_10.py
# Load npydata/imgs_train.npy & npydata/imgs_mask_train.npy and randomly
# split them into train / val / test with a 0.8 / 0.1 / 0.1 ratio.
#
# The splits are written to npydata/split/. Writing them next to the inputs
# would save the "train" images as imgs_train.npy, i.e. overwrite the full
# input image array with the 80% subset while imgs_mask_train.npy keeps all
# masks -- leaving the two input files misaligned for every later consumer.

import numpy as np
from pathlib import Path

try:
    ROOT = Path(__file__).resolve().parent
except NameError:
    # __file__ is undefined in a notebook cell / REPL; use the working directory.
    ROOT = Path.cwd()
NPYDIR = ROOT / "npydata"
OUTDIR = NPYDIR / "split"

X_path = NPYDIR / "imgs_train.npy"
Y_path = NPYDIR / "imgs_mask_train.npy"

print("[load]", X_path)
print("[load]", Y_path)
X = np.load(X_path)
Y = np.load(Y_path)

# Explicit check instead of `assert`, which is stripped under `python -O`.
if X.shape[0] != Y.shape[0]:
    raise ValueError("X 和 Y 的样本数量不一致！")
N = X.shape[0]
print(f"[info] 总样本数 N = {N}")

# Fixed seed so that the split is reproducible.
rng = np.random.default_rng(seed=42)
idx = rng.permutation(N)

n_train = int(N * 0.8)
n_val   = int(N * 0.1)
n_test  = N - n_train - n_val   # the remainder goes to test so the counts add up to N

i_train = idx[:n_train]
i_val   = idx[n_train:n_train + n_val]
i_test  = idx[n_train + n_val:]

splits = {
    "train": i_train,
    "val":   i_val,
    "test":  i_test,
}

OUTDIR.mkdir(parents=True, exist_ok=True)
for name, inds in splits.items():
    X_split = X[inds]
    Y_split = Y[inds]
    np.save(OUTDIR / f"imgs_{name}.npy", X_split)
    np.save(OUTDIR / f"masks_{name}.npy", Y_split)
    print(f"[save] imgs_{name}.npy 形状: {X_split.shape}")
    print(f"[save] masks_{name}.npy 形状: {Y_split.shape}")

print("[done] 划分完成：train/val/test =",
      splits["train"].size, splits["val"].size, splits["test"].size)


In [2]:
!pip install data_load

Defaulting to user installation because normal site-packages is not writeable
[31mERROR: Could not find a version that satisfies the requirement data_load (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for data_load[0m[31m
[0m

In [3]:
#train.py code

import os
import datetime

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, UpSampling2D,
    Dropout, BatchNormalization, concatenate,
    Conv2DTranspose, Concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from MoDL_seg.data_load import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


class myUnet(object):
   """U-Net builder and trainer for single-channel image segmentation.

   Holds the input size and exposes get_unet() (build + compile the Keras
   model) and train() (load .npy data, train, save checkpoint and curves).
   """

   def __init__(self, img_rows = 512, img_cols = 512):
      """Store the input height (img_rows) and width (img_cols) of the network."""
      self.img_rows = img_rows
      self.img_cols = img_cols

   def load_data(self):
      """Return (imgs_train, imgs_mask_train) from DataProcess.load_train_data().

      DataProcess is presumably provided by the star import of
      MoDL_seg.data_load -- verify. Note that train() does not call this
      method; it reads the .npy files directly.
      """
      mydata = DataProcess(self.img_rows, self.img_cols)
      imgs_train, imgs_mask_train = mydata.load_train_data()
      return imgs_train, imgs_mask_train

   def get_unet(self):
    """
    UNet architecture from scratch with encoder-decoder structure
    Same number of layers as the original implementation

    Input is (img_rows, img_cols, 1). Four encoder blocks (64/128/256/512
    filters) feed a 1024-filter bottleneck; four decoder blocks upsample,
    apply a transposed convolution and concatenate the matching encoder
    output. The head is a 2-filter 3x3 conv followed by a 1x1 sigmoid conv.

    Returns the model compiled with Adam(learning_rate=1e-5),
    binary cross-entropy loss and accuracy metric.
    """
    # Input layer
    inputs = Input((self.img_rows, self.img_cols, 1))
    
    # ========== ENCODER PATH (Contracting) ==========
    
    # Encoder Block 1: 64 filters
    e1_conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    e1_conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e1_conv1)
    e1_pool = MaxPooling2D(pool_size=(2, 2))(e1_conv2)
    
    # Encoder Block 2: 128 filters
    e2_conv1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e1_pool)
    e2_conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e2_conv1)
    e2_pool = MaxPooling2D(pool_size=(2, 2))(e2_conv2)
    
    # Encoder Block 3: 256 filters
    e3_conv1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e2_pool)
    e3_conv2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e3_conv1)
    e3_pool = MaxPooling2D(pool_size=(2, 2))(e3_conv2)
    
    # Encoder Block 4: 512 filters with dropout
    e4_conv1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e3_pool)
    e4_conv2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e4_conv1)
    e4_drop = Dropout(0.5)(e4_conv2)
    e4_pool = MaxPooling2D(pool_size=(2, 2))(e4_drop)
    
    # ========== BOTTLENECK ==========
    
    # Bottleneck Block: 1024 filters with dropout
    bottleneck_conv1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(e4_pool)
    bottleneck_conv2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(bottleneck_conv1)
    bottleneck_drop = Dropout(0.5)(bottleneck_conv2)
    
    # ========== DECODER PATH (Expanding) ==========
    # Each decoder block: 2x upsampling -> transposed conv -> concatenate with
    # the encoder output of the same resolution (skip connection) -> two convs.
    
    # Decoder Block 1: 512 filters
    d1_upsample = UpSampling2D(size=(2, 2))(bottleneck_drop)
    d1_transpose = Conv2DTranspose(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(d1_upsample)
    d1_merge = Concatenate(axis=3)([e4_drop, d1_transpose])
    d1_conv1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d1_merge)
    d1_conv2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d1_conv1)
    
    # Decoder Block 2: 256 filters
    d2_upsample = UpSampling2D(size=(2, 2))(d1_conv2)
    d2_transpose = Conv2DTranspose(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(d2_upsample)
    d2_merge = Concatenate(axis=3)([e3_conv2, d2_transpose])
    d2_conv1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d2_merge)
    d2_conv2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d2_conv1)
    
    # Decoder Block 3: 128 filters
    d3_upsample = UpSampling2D(size=(2, 2))(d2_conv2)
    d3_transpose = Conv2DTranspose(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(d3_upsample)
    d3_merge = Concatenate(axis=3)([e2_conv2, d3_transpose])
    d3_conv1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d3_merge)
    d3_conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d3_conv1)
    
    # Decoder Block 4: 64 filters
    d4_upsample = UpSampling2D(size=(2, 2))(d3_conv2)
    d4_transpose = Conv2DTranspose(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(d4_upsample)
    d4_merge = Concatenate(axis=3)([e1_conv2, d4_transpose])
    d4_conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d4_merge)
    d4_conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d4_conv1)
    
    # ========== OUTPUT LAYER ==========
    
    # Additional conv layer with 2 filters (matching original)
    output_conv1 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(d4_conv2)
    
    # Final output layer: 1 filter with sigmoid activation
    outputs = Conv2D(1, 1, activation='sigmoid')(output_conv1)
    
    # Create and compile model
    # NOTE(review): train() compiles the model again with learning_rate=1e-4,
    # so the 1e-5 set here is not the rate used during training.
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-5), loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

 

   def train(self):
         """Train a fresh U-Net on the .npy data and save checkpoint, curves and a summary.

         Steps: load ../npydata/imgs_train.npy and ../npydata/imgs_mask_train.npy,
         keep at most MAX_SAMPLES random samples, scale images to [0, 1] and
         subtract the per-pixel mean, binarize masks at 0.5, split 80/20 into
         train/validation, train with binary cross-entropy and keep the
         checkpoint with the lowest validation loss under ../model/.

         All paths are relative to the current working directory, so the
         method fails with FileNotFoundError when run from another directory.

         NOTE(review): the mean image is computed over all samples before the
         train/validation split and is not saved anywhere in this method.
         """
         print("loading data")

        # ------- load the npy arrays -------
         imgs = np.load("../npydata/imgs_train.npy").astype("float32")
         masks = np.load("../npydata/imgs_mask_train.npy").astype("float32")
         print("total samples (full):", imgs.shape[0])

        # ------- cap the number of samples used, to avoid handling too much at once -------
         MAX_SAMPLES = 6000  # can be raised or lowered later
         N = imgs.shape[0]
         if N > MAX_SAMPLES:
            rng = np.random.default_rng(seed=42)
            idx = rng.choice(N, size=MAX_SAMPLES, replace=False)
            imgs = imgs[idx]
            masks = masks[idx]
            print(f"subsampled to {MAX_SAMPLES} samples for GPU training")
         else:
            print("use all samples for GPU training")

        # ------- normalize + downcast to float16 (to reduce memory usage) -------
         imgs = imgs.astype("float16")
         masks = masks.astype("float16")

         imgs /= 255.0
         mean = imgs.mean(axis=0, dtype="float32")   # accumulate the mean in float32 for stability
         imgs = imgs - mean.astype("float16")

         masks /= 255.0
         masks[masks > 0.5] = 1.0
         masks[masks <= 0.5] = 0.0

         print("after subsample:", imgs.shape[0])

        # ------- split into training / validation sets (0.8 / 0.2) -------
         N = imgs.shape[0]
         val_ratio = 0.2
         val_size = int(N * val_ratio)

         rng = np.random.default_rng(seed=123)
         indices = rng.permutation(N)

         val_idx = indices[:val_size]
         train_idx = indices[val_size:]

         X_train = imgs[train_idx]
         Y_train = masks[train_idx]
         X_val   = imgs[val_idx]
         Y_val   = masks[val_idx]

         print(f"train: {X_train.shape[0]}  val: {X_val.shape[0]}")

        # ------- feed the GPU batch by batch through tf.data.Dataset -------
         BATCH_SIZE   = 2   # kept small for GPU memory (original note: start with 1, try 2 once it runs)
         TOTAL_EPOCHS = 30 # NOTE(review): the original comment said "only 10 epochs this run"; the value is 30

         train_ds = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
         train_ds = train_ds.shuffle(buffer_size=len(X_train))
         train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

         val_ds = tf.data.Dataset.from_tensor_slices((X_val, Y_val))
         val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

        # ------- always build a brand-new model; never resume from an old one -------
         print("building a fresh model...")
         model = self.get_unet()
         print("got unet")

         # Re-compile: this replaces the optimizer configured in get_unet().
         model.compile(
            optimizer=Adam(learning_rate=1e-4),
            loss="binary_crossentropy",
            metrics=["accuracy"],
         )

         # NOTE(review): the file name says "10ep" although TOTAL_EPOCHS is 30.
         CHECKPOINT_PATH = "../model/U-RNet+_gpu_10ep.keras"
         model_checkpoint = ModelCheckpoint(
            CHECKPOINT_PATH,
            monitor="val_loss",   # pick the best checkpoint by validation loss
            verbose=1,
            save_best_only=True,
         )

         starttrain = datetime.datetime.now()
         print("Fitting model...")

         history = model.fit(
            train_ds,
            epochs=TOTAL_EPOCHS,
            verbose=1,
            validation_data=val_ds,
            callbacks=[model_checkpoint],
         )

         endtrain = datetime.datetime.now()
         print("train time: %s seconds" % (endtrain - starttrain))

        # ------- plot the accuracy / loss curves -------
         acc      = history.history["accuracy"]
         val_acc  = history.history["val_accuracy"]
         loss     = history.history["loss"]
         val_loss = history.history["val_loss"]
         epochs   = range(len(acc))

         plt.figure()
         plt.plot(epochs, acc, "b", label="training accuracy")
         plt.plot(epochs, val_acc, ":r", label="validation accuracy")
         plt.title("Accuracy")
         plt.xlabel("Epoch")
         plt.ylabel("Accuracy")
         plt.legend()
         plt.savefig("../model/Accuracy.png")

         plt.figure()
         plt.plot(epochs, loss, "b", label="training loss")
         plt.plot(epochs, val_loss, ":r", label="validation loss")
         plt.title("Loss")
         plt.xlabel("Epoch")
         plt.ylabel("Loss")
         plt.legend()
         plt.savefig("../model/Loss.png")

         plt.show()

         # Record the last-epoch training loss/accuracy; "%.6s" truncates the
         # string form of each value to six characters.
         with open("../model/unet.txt", "wt") as ft:
            ft.write("loss: %.6s\n" % (loss[-1]))
            ft.write("accuracy: %.6s\n" % (acc[-1]))

# Script entry point: build the default 512x512 network wrapper and start training.
if __name__ == '__main__':

   myunet = myUnet()
   myunet.train()

loading data


FileNotFoundError: [Errno 2] No such file or directory: '../npydata/imgs_train.npy'