In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm

# Label column names, kept for reference: thirteen location names followed by
# 'Aneurysm Present'.
# NOTE(review): presumably these mirror the columns of the dataset's label
# table - confirm against the data source. Nothing in the visible part of this
# file reads this list.
LABEL_COLUMNS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery', 
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present'
]

def visualize_label(volume_path, label_percentage_path, output_path, series_uid):
    """
    Visualize lesion locations on volume slices

    One subplot is drawn per distinct slice index that holds at least one
    lesion. Each lesion is marked with a yellow cross and a red ``L<n>`` tag,
    where ``n`` is the 1-based row index of the lesion in the label file.

    Args:
        volume_path: Path to the volume .npy file. The array must be 3-D and
            is indexed as [slice, row, column].
        label_percentage_path: Path to the label percentage coordinates .npy
            file. Each row is unpacked as (x, y, z) and multiplied by the
            matching volume dimension, so values are presumably fractions in
            [0, 1]; results outside the volume are clamped to the nearest
            valid index.
        output_path: Path to save the visualization
        series_uid: Series UID, used only in the "no lesions" message

    Returns:
        Tuple ``(num_lesions, num_slices)``. ``(0, 0)`` is returned, and no
        image is written, when the label file contains no rows.

    Raises:
        ValueError: If the loaded volume is not 3-D.
    """
    volume = np.load(volume_path)
    percentage_coords = np.load(label_percentage_path)

    num_lesions = len(percentage_coords)
    if num_lesions == 0:
        print(f"No lesions found for {series_uid}")
        # Same tuple shape as the success path, so callers can always unpack.
        return 0, 0

    if volume.ndim != 3:
        raise ValueError(
            f"Expected a 3-D volume for {series_uid}, got shape {volume.shape}"
        )
    dim0, dim1, dim2 = volume.shape

    def to_index(fraction, size):
        # int() truncates toward zero. The upper clamp is size - 1 (the last
        # valid index): a fraction of exactly 1.0 must not index past the end.
        return min(max(int(fraction * size), 0), size - 1)

    # Group lesions by the slice they fall on: {z: [(x, y, lesion_idx), ...]}
    slice_to_lesions = {}
    for lesion_idx, (x_pct, y_pct, z_pct) in enumerate(percentage_coords):
        x_pixel = to_index(x_pct, dim2)  # dim2 corresponds to x (columns)
        y_pixel = to_index(y_pct, dim1)  # dim1 corresponds to y (rows)
        z_pixel = to_index(z_pct, dim0)  # dim0 corresponds to z (slice index)
        slice_to_lesions.setdefault(z_pixel, []).append(
            (x_pixel, y_pixel, lesion_idx)
        )

    # Grid of at most 3 columns; squeeze=False always yields a 2-D axes array,
    # so the single-slice case needs no special handling.
    num_slices = len(slice_to_lesions)
    cols = min(3, num_slices)
    rows = (num_slices + cols - 1) // cols
    fig, axes = plt.subplots(
        rows, cols, figsize=(8 * cols, 8 * rows), squeeze=False
    )
    axes = axes.ravel()

    cross_size = 20  # Half-length of each cross arm, in pixels

    try:
        # Sorted slice order keeps the subplot layout consistent between runs
        for ax, z_pixel in zip(axes, sorted(slice_to_lesions)):
            lesions = slice_to_lesions[z_pixel]

            # Fixed 0-255 window: assumes uint8-range intensities
            ax.imshow(volume[z_pixel, :, :], cmap='gray', vmin=0, vmax=255)

            for x_pixel, y_pixel, lesion_idx in lesions:
                # Horizontal arm of the cross
                ax.plot([x_pixel - cross_size, x_pixel + cross_size],
                        [y_pixel, y_pixel],
                        color='yellow', linewidth=1, alpha=0.5)

                # Vertical arm of the cross
                ax.plot([x_pixel, x_pixel],
                        [y_pixel - cross_size, y_pixel + cross_size],
                        color='yellow', linewidth=1, alpha=0.5)

                # Lesion number, placed just right of the cross
                ax.text(x_pixel + cross_size + 5, y_pixel,
                        f'L{lesion_idx + 1}',
                        color='red', fontsize=12, fontweight='bold',
                        bbox=dict(boxstyle='round,pad=0.3',
                                  facecolor='black', alpha=0.5))

            ax.set_title(f'Slice {z_pixel}/{dim0-1} ({len(lesions)} lesion(s))',
                         fontsize=14, fontweight='bold')
            ax.axis('off')

        # Hide the unused cells of the last grid row
        for ax in axes[num_slices:]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=150)
    finally:
        # Close this specific figure even on failure, so a long batch run
        # does not accumulate open figures.
        plt.close(fig)

    return num_lesions, num_slices

