In [None]:
import torch, nnunetv2
# Environment smoke test: report the PyTorch build, whether CUDA is usable,
# and confirm that the nnunetv2 package imported above is reachable.
torch_version = torch.__version__
cuda_ok = torch.cuda.is_available()
print("Torch:", torch_version)
print("CUDA available:", cuda_ok)
print("nnU-Net v2 imported successfully!")

# Show the file nnunetv2 was loaded from, to confirm which installation is active
nnunet_path = nnunetv2.__file__
print("nnunetv2 location:", nnunet_path)

In [None]:
import os, nnunetv2, torch

# Work directories - updated for your local setup
BASE = "D:/nnunet_with_classification/data"
RAW  = f"{BASE}/nnUNet_raw"
PREP = f"{BASE}/nnUNet_preprocessed"
RES  = f"{BASE}/nnUNet_results"

# Environment variable name -> folder it should point to
NNUNET_DIRS = {
    "nnUNet_raw": RAW,
    "nnUNet_preprocessed": PREP,
    "nnUNet_results": RES,
}

# Create nnU-Net directories (no error if they already exist)
for folder in NNUNET_DIRS.values():
    os.makedirs(folder, exist_ok=True)

# Set environment variables; later cells read them back through os.environ
for env_name, folder in NNUNET_DIRS.items():
    os.environ[env_name] = folder

print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("nnU-Net v2: imported successfully!")
print(f"nnU-Net directories created at: {BASE}")

In [None]:
import os, glob, pathlib, shutil, re, json, csv

# >>>>>>>>>>>> UPDATE THIS PATH <<<<<<<<<<<<
# Folder that contains the original train / validation / test split directories
SRC = "D:/nnunet_with_classification/data"

# Dataset identity; DSNAME is the folder name used under nnUNet_raw / nnUNet_preprocessed
DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"

# nnU-Net working folders, read back from the environment set in the previous cell
RAW  = os.environ["nnUNet_raw"]
PREP = os.environ["nnUNet_preprocessed"]

# Raw-dataset layout: training images, training labels, test images
DSROOT = f"{RAW}/{DSNAME}"
imgTr = f"{DSROOT}/imagesTr"
lblTr = f"{DSROOT}/labelsTr"
imgTs = f"{DSROOT}/imagesTs"
for folder in (imgTr, lblTr, imgTs):
    os.makedirs(folder, exist_ok=True)

def stem(p):
    """Return the file name of ``p`` without its trailing ``.nii.gz`` extension.

    Only a suffix is removed: a ``.nii.gz`` that appears in the middle of the
    name is left untouched. Names that do not end in ``.nii.gz`` are returned
    unchanged.
    """
    name = pathlib.Path(p).name
    suffix = ".nii.gz"
    return name[:-len(suffix)] if name.endswith(suffix) else name
# Matches a subtype marker in a path component and captures the class digit.
# Accepts 'subtype0', 'Subtype 1', 'subtype_2', 'subtype-2' (case-insensitive).
# The negative lookahead rejects multi-digit numbers such as 'subtype12', which
# would otherwise be read as subtype 1.
sub_regex = re.compile(r"subtype[\s_\-]*([012])(?!\d)", re.IGNORECASE)

def find_split(name, root=None):
    """Find a split folder by name, ignoring case.

    Args:
        name: Folder name to look for (e.g. "train"); compared case-insensitively.
        root: Directory to search. Defaults to the module-level ``SRC``.

    Returns:
        The path of the first matching sub-directory of ``root`` (candidates are
        visited in sorted order so the result is deterministic), or ``None`` when
        no directory with that name exists.
    """
    base = SRC if root is None else root
    wanted = name.lower()  # lower-case both sides so "Train" / "TRAIN" also match
    for cand in sorted(os.listdir(base)):
        if cand.lower() == wanted:
            p = os.path.join(base, cand)
            # a plain file with the right name is not a split folder
            if os.path.isdir(p):
                return p
    return None

# Locate the three source split folders under SRC (names matched case-insensitively).
train_dir = find_split("train")
val_dir   = find_split("validation")
test_dir  = find_split("test")

