In [3]:
import os
import numpy as np
from pathlib import Path
from scipy.ndimage import gaussian_filter1d
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


def generate_attention_map(volume_shape, target_pct, std=5):
    """Generate attention heatmap for a volume with multiple targets
    
    Args:
        volume_shape: (depth, height, width) - shape of the volume
        target_pct: (N, 3) array of [x_pct, y_pct, z_pct]
        std: standard deviation for Gaussian blur in pixels
    
    Returns:
        attention_map: (depth, height, width) array with values in [0, 1]
    """
    depth, height, width = volume_shape
    
    # Initialize final attention map
    att_map_aggregated = np.zeros(volume_shape, dtype=np.float32)
    
    # Process each target
    for target in target_pct:
        x_pct, y_pct, z_pct = target
        
        # Convert percentages to indices
        z_idx = int(z_pct * depth)
        row_idx = int(y_pct * height)
        col_idx = int(x_pct * width)
        
        # Clamp to valid range
        z_idx = max(0, min(z_idx, depth - 1))
        row_idx = max(0, min(row_idx, height - 1))
        col_idx = max(0, min(col_idx, width - 1))
        
        # Create individual attention map for this target
        att_map_single = np.zeros(volume_shape, dtype=np.float32)
        att_map_single[z_idx, row_idx, col_idx] = 1.0
        
        # Apply 3D Gaussian blur
        from scipy.ndimage import gaussian_filter
        att_map_blurred = gaussian_filter(att_map_single, sigma=std, mode='constant')
        
        # Normalize to [0, 1]
        max_val = att_map_blurred.max()
        if max_val > 0:
            att_map_blurred = att_map_blurred / max_val
        
        # Maximum aggregation
        att_map_aggregated = np.maximum(att_map_aggregated, att_map_blurred)
    
    return att_map_aggregated


def generate_slice_labels(target_pct, target_cls, depth=256, num_classes=14):
    """Generate slice-level labels for a volume (binary labels at exact indices)
    
    Args:
        target_pct: (N, 3) array of [x_pct, y_pct, z_pct]
        target_cls: (N,) array of class indices for each target
        depth: depth dimension (default 256)
        num_classes: number of condition classes (default 14)
    
    Returns:
        slice_labels: (num_classes+1, depth) array 
                      First 14 rows for each class, last row is maximum of all
                      Values are binary: 1 at exact target indices, 0 elsewhere
    """
    # Initialize slice labels: (14, 256)
    slice_labels = np.zeros((num_classes, depth), dtype=np.float32)
    
    # Process each target
    for target, cls_idx in zip(target_pct, target_cls):
        x_pct, y_pct, z_pct = target
        
        # Get z index
        z_idx = int(z_pct * depth)
        z_idx = max(0, min(z_idx, depth - 1))
        
        # Set label to 1 at the exact index for this class
        slice_labels[cls_idx, z_idx] = 1.0
    
    # Add last row as maximum of all classes
    slice_labels_with_max = np.zeros((num_classes + 1, depth), dtype=np.float32)
    slice_labels_with_max[:num_classes, :] = slice_labels
    slice_labels_with_max[-1, :] = np.maximum.reduce(slice_labels, axis=0)
    
    return slice_labels_with_max


