In [5]:
import os
import glob
import numpy as np
import pydicom
import nibabel as nib
from pathlib import Path
from scipy.ndimage import zoom
from tqdm import tqdm

def get_direction_label(vec):
    """Map a direction vector to a two-letter anatomical direction label.

    The component with the largest magnitude decides the axis; its sign
    decides which of the two labels for that axis is returned.

    Args:
        vec: Indexable sequence of numbers (list or numpy array).

    Returns:
        'RL'/'LR' when the first component dominates, 'AP'/'PA' when the
        second dominates, otherwise 'FH'/'HF'. The first label of each pair
        is used for a strictly positive dominant component.
    """
    # (label for positive value, label for non-positive value) per axis.
    label_pairs = (('RL', 'LR'), ('AP', 'PA'), ('FH', 'HF'))

    strongest = int(np.argmax(np.abs(vec)))
    # Any index past the second component falls into the last pair.
    positive_label, other_label = label_pairs[min(strongest, 2)]

    if vec[strongest] > 0:
        return positive_label
    return other_label

def determine_dicom_orientation(iop):
    """Derive a direction label for each volume dimension from ImageOrientationPatient.

    Args:
        iop: Sequence of six numbers; the first three are treated as the row
            vector and the next three as the column vector.

    Returns:
        List ``[dim0_dir, dim1_dir, dim2_dir]``. ``dim2_dir`` is the label of
        the row vector, ``dim1_dir`` the label of the column vector, and
        ``dim0_dir`` is the positive-direction label ('RL', 'AP' or 'FH') of
        the anatomical axis not used by the other two.
    """
    def _axis_of(label):
        # 0 = left/right, 1 = anterior/posterior, 2 = everything else.
        if label in ('LR', 'RL'):
            return 0
        if label in ('AP', 'PA'):
            return 1
        return 2

    dim2_dir = get_direction_label(np.array(iop[:3]))
    dim1_dir = get_direction_label(np.array(iop[3:6]))

    # Whatever axis the in-plane vectors leave unused is assigned to dim 0.
    leftover_axes = {0, 1, 2} - {_axis_of(dim2_dir), _axis_of(dim1_dir)}
    dim0_dir = ('RL', 'AP', 'FH')[leftover_axes.pop()]

    return [dim0_dir, dim1_dir, dim2_dir]

def load_dicom_volume(series_path):
    """Load a single-frame DICOM series as a 3-D volume with orientation labels.

    Args:
        series_path: Folder that is searched (non-recursively) for ``*.dcm`` files.

    Returns:
        Tuple ``(volume, orientation, error)``:
        - on success, ``volume`` is the per-file ``pixel_array`` stacked along
          axis 0, ``orientation`` is the label list produced by
          ``determine_dicom_orientation`` and ``error`` is ``None``;
        - when no file is found, or exactly one file is found (treated as a
          multi-frame DICOM and skipped), ``volume`` and ``orientation`` are
          ``None`` and ``error`` is a short message.
    """
    # Sorted file names make tie-breaking in the slice sort below deterministic.
    dcm_files = sorted(glob.glob(os.path.join(series_path, '*.dcm')))

    if not dcm_files:
        return None, None, "No DICOM files found"

    if len(dcm_files) == 1:
        return None, None, "Multi-frame DICOM (skipped)"

    slices = [pydicom.dcmread(f) for f in dcm_files]
    orientation = determine_dicom_orientation(slices[0].ImageOrientationPatient)

    # orientation[0] is always the positive-direction label of the slice axis,
    # so the slices must be ordered by ascending position along that same axis.
    # Sorting by z unconditionally is only valid for axial series; for sagittal
    # or coronal series z is (nearly) constant across slices.
    slice_axis = {'RL': 0, 'AP': 1, 'FH': 2}[orientation[0]]
    slices.sort(key=lambda s: float(s.ImagePositionPatient[slice_axis]))

    volume = np.stack([s.pixel_array for s in slices], axis=0)

    return volume, orientation, None

def determine_nifti_orientation(affine):
    """Translate a NIfTI affine into one direction label per array axis.

    Args:
        affine: Affine matrix accepted by ``nib.orientations.io_orientation``.

    Returns:
        List of two-letter labels, one per axis code reported by nibabel:
        'L'->'RL', 'R'->'LR', 'P'->'AP', 'A'->'PA', 'I'->'HF', 'S'->'FH'.
    """
    code_to_direction = dict(L='RL', R='LR', P='AP', A='PA', I='HF', S='FH')

    axis_codes = nib.orientations.ornt2axcodes(
        nib.orientations.io_orientation(affine)
    )

    labels = []
    for axis_code in axis_codes:
        labels.append(code_to_direction[axis_code])
    return labels

