In [12]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Resize, ToTensor
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import label
import nibabel as nib
import SimpleITK as sitk

# Paths: model checkpoint, 2D test slices, and the two output locations.
model_path = '2DUnet_no_augmentation_fold_1_1903.pth'
test_dir = './dataset/test_set_2d'
save_dir = './visualizations_postprocessed_no_secret'
slice_save_root = "./predicted_slices_2d_new"

# Create both output directories up front; existing directories are left untouched.
for out_dir in (save_dir, slice_save_root):
    os.makedirs(out_dir, exist_ok=True)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Post-processing helpers
# --------------------------
def keep_largest_component(pred_mask, min_size=50):
    """Keep, per foreground class, only its largest connected component.

    Args:
        pred_mask: integer label array; value 0 is treated as background.
        min_size: minimum number of elements the largest component of a class
            must have to be kept; smaller classes are dropped entirely.

    Returns:
        An array with the same shape and dtype as ``pred_mask`` holding, for
        each non-zero class, only its largest connected component (or nothing
        for that class if the component is smaller than ``min_size``).
    """
    cleaned = np.zeros_like(pred_mask)
    foreground_classes = [c for c in np.unique(pred_mask) if c != 0]

    for cls in foreground_classes:
        components, n_components = label((pred_mask == cls).astype(np.int32))
        if n_components == 0:
            continue

        # Element count per component id; id 0 is the background of this
        # binary mask and must never win the argmax.
        counts = np.bincount(components.ravel())
        counts[0] = 0
        winner = counts.argmax()

        if counts[winner] < min_size:
            continue
        cleaned[components == winner] = cls

    return cleaned

# Z-Score normalization transform
class ZScoreNormalize:
    """Callable transform that standardises a tensor to zero mean / unit std.

    The statistics of the most recent call are kept on the instance as
    ``self.mean`` and ``self.std`` so that a caller can undo the normalisation
    afterwards. They are overwritten on every call.
    """

    def __call__(self, img):
        """Return ``(img - mean) / std`` as a float32 ``torch.Tensor``.

        ``img`` must expose ``.numpy()``. Mean and std are taken over all
        elements; a zero std is replaced by 1 to avoid division by zero.
        """
        values = img.numpy()
        mu, sigma = np.mean(values), np.std(values)
        sigma = 1 if sigma == 0 else sigma

        # Remember the statistics of this call for later de-normalisation.
        self.mean, self.std = mu, sigma

        return torch.tensor((values - mu) / sigma, dtype=torch.float32)

# Image transforms (used during training)
# A single shared normaliser instance is kept in `z_norm` because the
# inference loop reads its `.mean` / `.std` after each sample is loaded.
z_norm = ZScoreNormalize()
preprocessing_steps = [
    LoadImage(image_only=True),        # read the file from its path
    EnsureChannelFirst(),              # add / move the channel axis to the front
    Resize(spatial_size=(352, 352)),   # fixed network input size
    ToTensor(),
    z_norm,                            # per-image z-score, applied last
]
transforms = Compose(preprocessing_steps)

