# Interactive Heart Segmentation Visualization

This notebook allows you to visualize model predictions with increasing detail:
1.  **Validation Set**: Standard comparison (GT vs Pred).
2.  **Test Set**: Inference on unseen data.
3.  **Extended Visualization**: Heatmaps (Probabilities) and Contour analysis.

In [6]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, Spacingd, ToTensord, DivisiblePadd, CropForegroundd
)
from pathlib import Path
from ipywidgets import interact, IntSlider

## 1. Setup & Load Model

In [7]:
# Determine paths (relative to the directory the notebook kernel was started in)
BASE_DIR = Path(os.getcwd())
DATASET_DIR = BASE_DIR / "datasets"
RESULTS_DIR = BASE_DIR / "results"
MODEL_PATH = RESULTS_DIR / "best_metric_model.pth"

print(f"Model path: {MODEL_PATH}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize Model.
# The architecture must match the one used at training time exactly, otherwise
# load_state_dict below fails on missing/unexpected/mis-shaped keys.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
    dropout=0.2,
).to(device)

# Load Weights.
# Fail loudly: every later cell visualizes this model's output, and predictions
# from randomly initialised weights would otherwise look like real results.
if not MODEL_PATH.exists():
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# The file is either a bare state dict, or a training checkpoint that wraps the
# weights under 'model_state_dict'.
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout, BatchNorm uses running stats
print("Model weights loaded successfully.")

Model path: c:\Users\swapn\code\AI Healthcare Imaging\results\best_metric_model.pth
Using device: cpu
Model weights loaded successfully.


---

## 2. Part A: Validation Set (With Ground Truth)
Standard view: Comparing Ground Truth vs Prediction masks.

In [8]:
# Target voxel spacing for resampling (units follow the NIfTI affine, presumably mm).
# Presumably must match the spacing used during training — TODO confirm.
PIXDIM = (1.5, 1.5, 1.0)

# Transforms for Validation (Includes Label).
# Images are resampled bilinearly; labels use nearest-neighbour so they stay
# integer-valued masks.
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=PIXDIM, mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS", labels=None),
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    DivisiblePadd(keys=["image", "label"], k=16),
    ToTensord(keys=["image", "label"]),
])

# Load Validation Files (Last 20% of training data).
# Dot-files (e.g. macOS "._*" resource forks) also match "*.nii" but are not
# NIfTI volumes, so they are skipped.
images_tr = sorted(
    str(p) for p in (DATASET_DIR / "imagesTr").glob("*.nii") if not p.name.startswith(".")
)
labels_tr = sorted(
    str(p) for p in (DATASET_DIR / "labelsTr").glob("*.nii") if not p.name.startswith(".")
)

assert images_tr, f"No images found in {DATASET_DIR / 'imagesTr'}"
# Images and labels are paired by sorted order; zip() would silently truncate
# (and mis-pair) if the two folders disagree.
assert len(images_tr) == len(labels_tr), (
    f"{len(images_tr)} images but {len(labels_tr)} labels — cannot pair them"
)

data_tr = [{"image": i, "label": l} for i, l in zip(images_tr, labels_tr)]
# Presumably mirrors the train/val split used during training — TODO confirm.
val_files = data_tr[int(len(data_tr) * 0.8):]

print(f"Validation samples: {len(val_files)}")

Validation samples: 4