def align_volume_to_target(volume, source_orient, target_orient):
    """Reorder and flip the axes of ``volume`` so its orientation matches the target.

    Args:
        volume: numpy array with one dimension per entry of ``source_orient``.
        source_orient: Direction label of each axis of ``volume``.
        target_orient: Direction label wanted for each axis of the result.

    Returns:
        The transposed volume, flipped along every result axis whose target
        label differs from the label of the source axis it came from.
    """
    # Both labels of an opposing pair live on the same anatomical axis.
    anatomical_axis = dict(LR=0, RL=0, PA=1, AP=1, HF=2, FH=2)

    # For each target axis, pick the first source axis on the same anatomical axis.
    axis_order = []
    for wanted in target_orient:
        wanted_axis = anatomical_axis[wanted]
        match = next(
            (idx for idx, have in enumerate(source_orient)
             if anatomical_axis[have] == wanted_axis),
            None,
        )
        if match is not None:
            axis_order.append(match)

    result = np.transpose(volume, axis_order)

    # Same anatomical axis but a different label means the axis runs backwards.
    reversed_axes = [
        out_axis
        for out_axis, (wanted, src_idx) in enumerate(zip(target_orient, axis_order))
        if wanted != source_orient[src_idx]
    ]
    for out_axis in reversed_axes:
        result = np.flip(result, axis=out_axis)

    return result

def resize_mask(mask, target_shape, order=0):
    """Rescale ``mask`` to ``target_shape`` and return it as ``uint8``.

    Args:
        mask: numpy array to resize.
        target_shape: Desired size of each dimension of the result.
        order: Spline order passed to ``scipy.ndimage.zoom``; the default of 0
            is nearest-neighbour interpolation.

    Returns:
        The zoomed array cast to ``np.uint8``.
    """
    # Per-axis scale factor = wanted size / current size.
    per_axis_scale = np.asarray(target_shape) / np.asarray(mask.shape)
    resized = zoom(mask, per_axis_scale, order=order)
    return resized.astype(np.uint8)

