In [1]:
import os
import numpy as np
import pydicom
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import ast

def should_rescale_ct(ds, pixel_array):
    """Determine if CT should be rescaled"""
    if ds.get('Modality', '') != 'CT':
        return False
    if not (hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept')):
        return False
    min_pixel = pixel_array.min()
    if min_pixel >= -100 or min_pixel == -2000:
        return True
    return False

def get_direction_label(vec):
    """Convert orientation vector to direction label"""
    abs_vec = np.abs(vec)
    dominant_idx = np.argmax(abs_vec)
    dominant_val = vec[dominant_idx]
    
    if dominant_idx == 0:  # X axis
        return 'RL' if dominant_val > 0 else 'LR'
    elif dominant_idx == 1:  # Y axis
        return 'AP' if dominant_val > 0 else 'PA'
    else:  # Z axis
        return 'FH' if dominant_val > 0 else 'HF'

def determine_orientation(iop):
    """Determine row and column directions from ImageOrientationPatient"""
    row_vec = np.array(iop[:3])
    col_vec = np.array(iop[3:6])
    
    dim2_dir = get_direction_label(row_vec)
    dim1_dir = get_direction_label(col_vec)
    
    dim2_axis = 0 if dim2_dir in ['LR', 'RL'] else (1 if dim2_dir in ['AP', 'PA'] else 2)
    dim1_axis = 0 if dim1_dir in ['LR', 'RL'] else (1 if dim1_dir in ['AP', 'PA'] else 2)
    
    used_axes = {dim2_axis, dim1_axis}
    dim0_axis = (set([0, 1, 2]) - used_axes).pop()
    dim0_dir = ['RL', 'AP', 'FH'][dim0_axis]
    
    return dim2_dir, dim1_dir, dim0_dir

def resize_volume_xy_only(volume, target_xy=256):
    """Resize volume to 256x256 in x,y dimensions, keep z unchanged"""
    # Calculate zoom factors: keep z (dim 0) unchanged, resize x,y (dims 1,2) to 256
    zoom_factors = [1.0, target_xy / volume.shape[1], target_xy / volume.shape[2]]
    resized = zoom(volume, zoom_factors, order=1)
    return resized

def parse_coordinates(coord_str):
    """Parse coordinate string to dictionary"""
    return ast.literal_eval(coord_str)

def get_slice_position(ds):
    """Get slice position for sorting"""
    return float(ds.ImagePositionPatient[2])

def build_sop_to_slice_mapping(series_folder):
    """Build mapping of SOP UID to slice index by sorting DICOM files"""
    dcm_files = list(Path(series_folder).glob("*.dcm"))
    
    is_multiframe_series = len(dcm_files) == 1
    
    rows, cols = None, None
    num_slices = 0
    sop_to_slice_mapping = {}
    
    if is_multiframe_series:
        dcm_file = dcm_files[0]
        ds = pydicom.dcmread(dcm_file)
        pixel_array = ds.pixel_array
        
        rows = ds.Rows
        cols = ds.Columns
        
        assert len(pixel_array.shape) == 3
        num_slices = pixel_array.shape[0]
        
        if 'NumberOfFrames' in ds:
            assert num_slices == int(ds.NumberOfFrames)
    else:
        slice_data = []
        
        for dcm_file in dcm_files:
            ds = pydicom.dcmread(dcm_file)
            
            if rows is None:
                rows = ds.Rows
                cols = ds.Columns
            
            sop_uid = ds.SOPInstanceUID
            position = get_slice_position(ds)
            
            slice_data.append({
                'sop_uid': sop_uid,
                'position': position
            })
        
        slice_data.sort(key=lambda x: x['position'])
        
        for idx, slice_info in enumerate(slice_data):
            sop_to_slice_mapping[slice_info['sop_uid']] = idx
        
        num_slices = len(slice_data)
    
    return sop_to_slice_mapping, num_slices, rows, cols, is_multiframe_series

def is_series_valid(series_uid):
    """Check if a series is valid for processing"""
    if series_uid in INVALID_SERIES_SET:
        return False, "invalid_series"
    return True, series_uid

def process_series_with_crop(series_path, localizers_df):
    """Process a single series folder with cropping and label computation"""
    dcm_files = list(Path(series_path).glob('*.dcm'))
    
    # Skip if only one file (multi-frame)
    # if len(dcm_files) <= 1:
        # return None
    
    # Read first DICOM to get orientation
    ds = pydicom.dcmread(dcm_files[0])
    series_uid = ds.SeriesInstanceUID
    modality = ds.get('Modality', 'Unknown')
    
    # iop = ds.ImageOrientationPatient
    # dim2_dir, dim1_dir, dim0_dir = determine_orientation(iop)
    
    # Build SOP mapping for label calculation
    sop_to_slice_mapping, num_slices, rows, cols, is_multiframe = build_sop_to_slice_mapping(series_path)
    
    # Read all slices and sort by Z position
    if not is_multiframe:
        slices = []
        for dcm_file in dcm_files:
            ds = pydicom.dcmread(dcm_file)
            z_pos = ds.ImagePositionPatient[2]
            slices.append((z_pos, ds))
        
        slices.sort(key=lambda x: x[0])
        
        # Stack volume
        volume = []
        for z_pos, ds in slices:
            pixel_array = ds.pixel_array.astype(np.float32)
            
            # Apply rescale for CT if needed
            if should_rescale_ct(ds, pixel_array):
                pixel_array = pixel_array * ds.RescaleSlope + ds.RescaleIntercept
            
            volume.append(pixel_array)
        
        volume = np.stack(volume, axis=0)
    else:
        ds = pydicom.dcmread(dcm_files[0])
        pixel_array = ds.pixel_array.astype(np.float32)
        
        # Apply rescale for CT if needed
        if should_rescale_ct(ds, pixel_array):
            pixel_array = pixel_array * ds.RescaleSlope + ds.RescaleIntercept
        
        volume = pixel_array  # shape (num_slices, rows, cols)

    # Record original size
    original_shape = volume.shape

    # Crop at dim0 [-600:, :, :]
    crop_start = max(0, volume.shape[0] - 600)
    volume_cropped = volume[crop_start:, :, :]
    cropped_shape = volume_cropped.shape
    
    # Calculate crop offset for coordinate adjustment
    crop_offset = crop_start
    
    # Clip CT
    if modality == 'CT':
        volume_cropped = np.clip(volume_cropped, -200, 600)
    
    # Normalization
    if modality == 'CT':
        vmin, vmax = volume_cropped.min(), volume_cropped.max()
    else:
        # Min-max normalization for MR
        vmin, vmax = np.percentile(volume_cropped, 0.5), np.percentile(volume_cropped, 99.5)
        volume_cropped = np.clip(volume_cropped, vmin, vmax)
    
    if vmax > vmin:
        volume_cropped = (volume_cropped - vmin) / (vmax - vmin)
    else:
        volume_cropped = np.zeros_like(volume_cropped)
    
    # Calculate percentage coordinates after cropping
    series_localizers = localizers_df[localizers_df['SeriesInstanceUID'] == series_uid]
    percentage_coords = []
    
    if len(series_localizers) > 0:
        for _, loc_row in series_localizers.iterrows():
            sop_uid = loc_row['SOPInstanceUID']
            coords = parse_coordinates(loc_row['coordinates'])
            
            if coords is None:
                continue
            
            x = coords['x']
            y = coords['y']
            
            # Determine z index
            if is_multiframe:
                z_idx = int(coords['f']) - 1
            else:
                z_idx = sop_to_slice_mapping.get(sop_uid)
                if z_idx is None:
                    continue
            
            # Adjust z index after cropping
            z_idx_cropped = z_idx - crop_offset
            
            # Skip if the coordinate is outside the cropped region
            if z_idx_cropped < 0 or z_idx_cropped >= cropped_shape[0]:
                continue
            
            # Calculate percentage coordinates based on cropped volume
            x_pct = x / cols
            y_pct = y / rows
            z_pct = z_idx_cropped / cropped_shape[0]
            
            percentage_coords.append([x_pct, y_pct, z_pct])
    
    # Resize to 256x256 in x,y only, keep z unchanged
    volume_resized = resize_volume_xy_only(volume_cropped, target_xy=256)
    
    # Scale to 0-255 and convert to uint8
    volume_uint8 = (volume_cropped * 255).astype(np.uint8)
    volume_uint8_256_z = (volume_resized * 255).astype(np.uint8)
    
    return {
        'SeriesInstanceUID': series_uid,
        'Modality': modality,
        'original_shape': original_shape,
        'cropped_shape': cropped_shape,
        'volume_uint8': volume_uint8,
        'volume_uint8_256_z': volume_uint8_256_z,
        'percentage_coords': percentage_coords
    }

# ============================================================================
# SETUP
# ============================================================================

segmentations_dir = Path(r'E:\data_old\segmentations')
series_dir = Path(r'E:\data_old\series')
output_dir = Path(r'./volume_uint8')
output_dir_256_z = Path(r'E:\kaggle-rsna-data_processing3\volume_uint8_256_z')
label_percentage_folder = Path(r'./label_percentage_z')

# Create output directories
output_dir.mkdir(exist_ok=True)
output_dir_256_z.mkdir(exist_ok=True, parents=True)
label_percentage_folder.mkdir(exist_ok=True)

# Load localizers CSV
localizers_csv_path = r"E:\data_old_updated\train_localizers.csv"
localizers_df = pd.read_csv(localizers_csv_path)

# Define invalid series to exclude
INVALID_SERIES = [
    "1.2.826.0.1.3680043.8.498.35204126697881966597435252550544407444",
    "1.2.826.0.1.3680043.8.498.11145695452143851764832708867797988068",
    "1.2.826.0.1.3680043.8.498.12937082136541515013380696257898978214",
    "1.2.826.0.1.3680043.8.498.86840850085811129970747331978337342341",
    "1.2.826.0.1.3680043.8.498.10733938921373716882398209756836684843",
    "1.2.826.0.1.3680043.8.498.11292203154407642658894712229998766945",
    "1.2.826.0.1.3680043.8.498.74390569791112039529514861261033590424",
    "1.2.826.0.1.3680043.8.498.99892390884723813599532075083872271516",
    "1.2.826.0.1.3680043.8.498.99421822954919332641371697175982753182",
    "1.2.826.0.1.3680043.8.498.93005379507993862369794871518209403819",
    "1.2.826.0.1.3680043.8.498.87133443408651185245864983172506753347",
    "1.2.826.0.1.3680043.8.498.85042275841446604538710616923989532822",
    "1.2.826.0.1.3680043.8.498.81867770017494605078034950552739870155",
    "1.2.826.0.1.3680043.8.498.75294325392457179365040684378207706807",
    "1.2.826.0.1.3680043.8.498.73348230187682293339845869829853553626",
    "1.2.826.0.1.3680043.8.498.34908224715351895924870591631151425521",
    "1.2.826.0.1.3680043.8.498.13356606276376861530476731358572238037",
    "1.2.826.0.1.3680043.8.498.13299935636593758131187104226860563078",
    "1.2.826.0.1.3680043.8.498.12780687841924878965940656634052376723",
    "1.2.826.0.1.3680043.8.498.12285352638636973719542944532929535087",
    "1.2.826.0.1.3680043.8.498.11019101980573889157112037207769236902",
    "1.2.826.0.1.3680043.8.498.10820472882684587647235099308830427864",
]

INVALID_SERIES_SET = set(INVALID_SERIES)

# Initialize size CSV
size_csv_path = './size.csv'
if not Path(size_csv_path).exists():
    pd.DataFrame(columns=['SeriesInstanceUID', 'original_dim0', 'original_dim1', 'original_dim2',
                          'cropped_dim0', 'cropped_dim1', 'cropped_dim2']).to_csv(size_csv_path, index=False)

# ============================================================================
# PHASE 1: Get series lists from both directories
# ============================================================================

print("="*80)
print("PHASE 1: Identifying series in segmentations and series directories")
print("="*80)

# Get series with segmentations
segmentation_series = set()
if segmentations_dir.exists():
    # Find all .nii files that don't end with _cowseg.nii
    for f in segmentations_dir.iterdir():
        if f.is_file() and f.name.endswith('.nii') and not f.name.endswith('_cowseg.nii'):
            # Extract seriesUID by removing the .nii extension
            series_uid = f.name.replace('.nii', '')
            segmentation_series.add(series_uid)
    
    print(f"\nSeries in segmentations: {len(segmentation_series)}")
else:
    print(f"\nWarning: Segmentations directory not found: {segmentations_dir}")

# Get all series
all_series = set()
if series_dir.exists():
    all_series = {f.name for f in series_dir.iterdir() if f.is_dir()}
    print(f"Series in series: {len(all_series)}")
else:
    print(f"\nError: Series directory not found: {series_dir}")
    exit(1)

# Series without segmentations
series_without_seg = all_series - segmentation_series
print(f"Series without segmentations: {len(series_without_seg)}")

# Check already processed
already_processed = set()
if output_dir_256_z.exists():
    already_processed = {f.stem for f in output_dir_256_z.glob('*.npy')}
print(f"Already processed: {len(already_processed)}")

# Filter out invalid and already processed
segmentation_series_to_process = segmentation_series - INVALID_SERIES_SET - already_processed
series_without_seg_to_process = series_without_seg - INVALID_SERIES_SET - already_processed

print("\n" + "="*80)
print("PROCESSING PLAN:")
print("="*80)
print(f"Phase 2 - Series with segmentations to process: {len(segmentation_series_to_process)}")
print(f"Phase 3 - Series without segmentations to process: {len(series_without_seg_to_process)}")
print(f"Total to process: {len(segmentation_series_to_process) + len(series_without_seg_to_process)}")
print(f"Skipping (already processed): {len(already_processed & all_series)}")
print(f"Skipping (invalid): {len(INVALID_SERIES_SET & all_series)}")
print("="*80)

# ============================================================================
# PHASE 2: Process series with segmentations first
# ============================================================================

def process_batch(series_list, source_dir, batch_name):
    """Process a batch of series from a given source directory"""
    if len(series_list) == 0:
        print(f"\nNo series to process in {batch_name}")
        return []
    
    print(f"\n{'='*80}")
    print(f"{batch_name}")
    print(f"{'='*80}")
    
    results = []
    
    for series_name in tqdm(list(series_list), desc=batch_name):
        series_folder = source_dir / series_name
        
        if not series_folder.exists():
            print(f"\nWarning: Series folder not found: {series_folder}")
            continue
        
        try:
            result = process_series_with_crop(series_folder, localizers_df)
            if result is not None:
                series_uid = result['SeriesInstanceUID']
                
                # Save volumes
                # np.save(output_dir / f"{series_uid}.npy", result['volume_uint8'])
                np.save(output_dir_256_z / f"{series_uid}.npy", result['volume_uint8_256_z'])
                
                # Save percentage coordinates
                if result['percentage_coords']:
                    np.save(label_percentage_folder / f"{series_uid}.npy",
                           np.array(result['percentage_coords'], dtype=np.float32))
                
                # Save size information
                size_info = {
                    'SeriesInstanceUID': series_uid,
                    'original_dim0': result['original_shape'][0],
                    'original_dim1': result['original_shape'][1],
                    'original_dim2': result['original_shape'][2],
                    'cropped_dim0': result['cropped_shape'][0],
                    'cropped_dim1': result['cropped_shape'][1],
                    'cropped_dim2': result['cropped_shape'][2]
                }
                pd.DataFrame([size_info]).to_csv(size_csv_path, mode='a', header=False, index=False)
                
                results.append({
                    'SeriesInstanceUID': series_uid,
                    'Modality': result['Modality']
                })
        except Exception as e:
            print(f"\nError processing {series_name}: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    return results

# Process segmentation series first
seg_results = process_batch(
    segmentation_series_to_process,
    series_dir,
    "PHASE 2: Processing series WITH segmentations"
)

# ============================================================================
# PHASE 3: Process remaining series without segmentations
# ============================================================================

other_results = process_batch(
    series_without_seg_to_process,
    series_dir,
    "PHASE 3: Processing series WITHOUT segmentations"
)

# ============================================================================
# FINAL SUMMARY
# ============================================================================

all_results = seg_results + other_results

print("\n" + "="*80)
print("FINAL SUMMARY")
print("="*80)
print(f"Series with segmentations processed: {len(seg_results)}")
print(f"Series without segmentations processed: {len(other_results)}")
print(f"Total successfully processed this run: {len(all_results)}")
print(f"Total already processed (skipped): {len(already_processed & all_series)}")
print(f"Total invalid (skipped): {len(INVALID_SERIES_SET & all_series)}")
print("="*80)

PHASE 1: Identifying series in segmentations and series directories

Series in segmentations: 178
Series in series: 4348
Series without segmentations: 4171
Already processed: 0

PROCESSING PLAN:
Phase 2 - Series with segmentations to process: 177
Phase 3 - Series without segmentations to process: 4150
Total to process: 4327
Skipping (already processed): 0
Skipping (invalid): 22

PHASE 2: Processing series WITH segmentations


PHASE 2: Processing series WITH segmentations:  45%|████▍     | 79/177 [08:14<07:22,  4.52s/it]




PHASE 2: Processing series WITH segmentations: 100%|██████████| 177/177 [14:32<00:00,  4.93s/it]



PHASE 3: Processing series WITHOUT segmentations


PHASE 3: Processing series WITHOUT segmentations: 100%|██████████| 4150/4150 [7:46:20<00:00,  6.74s/it]   


FINAL SUMMARY
Series with segmentations processed: 176
Series without segmentations processed: 4150
Total successfully processed this run: 4326
Total already processed (skipped): 0
Total invalid (skipped): 22





In [4]:
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm

# compare file with same name in ./label_percentage_z and ./label_percentage
print("\nChecking for label files in both label_percentage and label_percentage_z...")
label_percentage_dir = Path(r'./label_percentage')
label_percentage_z_dir = Path(r'./label_percentage_z')
label_percentage_files = {f.stem for f in label_percentage_dir.glob('*.npy')}
label_percentage_z_files = {f.stem for f in label_percentage_z_dir.glob('*.npy')}
common_label_files = label_percentage_files & label_percentage_z_files
print(f"Label files in label_percentage: {len(label_percentage_files)}")
print(f"Label files in label_percentage_z: {len(label_percentage_z_files)}")

for fname in tqdm(common_label_files):
    f1 = np.load(label_percentage_dir / f"{fname}.npy")
    f2 = np.load(label_percentage_z_dir / f"{fname}.npy")
    if not np.array_equal(f1, f2):
        print(f"Difference found in label file: {fname}.npy")


Checking for label files in both label_percentage and label_percentage_z...
Label files in label_percentage: 1842
Label files in label_percentage_z: 1842


100%|██████████| 1842/1842 [00:00<00:00, 2330.95it/s]
