In [None]:
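# Prerequisite (run in a shell, outside this notebook): produce the nnU-Net v2 predictions
# that the cells below consume. --save_probabilities additionally exports the predicted
# class probabilities as .npz files, which the uncertainty-guided pipeline reads.
# NOTE: in nnU-Net v2, `-f all` selects the single model trained on all training cases
# (the `fold_all` folder); it is not an ensemble of the five cross-validation folds.
# nnUNetv2_predict \
#   -i /path/to/nnUNet_raw/DatasetXXX/imagesTs \
#   -o /path/to/nnunet_output \
#   -d DatasetXXX \
#   -c 3d_fullres \
#   -f all \
#   --save_probabilities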



In [None]:
# Convert the nnU-Net outputs in `nnunet_output` into MedSAM2 prompt files in `prompts_out`.
# When this same script is run later in the notebook it reports writing JSON prompts plus
# coarse masks to its --out_dir.
# --dilation_iters / --bbox_padding: presumably mask-dilation iterations and bounding-box
# padding — TODO confirm meaning and units in nnunet_to_medsam2_prompts.py.
!python3 nnunet_to_medsam2_prompts.py \
    --nnunet_output nnunet_output \
    --ct_dir ct_images \
    --out_dir prompts_out \
    --dilation_iters 3 \
    --bbox_padding 5


In [None]:
# Minimal MedSAM2 inference: refine key-slice masks only
# Uses prompts from nnunet_to_medsam2_prompts.py (`prompts_out`) and writes refined masks to `medsam2_results`
# NOTE(review): the uncertainty-guided inference cell further down passes
# --cfg sam2/configs/sam2.1_hiera_t512.yaml, while this cell passes configs/sam2.1_hiera_t512.yaml.
# Confirm which path form this script resolves before running.
!python3 medsam2_infer_3D_CT_minimal.py \
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \
    --cfg configs/sam2.1_hiera_t512.yaml \
    --prompts_dir prompts_out \
    --output_dir medsam2_results

In [None]:
# Stitch individual MedSAM2 segmentations into combined multi-label files
# The per-structure masks are read from `medsam2_results` and the combined file is written
# back into the same directory (--masks_dir and --output_dir are identical).
# `ct_images` is passed as --reference_dir; the recorded run of this script later in the
# notebook logs "Using original CT as reference" and "Saved with reference spatial info".
!python3 stitch_medsam2_segmentations.py \
    --masks_dir medsam2_results \
    --output_dir medsam2_results \
    --reference_dir ct_images


## Alternative: Negative Prompts Workflow

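The following cells generate positive and negative mask prompts from the nnU-Net segmentation and use them for more constrained refinement (the inference step applies the negatives as point prompts). The stitching cell for this workflow (`--use_negprompts`) appears further down, after the uncertainty-guided pipeline.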


In [None]:
# Generate positive and negative mask prompts from nnUNet segmentations
# This creates separate positive (target label) and negative (other labels) masks
# Output goes to `prompts_out_masks`, which the negative-prompts inference cell reads via --prompts_dir.
# --dilation_iters: presumably the number of dilation iterations applied to the masks — TODO confirm in the script.
!python3 nnunet_to_medsam2_prompts_masks.py \
    --nnunet_output nnunet_output \
    --ct_dir ct_images \
    --out_dir prompts_out_masks \
    --dilation_iters 3


In [None]:
# MedSAM2 inference with positive and negative prompts
# Uses positive mask to segment target, negative point prompts to exclude other labels
# Reads prompts from `prompts_out_masks` and writes results to `medsam2_results_negprompts`.
# NOTE(review): the uncertainty-guided inference cell passes --cfg sam2/configs/sam2.1_hiera_t512.yaml,
# while this cell passes configs/sam2.1_hiera_t512.yaml — confirm which path form this script resolves.
!python3 medsam2_infer_3D_CT_negprompts.py \
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \
    --cfg configs/sam2.1_hiera_t512.yaml \
    --prompts_dir prompts_out_masks \
    --output_dir medsam2_results_negprompts


## Uncertainty-Guided Refinement Pipeline

This approach leverages MedSAM2 only where nnU-Net is uncertain, while preserving nnU-Net's superior anatomical topology for CHD-specific defects like VSDs.

**Key Features:**
- Uses **point prompts** from high-confidence seeds instead of bounding boxes
- **Selective refinement**: Only refines high-contrast structures (LV, RV, Aorta, Pulmonary)
- **Skips refinement** for topologically complex structures (Myocardium, VSD areas)
- **Uncertainty-guided ensemble**: Combines nnU-Net core with MedSAM2 refinement in uncertainty zones

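**Structure Classification:**
- **High-Contrast** (refined with MedSAM2): LV, RV, Aorta, Pulmonary
- **Topologically Complex** (uses nnU-Net directly): Myocardium, VSD areas
- **Not classified above**: LA and RA are also part of the label set (they appear in the analysis output below); check `medsam2_infer_3D_CT_uncertainty_guided.py` for how they are handled.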

**Note:** This workflow uses a separate prompts directory (`prompts_out_uncertainty`) to avoid mixing with previous approaches. The prompts are generated specifically from the test case probability maps.


In [22]:
# Step 0: Generate initial prompts from nnU-Net segmentations
# This creates the base JSON prompt files that will be enhanced with uncertainty analysis
# Note: The test folder contains both .npz (probabilities) and .nii.gz (segmentations) files
# Prompts are written to `prompts_out_uncertainty` so they stay separate from the `prompts_out`
# directory used by the first workflow.
!python3 nnunet_to_medsam2_prompts.py \
    --nnunet_output nnunet_probabilities_Dataset1Ensemblefrom100DAepoch5foldvalidationtest \
    --ct_dir ct_images \
    --out_dir prompts_out_uncertainty \
    --dilation_iters 3 \
    --bbox_padding 5


Processing nnUNet outputs: 100%|██████████████████| 1/1 [00:18<00:00, 18.16s/it]

✔ Prompt generation complete.
  → JSON + coarse masks written to: prompts_out_uncertainty


In [23]:
# Step 1: Analyze nnU-Net probability maps for uncertainty
# This calculates entropy and identifies high-confidence seeds and uncertainty zones
# Updates JSON prompt files with uncertainty analysis data
# The prompts are now based on the probability maps, not reused from other cases
# --prob_dir and --nnunet_seg_dir point at the same folder, which holds both the .npz
# probabilities and the .nii.gz segmentations.
# NOTE(review): --num_points is 8 here, but the recorded output below reports "10 sampled"
# per label — the output may be from an earlier run, or the flag may not control that count; verify.
!python3 nnunet_uncertainty_analysis.py \
    --prob_dir nnunet_probabilities_Dataset1Ensemblefrom100DAepoch5foldvalidationtest \
    --prompts_dir prompts_out_uncertainty \
    --nnunet_seg_dir nnunet_probabilities_Dataset1Ensemblefrom100DAepoch5foldvalidationtest \
    --num_points 8 \
    --confidence_threshold 0.95 \
    --entropy_percentile 75.0

Found 1 probability files
Processing probability maps:   0%|                        | 0/1 [00:00<?, ?it/s]  Loaded probabilities: shape (8, 221, 512, 512)
  Entropy threshold (75.0th percentile): 0.0027
    LV: 182426 seeds found, 10 sampled, 263707 uncertainty voxels
    RV: 686990 seeds found, 10 sampled, 1022888 uncertainty voxels
    LA: 216460 seeds found, 10 sampled, 645327 uncertainty voxels
    RA: 1865847 seeds found, 10 sampled, 2441826 uncertainty voxels
    Myo: 672424 seeds found, 10 sampled, 1568039 uncertainty voxels
    Aorta: 501099 seeds found, 10 sampled, 770132 uncertainty voxels
    Pulmonary: 554715 seeds found, 10 sampled, 1153105 uncertainty voxels
  ✓ Updated prompts_out_uncertainty/ct_1004.json
Processing probability maps: 100%|████████████████| 1/1 [00:16<00:00, 16.79s/it]


In [24]:
# Step 2: Run MedSAM2 inference with uncertainty-guided refinement
# Uses point prompts from high-confidence seeds
# Skips refinement for topologically complex structures (Myocardium, VSD areas)
# Uses the uncertainty-specific prompts directory
# Reads CT volumes from `ct_images` (-i) and writes per-structure masks to
# `medsam2_results_uncertainty` (-o).
# NOTE(review): this cell passes --cfg sam2/configs/sam2.1_hiera_t512.yaml, whereas the other
# inference cells in this notebook pass configs/sam2.1_hiera_t512.yaml — confirm the difference is intended.
!python3 medsam2_infer_3D_CT_uncertainty_guided.py \
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \
    --cfg sam2/configs/sam2.1_hiera_t512.yaml \
    -i ct_images \
    --prompts_dir prompts_out_uncertainty \
    --nnunet_seg_dir nnunet_probabilities_Dataset1Ensemblefrom100DAepoch5foldvalidationtest \
    -o medsam2_results_uncertainty


CUDA not available, using CPU
Found 1 JSON files in prompts_out_uncertainty
  Found case ct_1004: ct_images/ct_1004_0000.nii.gz
Processing 1 cases from JSON files
  0%|                                                     | 0/1 [00:00<?, ?it/s]  Processing LV (label_id=1)
    High-contrast: True, Topologically complex: False, VSD: False
    Using MedSAM2 refinement for LV
    Using central key slice: 74 (area: 5894 pixels, z-range: 46-112)
    Using coarse mask from nnUNet for initialization
    Using 1 point prompts from high-confidence seeds

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(

propagate in video:   0%|                               | 0/147 [00:00<?, ?it/s][A    Frame 74: out_mask_

In [None]:
# Step 3: Stitch segmentations with uncertainty-guided weighted ensemble
# Final Mask = (nnU-Net Core) + (MedSAM2 output restricted to Uncertainty Zone)
# This preserves nnU-Net's high-confidence regions while applying MedSAM2 refinement only in uncertainty zones
# Uses the uncertainty-specific prompts directory
# The per-structure masks are read from `medsam2_results_uncertainty` and the combined
# `*_medsamrefined_seg.nii.gz` file is written back into the same directory.
!python3 stitch_medsam2_segmentations.py \
    --masks_dir medsam2_results_uncertainty \
    --output_dir medsam2_results_uncertainty \
    --reference_dir ct_images \
    --nnunet_seg_dir nnunet_probabilities_Dataset1Ensemblefrom100DAepoch5foldvalidationtest \
    --prompts_dir prompts_out_uncertainty \
    --use_uncertainty_ensemble


Found 1 cases to process: ['ct_1004_0000']
Processing cases:   0%|                                   | 0/1 [00:00<?, ?it/s]
Processing case: ct_1004_0000
  Found 7 mask files:
    1: LV - ct_1004_0000_LV_mask.nii.gz
    2: RV - ct_1004_0000_RV_mask.nii.gz
    3: LA - ct_1004_0000_LA_mask.nii.gz
    4: RA - ct_1004_0000_RA_mask.nii.gz
    5: Myo - ct_1004_0000_Myo_mask.nii.gz
    6: Aorta - ct_1004_0000_Aorta_mask.nii.gz
    7: Pulmonary - ct_1004_0000_Pulmonary_mask.nii.gz
  Using original CT as reference: ct_images/ct_1004_0000.nii.gz
    Label 1 (LV): 1514523 voxels
    Label 2 (RV): 2928353 voxels
    Label 3 (LA): 1462359 voxels
    Label 4 (RA): 3937968 voxels
    Label 5 (Myo): 1166763 voxels
    Label 6 (Aorta): 1176333 voxels
    Label 7 (Pulmonary): 3123835 voxels
  Saved with reference spatial info: medsam2_results_uncertainty/ct_1004_0000_medsamrefined_seg.nii.gz
  ✓ Combined segmentation saved: medsam2_results_uncertainty/ct_1004_0000_medsamrefined_seg.nii.gz
    Shape: (22

In [None]:
# Stitch individual negative prompts masks into combined multi-label files
# Note: --use_negprompts flag tells the script to look for *_negprompts_mask.nii.gz files
# This cell belongs to the "Alternative: Negative Prompts Workflow" section above: it consumes
# the `medsam2_results_negprompts` directory written by medsam2_infer_3D_CT_negprompts.py,
# so run it after that inference cell.
!python3 stitch_medsam2_segmentations.py \
    --masks_dir medsam2_results_negprompts \
    --output_dir medsam2_results_negprompts \
    --reference_dir ct_images \
    --use_negprompts


## Approach 2: No-Dilation with Eroded Negatives

This approach fixes the "fattening" problem by:
1. Not dilating positive masks (or minimal dilation)
2. ERODING negative masks (creates safety buffer)
3. Using higher threshold (0.5 instead of 0.0)
4. Optional post-processing erosion

**Test Mode**: Set `TEST_CASE_ID` below to process only one case for testing. Set to `None` to process all cases.


In [None]:
# Step 1: Generate Prompts (No-Dilation)
# Positive masks: no dilation (or minimal 1 iteration)
# Negative masks: eroded (2 iterations) instead of dilated

# TEST MODE: Set to a case ID (e.g., "ct_1023") to process only one case, or None for all cases
TEST_CASE_ID = None  # Change to None to process all cases

# Build command
cmd = f"""python3 nnunet_to_medsam2_prompts_nodilation.py \\
    --nnunet_output nnunet_output \\
    --ct_dir ct_images \\
    --out_dir prompts_nodilation \\
    --positive_dilation 0 \\
    --negative_erosion 2"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"
    print(f"TEST MODE: Processing only case {TEST_CASE_ID}\n")
else:
    print("Processing all cases\n")

!{cmd}


In [None]:
# Step 2: Run MedSAM2 Inference (No-Dilation)
# Uses threshold=0.5 (instead of 0.0) to prevent fattening
# Optional: add --post_erosion flag if masks still too fat

# Build command
cmd = f"""python3 medsam2_infer_3D_CT_nodilation.py \\
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \\
    --cfg configs/sam2.1_hiera_t512.yaml \\
    --prompts_dir prompts_nodilation \\
    --output_dir medsam2_results_nodilation \\
    --threshold 0.5"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"

# Uncomment the following lines if masks are still too fat:
# cmd += " \\\n    --post_erosion"
# cmd += " \\\n    --erosion_iters 1"

!{cmd}


In [None]:
# Step 3: Stitch Results (No-Dilation)

# Build command
cmd = """python3 stitch_medsam2_segmentations_nodilation.py \\
    --masks_dir medsam2_results_nodilation \\
    --output_dir medsam2_results_nodilation \\
    --reference_dir ct_images"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"

!{cmd}


## Compare Results

Now you have:
- Original nnUNet: `nnunet_output/{case}_seg.nii.gz` (or `nnunet_output/{case}.nii.gz`)
- Original MedSAM2 (with dilation): `medsam2_results/{case}_medsamrefined_seg.nii.gz`
- MedSAM2 with negative prompts: `medsam2_results_negprompts/{case}_medsamrefined_negprompts_seg.nii.gz`
- New MedSAM2 (no dilation): `medsam2_results_nodilation/{case}_seg_nodilation.nii.gz`
- MedSAM2 uncertainty-guided: `medsam2_results_uncertainty/{case}_medsamrefined_seg.nii.gz` (not read by the comparison function below)


In [None]:
# Calculate Dice Scores for Comparison
import SimpleITK as sitk
import numpy as np
from pathlib import Path

def calculate_dice(pred, gt):
    """Calculate the Dice coefficient between two masks.

    Both inputs are converted to boolean masks first (any non-zero element counts
    as foreground), so the result is the same whether a mask is stored as bool,
    0/1, or with some other non-zero foreground value.

    Args:
        pred: Predicted mask (array-like).
        gt: Ground-truth mask (array-like) with the same shape as ``pred``.

    Returns:
        float: ``2 * |pred AND gt| / (|pred| + |gt|)``; ``1.0`` when both masks
        are empty (two empty masks agree perfectly).

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    pred_mask = np.asarray(pred).astype(bool)
    gt_mask = np.asarray(gt).astype(bool)
    if pred_mask.shape != gt_mask.shape:
        # Without this check NumPy would silently broadcast compatible shapes
        # and produce a meaningless score.
        raise ValueError(
            f"Shape mismatch: pred {pred_mask.shape} vs gt {gt_mask.shape}"
        )

    # Dice denominator: the sum of the two foreground voxel counts (not a set union).
    total = int(pred_mask.sum()) + int(gt_mask.sum())
    if total == 0:
        return 1.0

    overlap = int(np.logical_and(pred_mask, gt_mask).sum())
    return 2.0 * overlap / total

def compare_approaches(case_id, ground_truth_dir=None):
    """Compare all approaches on one case and print the comparison tables.

    Loads the nnUNet segmentation for ``case_id`` plus whichever MedSAM2 results
    exist on disk (original, negative-prompts, no-dilation). Then prints:

    * a per-label Dice table, only when ``ground_truth_dir`` is given and
      ``{ground_truth_dir}/{case_id}_seg.nii.gz`` exists;
    * a per-label voxel-count table with the percentage change of each MedSAM2
      result relative to nnUNet.

    Methods without a result file are shown as ``N/A`` in both tables. The
    percentage change is also ``N/A`` when nnUNet has no voxels for a label.

    Args:
        case_id: Case identifier used to build the file names (e.g. "ct_1023").
        ground_truth_dir: Optional directory holding ``{case_id}_seg.nii.gz``
            ground-truth label files.

    Returns:
        None. Results are printed. If no nnUNet output is found, a message is
        printed and the function returns early.
    """
    label_names = {1: "LV", 2: "RV", 3: "LA", 4: "RA",
                   5: "Myo", 6: "Aorta", 7: "Pulmonary"}

    # (column name, results directory, file-name suffix, column width in both tables)
    methods = (
        ("MedSAM-Orig", "medsam2_results", "_medsamrefined_seg.nii.gz", 15),
        ("MedSAM-NegPrompts", "medsam2_results_negprompts",
         "_medsamrefined_negprompts_seg.nii.gz", 20),
        ("MedSAM-NoDil", "medsam2_results_nodilation", "_seg_nodilation.nii.gz", 15),
    )

    def read_labels(path):
        """Read a label image from disk and return it as a NumPy array."""
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path)))

    # nnUNet output may be stored as {case}.nii.gz or {case}_seg.nii.gz
    nnunet_path = Path("nnunet_output") / f"{case_id}.nii.gz"
    if not nnunet_path.exists():
        nnunet_path = Path("nnunet_output") / f"{case_id}_seg.nii.gz"
    if not nnunet_path.exists():
        print(f"nnUNet output not found: {nnunet_path}")
        return

    nnunet = read_labels(nnunet_path)

    # Load only the MedSAM2 variants that have actually been produced
    results = {}
    for method_name, directory, suffix, _ in methods:
        method_path = Path(directory) / f"{case_id}{suffix}"
        if method_path.exists():
            results[method_name] = read_labels(method_path)

    # Dice comparison (needs ground truth)
    if ground_truth_dir:
        gt_path = Path(ground_truth_dir) / f"{case_id}_seg.nii.gz"
        if gt_path.exists():
            gt = read_labels(gt_path)
            print(f"\n{'Label':<12} {'nnUNet':<10} {'MedSAM-Orig':<15} {'MedSAM-NegPrompts':<20} {'MedSAM-NoDil':<15} {'Best':<12}")
            print("-" * 95)

            for label_id, label_name in label_names.items():
                gt_mask = gt == label_id
                dice_scores = {'nnUNet': calculate_dice(nnunet == label_id, gt_mask)}
                for method_name, pred_array in results.items():
                    dice_scores[method_name] = calculate_dice(pred_array == label_id, gt_mask)

                # On ties max() keeps the first entry, so nnUNet wins a tie
                best_name, best_score = max(dice_scores.items(), key=lambda item: item[1])

                row = f"{label_name:<12} {dice_scores['nnUNet']:<10.4f} "
                for method_name, _, _, width in methods:
                    # Same 4-decimal precision for every method; N/A when missing
                    if method_name in dice_scores:
                        cell = f"{dice_scores[method_name]:.4f}"
                    else:
                        cell = "N/A"
                    row += f"{cell:<{width}} "
                print(f"{row}{best_name}: {best_score:.4f}")
        else:
            # Say so explicitly instead of silently dropping the Dice table
            print(f"Ground truth not found: {gt_path} (skipping Dice comparison)")

    # Volume comparison
    print(f"\n{'Label':<12} {'nnUNet':<12} {'MedSAM-Orig':<15} {'% Change':<12} "
          f"{'MedSAM-NegPrompts':<20} {'% Change':<12} {'MedSAM-NoDil':<15} {'% Change':<12}")
    print("-" * 120)

    for label_id, label_name in label_names.items():
        vol_nnunet = int((nnunet == label_id).sum())
        row = f"{label_name:<12} {vol_nnunet:<12d} "

        for method_name, _, _, width in methods:
            vol_str = "N/A"
            pct_str = "N/A"
            if method_name in results:
                vol = int((results[method_name] == label_id).sum())
                vol_str = f"{vol}"
                # Relative change is undefined when nnUNet has no voxels for this label
                if vol_nnunet > 0:
                    pct_str = f"{(vol - vol_nnunet) / vol_nnunet * 100:+.1f}%"
            row += f"{vol_str:<{width}} {pct_str:<12} "

        print(row.rstrip())

# Run comparison (replace with your case ID and ground truth dir if available)
# compare_approaches("ct_1023", ground_truth_dir="ground_truth_labels")