# Fail with an explicit error (not an assert, which is stripped under `python -O`)
# and say exactly which split folder is missing.
_missing_splits = [
    split_name
    for split_name, split_path in (("train", train_dir), ("validation", val_dir), ("test", test_dir))
    if not split_path
]
if _missing_splits:
    raise FileNotFoundError(
        f"Could not find split folder(s) {', '.join(_missing_splits)} under {SRC}"
    )

def ingest_split(split_dir, cls_map):
    """Copy every image/mask pair found under ``split_dir`` into imagesTr/labelsTr.

    Images are files matching ``*_0000.nii.gz`` anywhere below ``split_dir``;
    the mask is expected next to the image under the same name without the
    ``_0000`` channel suffix. Images without a mask are skipped and reported.

    Args:
        split_dir: Root folder of the split to ingest (searched recursively).
        cls_map: Dict updated in place with ``case -> subtype index`` for every
            copied case whose path contains a component matched by ``sub_regex``.

    Returns:
        Number of image/mask pairs that were copied.
    """
    img_suffix = "_0000.nii.gz"
    imgs = glob.glob(os.path.join(split_dir, "**", "*" + img_suffix), recursive=True)
    used = 0
    skipped = []
    for img in sorted(imgs):
        folder, fname = os.path.split(img)
        # Strip only the trailing channel suffix; str.replace would also rewrite
        # a "_0000" that happens to occur elsewhere in the name or directory.
        case = fname[:-len(img_suffix)]
        msk = os.path.join(folder, case + ".nii.gz")
        if not os.path.exists(msk):
            skipped.append(fname)
            continue
        shutil.copy(img, f"{imgTr}/{case}_0000.nii.gz")
        shutil.copy(msk, f"{lblTr}/{case}.nii.gz")
        used += 1
        # infer subtype from folder names (expects 'subtype0/1/2');
        # the first path component that matches wins
        sub_idx = None
        for part in pathlib.Path(img).parts:
            m = sub_regex.search(part)
            if m:
                sub_idx = int(m.group(1))
                break
        if sub_idx is not None:
            cls_map[case] = sub_idx
    if skipped:
        # surface unpaired images instead of dropping them silently
        print(f"WARNING: skipped {len(skipped)} image(s) without a mask under {split_dir} "
              f"(first few): {skipped[:5]}")
    return used

# Ingest train + validation pairs into imagesTr/labelsTr; cls_map collects the
# case -> subtype index found by ingest_split for both splits.
cls_map = {}
n_tr = ingest_split(train_dir, cls_map)
n_va = ingest_split(val_dir,   cls_map)
print("Ingested image/mask pairs - train:", n_tr, "| validation:", n_va)

# test images
# The glob already guarantees the "_0000.nii.gz" suffix, so copy each file under
# its own basename rather than stripping and re-adding "_0000" (a replace-based
# round trip would rename files whose case id itself contains "_0000").
for img in sorted(glob.glob(os.path.join(test_dir, "**", "*_0000.nii.gz"), recursive=True)):
    shutil.copy(img, f"{imgTs}/{os.path.basename(img)}")

print("imagesTr:", len(glob.glob(f"{imgTr}/*_0000.nii.gz")))
print("labelsTr:", len(glob.glob(f"{lblTr}/*.nii.gz")))
print("imagesTs:", len(glob.glob(f"{imgTs}/*_0000.nii.gz")))
print("Mapped classification labels:", len(cls_map))

