In [6]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from pathlib import Path

def find_slice_with_largest_segmentation(mask, axis):
    """
    Find the slice along the given axis with the largest segmentation area

    The area of a slice is the number of elements strictly greater than 0 in
    that slice. Ties (including a mask with no foreground at all) resolve to
    the lowest slice index.

    Args:
        mask: N-dimensional array-like segmentation mask (typically 3D)
        axis: axis along which to find the slice; negative values count
            from the last axis, as in numpy

    Returns:
        slice_idx: index (Python int) of the slice with largest segmentation

    Raises:
        ValueError: if axis is out of range for the mask, or if the mask has
            no slices along that axis
    """
    foreground = np.asarray(mask) > 0
    ndim = foreground.ndim

    # Validate explicitly: an out-of-range axis would otherwise exclude
    # nothing from the reduction and silently yield index 0.
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for a mask with {ndim} dimension(s)"
        )
    axis = axis % ndim  # map negative axes onto their positive equivalent

    # Reduce over every axis except the requested one to get area per slice
    other_axes = tuple(i for i in range(ndim) if i != axis)
    slice_areas = foreground.sum(axis=other_axes)

    # Plain int so the result can be used for indexing, formatting or JSON alike
    return int(np.argmax(slice_areas))

def visualize_case(image_path, label_path, output_path=None):
    """
    Visualize a single case with 1x3 subplots showing volume, mask, and overlay for last dim

    The slice shown is the one along the last axis that
    find_slice_with_largest_segmentation selects for the label volume.

    Args:
        image_path: path to image NIfTI file
        label_path: path to label NIfTI file
        output_path: optional path to save figure; when omitted (or empty)
            the figure is shown with plt.show() instead

    Raises:
        ValueError: if the image and label are not 3D volumes of identical shape
    """
    # Load data
    img_data = nib.load(image_path).get_fdata()
    label_data = nib.load(label_path).get_fdata()

    # The slice index is derived from the label and applied to the image, so
    # the two volumes must share one 3D grid for the overlay to be meaningful.
    if img_data.ndim != 3 or img_data.shape != label_data.shape:
        raise ValueError(
            "Expected 3D image and label volumes of identical shape, "
            f"got image {img_data.shape} and label {label_data.shape}"
        )

    # Find slice with largest segmentation for last dimension
    last_dim = img_data.ndim - 1
    slice_idx = find_slice_with_largest_segmentation(label_data, axis=last_dim)

    # Extract slice from last dimension
    img_slice = img_data[:, :, slice_idx]
    mask_slice = label_data[:, :, slice_idx]

    # Create figure with 1x3 subplots
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    handed_to_show = False
    try:
        case_name = os.path.basename(image_path).replace('_0000.nii.gz', '')
        fig.suptitle(f'{case_name} - Dim{last_dim} Plane (slice {slice_idx})', fontsize=16, fontweight='bold')

        # Plot 1: Volume only
        axes[0].imshow(img_slice, cmap='gray')
        axes[0].set_title('Volume', fontsize=12)
        axes[0].axis('off')

        # Plot 2: Mask only
        axes[1].imshow(mask_slice, cmap='hot')
        axes[1].set_title('Mask', fontsize=12)
        axes[1].axis('off')

        # Plot 3: Volume with mask overlay (background is masked out so the
        # grayscale volume stays visible underneath)
        axes[2].imshow(img_slice, cmap='gray')
        mask_overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
        axes[2].imshow(mask_overlay, cmap='hot', alpha=0.5)
        axes[2].set_title('Overlay', fontsize=12)
        axes[2].axis('off')

        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        else:
            plt.show()
            handed_to_show = True
    finally:
        # Close this specific figure after saving or on any failure, so that
        # batch runs which catch the exception do not accumulate open figures.
        # A figure handed to plt.show() is left to the backend.
        if not handed_to_show:
            plt.close(fig)

