In [5]:
#!/usr/bin/env python3
import os
from pathlib import Path
import numpy as np
import scipy.io as sio
from PIL import Image
import random

# --------- Config ---------
# Input folders of .mat files. process_one() looks up "<stem>.mat" in each:
# GT_DIR files are written to the "A" subfolder, LR_DIR files (after taking
# the magnitude with np.abs) to the "B" subfolder.
GT_DIR = "/home/data1/musong/workspace/2025/8/08-20/tr/data/IXI_sim/data"
LR_DIR = "/home/data1/musong/workspace/2025/8/08-20/tr/data/IXI_sim/final_rxyacq_ROFFT"

# Output root; ensure_dirs() creates "A" and "B" under both train/ and test/.
OUT_ROOT = "/home/data1/musong/workspace/python/cyclegan/datasets/IXI"
TRAIN_DIR = os.path.join(OUT_ROOT, "train")
TEST_DIR  = os.path.join(OUT_ROOT, "test")

# Fraction of the shuffled stems assigned to train; the remainder goes to test.
SPLIT_TRAIN = 0.9     # 9:1 split
# Seed for the random.Random instance used by split_stems().
SEED = 20250823       # deterministic split
# --------------------------

def ensure_dirs():
    """Create the A and B output folders under both the train and test roots.

    Existing folders are left untouched (``exist_ok=True``).
    """
    for split_root in (TRAIN_DIR, TEST_DIR):
        for domain in ("A", "B"):
            os.makedirs(os.path.join(split_root, domain), exist_ok=True)

def normalize_to_uint8(arr: np.ndarray) -> np.ndarray:
    """Rescale an array linearly so its min maps to 0 and its max to 255.

    NaN and +/-inf entries are replaced by 0 before the range is measured.
    A constant array (max not greater than min) yields all zeros. The final
    cast to uint8 truncates the fractional part.
    """
    cleaned = np.nan_to_num(np.asarray(arr), nan=0.0, posinf=0.0, neginf=0.0)
    lo = float(cleaned.min())
    hi = float(cleaned.max())
    if hi <= lo:
        # Flat image: there is no range to stretch, so emit zeros.
        unit = np.zeros_like(cleaned, dtype=np.float32)
    else:
        unit = (cleaned - lo) / (hi - lo)
    return np.clip(unit * 255.0, 0, 255).astype(np.uint8)

def to_rgb(arr_2d: np.ndarray) -> np.ndarray:
    """Normalize a grayscale image to uint8 and copy it into three channels.

    The channel axis is appended last, so an (H, W) input becomes (H, W, 3).
    """
    gray = normalize_to_uint8(arr_2d)
    return np.stack((gray,) * 3, axis=-1)

def last_data_key(d: dict) -> str:
    """Pick the final key of *d* that is not a ``__dunder__``-style string.

    Non-string keys always count as data keys. When every key looks like
    metadata, the very last key of the dict is returned instead.
    """
    data_keys = [
        name for name in d
        if not isinstance(name, str) or not name.startswith("__")
    ]
    return data_keys[-1] if data_keys else list(d)[-1]

def load_mat_last_key(path: str) -> np.ndarray:
    """Read a .mat file and hand back the array under its last data key.

    The key is chosen by ``last_data_key``.

    Raises:
        ValueError: if the selected entry is not a numpy ndarray.
    """
    contents = sio.loadmat(path)
    payload = contents[last_data_key(contents)]
    if isinstance(payload, np.ndarray):
        return payload
    raise ValueError(f"Data at last key is not an ndarray in {path}")

def to_2d(arr: np.ndarray) -> np.ndarray:
    """
    Drop every length-1 axis and require that exactly two axes remain.
    Shapes such as (W,H), (1,W,H) and (W,H,1) pass; anything else raises
    ValueError reporting the squeezed shape.
    """
    squeezed = np.squeeze(np.asarray(arr))
    if squeezed.ndim == 2:
        return squeezed
    raise ValueError(f"Expected 2D after squeeze, got shape {squeezed.shape}")

def save_png(img_rgb: np.ndarray, out_path: str):
    """Convert the array to a PIL image and write it to *out_path*."""
    picture = Image.fromarray(img_rgb)
    picture.save(out_path)

def split_stems(stems):
    """Shuffle the stems with a SEED-seeded RNG and cut them into two lists.

    The first ``int(len * SPLIT_TRAIN)`` shuffled items form the train list;
    everything after that forms the test list. The input is copied, so the
    caller's sequence is not reordered.
    """
    shuffled = list(stems)
    rng = random.Random(SEED)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * SPLIT_TRAIN)
    train_part = shuffled[:cut]
    test_part = shuffled[cut:]
    return train_part, test_part