# dataset.json
# nnU-Net v2 needs "channel_names" and "file_ending" in addition to "labels" and
# "numTraining"; without them dataset verification rejects the file. The older
# descriptive keys (name, modality, training/test lists) are kept as extra
# metadata so nothing that may read them breaks.
dataset_json = {
  "name": DSNAME,
  "tensorImageSize": "3D",
  "modality": {"0": "CT"},
  "channel_names": {"0": "CT"},
  "labels": {"background": 0, "pancreas": 1, "lesion": 2},
  "numTraining": len(glob.glob(f"{lblTr}/*.nii.gz")),
  "numTest": len(glob.glob(f"{imgTs}/*_0000.nii.gz")),
  "file_ending": ".nii.gz",
  "training": [{"image": f"./imagesTr/{stem(i)}.nii.gz",
                "label": f"./labelsTr/{stem(i).replace('_0000','')}.nii.gz"}
               for i in sorted(glob.glob(f"{imgTr}/*_0000.nii.gz"))],
  "test": [f"./imagesTs/{stem(i)}.nii.gz" for i in sorted(glob.glob(f"{imgTs}/*_0000.nii.gz"))]
}
os.makedirs(DSROOT, exist_ok=True)
with open(f"{DSROOT}/dataset.json","w") as f:
    json.dump(dataset_json, f, indent=2)

# splits_final.json (use original validation as nnU-Net val)
all_cases = sorted([stem(p).replace("_0000","") for p in glob.glob(f"{imgTr}/*_0000.nii.gz")])

# Every image found in the validation folder ...
_val_found = {stem(p).replace("_0000","") for p in glob.glob(os.path.join(val_dir, "**", "*_0000.nii.gz"), recursive=True)}
# ... but only cases that were actually ingested may appear in the split:
# ingest_split skips images without a mask, and listing those here would make
# the split reference cases that do not exist in the dataset.
val_cases = _val_found & set(all_cases)
_val_dropped = sorted(_val_found - val_cases)
if _val_dropped:
    print(f"WARNING: {len(_val_dropped)} validation image(s) were not ingested and are "
          f"left out of the split (first few): {_val_dropped[:5]}")

train_cases = [c for c in all_cases if c not in val_cases]
spdir = f"{PREP}/{DSNAME}"
os.makedirs(spdir, exist_ok=True)
with open(f"{spdir}/splits_final.json","w") as f:
    # single fold: original train cases for training, original validation cases for validation
    json.dump([{"train": train_cases, "val": sorted(val_cases)}], f, indent=2)

# classification_labels.csv
# One "case,label" row per ingested case that has a subtype label in cls_map.
csv_path = f"{spdir}/classification_labels.csv"
_unlabeled = [c for c in all_cases if c not in cls_map]
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f)
    for c in all_cases:
        if c in cls_map:
            w.writerow([c, cls_map[c]])

if _unlabeled:
    # Report misses instead of dropping them silently: these cases get no row.
    # If you see many misses here, your folders may not be named 'subtype0/1/2'.
    print(f"WARNING: {len(_unlabeled)} case(s) have no classification label "
          f"(first few): {_unlabeled[:5]}")

print("Wrote dataset.json, splits_final.json, classification_labels.csv")

In [None]:
import json, glob, os, pathlib

# Re-derive the dataset location so this cell can run on its own
DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"
RAW = os.environ["nnUNet_raw"]
DSROOT = f"{RAW}/{DSNAME}"
lblTr = f"{DSROOT}/labelsTr"

# numTraining is the number of label files currently present in labelsTr
n_training = len(glob.glob(os.path.join(lblTr, "*.nii.gz")))
dataset_json_path = f"{DSROOT}/dataset.json"

# Minimal v2-style description: channel names, label map, case count, file ending.
# Writing it replaces any dataset.json already present at that path.
dataset_v2 = dict(
    channel_names={"0": "CT"},                                  # required in v2
    labels={"background": 0, "pancreas": 1, "lesion": 2},       # your classes
    numTraining=n_training,
    file_ending=".nii.gz",
)
with open(dataset_json_path, "w") as f:
    json.dump(dataset_v2, f, indent=2)

print("✅ Wrote v2 dataset.json at:", dataset_json_path)
print("numTraining:", dataset_v2["numTraining"])

In [None]:
import nibabel as nib, numpy as np, glob, os, pathlib

DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"
lblTr = os.path.join(os.environ["nnUNet_raw"], DSNAME, "labelsTr")

# Label values a mask is allowed to contain
VALID_LABELS = (0, 1, 2)


