In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

def _largest_mask_slice(mask):
    """Return the axis-0 index of the slice with the most non-zero mask pixels.

    Args:
        mask: 3-D array; slices are taken along axis 0.

    Returns:
        The slice index as an ``int``, or ``None`` when the mask contains no
        non-zero pixels (including the degenerate zero-slice case).
    """
    # count_nonzero measures area in pixels; a plain sum would weight
    # multi-label masks by their label values.
    areas = np.count_nonzero(mask, axis=(1, 2))
    if areas.size == 0 or not areas.any():
        return None
    return int(np.argmax(areas))


def _save_overlay_figure(volume_slice, mask_slice, slice_idx, series_uid, output_file):
    """Render volume / mask / overlay panels for one slice and save the figure.

    Args:
        volume_slice: 2-D image drawn in grayscale.
        mask_slice: 2-D mask with the same shape as ``volume_slice``.
        slice_idx: Slice index, used only in the panel titles.
        series_uid: Identifier shown as the figure title.
        output_file: Destination path passed to ``Figure.savefig``.
    """
    # Use one colour scale for both mask panels. Without this, the overlay
    # (whose unmasked values are all equal for a binary mask) autoscales to
    # vmin == vmax and is drawn in the lowest 'jet' colour instead of the
    # colour used in the mask-only panel.
    vmin = float(mask_slice.min())
    vmax = float(mask_slice.max())

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Panel 1: volume only
        axes[0].imshow(volume_slice, cmap='gray')
        axes[0].set_title(f'Volume - Slice {slice_idx}')
        axes[0].axis('off')

        # Panel 2: mask only
        axes[1].imshow(mask_slice, cmap='jet', vmin=vmin, vmax=vmax)
        axes[1].set_title(f'Mask - Slice {slice_idx}')
        axes[1].axis('off')

        # Panel 3: volume with the mask drawn on top, hiding zero-valued pixels
        axes[2].imshow(volume_slice, cmap='gray', alpha=1.0)
        masked_overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
        axes[2].imshow(masked_overlay, cmap='jet', alpha=0.5, vmin=vmin, vmax=vmax)
        axes[2].set_title(f'Overlay - Slice {slice_idx}')
        axes[2].axis('off')

        fig.suptitle(f'Series: {series_uid}', fontsize=10, y=0.98)
        fig.tight_layout()
        fig.savefig(output_file, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even when rendering or saving fails;
        # otherwise every failed series would leak an open figure.
        plt.close(fig)


def visualize_masks(volume_dir, mask_dir, output_dir):
    """Visualize masks overlaid on volumes.

    For every ``*.npy`` file in ``mask_dir``, loads the same-named ``.npy``
    file from ``volume_dir``. It then picks the axis-0 slice with the largest
    mask area and writes ``<output_dir>/<stem>.png``. The image has three
    panels: volume, mask, and overlay. A series is skipped, with a message,
    in these cases:
      - its volume file is missing;
      - the volume and mask shapes differ;
      - the mask is not 3-D;
      - the mask is empty;
      - loading or rendering raises.

    Args:
        volume_dir: Directory containing the volume ``.npy`` files.
        mask_dir: Directory containing the mask ``.npy`` files.
        output_dir: Directory for the PNG files (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so processing order (and the log output) is reproducible.
    mask_files = sorted(Path(mask_dir).glob('*.npy'))

    print(f"Found {len(mask_files)} mask files to visualize\n")

    successful = 0
    skipped = 0

    for mask_file in tqdm(mask_files, desc="Creating visualizations"):
        series_uid = mask_file.stem
        volume_file = Path(volume_dir) / f"{series_uid}.npy"

        if not volume_file.exists():
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f"Skipped {series_uid}: Volume file not found")
            skipped += 1
            continue

        try:
            volume = np.load(volume_file)
            mask = np.load(mask_file)

            if volume.shape != mask.shape:
                tqdm.write(f"Skipped {series_uid}: Shape mismatch - Volume {volume.shape} vs Mask {mask.shape}")
                skipped += 1
                continue

            if mask.ndim != 3:
                tqdm.write(f"Skipped {series_uid}: Expected a 3-D mask, got shape {mask.shape}")
                skipped += 1
                continue

            max_slice_idx = _largest_mask_slice(mask)
            if max_slice_idx is None:
                tqdm.write(f"Skipped {series_uid}: No mask present")
                skipped += 1
                continue

            _save_overlay_figure(
                volume[max_slice_idx],
                mask[max_slice_idx],
                max_slice_idx,
                series_uid,
                Path(output_dir) / f"{series_uid}.png",
            )
            successful += 1

        except Exception as e:
            # Per-series boundary: report and continue so one bad file does
            # not abort the whole batch.
            tqdm.write(f"Error processing {series_uid}: {e}")
            skipped += 1

    # Summary
    print(f"\n{'='*80}")
    print("VISUALIZATION SUMMARY")
    print(f"{'='*80}")
    print(f"Successfully visualized: {successful}")
    print(f"Skipped: {skipped}")
    print(f"Total: {len(mask_files)}")
    print(f"{'='*80}")

if __name__ == "__main__":
    # Input and output locations, relative to the current working directory.
    volume_dir, mask_dir, output_dir = (
        r"./volume_uint8_256",
        r"./mask_256",
        r"./visualize_masks",
    )

    visualize_masks(volume_dir=volume_dir, mask_dir=mask_dir, output_dir=output_dir)

Found 170 mask files to visualize



Creating visualizations: 100%|██████████| 170/170 [00:42<00:00,  4.02it/s]


VISUALIZATION SUMMARY
Successfully visualized: 170
Skipped: 0
Total: 170