def visualize_dataset(dataset_dir, output_dir, max_cases=None):
    """
    Visualize all cases in the dataset

    Images are read from <dataset_dir>/imagesTr (files ending in
    '_0000.nii.gz') and matched to <dataset_dir>/labelsTr/<case_id>.nii.gz.
    Cases without a label are skipped with a warning, and a failure in one
    case is reported and does not stop the remaining cases.

    Args:
        dataset_dir: path to dataset directory
        output_dir: directory to save visualizations (created if missing)
        max_cases: optional limit on number of cases to visualize;
            None means no limit and 0 means no cases
    """
    images_dir = os.path.join(dataset_dir, 'imagesTr')
    labels_dir = os.path.join(dataset_dir, 'labelsTr')
    image_suffix = '_0000.nii.gz'

    # Get all image files
    image_files = sorted(f for f in os.listdir(images_dir) if f.endswith(image_suffix))

    # Compare against None rather than truthiness so max_cases=0 is honoured
    # as "process nothing" instead of falling through to "process everything".
    if max_cases is not None:
        image_files = image_files[:max_cases]

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print(f"Visualizing {len(image_files)} cases from {dataset_dir}...")
    print(f"Output directory: {output_dir}")

    for img_file in image_files:
        # Strip only the trailing suffix (endswith is guaranteed by the filter
        # above) to get the case id and its corresponding label file
        case_id = img_file[:-len(image_suffix)]
        label_file = f'{case_id}.nii.gz'

        image_path = os.path.join(images_dir, img_file)
        label_path = os.path.join(labels_dir, label_file)

        # Check if label exists
        if not os.path.exists(label_path):
            print(f"Warning: Label not found for {img_file}")
            continue

        # Set output path
        output_path = os.path.join(output_dir, f'{case_id}_visualization.png')

        # Best-effort batch: report the failing case and carry on with the rest
        try:
            visualize_case(image_path, label_path, output_path)
            print(f"Processed: {case_id}")
        except Exception as e:
            print(f"Error processing {case_id}: {str(e)}")

def main():
    """Run the dataset visualizations and print a completion banner"""
    banner_rule = "=" * 60

    # The TOPCOW dataset is currently disabled. To re-enable it, process
    # r"E:\kaggle-rsna-data_processing3\Dataset500_Topcow" into
    # "./visualize_Topcow" the same way IAD is handled below.

    # IAD dataset location and the folder its figures are written to
    iad_dir = r"E:\kaggle-rsna-data_processing3\Dataset501_IAD"  # Adjust path as needed
    iad_output = "./visualize_IAD"

    if not os.path.exists(iad_dir):
        print(f"\nIAD directory not found: {iad_dir}")
    else:
        print("\n" + banner_rule, "Processing IAD Dataset", banner_rule, sep="\n")
        # max_cases=None visualizes every case in the dataset
        visualize_dataset(dataset_dir=iad_dir, output_dir=iad_output, max_cases=None)

    print("\n" + banner_rule, "Visualization complete!", banner_rule, sep="\n")

# Run main() only when this code is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


Processing IAD Dataset
Visualizing 170 cases from E:\kaggle-rsna-data_processing3\Dataset501_IAD...
Output directory: ./visualize_IAD
Processed: CASE_0001
Processed: CASE_0002
Processed: CASE_0003
Processed: CASE_0004
Processed: CASE_0005
Processed: CASE_0006
Processed: CASE_0007
Processed: CASE_0008
Processed: CASE_0009
Processed: CASE_0010
Processed: CASE_0011
Processed: CASE_0012
Processed: CASE_0013
Processed: CASE_0014
Processed: CASE_0015
Processed: CASE_0016
Processed: CASE_0017
Processed: CASE_0018
Processed: CASE_0019
Processed: CASE_0020
Processed: CASE_0021
Processed: CASE_0022
Processed: CASE_0023
Processed: CASE_0024
Processed: CASE_0025
Processed: CASE_0026
Processed: CASE_0027
Processed: CASE_0028
Processed: CASE_0029
Processed: CASE_0030
Processed: CASE_0031
Processed: CASE_0032
Processed: CASE_0033
Processed: CASE_0034
Processed: CASE_0035
Processed: CASE_0036
Processed: CASE_0037
Processed: CASE_0038
Processed: CASE_0039
Processed: CASE_0040
Processed: CASE_0041
Proc