def _masks_with_invalid_labels(label_dir):
    """Return ``[(path, unique_values)]`` for every ``*.nii.gz`` mask in
    ``label_dir`` that contains a value outside ``VALID_LABELS``."""
    offenders = []
    for p in sorted(glob.glob(os.path.join(label_dir, "*.nii.gz"))):
        uniq = np.unique(nib.load(p).get_fdata())  # floats possible
        if not np.all(np.isin(uniq, VALID_LABELS)):
            offenders.append((p, uniq))
    return offenders


bad = _masks_with_invalid_labels(lblTr)
print("Masks needing fix:", len(bad))
for p, uniq in bad[:10]:
    print("  ", pathlib.Path(p).name, "unique:", uniq[:10])

# Fix: round and clip to {0,1,2}, write as uint8 (files are overwritten in place)
for p, uniq in bad:
    img = nib.load(p)
    arr = img.get_fdata()
    # NaN has no defined uint8 value; map it to background explicitly
    arr = np.nan_to_num(arr, nan=0.0)
    arr = np.rint(arr)                                        # round to nearest integer
    arr = np.clip(arr, min(VALID_LABELS), max(VALID_LABELS))  # enforce label set
    arr = arr.astype(np.uint8)

    hdr = img.header.copy()
    hdr.set_data_dtype(np.uint8)
    # avoid unintended scaling
    hdr["scl_slope"] = 1
    hdr["scl_inter"] = 0

    fixed = nib.Nifti1Image(arr, img.affine, hdr)
    nib.save(fixed, p)

# Sanity check after fix: re-scan with the same rule used to find the bad masks
remaining = _masks_with_invalid_labels(lblTr)
print("Remaining masks with non-{0,1,2} labels:", len(remaining))

In [None]:
# Plan and preprocess dataset 777 for the 3d_fullres configuration with the
# nnUNetPlannerResEncM planner, verifying dataset integrity as part of the run.
!nnUNetv2_plan_and_preprocess -d 777 -c 3d_fullres --verify_dataset_integrity -pl nnUNetPlannerResEncM

In [None]:
# Train fold 0 of dataset 777 (3d_fullres) with the nnUNetTrainerWithClassification
# trainer and the nnUNetResEncUNetMPlans plans.
# NOTE(review): --c presumably continues from an existing checkpoint; confirm
# against `nnUNetv2_train -h` and drop it if a fresh run is intended.
!nnUNetv2_train 777 3d_fullres 0 -tr nnUNetTrainerWithClassification -p nnUNetResEncUNetMPlans --c

In [None]:
import torch, os ,glob
# Inspect the trained checkpoint and list the parameters that belong to the
# classification head, to confirm the head was actually saved.
ckpt = r"D:/nnunet_with_classification/data/nnUNet_results/Dataset777_M31Quiz/nnUNetTrainerWithClassification__nnUNetResEncUNetMPlans__3d_fullres/fold_0/checkpoint_best_combined.pth"
# weights_only=False: the checkpoint stores more than tensors, and torch >= 2.6
# defaults to weights_only=True, which refuses to unpickle such files.
# This runs the pickle loader unrestricted, so only use it on checkpoints you trust.
sd = torch.load(ckpt, map_location="cpu", weights_only=False)
state = sd.get("network_state_dict") or sd.get("network_weights") or sd  # fallbacks
hits = [k for k in state.keys() if "ClassificationHead" in k]
print("num ClassificationHead params:", len(hits))
print(hits[:5])




In [None]:
# Run scripts/inference.py on the flattened validation images with fold 0 and the
# checkpoint_best_combined.pth checkpoint, writing results to the predictions folder.
# NOTE(review): --no-tta presumably disables test-time augmentation; the two
# *_workers flags are set to 1. Confirm flag semantics in scripts/inference.py.
!python scripts/inference.py \
  --model_dir "D:/nnunet_with_classification/data/nnUNet_results/Dataset777_M31Quiz/nnUNetTrainerWithClassification__nnUNetResEncUNetMPlans__3d_fullres" \
  --input_dir "D:/nnunet_with_classification/data/validation_flat" \
  --output_dir "D:/nnunet_with_classification/data/validation_flat/predictions" \
  --fold 0 \
  --checkpoint checkpoint_best_combined.pth \
  --num_classes 3 \
  --no-tta \
  --preproc_workers 1 \
  --export_workers 1

