# UT-EndoMRI Dataset Exploration

This notebook provides an exploratory analysis of the UT-EndoMRI dataset for endometriosis segmentation.

**Contents:**
1. Dataset Overview
2. Load and Visualize MRI Scans
3. Label Statistics and Volume Analysis
4. Intensity Distribution Analysis
5. Class Imbalance Analysis
6. Train/Val/Test Split Analysis

In [None]:
# Imports
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
import json

from src.data.utils import (
    load_nifti,
    get_dataset_statistics,
    parse_filename,
    get_subject_files,
    load_data_splits
)

# Set style
# Apply seaborn's 'whitegrid' theme and a wide default figure size to the plots
# created after this cell (cells that pass figsize explicitly override the default).
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (15, 8)

# Reaching this line means every import above resolved, including the
# project-local helpers pulled from src.data.utils via the '..' path entry.
print("Imports successful!")

## 1. Dataset Overview

In [None]:
# Root folder that holds both UT-EndoMRI sub-datasets
data_root = "../data/raw/UT-EndoMRI"

# Summarise Dataset 1, then Dataset 2. Each header (title + 60-char rule) is
# printed before its statistics are computed, and the statistics are dumped
# as indented JSON.
print("Dataset 1 (D1_MHS - Multi-center, Multi-rater):", "="*60, sep="\n")
stats_d1 = get_dataset_statistics(data_root, "D1_MHS")
print(json.dumps(stats_d1, indent=2))

print("\n\nDataset 2 (D2_TCPW - Single-center, Single-rater):", "="*60, sep="\n")
stats_d2 = get_dataset_statistics(data_root, "D2_TCPW")
print(json.dumps(stats_d2, indent=2))

## 2. Load and Visualize MRI Scans

In [None]:
# Load a sample subject from Dataset 2
subject_dir = Path(data_root) / "D2_TCPW" / "D2-000"
files = get_subject_files(subject_dir)

# Load T2FS image; fail with an explicit message instead of a bare IndexError
# when the subject folder contains no T2FS volume.
t2fs_candidates = [f for f in files['images'] if 'T2FS' in f.name]
if not t2fs_candidates:
    raise FileNotFoundError(f"No T2FS image found for subject in {subject_dir}")
image_file = t2fs_candidates[0]
image_data, image_obj = load_nifti(str(image_file))

# Load labels, keyed by the token between the first and second underscore of
# the file name (assumes names shaped like '<subject>_<structure>.nii.gz' -- TODO confirm)
label_files = {f.name.split('_')[1].replace('.nii.gz', ''): f for f in files['labels']}

# Voxel edge lengths are the column norms of the affine's 3x3 block. Reading
# only the diagonal is correct solely for axis-aligned affines and understates
# the spacing of rotated/oblique acquisitions.
voxel_spacing = np.linalg.norm(np.asarray(image_obj.affine)[:3, :3], axis=0)

print(f"Image shape: {image_data.shape}")
print(f"Image spacing: {voxel_spacing} mm")
print(f"Intensity range: [{image_data.min():.2f}, {image_data.max():.2f}]")
print(f"Available labels: {list(label_files.keys())}")

In [None]:
# Visualize multiple slices
def visualize_slices(image, label=None, num_slices=6, cmap='gray'):
    """Show evenly spaced slices of a 3D volume in a two-row grid.

    Slices are taken along the third axis, at indices spread evenly between
    one quarter and three quarters of the volume depth.

    Args:
        image: 3D array indexed as ``image[:, :, slice_idx]``.
        label: Optional array indexed the same way as ``image``; voxels equal
            to 0 are hidden and the rest is overlaid semi-transparently.
        num_slices: Number of slices to display (must be >= 1). Odd values
            are supported; the unused grid cell is left blank.
        cmap: Matplotlib colormap used for the image itself.

    Raises:
        ValueError: If ``num_slices`` is smaller than 1.
    """
    if num_slices < 1:
        raise ValueError("num_slices must be at least 1")

    depth = image.shape[2]
    slice_indices = np.linspace(depth // 4, 3 * depth // 4, num_slices, dtype=int)

    # Ceil division so an odd num_slices still gets enough panels; flooring
    # here left one slice without an axis (IndexError) and produced zero
    # columns for num_slices=1.
    n_cols = (num_slices + 1) // 2
    fig, axes = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)
    axes = axes.flatten()

    for ax, slice_idx in zip(axes, slice_indices):
        # Transpose + origin='lower' so the first array axis runs horizontally
        ax.imshow(image[:, :, slice_idx].T, cmap=cmap, origin='lower')

        if label is not None:
            # Overlay label, masking out background so the image stays visible
            label_slice = label[:, :, slice_idx].T
            masked = np.ma.masked_where(label_slice == 0, label_slice)
            ax.imshow(masked, cmap='jet', alpha=0.5, origin='lower')

        ax.set_title(f'Slice {slice_idx}')
        ax.axis('off')

    # Blank out any grid cell that received no slice (odd num_slices)
    for ax in axes[len(slice_indices):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show the sample T2FS volume, overlaying the uterus mask when the subject has one
if 'ut' not in label_files:
    print("Visualizing T2FS image only:")
    visualize_slices(image_data)
else:
    ut_label, _ = load_nifti(str(label_files['ut']))
    print("Visualizing T2FS image with uterus label overlay:")
    visualize_slices(image_data, ut_label)

## 3. Label Statistics and Volume Analysis

In [None]:
# Collect volume statistics for all structures
def collect_volumes(data_root, dataset_name):
    """Collect the volume of every non-empty label mask in a dataset.

    Walks each subject directory under ``data_root / dataset_name``, loads
    every label file reported by ``get_subject_files`` and converts its
    foreground voxel count (voxels > 0) into a volume.

    Args:
        data_root: Path to the folder containing the datasets.
        dataset_name: Name of the dataset sub-folder to scan.

    Returns:
        A ``defaultdict(list)`` mapping the structure type (the ``'type'``
        entry returned by ``parse_filename``) to a list of volumes, one per
        non-empty label file. Volumes are in cubic centimetres provided the
        affine is expressed in millimetres (assumed -- TODO confirm).
        Label files that fail to load or process are skipped with a warning.
    """
    from collections import defaultdict
    import warnings

    dataset_path = Path(data_root) / dataset_name
    volumes = defaultdict(list)

    for subject_dir in dataset_path.iterdir():
        if not subject_dir.is_dir():
            continue

        files = get_subject_files(subject_dir)

        for label_file in files['labels']:
            info = parse_filename(label_file.name)
            struct_type = info['type']

            try:
                data, img = load_nifti(str(label_file))
                # Voxel edge lengths are the column norms of the affine's 3x3
                # block; the diagonal alone is only right for axis-aligned
                # affines and yields wrong volumes for oblique scans.
                spacing = np.linalg.norm(np.asarray(img.affine)[:3, :3], axis=0)
                voxel_volume = np.prod(spacing)
                num_voxels = np.sum(data > 0)
                volume_cc = (num_voxels * voxel_volume) / 1000  # mm^3 -> cc
            except Exception as exc:
                # Stay best-effort, but report the failure and let
                # KeyboardInterrupt/SystemExit propagate (a bare except
                # swallowed those too).
                warnings.warn(f"Skipping label file {label_file}: {exc}")
                continue

            # Empty masks carry no volume information, so they are left out
            if volume_cc > 0:
                volumes[struct_type].append(volume_cc)

    return volumes

# Gather per-structure volumes for Dataset 2
print("Collecting volume statistics for Dataset 2...")
volumes_d2 = collect_volumes(data_root, "D2_TCPW")

# Flatten the {structure: [volumes]} mapping into long format:
# one row per measured volume
volume_data = [
    {'Structure': struct, 'Volume (cc)': vol}
    for struct, vols in volumes_d2.items()
    for vol in vols
]
df_volumes = pd.DataFrame(volume_data)

# Per-structure descriptive statistics (count, mean, std, quartiles, ...)
print("\nVolume Statistics Summary:")
volume_by_structure = df_volumes.groupby('Structure')['Volume (cc)']
print(volume_by_structure.describe())

In [None]:
# Side-by-side box and violin plots of the per-structure volumes
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
box_ax, violin_ax = axes

# Left panel: pandas box plot grouped by structure
df_volumes.boxplot(column='Volume (cc)', by='Structure', ax=box_ax)
box_ax.set_title('Volume Distribution by Structure')
box_ax.set(ylabel='Volume (cc)', xlabel='Structure')
plt.sca(box_ax)
plt.xticks(rotation=45)

# Right panel: seaborn violin plot of the same data
sns.violinplot(data=df_volumes, x='Structure', y='Volume (cc)', ax=violin_ax)
violin_ax.set_title('Volume Distribution (Violin Plot)')
violin_ax.set(xlabel='Structure', ylabel='Volume (cc)')
violin_ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Report the mean volume of the main structures
print("\nKey Insights:")
structure_col = df_volumes['Structure']
volume_col = df_volumes['Volume (cc)']
ut_mean = volume_col[structure_col == 'ut'].mean()
print(f"- Uterus is the largest structure (mean: {ut_mean:.1f} cc)")
ov_mean = volume_col[structure_col == 'ov'].mean()
print(f"- Ovaries are much smaller (mean: {ov_mean:.1f} cc)")
if 'em' in volumes_d2:
    em_mean = volume_col[structure_col == 'em'].mean()
    print(f"- Endometriomas when present (mean: {em_mean:.1f} cc)")

## 4. Intensity Distribution Analysis

In [None]:
# Analyze intensity distributions within different structures
def analyze_intensities(image, labels_dict):
    """Return the image intensities that fall inside each labelled structure.

    For every entry of ``labels_dict`` the foreground mask is ``label > 0``.
    Structures whose mask is empty are left out of the result.

    Args:
        image: Array indexable by a boolean mask of the labels' shape.
        labels_dict: Mapping of structure name -> label array.

    Returns:
        dict mapping structure name -> 1D array of ``image`` values under
        that structure's mask, in ``labels_dict`` iteration order.
    """
    def _foreground_values():
        for structure, label_volume in labels_dict.items():
            foreground = label_volume > 0
            if not np.any(foreground):
                continue  # nothing labelled for this structure
            yield structure, image[foreground]

    return dict(_foreground_values())

# Read every label volume of the current subject, keeping only the voxel data
# (the first element returned by load_nifti)
labels_dict = {
    name: load_nifti(str(file_path))[0]
    for name, file_path in label_files.items()
}

# Intensities of the sample image inside each structure
intensities = analyze_intensities(image_data, labels_dict)

# Plot intensity distributions
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Histogram: one semi-transparent histogram per structure on shared axes
for name, values in intensities.items():
    axes[0].hist(values.flatten(), bins=50, alpha=0.6, label=name)
axes[0].set_xlabel('Intensity Value')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Intensity Distribution by Structure')
axes[0].legend()

# Box plot: built from at most 1000 voxels per structure to keep it light
intensity_data = []
for name, values in intensities.items():
    flat_values = values.flatten()
    # replace=False gives a true subsample. The default (with replacement)
    # duplicates voxels and, for structures with fewer than 1000 voxels,
    # resamples them instead of simply using every voxel once.
    sample = np.random.choice(
        flat_values, size=min(1000, flat_values.size), replace=False
    )
    intensity_data.extend({'Structure': name, 'Intensity': val} for val in sample)

# Explicit columns keep the frame well-formed even when nothing was sampled
df_intensity = pd.DataFrame(intensity_data, columns=['Structure', 'Intensity'])
sns.boxplot(data=df_intensity, x='Structure', y='Intensity', ax=axes[1])
axes[1].set_title('Intensity Distribution by Structure')
axes[1].set_xlabel('Structure')
axes[1].set_ylabel('Intensity')

plt.tight_layout()
plt.show()

# Print statistics (computed on all voxels of each structure, not the subsample)
print("\nIntensity Statistics:")
for name, values in intensities.items():
    print(f"\n{name}:")
    print(f"  Mean: {values.mean():.2f}")
    print(f"  Std: {values.std():.2f}")
    print(f"  Range: [{values.min():.2f}, {values.max():.2f}]")

## 5. Class Imbalance Analysis

In [None]:
# Analyze class distribution in labels
def compute_class_distribution(labels_dict):
    """Count labelled voxels per structure, plus the unlabelled background.

    A voxel is foreground for a structure when its label value is > 0.
    Background is every voxel that is foreground in none of the label
    arrays. The total voxel count is taken from the first label array.

    Args:
        labels_dict: Non-empty mapping of structure name -> label array.

    Returns:
        dict with a ``'background'`` entry first, followed by one entry per
        structure, each holding a voxel count.
    """
    label_volumes = list(labels_dict.values())
    reference = label_volumes[0]

    # Union of all foreground masks
    any_structure = np.zeros_like(reference)
    for volume in label_volumes:
        any_structure = np.logical_or(any_structure, volume > 0)

    class_dist = {'background': reference.size - np.sum(any_structure)}
    class_dist.update(
        (structure, np.sum(volume > 0)) for structure, volume in labels_dict.items()
    )
    return class_dist

# Compute distribution
class_dist = compute_class_distribution(labels_dict)

# Materialise the dict views once. Matplotlib expects array-like input, and a
# dict view is not a sequence, so it may not convert cleanly to an array.
names = list(class_dist.keys())
counts = list(class_dist.values())

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Pie chart
axes[0].pie(counts, labels=names, autopct='%1.1f%%')
axes[0].set_title('Class Distribution (Voxel Percentage)')

# Bar chart (log scale, since background dwarfs the structures)
axes[1].bar(names, counts)
axes[1].set_yscale('log')
axes[1].set_xlabel('Class')
axes[1].set_ylabel('Voxel Count (log scale)')
axes[1].set_title('Class Distribution (Log Scale)')
axes[1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Print statistics
# NOTE(review): percentages are relative to the sum of all class counts, so
# voxels shared by overlapping labels are counted once per label.
print("\nClass Imbalance Statistics:")
total = sum(counts)
for name, count in class_dist.items():
    percentage = (count / total) * 100
    print(f"{name}: {count} voxels ({percentage:.2f}%)")

# Compute imbalance ratio, guarding against a subject with no labelled voxels
foreground_voxels = sum([v for k, v in class_dist.items() if k != 'background'])
background_voxels = class_dist['background']
if foreground_voxels > 0:
    imbalance_ratio = background_voxels / foreground_voxels
    print(f"\nBackground/Foreground ratio: {imbalance_ratio:.2f}:1")
else:
    imbalance_ratio = float('inf')
    print("\nBackground/Foreground ratio: undefined (no foreground voxels)")

## 6. Train/Val/Test Split Analysis

In [None]:
# Report and plot the train/val/test split when a split file has been generated
splits_file = "../data/splits/split_info.json"
if not Path(splits_file).exists():
    print("Split file not found. Run: python scripts/create_splits.py")
else:
    splits = load_data_splits(splits_file)

    # Textual summary: provenance first, then the size and ratio of each subset
    print("Data Split Information:")
    print(f"\nDataset: {splits['dataset']}")
    print(f"Random seed: {splits['seed']}")
    print(f"\nSplit sizes:")
    print(f"  Train: {len(splits['train'])} subjects ({splits['ratios']['train']:.1%})")
    print(f"  Val: {len(splits['val'])} subjects ({splits['ratios']['val']:.1%})")
    print(f"  Test: {len(splits['test'])} subjects ({splits['ratios']['test']:.1%})")

    # Bar chart of subject counts, each bar annotated with its value
    fig, ax = plt.subplots(figsize=(8, 6))
    splits_counts = [len(splits[part]) for part in ('train', 'val', 'test')]
    bar_colors = ['#2ecc71', '#3498db', '#e74c3c']
    ax.bar(['Train', 'Val', 'Test'], splits_counts, color=bar_colors)
    ax.set_ylabel('Number of Subjects')
    ax.set_title('Train/Val/Test Split Distribution')
    for i, v in enumerate(splits_counts):
        # Place the count just above the top of its bar
        ax.text(i, v + 0.5, str(v), ha='center', va='bottom')
    plt.show()

## Summary and Key Findings

From this exploratory analysis, we can conclude:

1. **Dataset Composition:**
   - Dataset 1: Multi-center, multi-rater (more challenging)
   - Dataset 2: Single-center, single-rater (more consistent)

2. **Structure Sizes:**
   - Uterus is the largest structure (~220 cc)
   - Ovaries are much smaller (~12 cc)
   - This size difference contributes to segmentation difficulty

3. **Class Imbalance:**
   - Significant imbalance between background and foreground
   - Important for loss function design (Focal Loss, Tversky Loss)

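4. **Inter-rater Variability** (not measured in this notebook; to be verified on Dataset 1's multi-rater labels):
   - Lower agreement is expected for ovaries than for the uterus
   - If confirmed, this justifies the uncertainty quantification approach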

5. **Next Steps:**
   - Implement preprocessing pipeline
   - Design appropriate loss functions for imbalanced data
   - Develop uncertainty-aware model architecture