def process_all_masks(seg_base_path, series_base_path, output_path,
                      crop_depth=600, target_shape=(256, 256, 256)):
    """Align every segmentation mask with its DICOM series and save it as ``.npy``.

    For each ``<series_uid>_cowseg.nii`` file in ``seg_base_path`` the mask is
    re-oriented to match the DICOM series stored in
    ``series_base_path/<series_uid>``, cropped to its last ``crop_depth``
    slices along axis 0, resized to ``target_shape`` with nearest-neighbour
    interpolation and written to ``output_path/<series_uid>.npy``.
    Progress, skip reasons and a final summary are printed.

    Args:
        seg_base_path: Folder containing the ``*_cowseg.nii`` mask files.
        series_base_path: Folder containing one sub-folder per series UID.
        output_path: Folder for the resulting ``.npy`` files (created if missing).
        crop_depth: Number of trailing axis-0 slices that are kept.
        target_shape: Shape of every saved mask.

    Raises:
        ValueError: If ``crop_depth`` is not positive.
    """
    if crop_depth <= 0:
        raise ValueError("crop_depth must be a positive number of slices")

    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    # Series to exclude due to dimension mismatches
    EXCLUDED_SERIES = {
        "1.2.826.0.1.3680043.8.498.12271269630687930751200307891697907423",
        "1.2.826.0.1.3680043.8.498.14375161350968928494386548917647435597",
        "1.2.826.0.1.3680043.8.498.50369188120242587742908379292729868174",
        "1.2.826.0.1.3680043.8.498.54865110953409154322874363435644372368",
        "1.2.826.0.1.3680043.8.498.68654901185438820364160878605611510817",
        "1.2.826.0.1.3680043.8.498.75294325392457179365040684378207706807",
        "1.2.826.0.1.3680043.8.498.97256479550884529885940791074752719030",
    }

    # Known per-series issues: the listed axis-0 slice of the aligned mask is
    # cleared before cropping.
    KNOWN_BAD_SLICES = {
        "1.2.826.0.1.3680043.8.498.23047023542526806696555440426928375679": 102,
        "1.2.826.0.1.3680043.8.498.11938739392606296532297884225608408867": 0,
    }

    # Find all mask files (*_cowseg.nii); sorted so runs are reproducible.
    mask_files = sorted(glob.glob(os.path.join(seg_base_path, '*_cowseg.nii')))

    print(f"Found {len(mask_files)} mask files to process\n")

    successful = 0
    skipped = 0
    errors = 0

    for mask_path in tqdm(mask_files, desc="Processing masks"):
        # The series UID is the mask file name without '_cowseg' and extension.
        series_uid = Path(mask_path).stem.replace('_cowseg', '')

        # Skip excluded series
        if series_uid in EXCLUDED_SERIES:
            skipped += 1
            continue

        series_path = os.path.join(series_base_path, series_uid)

        # Check if series folder exists
        if not os.path.exists(series_path):
            print(f"\nSkipped {series_uid}: Series folder not found")
            skipped += 1
            continue

        try:
            # Load DICOM
            dcm_vol, dcm_orient, error = load_dicom_volume(series_path)
            if error:
                print(f"\nSkipped {series_uid}: {error}")
                skipped += 1
                continue

            # Load mask
            mask_img = nib.load(mask_path)
            mask_vol = mask_img.get_fdata()
            mask_orient = determine_nifti_orientation(mask_img.affine)

            # Align mask to DICOM orientation
            aligned_mask = align_volume_to_target(mask_vol, mask_orient, dcm_orient)

            # Check dimension match first, so the slice fix below can never
            # index past the end of a mismatched mask.
            if dcm_vol.shape != aligned_mask.shape:
                print(f"\nSkipped {series_uid}: Dimension mismatch - DICOM {dcm_vol.shape} vs Mask {aligned_mask.shape}")
                skipped += 1
                continue

            bad_slice = KNOWN_BAD_SLICES.get(series_uid)
            if bad_slice is not None:
                aligned_mask[bad_slice] = 0  # Fix known issue

            # Keep only the last `crop_depth` slices along axis 0.
            crop_start = max(0, aligned_mask.shape[0] - crop_depth)
            cropped_mask = aligned_mask[crop_start:, :, :]

            # Report labels that the crop throws away. Reducing per slice avoids
            # materialising the coordinates of every non-zero voxel.
            discarded_slices = np.flatnonzero(
                aligned_mask[:crop_start, :, :].any(axis=(1, 2))
            )
            if discarded_slices.size > 0:
                print(f"\nWarning: Non-zero values found in cropped region for {series_uid}")
                min_index = discarded_slices[0]
                depth = aligned_mask.shape[0] - min_index
                print(f"  Minimum index of non-zero value in cropped region: {min_index}, depth from end: {depth}, total depth: {aligned_mask.shape[0]}")

            # Resize to the target shape using nearest neighbor
            resized_mask = resize_mask(cropped_mask, target_shape, order=0)

            # Save
            output_file = os.path.join(output_path, f"{series_uid}.npy")
            np.save(output_file, resized_mask)

            successful += 1

        except Exception as e:
            # Best effort: one broken series must not abort the whole batch.
            print(f"\nError processing {series_uid}: {e}")
            errors += 1

    # Summary
    print(f"\n{'='*80}")
    print("PROCESSING SUMMARY")
    print(f"{'='*80}")
    print(f"Successfully processed: {successful}")
    # Errors are reported on their own line, so they are not part of this count.
    print(f"Skipped (excluded/not found/unloadable/mismatch): {skipped}")
    print(f"Errors: {errors}")
    print(f"Total: {len(mask_files)}")
    print(f"{'='*80}")

if __name__ == "__main__":
    # NOTE(review): these are machine-specific absolute Windows paths; adjust
    # them (or move them to a configuration cell) before running elsewhere.
    # Folder searched for '*_cowseg.nii' mask files.
    seg_base_path = r"E:\data_old\segmentations"
    # Folder expected to hold one sub-folder per series UID.
    series_base_path = r"E:\data_old\series"
    # Folder that receives one '<series_uid>.npy' file per processed mask.
    output_path = r"./mask_256"
    
    process_all_masks(seg_base_path, series_base_path, output_path)

Found 178 mask files to process



Processing masks:  33%|███▎      | 58/178 [02:19<07:57,  3.98s/it]




Processing masks:  33%|███▎      | 59/178 [02:30<12:06,  6.10s/it]

  Minimum index of non-zero value in cropped region: 101, depth from end: 775, total depth: 876


Processing masks:  49%|████▉     | 87/178 [03:34<04:30,  2.97s/it]


Skipped 1.2.826.0.1.3680043.8.498.43502795339700498960289295234851562632: Series folder not found


Processing masks:  95%|█████████▍| 169/178 [07:17<00:29,  3.28s/it]




Processing masks:  96%|█████████▌| 170/178 [07:29<00:46,  5.81s/it]

  Minimum index of non-zero value in cropped region: 19, depth from end: 970, total depth: 989


Processing masks: 100%|██████████| 178/178 [07:40<00:00,  2.59s/it]


PROCESSING SUMMARY
Successfully processed: 170
Skipped (excluded/not found/errors): 8
Errors: 0
Total: 178





In [6]:
1.2.826.0.1.3680043.8.498.23047023542526806696555440426928375679 [102] = 0

SyntaxError: invalid syntax (2830724151.py, line 1)