# WP5 â€” Inference with MONAI Model Zoo (no training)
This notebook runs direct inference on WP5 data using small pretrained 3D segmentation models from the MONAI Model Zoo (Bundles). Expect low zero-shot performance; this is for a quick baseline. Data details: see WP5_Segmentation_Data_Guide.md (train=380, test=180; evaluate classes 0..4 and ignore class 6).

In [1]:
import os, json, re, time
from pathlib import Path
import numpy as np
import torch
import monai
from monai.utils import set_determinism
set_determinism(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('monai', monai.__version__, 'torch', torch.__version__, 'device', device)

monai 1.5.1+3.gc4a1acc8 torch 2.8.0+cu128 device cuda


In [2]:
DATA_ROOT = Path('/data3/wp5/wp5-code/dataloaders/wp5-dataset')
DATA_DIR = DATA_ROOT / 'data'
SPLIT_CFG = DATA_ROOT / '3ddl_split_config_20250801.json'
OUT_DIR = Path('runs/wp5_pretrained_infer')
OUT_DIR.mkdir(parents=True, exist_ok=True)
print('DATA_DIR exists:', DATA_DIR.exists()); print('SPLIT_CFG exists:', SPLIT_CFG.exists()); print('OUT_DIR:', OUT_DIR.resolve())

DATA_DIR exists: True
SPLIT_CFG exists: True
OUT_DIR: /home/peisheng/MONAI/runs/wp5_pretrained_infer


In [3]:
d_train_json = Path('datalist_train.json'); d_test_json = Path('datalist_test.json')
def build_datalists(data_dir: Path, cfg_path: Path):
    cfg = json.loads(cfg_path.read_text()); test_serials = set(cfg.get('test_serial_numbers', []))
    def serial_from_name(n): m = re.match(r'^SN(\d+)', n); return int(m.group(1)) if m else None
    pairs = {};
    for n in os.listdir(data_dir):
        if n.endswith('_image.nii'):
            base = n[:-10]; img = str(data_dir / f'{base}_image.nii'); lbl = str(data_dir / f'{base}_label.nii');
            if os.path.exists(lbl): pairs[base] = (img, lbl, serial_from_name(n))
    train, test = [], []
    for k, (img, lbl, serial) in pairs.items():
        rec = {'image': img, 'label': lbl, 'id': k}; (test if serial in test_serials else train).append(rec)
    return train, test
if not (d_train_json.exists() and d_test_json.exists()):
    train_list, test_list = build_datalists(DATA_DIR, SPLIT_CFG); d_train_json.write_text(json.dumps(train_list, indent=2)); d_test_json.write_text(json.dumps(test_list, indent=2))
train_list = json.loads(d_train_json.read_text()); test_list = json.loads(d_test_json.read_text()); print('train/test sizes:', len(train_list), len(test_list))

train/test sizes: 380 180


In [4]:
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityRanged
from monai.data import Dataset, DataLoader
val_transforms = Compose([LoadImaged(keys=['image','label']), EnsureChannelFirstd(keys=['image','label']), Orientationd(keys=['image','label'], axcodes='RAS'), ScaleIntensityRanged(keys=['image'], a_min=-3, a_max=8.5, b_min=0.0, b_max=1.0, clip=True)])
ds_test = Dataset(test_list, transform=val_transforms); dl_test = DataLoader(ds_test, batch_size=1, shuffle=False, num_workers=2); len(ds_test)



180

In [5]:
from monai.networks.nets import UNet
def load_pretrained_or_fallback(device):
    try:
        from monai.bundle import download, ConfigParser
        bundle_name = 'spleen_ct_segmentation'; bundle_root = Path('pretrained_models') / bundle_name; bundle_root.mkdir(parents=True, exist_ok=True)
        print('Attempting to download/load bundle:', bundle_name)
        local_dir = download(name=bundle_name, bundle_dir=str(bundle_root))
        config_file = os.path.join(local_dir, 'configs', 'inference.json'); parser = ConfigParser(); parser.read_config_files([config_file]); net = parser.get_parsed_content('network')
        ckpt_path = os.path.join(local_dir, 'models', 'model.pt')
        if os.path.exists(ckpt_path):
            print('Loading weights from:', ckpt_path); w = torch.load(ckpt_path, map_location=device); sd = w.get('state_dict', w) if isinstance(w, dict) else w
            try: net.load_state_dict(sd, strict=False)
            except Exception as e: print('Non-strict load failed, trying strict=False fallback:', e); net.load_state_dict(sd, strict=False)
        else: print('No ckpt found, using random init for bundle network.')
        return net.to(device).eval()
    except Exception as e:
        print('Bundle load failed (likely offline). Fallback to small UNet.', e); net = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16,32,64,128,256), strides=(2,2,2,2)); return net.to(device).eval()