In [None]:
# Complete evaluation script for your results on validation data
import os
import csv
import numpy as np
import nibabel as nib
from sklearn.metrics import f1_score, classification_report

# Validation cases, grouped by the subtype digit embedded in the case id
# ("quiz_<subtype>_<number>"). The flattened list keeps the original order.
_VAL_NUMBERS_BY_SUBTYPE = {
    0: ("168", "171", "174", "184", "187", "189", "244", "253", "254"),
    1: ("090", "093", "094", "154", "158", "164", "166", "211", "213",
        "221", "227", "231", "242", "331", "335"),
    2: ("074", "080", "084", "085", "088", "089", "098", "191", "241",
        "364", "377", "379"),
}
validation_cases = [
    f"quiz_{subtype}_{number}"
    for subtype, numbers in _VAL_NUMBERS_BY_SUBTYPE.items()
    for number in numbers
]

# Evaluation inputs. Ground-truth masks live in the raw dataset's labelsTr folder
# (located through the nnUNet_raw environment variable); predictions and the
# predicted-class CSV come from the inference output folder.
gt_seg_dir = os.environ["nnUNet_raw"] + "/Dataset777_M31Quiz/labelsTr"
pred_seg_dir = "D:/nnunet_with_classification/data/validation_flat/predictions"
gt_cls_csv = "D:/nnunet_with_classification/data/validation_labels.csv"
pred_cls_csv = pred_seg_dir + "/classification_results.csv"

def dice_score(y_true, y_pred):
    """Dice similarity coefficient of two binary masks.

    Args:
        y_true: Array of 0/1 values (reference mask).
        y_pred: Array of 0/1 values (predicted mask); multiplied element-wise
            with ``y_true``.

    Returns:
        ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))``, or ``1.0``
        when both masks are empty (the denominator is zero).
    """
    overlap = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    if denominator > 0:
        return (2.0 * overlap) / denominator
    # both masks empty: treated as perfect agreement
    return 1.0

# Check paths: print each evaluation input together with whether it exists.
# Each entry carries its padded label and arrow so the output stays column-aligned.
print("🔍 Checking paths...")
path_checks = (
    ("GT segmentation dir:          ", gt_seg_dir,   "   -> "),
    ("Predicted segmentation dir:   ", pred_seg_dir, " -> "),
    ("GT classification CSV:        ", gt_cls_csv,   "   -> "),
    ("Predicted classification CSV: ", pred_cls_csv, " -> "),
)
for label, path, arrow in path_checks:
    print(f"{label}{path}{arrow}{os.path.exists(path)}")

# Segmentation metrics (exact filename match: case.nii.gz)
print("\n📊 Computing segmentation metrics...")
pancreas_dices, lesion_dices, missing_files = [], [], []

for case in validation_cases:
    gt_file   = f"{gt_seg_dir}/{case}.nii.gz"
    pred_file = f"{pred_seg_dir}/{case}.nii.gz"

    if not (os.path.exists(gt_file) and os.path.exists(pred_file)):
        missing_files.append(case)
        continue

    try:
        gt   = nib.load(gt_file).get_fdata()
        pred = nib.load(pred_file).get_fdata()

        # Refuse mismatched volumes: element-wise products of broadcastable but
        # different shapes would otherwise yield a silently wrong Dice.
        if gt.shape != pred.shape:
            raise ValueError(f"shape mismatch: GT {gt.shape} vs prediction {pred.shape}")

        # Whole pancreas (>0)
        case_pancreas_dsc = dice_score((gt > 0).astype(np.uint8), (pred > 0).astype(np.uint8))

        # Lesion (== 2) — change if your lesion label differs
        case_lesion_dsc = dice_score((gt == 2).astype(np.uint8), (pred == 2).astype(np.uint8))
    except Exception as e:
        # best effort: report the case and count it as unusable, then carry on
        print(f"Error processing {case}: {e}")
        missing_files.append(case)
        continue

    # Append both scores together so the two lists always cover the same cases
    pancreas_dices.append(case_pancreas_dsc)
    lesion_dices.append(case_lesion_dsc)

