In [None]:
import torch, nnunetv2
# Environment smoke test: report the PyTorch build and whether a CUDA device is visible.
for label, value in (
    ("Torch:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(label, value)
print("nnU-Net v2 imported successfully!")

# Printing the package location confirms which nnunetv2 installation this kernel picked up.
print("nnunetv2 location:", nnunetv2.__file__)

In [None]:
import os, nnunetv2, torch

# Root folder that holds the three nnU-Net working directories on this machine.
BASE = "D:/nnunet_with_classification/data"
RAW  = f"{BASE}/nnUNet_raw"
PREP = f"{BASE}/nnUNet_preprocessed"
RES  = f"{BASE}/nnUNet_results"

# Create every working directory first ...
for folder in (RAW, PREP, RES):
    os.makedirs(folder, exist_ok=True)

# ... then export its location. Later cells read these variables back through
# os.environ, and commands started with `!` inherit them from the kernel process.
os.environ.update({
    "nnUNet_raw": RAW,
    "nnUNet_preprocessed": PREP,
    "nnUNet_results": RES,
})

print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("nnU-Net v2: imported successfully!")
print(f"nnU-Net directories created at: {BASE}")

In [5]:
import os, glob, pathlib, shutil, re, json, csv

# >>>>>>>>>>>> UPDATE THIS PATH <<<<<<<<<<<<
# Folder that contains the original train / validation / test splits.
SRC = "D:/nnunet_with_classification/data"

# Dataset identity; the folder name is "Dataset" + zero-padded id + "_" + a free-form name.
DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"

# Locations exported by the environment-setup cell.
RAW  = os.environ["nnUNet_raw"]
PREP = os.environ["nnUNet_preprocessed"]

# Raw dataset layout: training images, training labels, test images.
DSROOT = f"{RAW}/{DSNAME}"
imgTr = f"{DSROOT}/imagesTr"
lblTr = f"{DSROOT}/labelsTr"
imgTs = f"{DSROOT}/imagesTs"
for folder in (imgTr, lblTr, imgTs):
    os.makedirs(folder, exist_ok=True)

def stem(p):
    """Return the file name of ``p`` (directories dropped) with ``.nii.gz`` removed."""
    filename = pathlib.Path(p).name
    return filename.replace(".nii.gz", "")


# Matches "subtype0", "Subtype 1", ... and captures the class digit (0, 1 or 2).
sub_regex = re.compile(r"subtype\s*([012])", flags=re.IGNORECASE)

def find_split(name):
    """Locate a split folder directly under SRC, ignoring the case of the folder name.

    Args:
        name: Lower-case folder name to look for (e.g. "train").

    Returns:
        Path of the first directory entry of SRC whose lower-cased name equals
        ``name``, or None when no such directory exists.
    """
    candidates = (
        os.path.join(SRC, entry)
        for entry in os.listdir(SRC)
        if entry.lower() == name
    )
    # Entries that match by name but are not directories are skipped.
    return next((path for path in candidates if os.path.isdir(path)), None)

# Resolve the three source splits in one pass; each entry is a path or None.
train_dir, val_dir, test_dir = (
    find_split(split_name) for split_name in ("train", "validation", "test")
)
assert all((train_dir, val_dir, test_dir)), f"Could not find train/validation/test under {SRC}"

def ingest_split(split_dir, cls_map):
    """Copy every image/mask pair found under ``split_dir`` into imagesTr / labelsTr.

    Images are files ending in ``_0000.nii.gz`` (searched recursively); the mask is
    expected next to the image under the same name without the ``_0000`` channel
    suffix. The subtype (0/1/2) is read from the first path component that matches
    ``sub_regex`` and stored in ``cls_map`` under the case name.

    Args:
        split_dir: Root folder of one source split.
        cls_map: Dict updated in place with ``case -> subtype index``.

    Returns:
        Number of image/mask pairs that were copied.
    """
    image_suffix = "_0000.nii.gz"
    images = sorted(glob.glob(os.path.join(split_dir, "**", f"*{image_suffix}"), recursive=True))

    used = 0
    without_mask = []
    without_subtype = []
    for img in images:
        case = stem(img).replace("_0000", "")
        # Swap only the trailing suffix so a directory that happens to contain
        # "_0000.nii.gz" in its name cannot corrupt the mask path.
        msk = img[: -len(image_suffix)] + ".nii.gz"
        if not os.path.exists(msk):
            without_mask.append(img)
            continue

        shutil.copy(img, f"{imgTr}/{case}_0000.nii.gz")
        shutil.copy(msk, f"{lblTr}/{case}.nii.gz")
        used += 1

        # Infer the subtype from folder names (expects 'subtype0/1/2').
        sub_idx = None
        for part in pathlib.Path(img).parts:
            m = sub_regex.search(part)
            if m:
                sub_idx = int(m.group(1))
                break
        if sub_idx is not None:
            cls_map[case] = sub_idx
        else:
            without_subtype.append(case)

    # Surface what used to be dropped silently, so an unexpected folder layout is visible.
    if without_mask:
        print(f"WARNING: skipped {len(without_mask)} image(s) without a mask under {split_dir}, "
              f"e.g. {without_mask[:3]}")
    if without_subtype:
        print(f"WARNING: no subtype folder found for {len(without_subtype)} case(s) under {split_dir}, "
              f"e.g. {without_subtype[:3]}")
    return used

# Train and validation cases both land in imagesTr/labelsTr; cls_map collects their subtypes.
cls_map = {}
n_tr, n_va = (ingest_split(split, cls_map) for split in (train_dir, val_dir))

# Test images have no masks: copy them into imagesTs under their case name.
test_images = sorted(glob.glob(os.path.join(test_dir, "**", "*_0000.nii.gz"), recursive=True))
for test_img in test_images:
    test_case = stem(test_img).replace("_0000", "")
    shutil.copy(test_img, f"{imgTs}/{test_case}_0000.nii.gz")

# Report what ended up in each destination folder.
for label, pattern in (
    ("imagesTr:", f"{imgTr}/*_0000.nii.gz"),
    ("labelsTr:", f"{lblTr}/*.nii.gz"),
    ("imagesTs:", f"{imgTs}/*_0000.nii.gz"),
):
    print(label, len(glob.glob(pattern)))
print("Mapped classification labels:", len(cls_map))

# dataset.json in the nnU-Net v2 schema.
# The v1-only keys (modality, tensorImageSize, training, test) are replaced by
# "channel_names" and "file_ending", which a later cell of this notebook marks as
# required in v2.
dataset_json = {
    "name": DSNAME,
    "channel_names": {"0": "CT"},
    "labels": {"background": 0, "pancreas": 1, "lesion": 2},
    "numTraining": len(glob.glob(f"{lblTr}/*.nii.gz")),
    "file_ending": ".nii.gz",
}
os.makedirs(DSROOT, exist_ok=True)
with open(f"{DSROOT}/dataset.json", "w") as f:
    json.dump(dataset_json, f, indent=2)

# splits_final.json (use the original validation split as the nnU-Net validation fold)
all_cases = sorted([stem(p).replace("_0000","") for p in glob.glob(f"{imgTr}/*_0000.nii.gz")])

# Case names found in the source validation folder. Images without a mask were not
# ingested, so keep only the cases that actually exist in imagesTr; otherwise the
# split file would reference cases the dataset does not contain.
val_source_cases = {stem(p).replace("_0000","")
                    for p in glob.glob(os.path.join(val_dir, "**", "*_0000.nii.gz"), recursive=True)}
val_cases = val_source_cases & set(all_cases)
val_not_ingested = sorted(val_source_cases - val_cases)
if val_not_ingested:
    print(f"WARNING: {len(val_not_ingested)} validation case(s) were not ingested and are left out "
          f"of the split, e.g. {val_not_ingested[:3]}")

train_cases = [c for c in all_cases if c not in val_cases]
if not train_cases or not val_cases:
    raise ValueError(f"Empty split: {len(train_cases)} training and {len(val_cases)} validation cases")

spdir = f"{PREP}/{DSNAME}"
os.makedirs(spdir, exist_ok=True)
with open(f"{spdir}/splits_final.json","w") as f:
    # A list with one entry, i.e. a single train/val split.
    json.dump([{"train": train_cases, "val": sorted(val_cases)}], f, indent=2)

# classification_labels.csv: one "case,subtype" row per ingested case, no header.
csv_path = f"{spdir}/classification_labels.csv"
unlabeled_cases = [c for c in all_cases if c not in cls_map]
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f)
    w.writerows((c, cls_map[c]) for c in all_cases if c in cls_map)

# Cases without a subtype are left out of the CSV; report them instead of skipping
# silently. Many misses usually mean the folders are not named 'subtype0/1/2'.
if unlabeled_cases:
    print(f"WARNING: {len(unlabeled_cases)} case(s) have no classification label, "
          f"e.g. {unlabeled_cases[:5]}")

print("Wrote dataset.json, splits_final.json, classification_labels.csv")

imagesTr: 288
labelsTr: 288
imagesTs: 72
Mapped classification labels: 288
Wrote dataset.json, splits_final.json, classification_labels.csv


In [7]:
import json, glob, os, pathlib

# Re-derive the dataset paths so this cell does not rely on variables from the ingest cell.
DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"
RAW   = os.environ["nnUNet_raw"]
DSROOT = f"{RAW}/{DSNAME}"
lblTr = f"{DSROOT}/labelsTr"

# One label file per training case, so the label count is the training-set size.
label_files = glob.glob(os.path.join(lblTr, "*.nii.gz"))

dataset_v2 = dict(
    channel_names={"0": "CT"},                                  # required in v2
    labels={"background": 0, "pancreas": 1, "lesion": 2},       # segmentation classes
    numTraining=len(label_files),
    file_ending=".nii.gz",
)

# Overwrites whatever dataset.json the ingest cell produced.
with open(f"{DSROOT}/dataset.json", "w") as fh:
    fh.write(json.dumps(dataset_v2, indent=2))

print("✅ Wrote v2 dataset.json at:", f"{DSROOT}/dataset.json")
print("numTraining:", dataset_v2["numTraining"])

✅ Wrote v2 dataset.json at: D:/nnunet_with_classification/data/nnUNet_raw/Dataset777_M31Quiz/dataset.json
numTraining: 288


In [9]:
import nibabel as nib, numpy as np, glob, os, pathlib

DSID = 777
DSNAME = f"Dataset{DSID:03d}_M31Quiz"
lblTr = os.path.join(os.environ["nnUNet_raw"], DSNAME, "labelsTr")


def _masks_outside_label_set(folder):
    """Return ``(path, unique_values)`` for every mask in ``folder`` with a value outside {0, 1, 2}."""
    offenders = []
    for mask_path in sorted(glob.glob(os.path.join(folder, "*.nii.gz"))):
        values = np.unique(nib.load(mask_path).get_fdata())  # floats possible
        if not np.all(np.isin(values, [0, 1, 2])):
            offenders.append((mask_path, values))
    return offenders


# First pass: find masks whose stored values are not exactly 0, 1 or 2
# (the recorded output of this cell shows values such as 1.00001526).
bad = _masks_outside_label_set(lblTr)

print("Masks needing fix:", len(bad))
for p, uniq in bad[:10]:
    print("  ", pathlib.Path(p).name, "unique:", uniq[:10])

# Fix: round to the nearest integer, clip into {0,1,2}, rewrite in place as uint8.
for p, _ in bad:
    source = nib.load(p)
    cleaned = np.clip(np.rint(source.get_fdata()), 0, 2).astype(np.uint8)

    hdr = source.header.copy()
    hdr.set_data_dtype(np.uint8)
    # Neutral slope/intercept so no unintended scaling is applied.
    hdr["scl_slope"] = 1
    hdr["scl_inter"] = 0

    nib.save(nib.Nifti1Image(cleaned, source.affine, hdr), p)

# Second pass: confirm nothing is left to repair.
remaining = _masks_outside_label_set(lblTr)
print("Remaining masks with non-{0,1,2} labels:", len(remaining))

Masks needing fix: 214
   quiz_0_060.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_066.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_077.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_117.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_126.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_139.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_145.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_150.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_159.nii.gz unique: [0.         1.00001526 2.        ]
   quiz_0_160.nii.gz unique: [0.         1.00001526 2.        ]
Remaining masks with non-{0,1,2} labels: 0


In [None]:
# Run nnU-Net planning and preprocessing for dataset 777, restricted to the 3d_fullres
# configuration, with the dataset integrity check enabled and the nnUNetPlannerResEncM planner.
!nnUNetv2_plan_and_preprocess -d 777 -c 3d_fullres --verify_dataset_integrity -pl nnUNetPlannerResEncM

In [None]:
# Train with the project trainer class `NNUNet` and the nnUNetResEncUNetMPlans plans.
# Positional arguments are presumably dataset id (777), configuration (3d_fullres) and
# fold (0) -- confirm against `nnUNetv2_train -h`.
!nnUNetv2_train 777 3d_fullres 0 -tr NNUNet -p nnUNetResEncUNetMPlans

In [None]:
# Build a flat copy of the validation images: nnUNetv2_predict expects a single flat
# input folder, not one sub-folder per subtype.
import os
import glob
import shutil

# Create a flat validation folder for nnU-Net
val_flat_dir = "D:/nnunet_with_classification/data/validation_flat"
os.makedirs(val_flat_dir, exist_ok=True)

# Source split under the same data root that every other cell of this notebook uses.
val_source = "D:/nnunet_with_classification/data/validation"
val_files = sorted(glob.glob(os.path.join(val_source, "**", "*_0000.nii.gz"), recursive=True))
if not val_files:
    # Fail loudly: an empty flat folder would only surface much later as missing predictions.
    raise FileNotFoundError(f"No '*_0000.nii.gz' files found under {val_source}")

# Flattening drops the sub-folder, so two files with the same name would overwrite each other.
first_seen = {}
for src_file in val_files:
    filename = os.path.basename(src_file)
    if filename in first_seen:
        raise ValueError(f"Duplicate file name {filename!r}: {first_seen[filename]} and {src_file}")
    first_seen[filename] = src_file

print(f"Copying {len(val_files)} validation files to flat folder...")
for src_file in val_files:
    filename = os.path.basename(src_file)
    dst_file = os.path.join(val_flat_dir, filename)
    shutil.copy2(src_file, dst_file)
    print(f"  Copied: {filename}")

print(f"✅ Flat validation folder ready: {val_flat_dir}")

In [None]:
# Train with the project trainer class `NNUNet_tuned` and the nnUNetResEncUNetMPlans plans;
# the inference cells below load the resulting fold_0 checkpoint.
# Positional arguments are presumably dataset id (777), configuration (3d_fullres) and
# fold (0) -- confirm against `nnUNetv2_train -h`.
!nnUNetv2_train 777 3d_fullres 0 -tr NNUNet_tuned -p nnUNetResEncUNetMPlans

In [None]:
#inference for validation data for segementation
# After creating flat folder:
!nnUNetv2_predict \
  -i D:/nnunet_with_classification/data/test \
  -o D:/nnunet_with_classification/data/predictions_validation \
  -d 777 \
  -f 0 \
  -tr NNUNet_tuned  \
  -c 3d_fullres \
  -p nnUNetResEncUNetMPlans

In [None]:
# Cell: Validation Classification Inference (FIXED for FP16/FP32 mismatch)

import os, json, torch, csv, nibabel as nib
from pathlib import Path
import numpy as np
from typing import List

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import torch.nn as nn

class MultiScaleClassificationHead(nn.Module):
    """Classification head that fuses the last three encoder stages with learned scale attention.

    Each of the last ``num_scales`` encoder feature maps is pooled to a small grid,
    projected to ``target_channels`` and globally averaged into one vector. A softmax
    attention over the scales blends the vectors, and an MLP produces the logits.

    The attribute names and the order of layers inside each ``nn.Sequential`` define
    the state-dict keys, so they must stay in sync with the head in NNUNet.py whose
    weights are loaded with ``strict=True`` further down in this cell.
    """

    def __init__(self, encoder_channels: List[int], num_classes: int, dim: int = 3,
                 target_channels: int = 256, spatial_reduction: int = 4):
        super().__init__()
        self.num_scales = 3  # Use last 3 encoder stages
        self.dim = dim

        # Select the 3-D or 2-D variant of every layer once instead of branching per adapter.
        if dim == 3:
            pool_cls, conv_cls, norm_cls, drop_cls = (
                nn.AdaptiveAvgPool3d, nn.Conv3d, nn.BatchNorm3d, nn.Dropout3d)
            grid = (spatial_reduction, spatial_reduction, spatial_reduction)
        else:
            pool_cls, conv_cls, norm_cls, drop_cls = (
                nn.AdaptiveAvgPool2d, nn.Conv2d, nn.BatchNorm2d, nn.Dropout2d)
            grid = (spatial_reduction, spatial_reduction)

        # One adapter per scale: pool to a fixed grid, then 1x1 conv to a common channel count.
        self.feature_adapters = nn.ModuleList([
            nn.Sequential(
                pool_cls(grid),
                conv_cls(channels, target_channels, kernel_size=1, bias=False),
                norm_cls(target_channels),
                nn.ReLU(inplace=True),
                drop_cls(0.1),
            )
            for channels in encoder_channels[-self.num_scales:]
        ])

        # Collapses each adapted map to a single vector.
        self.global_pool = pool_cls(1)

        # Produces one softmax weight per scale from the concatenated scale vectors.
        self.attention = nn.Sequential(
            nn.Linear(target_channels * self.num_scales, target_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(target_channels, self.num_scales),
            nn.Softmax(dim=1),
        )

        # Refines the attention-weighted sum of scale vectors.
        self.feature_fusion = nn.Sequential(
            nn.Linear(target_channels, target_channels),
            nn.BatchNorm1d(target_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )

        # MLP that narrows target_channels -> /2 -> /4 -> num_classes.
        half, quarter = target_channels // 2, target_channels // 4
        self.classifier = nn.Sequential(
            nn.Linear(target_channels, half),
            nn.BatchNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(half, quarter),
            nn.BatchNorm1d(quarter),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(quarter, num_classes),
        )

    def forward(self, encoder_features):
        """Compute classification logits from encoder feature maps.

        Args:
            encoder_features: List of feature tensors from the encoder stages; only
                the last ``num_scales`` entries are used.

        Returns:
            Classification logits of shape [B, num_classes].

        Raises:
            ValueError: If fewer than ``num_scales`` feature maps are given.
        """
        if len(encoder_features) < self.num_scales:
            raise ValueError(f"Expected at least {self.num_scales} encoder features, "
                             f"got {len(encoder_features)}")

        # One [B, target_channels] vector per scale.
        deepest = encoder_features[-self.num_scales:]
        per_scale = [
            self.global_pool(adapter(feature_map)).flatten(1)
            for feature_map, adapter in zip(deepest, self.feature_adapters)
        ]

        # [B, num_scales, 1] weights computed from the concatenated scale vectors.
        scale_weights = self.attention(torch.cat(per_scale, dim=1)).unsqueeze(-1)

        # Weighted sum over the scale axis -> [B, target_channels].
        blended = (torch.stack(per_scale, dim=1) * scale_weights).sum(dim=1)

        return self.classifier(self.feature_fusion(blended))

# ---- PATHS ----
MODEL_DIR = Path("D:/nnunet_with_classification/data/nnUNet_results/Dataset777_M31Quiz/NNUNet_tuned__nnUNetResEncUNetMPlans__3d_fullres")
FOLD_DIR  = MODEL_DIR / "fold_0"
CKPT      = FOLD_DIR / "checkpoint_final.pth"

# The images are looked up in subtype0/1/2 sub-folders below, which is the layout of the
# validation split; data/test is handled as a flat folder elsewhere in this notebook.
IMAGES_TS = Path("D:/nnunet_with_classification/data/validation")
OUT_DIR   = Path("D:/nnunet_with_classification/predictions_validation")
CSV_OUT   = OUT_DIR / "subtype_results.csv"

IMAGE_SUFFIX = "_0000.nii.gz"
NUM_SCALES = 3          # the classification head consumes the last three encoder stages
FALLBACK_SUBTYPE = 0    # written to the CSV when a case cannot be classified

OUT_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Build Predictor ----
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    device=device,
    verbose=True,
    verbose_preprocessing=False,
    allow_tqdm=True,
)

predictor.initialize_from_trained_model_folder(
    model_training_output_dir=str(MODEL_DIR),
    use_folds=[0],
    checkpoint_name=CKPT.name,
)

net = predictor.network

# ---- Load Classification Head ----
print("🔍 Loading checkpoint...")
# NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints you trained yourself.
# The checkpoint is read on the CPU; load_state_dict copies the head weights to `device`.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cls_sd = ckpt.get("cls_state_dict", None)
if cls_sd is None:
    raise KeyError(f"'cls_state_dict' not found in {CKPT}; was this checkpoint trained "
                   f"with the classification head?")

encoder_channels = net.encoder.output_channels
print(f"Encoder channels: {encoder_channels}")

classifier = MultiScaleClassificationHead(
    encoder_channels=encoder_channels,
    num_classes=3,
    dim=3,
    target_channels=256,
    spatial_reduction=4
).to(device)

classifier.load_state_dict(cls_sd, strict=True)
print("✅ Successfully loaded MultiScaleClassificationHead weights")

# ---- Setup Feature Capture ----
# NOTE(review): every forward pass overwrites the captured features, so after
# predict_from_files the list holds the features of the LAST forward pass only
# (presumably the last sliding-window patch / mirrored variant). Confirm that this
# matches how the head was trained, or aggregate logits over patches instead.
encoder_features = []

def create_hook(stage_idx):
    """Return a forward hook that stores the output of encoder stage ``stage_idx``."""
    def hook_fn(module, input, output):
        while len(encoder_features) <= stage_idx:
            encoder_features.append(None)
        encoder_features[stage_idx] = output
    return hook_fn

hooks = [stage.register_forward_hook(create_hook(i))
         for i, stage in enumerate(net.encoder.stages)]

net.eval()
classifier.eval()

# ---- Find Validation Images ----
case_files = []
for subfolder in ["subtype0", "subtype1", "subtype2"]:
    subfolder_path = IMAGES_TS / subfolder
    if subfolder_path.exists():
        case_files.extend(subfolder_path.glob(f"*{IMAGE_SUFFIX}"))

case_files = sorted(case_files)
print(f"\n🔍 Found {len(case_files)} validation files")
if not case_files:
    for hook in hooks:
        hook.remove()
    raise FileNotFoundError(f"No '*{IMAGE_SUFFIX}' files under {IMAGES_TS}/subtype0|1|2")

# ---- Process Cases ----
rows = [("Names", "Subtype")]
fallback_cases = []   # cases that received FALLBACK_SUBTYPE instead of a real prediction

try:
    for img in case_files:
        case_stem = img.name.replace(IMAGE_SUFFIX, "")
        case_id = case_stem + ".nii.gz"
        print(f"Processing: {img.name}")

        encoder_features.clear()
        pred = FALLBACK_SUBTYPE

        # The segmentation itself is a throw-away by-product here; the forward passes
        # are only needed to populate encoder_features through the hooks.
        try:
            predictor.predict_from_files(
                [[str(img)]],
                # Pass the name without its file ending: the parameter is documented as
                # a list of *truncated* output files, i.e. nnU-Net is expected to append
                # the ending itself.
                [str(OUT_DIR / f"temp_{case_stem}")],
                save_probabilities=False,
                overwrite=True,
                num_processes_preprocessing=1,
                num_processes_segmentation_export=1
            )
        except Exception as e:
            print(f"❌ Error processing {img.name}: {e}")
            fallback_cases.append(case_id)
            rows.append((case_id, pred))
            continue
        finally:
            # Remove the temporary segmentation whatever ending it was written with.
            for leftover in OUT_DIR.glob(f"temp_{case_stem}.*"):
                leftover.unlink()

        valid_features = [f for f in encoder_features if f is not None]
        if len(valid_features) >= NUM_SCALES:
            try:
                # Features may be FP16 (autocast) while the head weights are FP32.
                with torch.no_grad():
                    logits = classifier([f.float() for f in valid_features])
                pred = int(torch.argmax(logits, dim=1).item())
            except Exception as e:
                print(f"⚠️  Classification error for {img.name}: {e}")
                fallback_cases.append(case_id)
        else:
            print(f"⚠️  Insufficient encoder features captured for {img.name} "
                  f"({len(valid_features)} of {NUM_SCALES})")
            fallback_cases.append(case_id)

        rows.append((case_id, pred))
        print(f"  ✅ Classification: {pred}")
finally:
    # Always detach the hooks so a failed run does not leave them on the network.
    for hook in hooks:
        hook.remove()

# ---- Save Results ----
with open(CSV_OUT, "w", newline="") as f:
    w = csv.writer(f)
    w.writerows(rows)

print(f"\n✅ Validation Classification CSV written: {CSV_OUT}")
if fallback_cases:
    print(f"⚠️  {len(fallback_cases)} case(s) could not be classified and were written as "
          f"subtype {FALLBACK_SUBTYPE}: {fallback_cases}")

# Show distribution
import pandas as pd
df = pd.read_csv(CSV_OUT)
print(f"\nValidation Classification Results:")
print(f"   Total cases: {len(df)}")
print(f"   Distribution:")
print(df['Subtype'].value_counts().sort_index())

In [None]:
# Complete evaluation script for your results on validation data
import os
import csv
import numpy as np
import nibabel as nib
from sklearn.metrics import f1_score, classification_report

# Validation cases to evaluate (hard-coded case ids; presumably the same cases as the
# "val" list written to splits_final.json -- verify if the split changes).
validation_cases = [
    "quiz_0_168", "quiz_0_171", "quiz_0_174", "quiz_0_184", "quiz_0_187", 
    "quiz_0_189", "quiz_0_244", "quiz_0_253", "quiz_0_254", "quiz_1_090",
    "quiz_1_093", "quiz_1_094", "quiz_1_154", "quiz_1_158", "quiz_1_164",
    "quiz_1_166", "quiz_1_211", "quiz_1_213", "quiz_1_221", "quiz_1_227",
    "quiz_1_231", "quiz_1_242", "quiz_1_331", "quiz_1_335", "quiz_2_074",
    "quiz_2_080", "quiz_2_084", "quiz_2_085", "quiz_2_088", "quiz_2_089",
    "quiz_2_098", "quiz_2_191", "quiz_2_241", "quiz_2_364", "quiz_2_377",
    "quiz_2_379"
]

# Paths
# Ground-truth masks come from labelsTr (validation cases were ingested into the training folders).
gt_seg_dir = f"{os.environ['nnUNet_raw']}/Dataset777_M31Quiz/labelsTr"
# Predicted masks: the output folder of the validation nnUNetv2_predict cell (under data/).
pred_seg_dir = "D:/nnunet_with_classification/data/predictions_validation"
# Ground-truth subtypes written by the ingest cell (rows of "case,subtype", no header).
gt_cls_csv = f"{os.environ['nnUNet_preprocessed']}/Dataset777_M31Quiz/classification_labels.csv"
# Predicted subtypes from the validation classification cell (note: NOT under data/).
pred_cls_csv = "D:/nnunet_with_classification/predictions_validation/subtype_results.csv"

def dice_score(y_true, y_pred):
    """Dice coefficient of two binary masks of equal shape.

    Computes ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))``.

    Returns:
        The Dice value, or 1.0 when both masks sum to zero (two empty masks agree).
    """
    overlap = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    if denominator > 0:
        return (2.0 * overlap) / denominator
    return 1.0

# Check paths: print whether each input of the evaluation exists.
print("🔍 Checking paths...")
for description, location in (
    ("GT segmentation dir", gt_seg_dir),
    ("Predicted segmentation dir", pred_seg_dir),
    ("GT classification CSV", gt_cls_csv),
    ("Predicted classification CSV", pred_cls_csv),
):
    print(f"{description}: {os.path.exists(location)}")

# Calculate segmentation metrics
pancreas_dices = []
lesion_dices = []
missing_files = []   # cases with a missing file OR a processing error

print("\n📊 Computing segmentation metrics...")
for case in validation_cases:
    gt_file = f"{gt_seg_dir}/{case}.nii.gz"
    pred_file = f"{pred_seg_dir}/{case}.nii.gz"

    # Guard clause: both the reference and the prediction must be on disk.
    if not (os.path.exists(gt_file) and os.path.exists(pred_file)):
        missing_files.append(case)
        print(f"Missing files for {case}")
        continue

    try:
        gt = nib.load(gt_file).get_fdata()
        pred = nib.load(pred_file).get_fdata()

        # Whole pancreas = every foreground label (pancreas and lesion together).
        pancreas_dice = dice_score((gt > 0).astype(int), (pred > 0).astype(int))
        pancreas_dices.append(pancreas_dice)

        # Lesion only (label == 2)
        lesion_dice = dice_score((gt == 2).astype(int), (pred == 2).astype(int))
        lesion_dices.append(lesion_dice)

        print(f"{case}: Pancreas={pancreas_dice:.3f}, Lesion={lesion_dice:.3f}")

    except Exception as e:
        print(f"Error processing {case}: {e}")
        missing_files.append(case)

# Mean Dice per structure; both default to 0.0 when no case could be evaluated.
whole_pancreas_dsc = 0.0
lesion_dsc = 0.0
if pancreas_dices:
    whole_pancreas_dsc = np.mean(pancreas_dices)
    lesion_dsc = np.mean(lesion_dices)

    print(f"\n📈 Segmentation Statistics:")
    print(f"Pancreas DSC - Mean: {whole_pancreas_dsc:.4f}, Std: {np.std(pancreas_dices):.4f}")
    print(f"Lesion DSC - Mean: {lesion_dsc:.4f}, Std: {np.std(lesion_dices):.4f}")

# Classification metrics
print("\n🎯 Computing classification metrics...")
SUBTYPE_LABELS = [0, 1, 2]
SUBTYPE_NAMES = ['Subtype 0', 'Subtype 1', 'Subtype 2']

# Ground truth: rows of "case,subtype" without a header.
gt_labels = {}
with open(gt_cls_csv, 'r', newline='') as f:
    for row in csv.reader(f):
        if len(row) == 2:
            gt_labels[row[0]] = int(row[1])

# Predictions: header row, then "case.nii.gz,subtype".
pred_labels = {}
with open(pred_cls_csv, 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # Skip header (tolerates an empty file)
    for row in reader:
        if len(row) == 2:
            case_name = row[0].replace('.nii.gz', '')
            pred_labels[case_name] = int(row[1])

# Keep only validation cases present in both tables, and say which ones are not.
matched_cases = [c for c in validation_cases if c in gt_labels and c in pred_labels]
unmatched_cases = [c for c in validation_cases if c not in matched_cases]
val_gt = [gt_labels[c] for c in matched_cases]
val_pred = [pred_labels[c] for c in matched_cases]
if unmatched_cases:
    print(f"⚠️ {len(unmatched_cases)} validation case(s) lack a ground-truth or predicted "
          f"subtype and are excluded: {unmatched_cases}")

if val_gt and val_pred:
    # Passing the full label set keeps the metric defined over all three subtypes and
    # prevents classification_report from raising when a subtype is absent from both
    # the ground truth and the predictions (target_names would no longer match).
    macro_f1 = f1_score(val_gt, val_pred, labels=SUBTYPE_LABELS, average='macro', zero_division=0)

    print(f"\n🔍 Classification Analysis:")
    print(f"Ground truth distribution: {np.bincount(val_gt, minlength=len(SUBTYPE_LABELS))}")
    print(f"Prediction distribution: {np.bincount(val_pred, minlength=len(SUBTYPE_LABELS))}")
    print(f"\nDetailed Classification Report:")
    print(classification_report(val_gt, val_pred, labels=SUBTYPE_LABELS,
                                target_names=SUBTYPE_NAMES, zero_division=0))
else:
    macro_f1 = 0.0
    print("⚠️ No validation case has both a ground-truth and a predicted subtype; macro F1 set to 0.0")

# Final Results
def _verdict(passed, fail_text="Below expectations"):
    """Return the pass marker, or ``fail_text`` when the threshold was missed."""
    return "✅ PASS" if passed else fail_text


# Threshold checks, evaluated once and reused for the per-metric lines and the overall result.
pancreas_ok = whole_pancreas_dsc >= 0.91
lesion_ok = lesion_dsc >= 0.31
f1_ok = macro_f1 >= 0.70

banner = "=" * 70
print("\n" + banner)
print("🎓 PhD LEVEL REQUIREMENTS EVALUATION")
print(banner)
print(f"Whole Pancreas DSC: {whole_pancreas_dsc:.4f} ≥ 0.91: {_verdict(pancreas_ok)}")
print(f"Lesion DSC: {lesion_dsc:.4f} ≥ 0.31: {_verdict(lesion_ok)}")
print(f"Macro F1: {macro_f1:.4f} ≥ 0.70: {_verdict(f1_ok)}")

overall_pass = pancreas_ok and lesion_ok and f1_ok
print(f"\n🎯 OVERALL RESULT: {_verdict(overall_pass, 'Below expectations ')}")

# Summary for your report
print(f"\n📋 SUMMARY FOR REPORT:")
print(f"Processed {len(pancreas_dices)} segmentation cases and {len(val_gt)} classification cases")
for metric_name, metric_value in (
    ("Whole Pancreas DSC", whole_pancreas_dsc),
    ("Lesion DSC", lesion_dsc),
    ("Classification Macro F1", macro_f1),
):
    print(f"{metric_name}: {metric_value:.4f}")

if missing_files:
    print(f"⚠️ Missing files for {len(missing_files)} cases: {missing_files}")

In [None]:
# Segmentation inference on the test images: reads data/test and writes the predicted
# masks to data/predictions_test_label, using fold 0 of the NNUNet_tuned trainer
# (3d_fullres configuration, nnUNetResEncUNetMPlans plans).

!nnUNetv2_predict \
  -i D:/nnunet_with_classification/data/test \
  -o D:/nnunet_with_classification/data/predictions_test_label \
  -d 777 \
  -f 0 \
  -tr NNUNet_tuned  \
  -c 3d_fullres \
  -p nnUNetResEncUNetMPlans

In [None]:
# Cell: Test labels Inference for classification (FIXED for FP16/FP32 mismatch + File Discovery)
# These predictions are not expected to be accurate, because the model has not learned well from the training data.
import os, json, torch, csv, nibabel as nib
from pathlib import Path
import numpy as np
from typing import List

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import torch.nn as nn

class MultiScaleClassificationHead(nn.Module):
    """Classification head that fuses the last three encoder stages with learned scale attention.

    Each of the last ``num_scales`` encoder feature maps is pooled to a small grid,
    projected to ``target_channels`` and globally averaged into one vector. A softmax
    attention over the scales blends the vectors, and an MLP produces the logits.

    The attribute names and the order of layers inside each ``nn.Sequential`` define
    the state-dict keys, so they must stay in sync with the head in NNUNet.py whose
    weights are loaded with ``strict=True`` further down in this cell. The validation
    classification cell of this notebook defines the same class; keep both copies in sync.
    """

    def __init__(self, encoder_channels: List[int], num_classes: int, dim: int = 3,
                 target_channels: int = 256, spatial_reduction: int = 4):
        super().__init__()
        self.num_scales = 3  # Use last 3 encoder stages
        self.dim = dim

        # Select the 3-D or 2-D variant of every layer once instead of branching per adapter.
        if dim == 3:
            pool_cls, conv_cls, norm_cls, drop_cls = (
                nn.AdaptiveAvgPool3d, nn.Conv3d, nn.BatchNorm3d, nn.Dropout3d)
            grid = (spatial_reduction, spatial_reduction, spatial_reduction)
        else:
            pool_cls, conv_cls, norm_cls, drop_cls = (
                nn.AdaptiveAvgPool2d, nn.Conv2d, nn.BatchNorm2d, nn.Dropout2d)
            grid = (spatial_reduction, spatial_reduction)

        # One adapter per scale: pool to a fixed grid, then 1x1 conv to a common channel count.
        self.feature_adapters = nn.ModuleList([
            nn.Sequential(
                pool_cls(grid),
                conv_cls(channels, target_channels, kernel_size=1, bias=False),
                norm_cls(target_channels),
                nn.ReLU(inplace=True),
                drop_cls(0.1),
            )
            for channels in encoder_channels[-self.num_scales:]
        ])

        # Collapses each adapted map to a single vector.
        self.global_pool = pool_cls(1)

        # Produces one softmax weight per scale from the concatenated scale vectors.
        self.attention = nn.Sequential(
            nn.Linear(target_channels * self.num_scales, target_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(target_channels, self.num_scales),
            nn.Softmax(dim=1),
        )

        # Refines the attention-weighted sum of scale vectors.
        self.feature_fusion = nn.Sequential(
            nn.Linear(target_channels, target_channels),
            nn.BatchNorm1d(target_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
        )

        # MLP that narrows target_channels -> /2 -> /4 -> num_classes.
        half, quarter = target_channels // 2, target_channels // 4
        self.classifier = nn.Sequential(
            nn.Linear(target_channels, half),
            nn.BatchNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(half, quarter),
            nn.BatchNorm1d(quarter),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(quarter, num_classes),
        )

    def forward(self, encoder_features):
        """Compute classification logits from encoder feature maps.

        Args:
            encoder_features: List of feature tensors from the encoder stages; only
                the last ``num_scales`` entries are used.

        Returns:
            Classification logits of shape [B, num_classes].

        Raises:
            ValueError: If fewer than ``num_scales`` feature maps are given.
        """
        if len(encoder_features) < self.num_scales:
            raise ValueError(f"Expected at least {self.num_scales} encoder features, "
                             f"got {len(encoder_features)}")

        # One [B, target_channels] vector per scale.
        deepest = encoder_features[-self.num_scales:]
        per_scale = [
            self.global_pool(adapter(feature_map)).flatten(1)
            for feature_map, adapter in zip(deepest, self.feature_adapters)
        ]

        # [B, num_scales, 1] weights computed from the concatenated scale vectors.
        scale_weights = self.attention(torch.cat(per_scale, dim=1)).unsqueeze(-1)

        # Weighted sum over the scale axis -> [B, target_channels].
        blended = (torch.stack(per_scale, dim=1) * scale_weights).sum(dim=1)

        return self.classifier(self.feature_fusion(blended))

# ---- PATHS ----
MODEL_DIR = Path("D:/nnunet_with_classification/data/nnUNet_results/Dataset777_M31Quiz/NNUNet_tuned__nnUNetResEncUNetMPlans__3d_fullres")
FOLD_DIR  = MODEL_DIR / "fold_0"
CKPT      = FOLD_DIR / "checkpoint_final.pth"

IMAGES_TS = Path("D:/nnunet_with_classification/data/test")
OUT_DIR   = Path("D:/nnunet_with_classification/data/predictions_test_label")
CSV_OUT   = OUT_DIR / "subtype_results.csv"

IMAGE_SUFFIX = "_0000.nii.gz"
NUM_SCALES = 3          # the classification head consumes the last three encoder stages
FALLBACK_SUBTYPE = 0    # written to the CSV when a case cannot be classified

OUT_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---- Find Test Images ----
# Done before the model is loaded so a wrong path fails fast. Raising (instead of
# calling exit(1)) is what actually stops a notebook cell at this point.
print(f"🔍 Looking for test files in: {IMAGES_TS}")

if not IMAGES_TS.exists():
    raise FileNotFoundError(f"❌ Test directory does not exist: {IMAGES_TS}")

# Look for files directly in the test directory (flat structure)
case_files = sorted(IMAGES_TS.glob(f"*{IMAGE_SUFFIX}"))
print(f"🔍 Found {len(case_files)} test files")

if case_files:
    # Debug: show the first few files found
    print("📋 Sample files found:")
    for i, file in enumerate(case_files[:5]):
        print(f"   {i+1}. {file.name}")
    if len(case_files) > 5:
        print(f"   ... and {len(case_files) - 5} more files")
else:
    print("❌ No files matching pattern '*_0000.nii.gz' found!")
    print("📋 Files in directory:")
    all_files = list(IMAGES_TS.glob("*"))
    for file in all_files[:10]:
        print(f"   {file.name}")
    if len(all_files) > 10:
        print(f"   ... and {len(all_files) - 10} more files")
    raise FileNotFoundError(f"No '*{IMAGE_SUFFIX}' files in {IMAGES_TS}")

# ---- Build Predictor ----
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    device=device,
    verbose=True,
    verbose_preprocessing=False,
    allow_tqdm=True,
)

predictor.initialize_from_trained_model_folder(
    model_training_output_dir=str(MODEL_DIR),
    use_folds=[0],
    checkpoint_name=CKPT.name,
)

net = predictor.network

# ---- Load Classification Head ----
print("🔍 Loading checkpoint...")
# NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints you trained yourself.
# The checkpoint is read on the CPU; load_state_dict copies the head weights to `device`.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cls_sd = ckpt.get("cls_state_dict", None)
if cls_sd is None:
    raise KeyError(f"'cls_state_dict' not found in {CKPT}; was this checkpoint trained "
                   f"with the classification head?")

encoder_channels = net.encoder.output_channels
print(f"Encoder channels: {encoder_channels}")

classifier = MultiScaleClassificationHead(
    encoder_channels=encoder_channels,
    num_classes=3,
    dim=3,
    target_channels=256,
    spatial_reduction=4
).to(device)

classifier.load_state_dict(cls_sd, strict=True)
print("✅ Successfully loaded MultiScaleClassificationHead weights")

# ---- Setup Feature Capture ----
# NOTE(review): every forward pass overwrites the captured features, so after
# predict_from_files the list holds the features of the LAST forward pass only
# (presumably the last sliding-window patch / mirrored variant). Confirm that this
# matches how the head was trained, or aggregate logits over patches instead.
encoder_features = []

def create_hook(stage_idx):
    """Return a forward hook that stores the output of encoder stage ``stage_idx``."""
    def hook_fn(module, input, output):
        while len(encoder_features) <= stage_idx:
            encoder_features.append(None)
        encoder_features[stage_idx] = output
    return hook_fn

hooks = [stage.register_forward_hook(create_hook(i))
         for i, stage in enumerate(net.encoder.stages)]

net.eval()
classifier.eval()

# ---- Process Cases ----
rows = [("Names", "Subtype")]
fallback_cases = []   # cases that received FALLBACK_SUBTYPE instead of a real prediction

try:
    for img in case_files:
        # For files like 'quiz_037_0000.nii.gz', we want 'quiz_037.nii.gz'
        case_stem = img.name.replace(IMAGE_SUFFIX, "")
        case_id = case_stem + ".nii.gz"
        print(f"Processing: {img.name}")

        encoder_features.clear()
        pred = FALLBACK_SUBTYPE

        # The segmentation itself is a throw-away by-product here; the forward passes
        # are only needed to populate encoder_features through the hooks.
        try:
            predictor.predict_from_files(
                [[str(img)]],
                # Pass the name without its file ending: the parameter is documented as
                # a list of *truncated* output files, i.e. nnU-Net is expected to append
                # the ending itself.
                [str(OUT_DIR / f"temp_{case_stem}")],
                save_probabilities=False,
                overwrite=True,
                num_processes_preprocessing=1,
                num_processes_segmentation_export=1
            )
        except Exception as e:
            print(f"❌ Error processing {img.name}: {e}")
            fallback_cases.append(case_id)
            rows.append((case_id, pred))
            continue
        finally:
            # OUT_DIR also holds the real test segmentations, so make sure no temp_*
            # file is left behind, whatever ending it was written with.
            for leftover in OUT_DIR.glob(f"temp_{case_stem}.*"):
                leftover.unlink()

        valid_features = [f for f in encoder_features if f is not None]
        if len(valid_features) >= NUM_SCALES:
            try:
                # Features may be FP16 (autocast) while the head weights are FP32.
                with torch.no_grad():
                    logits = classifier([f.float() for f in valid_features])
                pred = int(torch.argmax(logits, dim=1).item())
            except Exception as e:
                print(f"⚠️  Classification error for {img.name}: {e}")
                fallback_cases.append(case_id)
        else:
            print(f"⚠️  Insufficient encoder features captured for {img.name} "
                  f"({len(valid_features)} of {NUM_SCALES})")
            fallback_cases.append(case_id)

        rows.append((case_id, pred))
        print(f"  ✅ Classification: {pred}")
finally:
    # Always detach the hooks so a failed run does not leave them on the network.
    for hook in hooks:
        hook.remove()

# ---- Save Results ----
with open(CSV_OUT, "w", newline="") as f:
    w = csv.writer(f)
    w.writerows(rows)

print(f"\n✅ Test Classification CSV written: {CSV_OUT}")
if fallback_cases:
    print(f"⚠️  {len(fallback_cases)} case(s) could not be classified and were written as "
          f"subtype {FALLBACK_SUBTYPE}: {fallback_cases}")

# Show distribution
import pandas as pd
df = pd.read_csv(CSV_OUT)
print(f"\nTest Classification Results:")
print(f"   Total cases: {len(df)}")
print(f"   Distribution:")
print(df['Subtype'].value_counts().sort_index())