# Dataset definition
class TestDataset(Dataset):
    """Dataset over the NIfTI slice files found one level below ``root_dir``.

    ``root_dir`` is expected to contain one sub-directory per patient. Every
    file in those sub-directories ending in ``.nii`` / ``.nii.gz`` whose name
    does not contain ``_gt`` is collected, in sorted patient order and sorted
    file order within each patient.
    """

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and remember ``transform`` for item loading."""
        self.image_paths = []
        for patient in sorted(os.listdir(root_dir)):
            patient_path = os.path.join(root_dir, patient)
            if not os.path.isdir(patient_path):
                continue  # ignore stray files next to the patient folders
            self.image_paths.extend(
                os.path.join(patient_path, name)
                for name in sorted(os.listdir(patient_path))
                if name.endswith(('.nii', '.nii.gz')) and '_gt' not in name
            )
        self.transform = transform

    def __len__(self):
        """Number of collected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transform(path), path)`` for the ``idx``-th file."""
        img_path = self.image_paths[idx]
        return self.transform(img_path), img_path

# Load model
# Architecture hyper-parameters; they must match the checkpoint being loaded,
# otherwise load_state_dict below fails on mismatching keys/shapes.
unet_kwargs = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(64, 128, 256, 512, 1024, 2048),
    strides=(2, 2, 2, 2, 2),
    num_res_units=3,
    norm="batch",
    dropout=0.2,
)
model = UNet(**unet_kwargs)

# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint_state = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint_state)
model.to(device)
model.eval()  # inference mode: disables dropout, uses running batch-norm stats

# Load test data
# batch_size=1 and shuffle=False: the inference loop handles one slice at a
# time, in the deterministic order produced by TestDataset.
test_dataset = TestDataset(test_dir, transform=transforms)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

def _resample_to_reference(array_xy, reference_img, interpolator):
    """Resample a model-space 2D array onto the pixel grid of ``reference_img``.

    The network works on a copy of the slice that was stretched to a fixed
    size, so the array covers the *same physical extent* as the reference
    slice, just with a different pixel count. The geometry assigned here
    encodes exactly that (outer pixel edges aligned), so resampling with the
    identity transform undoes the stretch instead of cropping.

    Args:
        array_xy: 2D numpy array in model space.
            NOTE(review): MONAI's NIfTI loading yields arrays indexed [x, y],
            while SimpleITK builds images from arrays indexed [y, x]; the
            array is therefore transposed before conversion. Confirm on one
            slice by overlaying the saved mask on its source image.
        reference_img: SimpleITK image whose size, spacing, origin and
            direction define the output grid.
        interpolator: SimpleITK interpolator constant (nearest neighbour for
            label maps, linear for intensities).

    Returns:
        A SimpleITK image on the grid of ``reference_img``.
    """
    model_itk = sitk.GetImageFromArray(np.ascontiguousarray(array_xy.T))

    dim = reference_img.GetDimension()
    ref_size = np.asarray(reference_img.GetSize(), dtype=np.float64)
    ref_spacing = np.asarray(reference_img.GetSpacing(), dtype=np.float64)
    ref_origin = np.asarray(reference_img.GetOrigin(), dtype=np.float64)
    ref_direction = np.asarray(reference_img.GetDirection(), dtype=np.float64).reshape(dim, dim)
    model_size = np.asarray(model_itk.GetSize(), dtype=np.float64)

    # Same physical extent, different pixel count -> scaled spacing.
    model_spacing = ref_spacing * ref_size / model_size
    # Origins are pixel *centres*; shift by half the spacing difference so the
    # outer edges of the first pixels coincide.
    model_origin = ref_origin + ref_direction @ ((model_spacing - ref_spacing) / 2.0)

    model_itk.SetSpacing(model_spacing.tolist())
    model_itk.SetOrigin(model_origin.tolist())
    model_itk.SetDirection(reference_img.GetDirection())

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference_img)  # output size/spacing/origin/direction
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(model_itk)


# Inference and visualization
with torch.no_grad():
    for image, path in test_loader:
        image = image.to(device)
        output = model(image)
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
        pred = keep_largest_component(pred)

        # Patient and slice info (batch_size is 1, so path holds one entry)
        original_img_path = path[0]
        patient_id = os.path.basename(os.path.dirname(original_img_path))
        original_img = sitk.ReadImage(original_img_path)

        # Bring the label map back onto the original slice grid.
        pred_resampled = _resample_to_reference(
            pred.astype(np.uint8), original_img, sitk.sitkNearestNeighbor
        )

        # Strip the full NIfTI extension; os.path.splitext would only remove
        # ".gz" from ".nii.gz" and leave ".nii" inside the output names.
        file_name = os.path.basename(original_img_path)
        for ext in ('.nii.gz', '.nii'):
            if file_name.endswith(ext):
                base_name = file_name[:-len(ext)]
                break
        else:
            base_name = os.path.splitext(file_name)[0]

        # Save prediction slice
        patient_out_dir = os.path.join(slice_save_root, patient_id)
        os.makedirs(patient_out_dir, exist_ok=True)
        slice_pred_path = os.path.join(patient_out_dir, base_name + "_pred.nii.gz")
        sitk.WriteImage(pred_resampled, slice_pred_path)

        # For visualization: undo the z-score normalisation.
        # z_norm holds the statistics of the most recently transformed sample;
        # that is this sample only because the loader runs in-process with
        # batch_size=1 - revisit if num_workers or batch_size is changed.
        input_img = image.squeeze(0).squeeze(0).cpu().numpy()
        restored_img = (input_img * z_norm.std) + z_norm.mean

        # Resize restored image to match original for side-by-side display
        img_resized = sitk.GetArrayFromImage(
            _resample_to_reference(restored_img, original_img, sitk.sitkLinear)
        )
        mask_resized = sitk.GetArrayFromImage(pred_resampled)

        # Save side-by-side visualization
        fig, axs = plt.subplots(2, 1, figsize=(6, 6))
        axs[0].imshow(img_resized, cmap="gray")
        axs[0].set_title("Restored Input Image")
        axs[0].axis("off")
        axs[1].imshow(mask_resized, cmap="jet")
        axs[1].set_title("Resampled Prediction")
        axs[1].axis("off")

        fig.tight_layout()
        save_path = os.path.join(save_dir, base_name + "_postprocessed.png")
        fig.savefig(save_path)
        plt.close(fig)  # close this figure explicitly so figures do not pile up

        print(f"Saved: {save_path}")


Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice00_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice01_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice02_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice03_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice04_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice05_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice06_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice07_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice08_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_frame01.nii_slice09_postprocessed.png
Saved: ./visualizations_postprocessed_secret/patient151_fram

In [14]:
import os
import numpy as np
import SimpleITK as sitk
from natsort import natsorted
from collections import defaultdict
import re  # For robust slice index extraction

# Paths
slices_dir = './predicted_slices_2d_new'
reference_dir = './dataset/secret_test_set_extracted'
output_dir = './final_segmentations'
os.makedirs(output_dir, exist_ok=True)

NIFTI_EXTENSIONS = ('.nii.gz', '.nii')


def _strip_nifti_suffix(name):
    """Return ``name`` without a trailing ``.nii.gz`` / ``.nii`` (else unchanged)."""
    for ext in NIFTI_EXTENSIONS:
        if name.endswith(ext):
            return name[:-len(ext)]
    return name


# Group slices by patient (natural sort keeps the processing order stable)
patient_slices = defaultdict(list)
for patient in natsorted(os.listdir(slices_dir)):
    patient_path = os.path.join(slices_dir, patient)
    if not os.path.isdir(patient_path):
        continue
    for file in natsorted(os.listdir(patient_path)):
        if file.endswith(NIFTI_EXTENSIONS) and '_pred' in file:
            patient_slices[patient].append(os.path.join(patient_path, file))

# Reconstruct volumes
for patient, slice_paths in patient_slices.items():
    print(f"\nReconstructing 3D mask for {patient} ...")

    # Candidate reference volumes: frame images only. Ground-truth files are
    # excluded, mirroring the `_gt` filter used when the slices were predicted.
    patient_folder = os.path.join(reference_dir, patient)
    if not os.path.isdir(patient_folder):
        print(f"Warning: no reference folder {patient_folder}, skipping {patient}")
        continue
    reference_files = natsorted(
        f for f in os.listdir(patient_folder)
        if 'frame' in f and '_gt' not in f and f.endswith(NIFTI_EXTENSIONS)
    )
    if not reference_files:
        print(f"Warning: no reference frame volume found in {patient_folder}, skipping {patient}")
        continue
    reference_by_stem = {_strip_nifti_suffix(f): f for f in reference_files}

    # Assign every slice to the reference volume it was cut from. The part of
    # the slice filename before "_sliceNN" names the source volume (e.g.
    # "patient151_frame01.nii_slice00_pred.nii.gz" -> "patient151_frame01").
    # Slices whose source cannot be matched fall back to the first reference,
    # which is what happened to every slice before.
    slices_by_reference = defaultdict(list)
    for path in slice_paths:
        filename = os.path.basename(path)

        # Use regex to extract slice index from filename safely
        match = re.search(r'_slice(\d+)', filename)
        if not match:
            print(f"Warning: could not parse slice index from {filename}")
            continue

        source_stem = _strip_nifti_suffix(filename[:match.start()])
        reference_name = reference_by_stem.get(source_stem, reference_files[0])
        slices_by_reference[reference_name].append((int(match.group(1)), path))

    for reference_img_path, indexed_slices in slices_by_reference.items():
        # Load reference to determine correct number of slices and metadata
        reference_img = sitk.ReadImage(os.path.join(patient_folder, reference_img_path))
        ref_size = reference_img.GetSize()  # (x, y, z)

        # Prepare empty volume: shape (z, y, x)
        volume_np = np.zeros((ref_size[2], ref_size[1], ref_size[0]), dtype=np.uint8)

        # Fill in predictions
        filled = set()
        for slice_idx, path in indexed_slices:
            if slice_idx >= ref_size[2]:
                print(f"Warning: slice index {slice_idx} exceeds volume depth {ref_size[2]} in {patient}")
                continue

            # Read predicted slice
            slice_arr = sitk.GetArrayFromImage(sitk.ReadImage(path))  # shape: (y, x)
            if slice_arr.shape != volume_np.shape[1:]:
                # Never let numpy broadcast a wrongly sized slice into the volume.
                print(f"Warning: slice {os.path.basename(path)} has shape {slice_arr.shape}, "
                      f"expected {volume_np.shape[1:]}; skipping")
                continue
            volume_np[slice_idx] = slice_arr
            filled.add(slice_idx)

        missing = sorted(set(range(ref_size[2])) - filled)
        if missing:
            print(f"Warning: no prediction for slices {missing} of {reference_img_path}; left as background")

        # Convert back to SimpleITK and assign correct metadata
        final_mask = sitk.GetImageFromArray(volume_np)  # shape: (z, y, x)
        final_mask.CopyInformation(reference_img)

        # Save with the exact same filename as original
        output_path = os.path.join(output_dir, reference_img_path)
        sitk.WriteImage(final_mask, output_path)
        print(f"Saved: {output_path}")



Reconstructing 3D mask for patient151 ...
Saved: ./final_segmentations/patient151_frame01.nii.gz

Reconstructing 3D mask for patient152 ...
Saved: ./final_segmentations/patient152_frame01.nii.gz

Reconstructing 3D mask for patient153 ...
Saved: ./final_segmentations/patient153_frame01.nii.gz

Reconstructing 3D mask for patient154 ...
Saved: ./final_segmentations/patient154_frame01.nii.gz

Reconstructing 3D mask for patient155 ...
Saved: ./final_segmentations/patient155_frame01.nii.gz

Reconstructing 3D mask for patient156 ...
Saved: ./final_segmentations/patient156_frame01.nii.gz

Reconstructing 3D mask for patient157 ...
Saved: ./final_segmentations/patient157_frame01.nii.gz

Reconstructing 3D mask for patient158 ...
Saved: ./final_segmentations/patient158_frame01.nii.gz

Reconstructing 3D mask for patient159 ...
Saved: ./final_segmentations/patient159_frame01.nii.gz

Reconstructing 3D mask for patient160 ...
Saved: ./final_segmentations/patient160_frame01.nii.gz

Reconstructing 3D m