In [None]:
def visualize_validation(sample_idx=0):
    """Predict one validation volume and show an interactive slice viewer.

    Runs the model on ``val_files[sample_idx]`` (after ``val_transforms``) and
    shows, for a slider-selected slice along the last spatial axis: the image,
    the ground-truth mask, the predicted mask, and a colour overlay
    (green = ground truth, red = prediction, yellow where they overlap).

    Args:
        sample_idx: Index into ``val_files``; out-of-range indices are ignored.
    """
    if sample_idx >= len(val_files):
        return

    data = val_transforms(val_files[sample_idx])
    image = data["image"].unsqueeze(0).to(device)  # add batch dimension
    label = data["label"].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(image)
    # The network has 2 output channels (channel 1 = foreground). Decode them as
    # mutually exclusive classes with argmax; thresholding sigmoid(channel 1)
    # alone ignores the background logit. Assumes softmax-style training over the
    # 2 classes (e.g. DiceLoss(softmax=True)) — TODO confirm against training script.
    pred_vol = torch.argmax(logits, dim=1)[0].float().cpu().numpy()

    img_vol = image[0, 0].cpu().numpy()
    lbl_vol = label[0, 0].cpu().numpy()

    def plot(slice_idx):
        img_slice = img_vol[:, :, slice_idx]
        gt_slice = lbl_vol[:, :, slice_idx]
        pred_slice = pred_vol[:, :, slice_idx]

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        # 1. MRI
        axes[0].imshow(img_slice, cmap='gray')
        axes[0].set_title(f"MRI Slice {slice_idx}")
        # 2. GT
        axes[1].imshow(gt_slice, cmap='gray')
        axes[1].set_title("Ground Truth")
        # 3. Pred
        axes[2].imshow(pred_slice, cmap='gray')
        axes[2].set_title("Prediction")
        # 4. Overlay: an RGBA layer drawn over the image
        axes[3].imshow(img_slice, cmap='gray')
        overlay = np.zeros((*img_slice.shape, 4))
        overlay[gt_slice > 0, 1] = 1.0  # Green=GT
        overlay[gt_slice > 0, 3] = 0.3
        overlay[pred_slice > 0, 0] = 1.0  # Red=Pred
        overlay[pred_slice > 0, 3] = 0.5
        axes[3].imshow(overlay)
        axes[3].set_title("Overlay (G=GT, R=Pred)")
        for ax in axes:
            ax.axis('off')
        plt.show()

    interact(plot, slice_idx=IntSlider(min=0, max=img_vol.shape[2]-1, step=1, value=img_vol.shape[2]//2))

print("Visualizing Validation Set (Images with Labels)...")
# Trailing ';' hides the returned function's repr in the cell output.
interact(visualize_validation, sample_idx=IntSlider(min=0, max=len(val_files)-1, step=1, value=0));

Visualizing Validation Set (Images with Labels)...


interactive(children=(IntSlider(value=0, description='sample_idx', max=3), Output()), _dom_classes=('widget-in…

<function __main__.visualize_validation(sample_idx=0)>

---

## 3. Part B: Test Set (Inference Only)
Inference on unseen images that have no ground-truth labels, so only the prediction can be shown.

In [10]:
# Transforms for Test (Image Only): the image branch of the validation
# preprocessing, without the label key.
test_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=PIXDIM, mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    CropForegroundd(keys=["image"], source_key="image"),
    DivisiblePadd(keys=["image"], k=16),
    ToTensord(keys=["image"]),
])

# Load Test Files.
# Dot-files (e.g. macOS "._*" resource forks) also match "*.nii" but are not
# NIfTI volumes, so they are skipped.
images_ts = sorted(
    str(p) for p in (DATASET_DIR / "imagesTs").glob("*.nii") if not p.name.startswith(".")
)
assert images_ts, f"No test images found in {DATASET_DIR / 'imagesTs'}"
test_files = [{"image": i} for i in images_ts]

print(f"Test samples: {len(test_files)}")

Test samples: 10




In [11]:
def visualize_test(sample_idx=0):
    """Predict one unlabeled test volume and show an interactive slice viewer.

    Runs the model on ``test_files[sample_idx]`` (after ``test_transforms``)
    and shows, for a slider-selected slice along the last spatial axis: the
    image, the predicted mask, and the prediction overlaid in red.

    Args:
        sample_idx: Index into ``test_files``; out-of-range indices are ignored.
    """
    if sample_idx >= len(test_files):
        return

    data = test_transforms(test_files[sample_idx])
    image = data["image"].unsqueeze(0).to(device)  # add batch dimension

    model.eval()
    with torch.no_grad():
        logits = model(image)
    # The network has 2 output channels (channel 1 = foreground). Decode them as
    # mutually exclusive classes with argmax; thresholding sigmoid(channel 1)
    # alone ignores the background logit. Assumes softmax-style training over the
    # 2 classes (e.g. DiceLoss(softmax=True)) — TODO confirm against training script.
    pred_vol = torch.argmax(logits, dim=1)[0].float().cpu().numpy()

    img_vol = image[0, 0].cpu().numpy()

    def plot(slice_idx):
        img_slice = img_vol[:, :, slice_idx]
        pred_slice = pred_vol[:, :, slice_idx]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img_slice, cmap='gray')
        axes[0].set_title(f"MRI Slice {slice_idx}")

        axes[1].imshow(pred_slice, cmap='gray')
        axes[1].set_title("Predicted Mask")

        # RGBA layer drawn over the image
        axes[2].imshow(img_slice, cmap='gray')
        overlay = np.zeros((*img_slice.shape, 4))
        overlay[pred_slice > 0, 0] = 1.0  # Red=Pred
        overlay[pred_slice > 0, 3] = 0.5
        axes[2].imshow(overlay)
        axes[2].set_title("Prediction Overlay")

        for ax in axes:
            ax.axis('off')
        plt.show()

    interact(plot, slice_idx=IntSlider(min=0, max=img_vol.shape[2]-1, step=1, value=img_vol.shape[2]//2))

print("Visualizing Test Set (Unseen, No Labels)...")
# Trailing ';' hides the returned function's repr in the cell output.
interact(visualize_test, sample_idx=IntSlider(min=0, max=len(test_files)-1, step=1, value=0));

Visualizing Test Set (Unseen, No Labels)...


interactive(children=(IntSlider(value=0, description='sample_idx', max=9), Output()), _dom_classes=('widget-in…

<function __main__.visualize_test(sample_idx=0)>

---

## 4. Part C: Extended Visualization
**Advanced Analysis**: Heatmaps (Probabilities) and Contours.
This helps identify "uncertain" areas: voxels whose predicted probability lies between 0 and 1 rather than near either extreme, i.e. where the model is not confident in its binary decision.

In [12]:
def visualize_extended(sample_idx=0):
    """Show probability heatmaps and contour comparisons for one validation volume.

    Runs the model on ``val_files[sample_idx]`` (after ``val_transforms``) and
    shows, for a slider-selected slice along the last spatial axis:
      1. the foreground probability map,
      2. ground-truth (green) and predicted (red, dashed) contours over the image,
      3. the image with a translucent probability overlay and the GT contour.

    Args:
        sample_idx: Index into ``val_files``; out-of-range indices are ignored.
    """
    if sample_idx >= len(val_files):
        return

    data = val_transforms(val_files[sample_idx])
    image = data["image"].unsqueeze(0).to(device)  # add batch dimension
    label = data["label"].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(image)
        # The 2 output channels are treated as mutually exclusive classes
        # (channel 1 = foreground). Assumes softmax-style training — TODO confirm
        # against the training script.
        probs = torch.softmax(logits, dim=1)

    img_vol = image[0, 0].cpu().numpy()
    lbl_vol = label[0, 0].cpu().numpy()
    prob_vol = probs[0, 1].cpu().numpy()  # foreground probability map
    # Binary foreground mask from the SAME channel as the heatmap. (The previous
    # code contoured pred[:, 0, ...]: the background channel, and a 3-D array
    # that Axes.contour rejects.) For 2 softmax classes this equals argmax.
    pred_vol = (prob_vol > 0.5).astype(np.float32)

    def plot(slice_idx):
        img_slice = img_vol[:, :, slice_idx]
        gt_slice = lbl_vol[:, :, slice_idx]
        prob_slice = prob_vol[:, :, slice_idx]
        pred_slice = pred_vol[:, :, slice_idx]

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # 1. Probability Heatmap: fixed 0..1 colour scale (dark = low, bright = high)
        im1 = axes[0].imshow(prob_slice, cmap='inferno', vmin=0, vmax=1)
        axes[0].set_title("Probability Heatmap (Model Confidence)")
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

        # 2. Contour Plot. Overlay lines: Green=GT, Red=Pred.
        # Slices with an empty mask are skipped (nothing to outline).
        axes[1].imshow(img_slice, cmap='gray')
        if np.any(gt_slice):
            axes[1].contour(gt_slice, levels=[0.5], colors='lime', linewidths=2)
        if np.any(pred_slice):
            axes[1].contour(pred_slice, levels=[0.5], colors='red', linewidths=2, linestyles='dashed')
        axes[1].set_title("Contour Overlay (Green=GT, Red=Pred)")
        axes[1].axis('off')

        # 3. Enhanced Overlay (Hybrid): weak probability overlay on a fixed
        # 0..1 scale, so colours are comparable across slices.
        axes[2].imshow(img_slice, cmap='gray')
        axes[2].imshow(prob_slice, cmap='jet', alpha=0.3, vmin=0, vmax=1)
        if np.any(gt_slice):
            axes[2].contour(gt_slice, levels=[0.5], colors='white', linewidths=1.5)
        axes[2].set_title("Hybrid (MRI + Heatmap + GT Contour)")
        axes[2].axis('off')

        plt.show()

    interact(plot, slice_idx=IntSlider(min=0, max=img_vol.shape[2]-1, step=1, value=img_vol.shape[2]//2))

print("Extended Visualization (Heatmaps & Contours)...")
# Trailing ';' hides the returned function's repr in the cell output.
interact(visualize_extended, sample_idx=IntSlider(min=0, max=len(val_files)-1, step=1, value=0));

Extended Visualization (Heatmaps & Contours)...


interactive(children=(IntSlider(value=0, description='sample_idx', max=3), Output()), _dom_classes=('widget-in…

<function __main__.visualize_extended(sample_idx=0)>