def process_series_attention_and_labels(series_uid, volume_dir, target_pct_dir, 
                                       target_cls_dir, att_map_dir, slice_label_dir):
    """Process a single series to generate attention map and slice labels
    
    Args:
        series_uid: series identifier
        volume_dir: directory containing volumes
        target_pct_dir: directory containing target percentages
        target_cls_dir: directory containing target class labels
        att_map_dir: output directory for attention maps
        slice_label_dir: output directory for slice labels
    
    Returns:
        success: True if processed successfully
    """
    # Check if target file exists
    target_pct_file = Path(target_pct_dir) / f"{series_uid}.npy"
    if not target_pct_file.exists():
        return False  # No targets, skip
    
    # Load volume to get shape
    volume_file = Path(volume_dir) / f"{series_uid}.npy"
    if not volume_file.exists():
        print(f"Warning: Volume not found for {series_uid}")
        return False
    
    volume = np.load(volume_file)
    volume_shape = volume.shape
    
    # Load target percentages
    target_pct = np.load(target_pct_file)
    
    # Load target classes
    target_cls_file = Path(target_cls_dir) / f"{series_uid}.npy"
    if not target_cls_file.exists():
        print(f"Warning: Target classes not found for {series_uid}")
        return False
    
    target_cls = np.load(target_cls_file)
    
    # Validate dimensions
    if len(target_pct) != len(target_cls):
        print(f"Warning: Mismatch in target counts for {series_uid}: {len(target_pct)} vs {len(target_cls)}")
        return False
    
    # Generate attention map
    att_map = generate_attention_map(volume_shape, target_pct, std=5)
    
    # Save attention map
    np.save(Path(att_map_dir) / f"{series_uid}.npy", att_map.astype(np.float32))
    
    # Generate slice labels (assuming depth=256) - now binary at exact indices
    slice_labels = generate_slice_labels(target_pct, target_cls, depth=256, num_classes=14)
    
    # Save slice labels
    np.save(Path(slice_label_dir) / f"{series_uid}.npy", slice_labels)
    
    return True


def main():
    # Paths
    volume_dir =  r"E:\kaggle-rsna-data_processing3\volume_uint8_256_z"
    target_pct_dir = r"E:\kaggle-rsna-data_processing3\label_percentage_z"
    target_cls_dir = r"E:\kaggle-rsna-crop2\target_cls"
    
    att_map_dir = r"E:\kaggle-rsna-data_processing3\att_map_z"
    slice_label_dir = r".\slice_level_label_z"
    
    # Create output directories
    Path(att_map_dir).mkdir(parents=True, exist_ok=True)
    Path(slice_label_dir).mkdir(parents=True, exist_ok=True)
    
    # Get all series from volume directory
    volume_files = list(Path(volume_dir).glob("*.npy"))
    series_uids = [f.stem for f in volume_files]
    
    print(f"Found {len(series_uids)} series in volume directory")
    
    # Statistics
    successful = 0
    skipped_no_targets = 0
    failed = 0
    
    # Process all series
    for series_uid in tqdm(series_uids, desc="Processing series"):
        try:
            success = process_series_attention_and_labels(
                series_uid,
                volume_dir,
                target_pct_dir,
                target_cls_dir,
                att_map_dir,
                slice_label_dir
            )
            
            if success:
                successful += 1
            else:
                skipped_no_targets += 1
                
        except Exception as e:
            print(f"\nERROR processing {series_uid}: {e}")
            import traceback
            traceback.print_exc()
            failed += 1
            continue
    
    # Summary
    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"Total series found: {len(series_uids)}")
    print(f"Successfully processed: {successful}")
    print(f"Skipped (no targets): {skipped_no_targets}")
    print(f"Failed: {failed}")
    print("="*60)
    print(f"\nOutput directories:")
    print(f"  Attention maps: {att_map_dir}")
    print(f"  Slice labels: {slice_label_dir}")
    
    # Verify outputs
    att_map_count = len(list(Path(att_map_dir).glob("*.npy")))
    slice_label_count = len(list(Path(slice_label_dir).glob("*.npy")))
    print(f"\nGenerated files:")
    print(f"  Attention maps: {att_map_count}")
    print(f"  Slice labels: {slice_label_count}")
    print("="*60)


if __name__ == "__main__":
    main()

Found 4326 series in volume directory


Processing series: 100%|██████████| 4326/4326 [46:35<00:00,  1.55it/s]  


PROCESSING COMPLETE
Total series found: 4326
Successfully processed: 1842
Skipped (no targets): 2484
Failed: 0

Output directories:
  Attention maps: E:\kaggle-rsna-data_processing3\att_map_z
  Slice labels: .\slice_level_label_z

Generated files:
  Attention maps: 1842
  Slice labels: 1842