def process_one(stem: str, dst_base: str,
                gt_counts: dict, lr_counts: dict):
    """Convert the GT and LR .mat files for one stem into PNGs under *dst_base*.

    The GT file goes to ``<dst_base>/A/<stem>.png`` as-is; the LR file goes to
    ``<dst_base>/B/<stem>.png`` after taking its magnitude. For each of the
    two sources exactly one counter is bumped in the matching dict:
    "saved" on success, "warn" if anything raised (the error is printed),
    or "miss" if the input file does not exist.
    """
    mat_name = stem + ".mat"
    png_name = stem + ".png"

    # One row per source: (input path, output subfolder, warning prefix,
    # counter dict, whether to take the magnitude first). GT is handled
    # before LR. A missing file is not fatal; dataset A/B can be unpaired.
    jobs = (
        (os.path.join(GT_DIR, mat_name), "A", "[GT  WARN]", gt_counts, False),
        (os.path.join(LR_DIR, mat_name), "B", "[LR  WARN]", lr_counts, True),
    )

    for src_path, domain, warn_tag, tally, use_magnitude in jobs:
        if not os.path.exists(src_path):
            tally["miss"] += 1
            continue
        try:
            plane = to_2d(load_mat_last_key(src_path))
            if use_magnitude:
                plane = np.abs(plane)
            save_png(to_rgb(plane), os.path.join(dst_base, domain, png_name))
            tally["saved"] += 1
        except Exception as e:
            # Best-effort conversion: report the failure and keep going.
            print(f"{warn_tag} {stem}: {e}")
            tally["warn"] += 1

def main():
    """Convert every .mat stem found in GT_DIR or LR_DIR and print a summary.

    Steps: create the output folders, collect the union of file stems from
    both input folders, split them into train/test with ``split_stems``,
    run ``process_one`` on each stem, then print per-split counters.
    """
    ensure_dirs()

    gt_stems = {Path(name).stem for name in os.listdir(GT_DIR) if name.endswith(".mat")}
    lr_stems = {Path(name).stem for name in os.listdir(LR_DIR) if name.endswith(".mat")}

    # Sort the union so the seeded shuffle always starts from the same order.
    train_stems, test_stems = split_stems(sorted(gt_stems | lr_stems))

    counts = {
        bucket: {"saved": 0, "warn": 0, "miss": 0}
        for bucket in ("train_gt", "train_lr", "test_gt", "test_lr")
    }

    # Train stems are processed first, then test stems.
    for split, subset, out_dir in (
        ("train", train_stems, TRAIN_DIR),
        ("test", test_stems, TEST_DIR),
    ):
        gt_tally = counts[split + "_gt"]
        lr_tally = counts[split + "_lr"]
        for stem in subset:
            process_one(stem, out_dir, gt_tally, lr_tally)

    # Summary
    print("\n✅ Done.")
    print(f"Train size: {len(train_stems)} | Test size: {len(test_stems)}")
    print(f"Train A(GT): saved={counts['train_gt']['saved']}, warn={counts['train_gt']['warn']}, miss={counts['train_gt']['miss']}")
    print(f"Train B(LR): saved={counts['train_lr']['saved']}, warn={counts['train_lr']['warn']}, miss={counts['train_lr']['miss']}")
    print(f"Test  A(GT): saved={counts['test_gt']['saved']},  warn={counts['test_gt']['warn']},  miss={counts['test_gt']['miss']}")
    print(f"Test  B(LR): saved={counts['test_lr']['saved']},  warn={counts['test_lr']['warn']},  miss={counts['test_lr']['miss']}")
    print(f"\nTrain dir: {TRAIN_DIR}")
    print(f"Test  dir: {TEST_DIR}")


# Run the conversion as soon as this cell/script executes. There is no
# `if __name__ == "__main__"` guard, so importing this file also triggers it.
main()



✅ Done.
Train size: 900 | Test size: 100
Train A(GT): saved=900, warn=0, miss=0
Train B(LR): saved=900, warn=0, miss=0
Test  A(GT): saved=100,  warn=0,  miss=0
Test  B(LR): saved=100,  warn=0,  miss=0

Train dir: /home/data1/musong/workspace/python/cyclegan/datasets/IXI/train
Test  dir: /home/data1/musong/workspace/python/cyclegan/datasets/IXI/test