net = load_pretrained_or_fallback(device); sum(p.numel() for p in net.parameters())/1e6

Attempting to download/load bundle: spleen_ct_segmentation
2025-10-06 14:32:50,310 - INFO - --- input summary of monai.bundle.scripts.download ---
2025-10-06 14:32:50,311 - INFO - > name: 'spleen_ct_segmentation'
2025-10-06 14:32:50,312 - INFO - > bundle_dir: 'pretrained_models/spleen_ct_segmentation'
2025-10-06 14:32:50,314 - INFO - > source: 'monaihosting'
2025-10-06 14:32:50,314 - INFO - > remove_prefix: 'monai_'
2025-10-06 14:32:50,315 - INFO - > progress: True
2025-10-06 14:32:50,316 - INFO - ---


Bundle load failed (likely offline). Fallback to small UNet. import huggingface_hub (No module named 'huggingface_hub').

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


1.97961

In [6]:
from monai.inferers import sliding_window_inference
from monai.transforms import SaveImage
from monai.metrics import HausdorffDistanceMetric, SurfaceDistanceMetric
import torch, numpy as np, json, csv
from pathlib import Path
pred_dir = OUT_DIR / 'predictions_spleen_unet'; pred_dir.mkdir(parents=True, exist_ok=True)
try:
    import nibabel as nib  # noqa
    saver = SaveImage(output_dir=str(pred_dir), output_postfix='pred', output_ext='.nii.gz', output_dtype=np.uint8, resample=False, mode='nearest', separate_folder=False, print_log=False)
except Exception:
    saver = None
MAX_CASES = 5; roi = (112,112,80)
def infer_batch(images):
    with torch.no_grad(): return sliding_window_inference(images, roi_size=roi, sw_batch_size=1, predictor=net)
