In [7]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import os
import ipywidgets as widgets
from ipywidgets import interact

# Input folders on the mounted Google Drive: MRI volumes, ground-truth masks,
# and SAM predictions (all expected as .nii.gz files).
mri_dir, gt_dir, pred_dir = (
    "/content/drive/MyDrive/nnunet_inputs/imagesTs",  # MRI volumes
    "/content/drive/MyDrive/mask_niftis",             # ground-truth masks
    "/content/drive/MyDrive/SAM_results",             # SAM predictions
)

def load_nifti(path):
    """Read the NIfTI file at *path* and return its voxel data.

    The data comes from nibabel's ``get_fdata``, i.e. a floating-point
    NumPy array.
    """
    image = nib.load(path)
    voxels = image.get_fdata()
    return voxels

def _list_nifti_files(directory):
    """Return the names of the ``.nii.gz`` files directly inside *directory*, sorted."""
    matches = [name for name in os.listdir(directory) if name.endswith(".nii.gz")]
    matches.sort()
    return matches


# Gather the NIfTI filenames present in each input folder.
mri_files = _list_nifti_files(mri_dir)
gt_files = _list_nifti_files(gt_dir)
pred_files = _list_nifti_files(pred_dir)

# Filename -> patient-key normalisation, used to match files across folders.
def strip_nii_gz(filename):
    """Return *filename* without its NIfTI extension.

    ``.nii.gz`` is removed as a unit (``os.path.splitext`` alone would only
    drop ``.gz``); any other name falls back to ``os.path.splitext``.
    """
    extension = ".nii.gz"
    if filename.endswith(extension):
        return filename[:-len(extension)]
    return os.path.splitext(filename)[0]


def normalize_key(
    filename,
    *,
    modality_suffixes=("_0000", "_0001", "_0002", "_0003"),
    prediction_suffix="_sam_predictions",
):
    """Reduce an image, mask or prediction filename to a shared patient key.

    Steps, in order:
      1. drop the ``.nii.gz`` extension (see ``strip_nii_gz``);
      2. drop *prediction_suffix* if present;
      3. drop the first matching nnU-Net modality suffix from
         *modality_suffixes*.

    The prediction suffix is removed first so that a prediction named after
    its input image (``<case>_0000_sam_predictions``) reduces to the same key
    as that image (``<case>_0000`` -> ``<case>``).

    Args:
        filename: Bare filename (no directory).
        modality_suffixes: Channel suffixes to strip; at most one is removed.
        prediction_suffix: Suffix marking prediction files; ``""`` disables it.

    Returns:
        The normalised key string.
    """
    key = strip_nii_gz(filename)
    # Guard against "": every string ends with "", and key[:-0] would be "".
    if prediction_suffix and key.endswith(prediction_suffix):
        key = key[:-len(prediction_suffix)]
    for suffix in modality_suffixes:
        if suffix and key.endswith(suffix):
            key = key[:-len(suffix)]
            break
    return key

def _index_by_key(filenames):
    """Map ``normalize_key(name)`` to *name* for every filename in *filenames*.

    Names are processed in the given order and the first name for a key is
    kept. Later names that normalise to the same key (e.g. several modality
    files ``case_0000`` / ``case_0001`` of one case) are reported and skipped.
    A plain dict comprehension would silently let the last one win.
    """
    index = {}
    for name in filenames:
        key = normalize_key(name)
        if key in index:
            print(f"Warning: {name!r} maps to key {key!r}, already used by {index[key]!r}; ignoring it")
            continue
        index[key] = name
    return index


mri_stems = _index_by_key(mri_files)
gt_stems = _index_by_key(gt_files)
pred_stems = _index_by_key(pred_files)

# Keep only patients present in all three folders, and say which keys were
# dropped so an empty or short patient list can be diagnosed.
patients = sorted(mri_stems.keys() & gt_stems.keys() & pred_stems.keys())
_unmatched = sorted((mri_stems.keys() | gt_stems.keys() | pred_stems.keys()) - set(patients))
if _unmatched:
    print(f"{len(_unmatched)} key(s) missing from at least one folder, e.g. {_unmatched[:10]}")

def view_patient(patient):
    """Display an interactive slice browser for one matched patient.

    Loads the MRI, ground-truth and prediction volumes whose filenames are
    registered under *patient* in ``mri_stems`` / ``gt_stems`` / ``pred_stems``.
    For the slice picked with an ``IntSlider`` (an index into the third array
    axis), it shows three panels:
      * the MRI alone;
      * the MRI with the ground truth tinted red;
      * the MRI with the prediction tinted blue.
    Overlays are drawn at 40% opacity.

    Args:
        patient: A key present in all three stem mappings.
    """
    mri = load_nifti(os.path.join(mri_dir, mri_stems[patient]))
    gt = load_nifti(os.path.join(gt_dir, gt_stems[patient]))
    pred = load_nifti(os.path.join(pred_dir, pred_stems[patient]))

    # One slider index is applied to all three volumes, so only offer the
    # slices they all have; differing shapes usually mean misaligned files.
    if not (mri.shape == gt.shape == pred.shape):
        print(
            f"Warning: shape mismatch for {patient}: MRI {mri.shape}, "
            f"ground truth {gt.shape}, prediction {pred.shape}"
        )
    n_slices = min(volume.shape[2] for volume in (mri, gt, pred))

    # Fixed colour ranges keep a label's colour stable from slice to slice.
    # Without them, a slice holding a single label value is normalised down to
    # the lightest colour of the map.
    gt_vmax = max(float(gt.max()), 1.0)
    pred_vmax = max(float(pred.max()), 1.0)

    def _draw_panel(ax, mri_slice, overlay, cmap, vmax, title):
        """Draw *mri_slice* in grey, optionally tinted by *overlay* via *cmap*."""
        ax.imshow(mri_slice, cmap="gray", origin="lower")
        if overlay is not None:
            # Mask the background (0) so it is transparent; otherwise the
            # colormap's lightest colour is laid over the whole image.
            ax.imshow(
                np.ma.masked_equal(overlay, 0),
                cmap=cmap, alpha=0.4, origin="lower", vmin=0, vmax=vmax,
            )
        ax.set_title(title)
        ax.axis("off")

    def browse_slices(idx):
        mri_slice = mri[:, :, idx].T
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        _draw_panel(axs[0], mri_slice, None, None, None, "MRI")
        _draw_panel(axs[1], mri_slice, gt[:, :, idx].T, "Reds", gt_vmax, "Ground Truth")
        _draw_panel(axs[2], mri_slice, pred[:, :, idx].T, "Blues", pred_vmax, "Prediction")
        fig.suptitle(f"Patient: {patient} | Slice {idx}/{n_slices - 1}")
        plt.show()

    interact(
        browse_slices,
        idx=widgets.IntSlider(min=0, max=n_slices - 1, step=1, value=n_slices // 2),
    )

# Dropdown for patient selection. With no matched patients the Dropdown's
# value would be None, and view_patient(None) would fail with a KeyError, so
# report the empty match instead.
if patients:
    interact(view_patient, patient=widgets.Dropdown(options=patients, description="Patient"))
else:
    print(
        "No patient found in all of mri_dir, gt_dir and pred_dir; "
        "check the filenames against normalize_key."
    )



interactive(children=(IntSlider(value=18, description='idx', max=35), Output()), _dom_classes=('widget-interac…