if pancreas_dices:
    whole_pancreas_dsc = float(np.mean(pancreas_dices))
    lesion_dsc         = float(np.mean(lesion_dices))
    print(f"Pancreas DSC (mean): {whole_pancreas_dsc:.4f}")
    print(f"Lesion   DSC (mean): {lesion_dsc:.4f}")
else:
    whole_pancreas_dsc = 0.0
    lesion_dsc         = 0.0
    print("No matched segmentation files found.")

# Classification metrics (exact case key; no stripping)
print("\n🎯 Computing classification metrics...")

# Ground truth: CSV with 'case' and 'label' columns.
# 'utf-8-sig' reads plain UTF-8 and also drops a leading BOM (e.g. from Excel),
# which would otherwise turn the first header into '\ufeffcase'.
gt_labels = {}
if os.path.exists(gt_cls_csv):
    with open(gt_cls_csv, 'r', newline='', encoding='utf-8-sig') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []   # None for an empty file
        if 'case' not in fieldnames or 'label' not in fieldnames:
            raise RuntimeError(f"Expected columns 'case' and 'label' in {gt_cls_csv}")
        for row in reader:
            gt_labels[row['case'].strip()] = int(row['label'])

# Predictions: CSV with 'case' plus either 'pred_class' or probability columns p0..pN.
pred_labels = {}
if os.path.exists(pred_cls_csv):
    with open(pred_cls_csv, 'r', newline='', encoding='utf-8-sig') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []   # None for an empty file
        if 'case' not in fieldnames:
            raise RuntimeError(f"Expected column 'case' in {pred_cls_csv}")
        # prefer explicit pred_class; else derive from probs p0..pN
        has_pred_class = 'pred_class' in fieldnames
        # order by the numeric suffix so p10 sorts after p2 regardless of file order
        prob_cols = sorted(
            (c for c in fieldnames if c.startswith('p') and c[1:].isdigit()),
            key=lambda c: int(c[1:]),
        )
        for row in reader:
            case = row['case'].strip()
            pred_class_text = (row.get('pred_class') or '').strip() if has_pred_class else ''
            if pred_class_text != '':
                pred_labels[case] = int(pred_class_text)
            elif prob_cols:
                # the class id is the column's numeric suffix, not its position
                best_col = max(prob_cols, key=lambda c: float(row[c]))
                pred_labels[case] = int(best_col[1:])

# Evaluate only overlapping cases (exact name match): keep a validation case
# only when both a ground-truth and a predicted label are available for it.
paired_labels = [
    (gt_labels[case_id], pred_labels[case_id])
    for case_id in validation_cases
    if case_id in gt_labels and case_id in pred_labels
]
val_gt = [truth for truth, _ in paired_labels]
val_pred = [guess for _, guess in paired_labels]

if not paired_labels:
    # nothing to score; report 0.0 so the summary below still prints
    macro_f1 = 0.0
    print("No overlapping classification cases found between GT and predictions.")
else:
    macro_f1 = f1_score(val_gt, val_pred, average='macro')
    print(f"Macro F1: {macro_f1:.4f}")

# Final summary: the three headline metrics, then any cases whose segmentation
# files were missing or could not be processed.
summary_lines = (
    "\n===== SUMMARY =====",
    f"Whole Pancreas DSC: {whole_pancreas_dsc:.4f}",
    f"Lesion DSC:         {lesion_dsc:.4f}",
    f"Macro F1:           {macro_f1:.4f}",
)
for line in summary_lines:
    print(line)

n_missing = len(missing_files)
if n_missing > 0:
    print(f"Missing seg files for {n_missing} cases (first few): {missing_files[:5]}")