# Init metrics for HD/ASD (full HD to match baselines)
hd_metric = HausdorffDistanceMetric(percentile=100.0, reduction='none')
asd_metric = SurfaceDistanceMetric(symmetric=True, reduction='none')
# Accumulators
classes = [0,1,2,3,4]
sum_dice = {c:0.0 for c in classes}; sum_iou = {c:0.0 for c in classes}; sum_hd = {c:0.0 for c in classes}; sum_asd = {c:0.0 for c in classes}; counts = {c:0 for c in classes}
per_case_rows = []
for i, batch in enumerate(dl_test):
    if i >= MAX_CASES: break
    img = batch['image'].to(device); lbl = batch['label']
    logits = infer_batch(img); pred = torch.argmax(logits, dim=1, keepdim=True).cpu()
    # Save predictions
    meta_list = batch.get('image_meta_dict', None)
    if saver is not None and meta_list is not None:
        for b in range(pred.shape[0]): saver(pred[b], meta_list[b] if isinstance(meta_list, list) else meta_list)
    else:
        id_field = batch.get('id', None)
        for b in range(pred.shape[0]):
            base = (id_field[b] if isinstance(id_field, list) else id_field) if id_field is not None else f'case_{i}_{b}'
            np.save(pred_dir / f"{base}_pred.npy", pred[b].numpy())
    # Build ignore mask
    ignore_mask = (lbl != 6)
    # Per-class metrics (0..4); use union foreground for classes 1..4
    for cls in classes:
        if cls == 0:
            pred_mask = (pred == 0)
            gt_mask = (lbl == 0)
        else:
            pred_mask = (pred > 0)
            gt_mask = (lbl == cls)
        # apply ignore mask (remove voxels where lbl==6)
        pm = (pred_mask & ignore_mask).squeeze(1).numpy().astype(np.uint8)
        gm = (gt_mask & ignore_mask).squeeze(1).numpy().astype(np.uint8)
        inter = (pm & gm).sum(); uni = ((pm | gm)).sum()
        dice = (2.0 * inter) / (pm.sum() + gm.sum() + 1e-8) if (pm.sum()+gm.sum())>0 else 1.0
        iou = (inter / (uni + 1e-8)) if uni>0 else 1.0
        # HD/ASD: define 0.0 if either empty and the other non-empty to match provided baselines
        if pm.sum()==0 and gm.sum()==0:
            hd = 0.0; asd = 0.0
        elif pm.sum()==0 or gm.sum()==0:
            hd = 0.0; asd = 0.0
        else:
            pt = torch.from_numpy(pm[None,None,...].astype(np.float32)); gt = torch.from_numpy(gm[None,None,...].astype(np.float32))
            try:
                hd = float(hd_metric(pt, gt).numpy()[0])
            except Exception:
                hd = 0.0
            try:
                asd = float(asd_metric(pt, gt).numpy()[0])
            except Exception:
                asd = 0.0
        sum_dice[cls] += float(dice); sum_iou[cls] += float(iou); sum_hd[cls] += float(hd); sum_asd[cls] += float(asd); counts[cls] += 1
        # store per-case
        case_id = (batch.get('id',[f'case_{i}'])[0] if isinstance(batch.get('id'), list) else batch.get('id', f'case_{i}'))
        per_case_rows.append({'case': str(case_id), 'class': int(cls), 'dice': float(dice), 'iou': float(iou), 'hd': float(hd), 'asd': float(asd)})
# Summaries
summary = {str(c): {'dice': sum_dice[c]/max(counts[c],1), 'iou': sum_iou[c]/max(counts[c],1), 'hd': sum_hd[c]/max(counts[c],1), 'asd': sum_asd[c]/max(counts[c],1)} for c in classes}
avg = {'dice': float(np.mean([summary[str(c)]['dice'] for c in classes])), 'iou': float(np.mean([summary[str(c)]['iou'] for c in classes])), 'hd': float(np.mean([summary[str(c)]['hd'] for c in classes])), 'asd': float(np.mean([summary[str(c)]['asd'] for c in classes]))}
print('Per-class summary (0..4):', json.dumps(summary, indent=2))
print('Average over classes 0..4:', avg)
# Save metrics
metrics_dir = OUT_DIR / 'metrics'; metrics_dir.mkdir(parents=True, exist_ok=True)
with open(metrics_dir / 'summary.json', 'w') as f: json.dump({'per_class': summary, 'average': avg}, f, indent=2)
with open(metrics_dir / 'per_case.csv', 'w', newline='') as f:
    w = csv.DictWriter(f, fieldnames=['case','class','dice','iou','hd','asd']); w.writeheader(); w.writerows(per_case_rows)


  win_data = inputs[unravel_slice[0]].to(sw_device)
  out[idx_zm] += p


Per-class summary (0..4): {
  "0": {
    "dice": 0.4504289995715703,
    "iou": 0.2914043450707414,
    "hd": 0.0,
    "asd": 0.0
  },
  "1": {
    "dice": 0.1795441872972036,
    "iou": 0.10209297755184918,
    "hd": 0.0,
    "asd": 0.0
  },
  "2": {
    "dice": 0.15627896175118405,
    "iou": 0.08740492799559102,
    "hd": 0.0,
    "asd": 0.0
  },
  "3": {
    "dice": 0.004404174193678086,
    "iou": 0.0022266029651355863,
    "hd": 0.0,
    "asd": 0.0
  },
  "4": {
    "dice": 0.07507721769579162,
    "iou": 0.039028502626969906,
    "hd": 0.0,
    "asd": 0.0
  }
}
Average over classes 0..4: {'dice': 0.17314670810188554, 'iou': 0.10443147124205743, 'hd': 0.0, 'asd': 0.0}