def visualize_all_labels(volume_dir, label_percentage_dir, output_dir):
    """
    Visualize all labels from the dataset

    For every ``*.npy`` file in ``label_percentage_dir`` this looks for a
    volume with the same file stem in ``volume_dir`` and, if found, calls
    ``visualize_label`` to write ``<stem>.png`` into ``output_dir``. Series
    with a missing volume, with no lesions, or whose visualization raises are
    reported (or silently skipped for the no-lesion case) and left out of the
    statistics. A summary is printed at the end.

    Args:
        volume_dir: Directory containing volume .npy files
        label_percentage_dir: Directory containing label percentage .npy files
        output_dir: Directory to save visualizations (created, including any
            missing parents, if it does not exist)
    """
    import traceback

    volume_path = Path(volume_dir)
    label_path = Path(label_percentage_dir)
    output_path = Path(output_dir)

    # parents=True so that a nested output directory can be created in one go
    output_path.mkdir(parents=True, exist_ok=True)

    # Sorted so the processing order does not depend on the filesystem
    label_files = sorted(label_path.glob('*.npy'))

    print(f"Found {len(label_files)} label files")

    stats = {
        'total_processed': 0,
        'total_lesions': 0,
        'total_slices_with_lesions': 0
    }

    for label_file in tqdm(label_files, desc="Visualizing labels"):
        series_uid = label_file.stem

        # The volume shares the label file's stem
        volume_file = volume_path / f"{series_uid}.npy"

        if not volume_file.exists():
            print(f"Warning: Volume not found for {series_uid}")
            continue

        output_file = output_path / f"{series_uid}.png"

        # Top-level batch boundary: one bad series must not abort the run,
        # so any failure is reported with its traceback and skipped.
        try:
            result = visualize_label(
                volume_file,
                label_file,
                output_file,
                series_uid
            )
        except Exception as e:
            print(f"Error processing {series_uid}: {e}")
            traceback.print_exc()
            continue

        # A series without lesions produces no image (visualize_label returns
        # early), so it is not counted. Both a bare None and a zero lesion
        # count are accepted here rather than unpacked blindly.
        if result is None or result[0] == 0:
            continue

        num_lesions, num_slices = result
        stats['total_processed'] += 1
        stats['total_lesions'] += num_lesions
        stats['total_slices_with_lesions'] += num_slices

    # Print statistics
    separator = '=' * 60
    print(f"\n{separator}")
    print("VISUALIZATION COMPLETE")
    print(separator)
    print(f"Total series processed: {stats['total_processed']}")
    print(f"Total lesions visualized: {stats['total_lesions']}")
    print(f"Total slices with lesions: {stats['total_slices_with_lesions']}")
    if stats['total_processed'] > 0:
        print(f"Average lesions per series: {stats['total_lesions'] / stats['total_processed']:.2f}")
        print(f"Average slices per series: {stats['total_slices_with_lesions'] / stats['total_processed']:.2f}")
    print(separator)

# Script entry point: render lesion overlays for every labelled series.
if __name__ == "__main__":
    # Input and output directories, relative to the current working directory.
    # NOTE(review): an earlier comment offered this same volume path as the
    # alternative "for resized volumes"; the other intended directory is not
    # visible here.
    volume_dir, label_percentage_dir, output_dir = (
        "./volume_uint8_256",
        "./label_percentage",
        "./label_visual",
    )

    visualize_all_labels(volume_dir, label_percentage_dir, output_dir)

Found 1842 label files


Visualizing labels: 100%|██████████| 1842/1842 [08:06<00:00,  3.79it/s]


VISUALIZATION COMPLETE
Total series processed: 1842
Total lesions visualized: 2214
Total slices with lesions: 2169
Average lesions per series: 1.20
Average slices per series: 1.18



