In [None]:
# Upstream step (kept commented out here; run it separately): nnU-Net v2 inference.
# The -o directory must be the folder the cells below read as --nnunet_output
# (they use the relative path `nnunet_output`).
# nnUNetv2_predict \
#   -i /path/to/nnUNet_raw/DatasetXXX/imagesTs \
#   -o /path/to/nnunet_output \
#   -d DatasetXXX \
#   -c 3d_fullres \
#   -f all \
#   --save_probabilities



In [None]:
# Baseline workflow, step 1: derive MedSAM2 prompts from the nnU-Net segmentations
# (nnunet_output) and the matching CT volumes (ct_images); prompts are written to prompts_out.
# See the script's --help for the exact meaning of --dilation_iters and --bbox_padding.
!python3 nnunet_to_medsam2_prompts.py \
    --nnunet_output nnunet_output \
    --ct_dir ct_images \
    --out_dir prompts_out \
    --dilation_iters 3 \
    --bbox_padding 5


In [None]:
# Minimal MedSAM2 inference: refine key-slice masks only
# Uses prompts from nnunet_to_medsam2_prompts.py and writes refined masks to medsam2_results
# The same checkpoint (MedSAM2_CTLesion.pt) and model config (sam2.1_hiera_t512.yaml)
# are reused by every MedSAM2 inference cell in this notebook.
!python3 medsam2_infer_3D_CT_minimal.py \
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \
    --cfg configs/sam2.1_hiera_t512.yaml \
    --prompts_dir prompts_out \
    --output_dir medsam2_results

In [None]:
# Stitch individual MedSAM2 segmentations into combined multi-label files
# Per-label masks are read from medsam2_results, and the combined files are written
# back into the same directory (--output_dir == --masks_dir).
# ct_images is passed as --reference_dir (presumably for output geometry — TODO confirm in the script).
!python3 stitch_medsam2_segmentations.py \
    --masks_dir medsam2_results \
    --output_dir medsam2_results \
    --reference_dir ct_images


## Alternative: Negative Prompts Workflow

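The following cells derive two mask prompts from each nnU-Net segmentation: a positive mask (the target label) and a negative mask (all other labels). Both are passed to MedSAM2, which constrains the refinement away from neighbouring structures.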


In [None]:
# Generate positive and negative mask prompts from nnUNet segmentations
# This creates separate positive (target label) and negative (other labels) masks
# Per case, the inference log below reads <label>_positive_mask.nii.gz and
# <label>_negative_mask.nii.gz from a case folder under prompts_out_masks.
# NOTE(review): that log shows case folders named like `ct_1083.nii/` (the `.nii` part of
# `.nii.gz` is kept), whereas the no-dilation pipeline produces `ct_1138/`; confirm how this
# script derives case IDs from file names.
!python3 nnunet_to_medsam2_prompts_masks.py \
    --nnunet_output nnunet_output \
    --ct_dir ct_images \
    --out_dir prompts_out_masks \
    --dilation_iters 3


In [15]:
# MedSAM2 inference with positive and negative prompts
# Uses positive mask to segment target, negative point prompts to exclude other labels
# Writes one mask per case and label, e.g. medsam2_results_negprompts/<case>_<label>_negprompts_mask.nii.gz
# (see the log below).
# NOTE(review): the prompts are stored as negative *masks*; how points are derived from them is
# internal to medsam2_infer_3D_CT_negprompts.py — confirm there. The logged run used the CPU.
!python3 medsam2_infer_3D_CT_negprompts.py \
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \
    --cfg configs/sam2.1_hiera_t512.yaml \
    --prompts_dir prompts_out_masks \
    --output_dir medsam2_results_negprompts


Using device: cpu for MedSAM2 inference with negative prompts
Found 66 case directories with prompts.json
Cases:   0%|                                             | 0/66 [00:00<?, ?it/s]
Processing case ct_1083
  Label LV: pos=prompts_out_masks/ct_1083.nii/LV_positive_mask.nii.gz, neg=prompts_out_masks/ct_1083.nii/LV_negative_mask.nii.gz
    Refining 93 slices with negative prompts (z=37..129)

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
    Saved refined mask with negative prompts: medsam2_results_negprompts/ct_1083_LV_negprompts_mask.nii.gz
      Non-zero voxels: 516034, Unique values: [0 1]
  Label RV: pos=prompts_out_masks/ct_1083.nii/RV_positive_mask.nii.gz, neg=prompts_out_masks/ct_1083

In [17]:
# Stitch individual negative-prompt masks into combined multi-label files
# Note: --use_negprompts flag tells the script to look for *_negprompts_mask.nii.gz files
# Combined files are written back into medsam2_results_negprompts (--output_dir == --masks_dir).
!python3 stitch_medsam2_segmentations.py \
    --masks_dir medsam2_results_negprompts \
    --output_dir medsam2_results_negprompts \
    --reference_dir ct_images \
    --use_negprompts


Found 22 cases to process: ['ct_1002', 'ct_1005', 'ct_1010', 'ct_1011', 'ct_1016', 'ct_1023', 'ct_1028', 'ct_1033', 'ct_1035', 'ct_1036', 'ct_1042', 'ct_1044', 'ct_1046', 'ct_1050', 'ct_1054', 'ct_1059', 'ct_1060', 'ct_1083', 'ct_1092', 'ct_1119', 'ct_1135', 'ct_1138']
Processing cases:   0%|                                  | 0/22 [00:00<?, ?it/s]
Processing case: ct_1002
  Found 7 mask files:
    1: LV - ct_1002_LV_negprompts_mask.nii.gz
    2: RV - ct_1002_RV_negprompts_mask.nii.gz
    3: LA - ct_1002_LA_negprompts_mask.nii.gz
    4: RA - ct_1002_RA_negprompts_mask.nii.gz
    5: Myo - ct_1002_Myo_negprompts_mask.nii.gz
    6: Aorta - ct_1002_Aorta_negprompts_mask.nii.gz
    7: Pulmonary - ct_1002_Pulmonary_negprompts_mask.nii.gz
  Using first mask as reference (affine only): medsam2_results_negprompts/ct_1002_Aorta_negprompts_mask.nii.gz
    Label 1 (LV): 691523 voxels
    Label 2 (RV): 635256 voxels
    Label 3 (LA): 1441263 voxels
    Label 4 (RA): 1422775 voxels
    Label 5 (Myo)

## Approach 2: No-Dilation with Eroded Negatives

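This approach addresses the "fattening" problem (refined masks coming out larger than the structures they should cover) by:
1. Not dilating the positive masks (or using minimal dilation, e.g. 1 iteration)
2. Eroding the negative masks instead of dilating them, which leaves a safety buffer between the target and neighbouring labels
3. Using a higher binarization threshold for the MedSAM2 output (0.5 instead of 0.0)
4. Optionally eroding the refined masks as a post-processing step (`--post_erosion` in Step 2)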

**Test mode**: set `TEST_CASE_ID` in the Step 1 cell below to a case ID (e.g. `"ct_1023"`) to run Steps 1–3 on that case only. Leave it as `None` to process all cases.


In [28]:
# Step 1: Generate Prompts (No-Dilation)
# Positive masks: no dilation (or minimal 1 iteration)
# Negative masks: eroded (2 iterations) instead of dilated

# TEST MODE: Set to a case ID (e.g., "ct_1023") to process only one case, or None for all cases
TEST_CASE_ID = None  # Change to None to process all cases

# Build command
cmd = f"""python3 nnunet_to_medsam2_prompts_nodilation.py \\
    --nnunet_output nnunet_output \\
    --ct_dir ct_images \\
    --out_dir prompts_nodilation \\
    --positive_dilation 0 \\
    --negative_erosion 2"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"
    print(f"TEST MODE: Processing only case {TEST_CASE_ID}\n")
else:
    print("Processing all cases\n")

!{cmd}


Processing all cases

Found 33 segmentation files
Positive mask dilation: 0 iterations
Negative mask erosion: 2 iterations
Output directory: prompts_nodilation
Processing ct_1002...
  LV: pos=505002 voxels (dilation=0), neg=4104032 voxels (erosion=2)
  RV: pos=454363 voxels (dilation=0), neg=4159341 voxels (erosion=2)
  LA: pos=1006135 voxels (dilation=0), neg=3911759 voxels (erosion=2)
  RA: pos=1052058 voxels (dilation=0), neg=3834453 voxels (erosion=2)
  Myo: pos=1697319 voxels (dilation=0), neg=3057589 voxels (erosion=2)
  Aorta: pos=445588 voxels (dilation=0), neg=4406010 voxels (erosion=2)
  Pulmonary: pos=551410 voxels (dilation=0), neg=4330326 voxels (erosion=2)
Processing ct_1003...
  LV: pos=395505 voxels (dilation=0), neg=4612546 voxels (erosion=2)
  RV: pos=436344 voxels (dilation=0), neg=4555796 voxels (erosion=2)
  LA: pos=1202805 voxels (dilation=0), neg=4116691 voxels (erosion=2)
  RA: pos=1410493 voxels (dilation=0), neg=3901002 voxels (erosion=2)
  Myo: pos=1880922 vo

In [29]:
# Step 2: Run MedSAM2 Inference (No-Dilation)
# Uses threshold=0.5 (instead of 0.0) to prevent fattening
# Optional: add --post_erosion flag if masks still too fat

# Build command
cmd = f"""python3 medsam2_infer_3D_CT_nodilation.py \\
    --checkpoint MedSAM2/checkpoints/MedSAM2_CTLesion.pt \\
    --cfg configs/sam2.1_hiera_t512.yaml \\
    --prompts_dir prompts_nodilation \\
    --output_dir medsam2_results_nodilation \\
    --threshold 0.5"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"

# Uncomment the following lines if masks are still too fat:
# cmd += " \\\n    --post_erosion"
# cmd += " \\\n    --erosion_iters 1"

!{cmd}


Using device: cpu for MedSAM2 inference (no-dilation)
Found 33 case directories with prompts.json
Threshold: 0.5
Output directory: medsam2_results_nodilation
Cases:   0%|                                             | 0/33 [00:00<?, ?it/s]
Processing case ct_1138
  Label LV: pos=prompts_nodilation/ct_1138/LV_positive_mask.nii.gz, neg=prompts_nodilation/ct_1138/LV_negative_mask.nii.gz
    Refining 72 slices with no-dilation approach (z=32..103)

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
    Saved refined mask (no-dilation): medsam2_results_nodilation/ct_1138_LV_nodilation_mask.nii.gz
      Non-zero voxels: 325787, Unique values: [0 1]
  Label RV: pos=prompts_nodilation/ct_1138/RV_positive_mas

In [30]:
# Step 3: Stitch Results (No-Dilation)

# Build command
cmd = """python3 stitch_medsam2_segmentations_nodilation.py \\
    --masks_dir medsam2_results_nodilation \\
    --output_dir medsam2_results_nodilation \\
    --reference_dir ct_images"""

if TEST_CASE_ID:
    cmd += f" \\\n    --case_id {TEST_CASE_ID}"

!{cmd}


Found 33 cases to process: ['ct_1002', 'ct_1003', 'ct_1004', 'ct_1005', 'ct_1010', 'ct_1011', 'ct_1014', 'ct_1016', 'ct_1023', 'ct_1028', 'ct_1030', 'ct_1033', 'ct_1035', 'ct_1036', 'ct_1042', 'ct_1044', 'ct_1046', 'ct_1048', 'ct_1050', 'ct_1054', 'ct_1059', 'ct_1060', 'ct_1064', 'ct_1070', 'ct_1083', 'ct_1092', 'ct_1105', 'ct_1114', 'ct_1119', 'ct_1135', 'ct_1138', 'ct_1145', 'ct_1150']
Processing cases:   0%|                                  | 0/33 [00:00<?, ?it/s]
Processing case: ct_1002
  Found 7 mask files:
    1: LV - ct_1002_LV_nodilation_mask.nii.gz
    2: RV - ct_1002_RV_nodilation_mask.nii.gz
    3: LA - ct_1002_LA_nodilation_mask.nii.gz
    4: RA - ct_1002_RA_nodilation_mask.nii.gz
    5: Myo - ct_1002_Myo_nodilation_mask.nii.gz
    6: Aorta - ct_1002_Aorta_nodilation_mask.nii.gz
    7: Pulmonary - ct_1002_Pulmonary_nodilation_mask.nii.gz
  Using first mask as reference (affine only): medsam2_results_nodilation/ct_1002_Myo_nodilation_mask.nii.gz
    Label 1 (LV): 500011 vox

## Compare Results

After running the cells above, each case has the following segmentation outputs:
- Original nnUNet: `nnunet_output/{case}_seg.nii.gz` (or `nnunet_output/{case}.nii.gz`)
- Original MedSAM2 (with dilation): `medsam2_results/{case}_medsamrefined_seg.nii.gz`
- MedSAM2 with negative prompts: `medsam2_results_negprompts/{case}_medsamrefined_negprompts_seg.nii.gz`
- New MedSAM2 (no dilation): `medsam2_results_nodilation/{case}_seg_nodilation.nii.gz`


In [None]:
# Calculate Dice Scores for Comparison
import SimpleITK as sitk
import numpy as np
from pathlib import Path

def calculate_dice(pred, gt):
    """Compute the Dice coefficient between two masks.

    Dice = 2 * |pred AND gt| / (|pred| + |gt|).

    Args:
        pred: Array-like prediction mask. Any non-zero value counts as foreground,
            so label-valued arrays are handled consistently.
        gt: Array-like ground-truth mask with exactly the same shape as ``pred``.

    Returns:
        Dice score in [0, 1] as a Python float. Two empty masks count as a
        perfect match (1.0).

    Raises:
        ValueError: if ``pred`` and ``gt`` have different shapes (broadcasting
            them would silently produce a meaningless score).
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    if pred.shape != gt.shape:
        raise ValueError(f"Shape mismatch: pred {pred.shape} vs gt {gt.shape}")

    # Sum of the two mask sizes (not their union).
    denominator = int(pred.sum()) + int(gt.sum())
    if denominator == 0:
        # Both masks are empty, so they agree perfectly.
        return 1.0
    intersection = int(np.logical_and(pred, gt).sum())
    return 2.0 * intersection / denominator

def _read_label_array(path):
    """Read a label image with SimpleITK and return its voxel array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(str(path)))


def _format_score(value):
    """Format a Dice score with 4 decimals, or "N/A" when it is unavailable."""
    return "N/A" if value is None else f"{value:.4f}"


def _format_pct_change(value, reference):
    """Percent change of ``value`` relative to ``reference``, or "N/A" if undefined."""
    if value is None or reference == 0:
        return "N/A"
    return f"{(value - reference) / reference * 100:+.1f}%"


def _print_dice_table(nnunet, results, gt, label_names, method_widths):
    """Print per-label Dice of nnU-Net and each MedSAM2 variant against ground truth.

    Predictions whose array shape differs from the ground truth are reported as
    N/A (with a warning) instead of being broadcast against it.
    """
    predictions = {'nnUNet': nnunet, **results}
    comparable = {}
    for name, arr in predictions.items():
        if arr.shape == gt.shape:
            comparable[name] = arr
        else:
            print(f"Warning: {name} shape {arr.shape} != ground truth shape {gt.shape}; Dice shown as N/A")

    header = " ".join(f"{name:<{width}}" for name, width in method_widths.items())
    print(f"\n{'Label':<12} {'nnUNet':<10} {header} {'Best':<12}")
    print("-" * 95)

    for label_id, label_name in label_names.items():
        gt_mask = gt == label_id  # computed once per label, reused for every method
        scores = {name: calculate_dice(arr == label_id, gt_mask) for name, arr in comparable.items()}
        if scores:
            # Ties resolve to the first method listed (nnU-Net first).
            best_name, best_score = max(scores.items(), key=lambda item: item[1])
            best_str = f"{best_name}: {best_score:.4f}"
        else:
            best_str = "N/A"
        method_cells = " ".join(f"{_format_score(scores.get(name)):<{width}}"
                                for name, width in method_widths.items())
        print(f"{label_name:<12} {_format_score(scores.get('nnUNet')):<10} {method_cells} {best_str:<12}")


def _print_volume_table(nnunet, results, label_names, method_widths):
    """Print per-label voxel counts of each variant and their % change vs nnU-Net.

    Variants missing from ``results``, and percentages relative to an empty
    nnU-Net label, are shown as "N/A" rather than a misleading "+0.0%".
    """
    header = " ".join(f"{name:<{width}} {'% Change':<12}" for name, width in method_widths.items())
    print(f"\n{'Label':<12} {'nnUNet':<12} {header}")
    print("-" * 120)

    for label_id, label_name in label_names.items():
        vol_nnunet = int((nnunet == label_id).sum())
        cells = []
        for name, width in method_widths.items():
            vol = int((results[name] == label_id).sum()) if name in results else None
            vol_str = "N/A" if vol is None else str(vol)
            cells.append(f"{vol_str:<{width}} {_format_pct_change(vol, vol_nnunet):<12}")
        print(f"{label_name:<12} {vol_nnunet:<12d} {' '.join(cells)}")


def compare_approaches(case_id, ground_truth_dir=None):
    """Compare the nnU-Net segmentation with each MedSAM2 variant for one case.

    Prints a Dice table when ground truth is available, followed by a per-label
    volume table. Variants whose output file is missing are shown as "N/A".

    Args:
        case_id: Case identifier, e.g. "ct_1023".
        ground_truth_dir: Optional directory containing ``{case_id}_seg.nii.gz``.
            If given but the file is missing, a notice is printed and the Dice
            table is skipped.

    Returns:
        None. All results are printed.
    """
    label_names = {1: "LV", 2: "RV", 3: "LA", 4: "RA",
                   5: "Myo", 6: "Aorta", 7: "Pulmonary"}
    # Display order and column width of each MedSAM2 variant.
    method_widths = {'MedSAM-Orig': 15, 'MedSAM-NegPrompts': 20, 'MedSAM-NoDil': 15}
    method_paths = {
        'MedSAM-Orig': Path("medsam2_results") / f"{case_id}_medsamrefined_seg.nii.gz",
        'MedSAM-NegPrompts': Path("medsam2_results_negprompts") / f"{case_id}_medsamrefined_negprompts_seg.nii.gz",
        'MedSAM-NoDil': Path("medsam2_results_nodilation") / f"{case_id}_seg_nodilation.nii.gz",
    }

    # The nnU-Net output may be named {case}.nii.gz or {case}_seg.nii.gz; the first match wins.
    nnunet_candidates = (Path("nnunet_output") / f"{case_id}.nii.gz",
                         Path("nnunet_output") / f"{case_id}_seg.nii.gz")
    nnunet_path = next((p for p in nnunet_candidates if p.exists()), None)
    if nnunet_path is None:
        print(f"nnUNet output not found: {' or '.join(str(p) for p in nnunet_candidates)}")
        return

    nnunet = _read_label_array(nnunet_path)
    # Variants without an output file are left out and rendered as N/A.
    results = {name: _read_label_array(path) for name, path in method_paths.items() if path.exists()}

    if ground_truth_dir:
        gt_path = Path(ground_truth_dir) / f"{case_id}_seg.nii.gz"
        if gt_path.exists():
            _print_dice_table(nnunet, results, _read_label_array(gt_path), label_names, method_widths)
        else:
            print(f"Ground truth not found: {gt_path}; skipping Dice comparison")

    _print_volume_table(nnunet, results, label_names, method_widths)

# Run comparison (replace with your case ID and ground truth dir if available)
# compare_approaches("ct_1023", ground_truth_dir="ground_truth_labels")
