#### Notebook for image preprocessing required in training the 3D-UNet for fine AC-PC localization

##### Steps
* Grab the coarse localized and true AC-PC landmarks, output from CoarseLocalization_Slicer.ipynb
* Grab the skull stripped brain scans, also output from CoarseLocalization_Slicer.ipynb
* Apply rotational augmentations, then compute the corresponding 2-channel input patches and the intermediately processed 4-channel heatmap patches for 3D-UNet training. Also apply the same rotations to the coarse localized and true AC-PC landmarks
* Write the information required for 3D-UNet training. This involves the image coordinates of the coarse localized and true
  AC-PC landmarks, along with their rotated versions. It also involves the scaling factors required to modulate the
  Gaussian heatmaps and assemble the full 6-channel ground truth required for 3D-UNet training

In [1]:
#Import libraries
import pandas as pd
import numpy as np
import SimpleITK as sitk
import os, re, time, sys
from pathlib import Path
from itertools import product
import concurrent.futures

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
def window_image(image_arr, window_center, window_width):
    """Clamps the intensities of a volume to a window of HU values
    
    Args: 
        image_arr: numpy array of the image volume
        window_center: midpoint of the desired HU window
        window_width: total width of the desired HU window, centred on window_center
    Returns: 
        A copy of image_arr in which values below the window are raised to its lower bound and values above 
        the window are lowered to its upper bound. The input array itself is not modified
    
    """
    #half of the window on either side of the center (integer division, so odd widths round down)
    half_width = window_width // 2
    lower_bound = window_center - half_width
    upper_bound = window_center + half_width
    
    clamped = image_arr.copy()
    clamped[clamped < lower_bound] = lower_bound
    clamped[clamped > upper_bound] = upper_bound
    return clamped

In [3]:
def rotate_image(image, physical_coordinates, angles, resampling_type = 'linear', dimension = 3):
    """Rotates a 3D image and its associated landmarks based on the provided angles
    
    Args:
        image: SITK image object that needs to be rotated
        physical_coordinates: flat sequence holding, in order, the physical coordinates of the true AC, true PC,
                              coarse AC and coarse PC landmarks (three values each)
        angles: desired rotation angles, in degrees, around the x, y and z axes (in that order)
        resampling_type: 'linear' selects linear interpolation; any other value selects B-spline interpolation
        dimension: image dimension
    Returns:
        rotated_image: SITK image resampled in the rotated space. It has the size and spacing of the input, an 
                       identity direction matrix, and voxels falling outside the input are filled with -3
        landmarks_transformed: list with the transformed physical coordinates of the four landmarks, ordered as
                               [true AC, true PC, coarse AC, coarse PC]
    """
    
    #unpack the physical coordinates of associated AC-PC landmarks (true and coarse)
    landmark_points = [physical_coordinates[:3],    #true AC
                       physical_coordinates[3:6],   #true PC
                       physical_coordinates[6:9],   #coarse AC
                       physical_coordinates[9:]]    #coarse PC
    
    #unpack theta values (for image rotation) and convert to radians
    theta_radians_x, theta_radians_y, theta_radians_z  = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])
    
    #every transform rotates about the same point, so compute it only once
    center = image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize())//2.0)
    
    matrix_x = np.array([[1.0, 0.0, 0.0],
                         [0.0, np.cos(theta_radians_x), -np.sin(theta_radians_x)],
                         [0.0, np.sin(theta_radians_x), np.cos(theta_radians_x)]]) #rotation around the x axis 
    
    matrix_y = np.array([[np.cos(theta_radians_y), 0.0, np.sin(theta_radians_y)],
                         [0.0, 1.0, 0.0],
                         [-np.sin(theta_radians_y), 0.0, np.cos(theta_radians_y)]])  #rotation around the y axis 
    
    matrix_z = np.array([[np.cos(theta_radians_z), -np.sin(theta_radians_z), 0.0],
                         [np.sin(theta_radians_z),  np.cos(theta_radians_z), 0.0],  #rotation around the z axis 
                         [0.0, 0.0, 1.0]])

    #direction cosine matrix of the input image
    matrix_cor = np.array(image.GetDirection()).reshape(3,3)    
    
    #define the 3D rotational transformation: x, y, z rotations followed by the direction matrix, all about the center
    transforms = []
    for matrix in (matrix_x, matrix_y, matrix_z, matrix_cor):
        transform = sitk.AffineTransform(dimension)
        transform.SetCenter(center)
        transform.SetMatrix(matrix.ravel())
        transforms.append(transform)
    
    composite_transform = sitk.CompositeTransform(transforms)
    
    #derive the physical coordinates of the bounding box (8 corners) of the given image
    width, height, depth = image.GetWidth(), image.GetHeight(), image.GetDepth()
    extreme_points = [image.TransformIndexToPhysicalPoint((i, j, k))
                      for k in (0, depth) for j in (0, height) for i in (0, width)]
    
    #obtain the points where the extreme points of the image get mapped to, when they are transformed. This is required to 
    #specify the origin of the resampled and rotated image. Note that SITK uses the inverse of the specified transformation 
    inv_transform = composite_transform.GetInverse()

    extreme_points_transformed = [inv_transform.TransformPoint(pnt) for pnt in extreme_points]

    #transform the physical coordinates of the associated landmarks 
    landmarks_transformed = [inv_transform.TransformPoint(pnt) for pnt in landmark_points]
    
    # Use the original spacing (arbitrary decision).
    output_spacing = image.GetSpacing()
    # Identity direction cosine matrix.   
    output_direction = [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
    # Define the new origin: the smallest x, y and z reached by any transformed corner.
    output_origin = [min(pnt[axis] for pnt in extreme_points_transformed) for axis in range(3)]
    #same size as input
    output_size = image.GetSize()
    
    #pick the interpolator used while resampling the image in its rotated space
    if resampling_type == 'linear':
        interpolator = sitk.sitkLinear
    else:
        interpolator = sitk.sitkBSpline
    
    rotated_image = sitk.Resample(image, output_size, composite_transform, interpolator, output_origin, output_spacing,
                                  output_direction, defaultPixelValue = -3)
    
    return rotated_image, landmarks_transformed

In [4]:
def heatmap_processing(scan_id, rot_id, img_arr, patch_size, ac_true_coords, pc_true_coords,
                       ac_pred_coords, pc_pred_coords, sigma_list, write_path):
    """Generates ground-truth patches for the 3D-UNet. 
    
    Args:
        scan_id (str): Scan identifier.
        rot_id (int): Rotation identifier.
        img_arr (numpy array): The 3D image data, indexed as (z, y, x). Only its shape is used here.
        patch_size (int): Edge length of the cubic patches. Must be even, since the crop takes patch_size//2 voxels 
            on either side of the rounded landmark.
        ac_true_coords (sequence of float): True AC coordinates (x, y, z) in the image reference system. 
        pc_true_coords (sequence of float): True PC coordinates (x, y, z) in the image reference system. 
        ac_pred_coords (sequence of float): Predicted (coarse) AC coordinates (x, y, z) in the image reference system. 
        pc_pred_coords (sequence of float): Predicted (coarse) PC coordinates (x, y, z) in the image reference system. 
        sigma_list (list of int): List of sigmas (standard deviation) to modulate the Gaussians with.
        write_path (pathlib.Path): Folder under which <scan_id>/hm_patches_Rot_<rot_id>.npy is written. 
 
    Returns:
        scaling_factors_df_sigma_all (pandas dataframe): One row per sigma in sigma_list, holding scan_id, rot_id, 
        sigma and the minimum and maximum of the sigma-modulated AC and PC gaussian heatmaps over the full volume. 
        Empty if sigma_list is empty.

    Raises:
        ValueError: If a patch of size patch_size around a predicted landmark does not fit inside img_arr.

    Implementation details:
        Define the squared-distance maps around the true landmarks (these include rotated positions). For each sigma, 
        find the modulating factors. Chop the maps up based on the patch_size and save them. 
    """

    def crop_cube(volume, center_xyz):
        """Returns the patch_size cube of volume centred on the rounded (x, y, z) landmark."""
        half = patch_size // 2
        bounds = []
        #the volume is indexed (z, y, x) while landmarks are given as (x, y, z)
        for axis_len, coord in zip(volume.shape, (center_xyz[2], center_xyz[1], center_xyz[0])):
            start = int(np.round(coord) - half)
            stop = int(np.round(coord) + half)
            #a negative start would silently wrap around and an overshooting stop would silently truncate
            if start < 0 or stop > axis_len or stop - start != patch_size:
                raise ValueError(f'Patch of size {patch_size} centred at {tuple(center_xyz)} does not fit inside '
                                 f'volume of shape {volume.shape}')
            bounds.append(slice(start, stop))
        return volume[tuple(bounds)]

    #unpack the true AC-PC coordinates
    ac_x, ac_y, ac_z = ac_true_coords[0], ac_true_coords[1], ac_true_coords[2]
    pc_x, pc_y, pc_z = pc_true_coords[0], pc_true_coords[1], pc_true_coords[2]

    #open (broadcastable) index grids along z, y and x; avoids materialising one coordinate row per voxel
    zz, yy, xx = np.ogrid[:img_arr.shape[0], :img_arr.shape[1], :img_arr.shape[2]]

    #This only computes the (x-mu_x)**2 + (y-mu_y)**2 + (z-mu_z)**2 part of the gaussian heatmap. 
    #These intermediate heatmaps get modulated by the chosen sigma and exponentiated in a data loader process during training
    gaussian_heatmap_ac = (zz - ac_z)**2 + (yy - ac_y)**2 + (xx - ac_x)**2
    gaussian_heatmap_pc = (zz - pc_z)**2 + (yy - pc_y)**2 + (xx - pc_x)**2

    #for all possible sigma values, calculate the minimum and maximum values of the sigma-modulated gaussian heatmaps
    #these values will be used to scale the heatmaps during training to obtain a normalized (0-1) heatmap
    scaling_rows = []
    for sigma in sigma_list:
        gaussian_heatmap_ac_sigma = np.exp(-1/(2*sigma**2) * gaussian_heatmap_ac)    
        gaussian_heatmap_pc_sigma = np.exp(-1/(2*sigma**2) * gaussian_heatmap_pc)

        scaling_rows.append(pd.DataFrame({'scan_id':[scan_id],
                                          'rot_id':[rot_id], 'sigma':[sigma],
                                          'min_gaussian_heatmap_ac':[gaussian_heatmap_ac_sigma.min()],
                                          'max_gaussian_heatmap_ac':[gaussian_heatmap_ac_sigma.max()],
                                          'min_gaussian_heatmap_pc':[gaussian_heatmap_pc_sigma.min()],
                                          'max_gaussian_heatmap_pc':[gaussian_heatmap_pc_sigma.max()]}))

    #concatenate once instead of growing a dataframe inside the loop
    scaling_factors_df_sigma_all = pd.concat(scaling_rows) if scaling_rows else pd.DataFrame()

    #patch the heatmap up based on the coarse localized AC-PC. Note that we do not need to store full heatmaps
    #as the sigma modulation is on the voxel level. This trick enables us to work in memory-constrained setups. 
    ac_patch_ac_hm = crop_cube(gaussian_heatmap_ac, ac_pred_coords)
    ac_patch_pc_hm = crop_cube(gaussian_heatmap_pc, ac_pred_coords)
    pc_patch_ac_hm = crop_cube(gaussian_heatmap_ac, pc_pred_coords)
    pc_patch_pc_hm = crop_cube(gaussian_heatmap_pc, pc_pred_coords)

    #assemble the intermediately processed heatmap. Note that these do not contain the background channel as the
    #background channel values depend upon the sigma modulated AC and PC channels, which will be computed on the fly during 
    #training. Channel order: AC map at PC patch, AC map at AC patch, PC map at PC patch, PC map at AC patch
    gt_pat = np.stack((pc_patch_ac_hm, ac_patch_ac_hm, pc_patch_pc_hm, ac_patch_pc_hm))

    hm_write_path = write_path / scan_id
    #exist_ok avoids a failure if the folder appears between an existence check and its creation
    hm_write_path.mkdir(parents=True, exist_ok=True)
    with open(hm_write_path / f'hm_patches_Rot_{rot_id}.npy', mode='wb') as f:
        np.save(f, gt_pat)      

    return scaling_factors_df_sigma_all

In [5]:
def generate_input_patches(scan_id, rot_id, img_arr, patch_size, ac_pred_coords, pc_pred_coords, write_path):
    """Generates input patches for the 3D-UNet. 

    Args:
        scan_id (str): Scan identifier.
        rot_id (int): Rotation identifier.
        img_arr (numpy array): The 3D image data, indexed as (z, y, x).
        patch_size (int): Edge length of the cubic patches. Must be even, since the crop takes patch_size//2 voxels 
            on either side of the rounded landmark.
        ac_pred_coords (sequence of float): Predicted (coarse) AC coordinates (x, y, z) in the image reference system. 
        pc_pred_coords (sequence of float): Predicted (coarse) PC coordinates (x, y, z) in the image reference system. 
        write_path (pathlib.Path): Folder under which <scan_id>/input_patches_Rot_<rot_id>.npy is written.

    Returns:
        None. The 2-channel patch array of shape (2, patch_size, patch_size, patch_size) is saved to disk, with the 
        PC-centred patch in channel 0 and the AC-centred patch in channel 1.

    Raises:
        ValueError: If a patch of size patch_size around a predicted landmark does not fit inside img_arr.

    """

    def crop_cube(center_xyz):
        """Returns the patch_size cube of img_arr centred on the rounded (x, y, z) landmark."""
        half = patch_size // 2
        bounds = []
        #the volume is indexed (z, y, x) while landmarks are given as (x, y, z)
        for axis_len, coord in zip(img_arr.shape, (center_xyz[2], center_xyz[1], center_xyz[0])):
            start = int(np.round(coord)) - half
            stop = int(np.round(coord)) + half
            #a negative start would silently wrap around and an overshooting stop would silently truncate
            if start < 0 or stop > axis_len or stop - start != patch_size:
                raise ValueError(f'Patch of size {patch_size} centred at {tuple(center_xyz)} does not fit inside '
                                 f'volume of shape {img_arr.shape}')
            bounds.append(slice(start, stop))
        return img_arr[tuple(bounds)]

    #crop image patches of size patch_size around the coarse landmarks
    voxel_patch_ac = crop_cube(ac_pred_coords)
    voxel_patch_pc = crop_cube(pc_pred_coords)

    #stack the PC and AC patches to form the 2 channel input
    input_pat = np.stack((voxel_patch_pc, voxel_patch_ac))

    #save the input to given location
    ip_write_path = write_path / scan_id
    #exist_ok avoids a failure if the folder appears between an existence check and its creation
    ip_write_path.mkdir(parents=True, exist_ok=True)
    with open(ip_write_path / f'input_patches_Rot_{rot_id}.npy', mode='wb') as f:
        np.save(f, input_pat)      


In [23]:
def process_scan(scan_id):
    """Generates the 3D-UNet training files for one scan, for the original volume and every rotation.

    For each entry of rot_angle_combinations (the entry at index 0 is processed as the unrotated original), the 
    volume is windowed, the landmarks are converted to image coordinates, the input and heatmap patches are written, 
    and one row of scan information plus the per-sigma scaling factors are appended to their CSV files.

    Relies on the notebook-level names data_path, all_landmarks, rot_angle_combinations, patch_size, sigma_list,
    ip_patches_write_path, heatmap_patches_write_path, scaling_factors_write_path and scan_info_write_path.

    Args:
        scan_id: scan identifier; used as the folder name under data_path and as the lookup key in all_landmarks

    Returns:
        None. All results are written to disk.

    Raises:
        SystemExit: if the image cannot be read, the landmarks for the scan are missing, a landmark has a negative 
        image coordinate, or patch generation fails.
    """
    
    #read the axial image in
    axial_img_path = str((data_path / str(scan_id).lstrip(os.sep) / "Axial brain.nii").resolve())

    try:            
        img = sitk.ReadImage(axial_img_path)
    except Exception as e: 
        sys.exit(f'Image read error for scan {scan_id}: {e}')

    #look up the pre computed coarse landmarks and the true AC-PC locations of this scan
    img_labels = all_landmarks[all_landmarks['scan_id'] == scan_id][['ac_gt', 'pc_gt', 'ac', 'pc']].values

    #each landmark is stored as a bracketed, comma separated string; flatten all four into one array ordered as 
    #true AC, true PC, coarse AC, coarse PC
    try:
        physical_coordinates = np.array([np.float64(value.strip("[]"))
                                         for landmark in img_labels[0]
                                         for value in landmark.split(",")])
    except IndexError:
        sys.exit(f'Coarse and ground-truth AC-PC coordinates not found for scan {scan_id}')

    #make sure the folders of both output CSV files exist before anything is appended to them
    scaling_factors_write_path.parent.mkdir(parents=True, exist_ok=True)
    scan_info_write_path.parent.mkdir(parents=True, exist_ok=True)

    #start processing 
    for id_, deg in enumerate(rot_angle_combinations):  
        if id_ > 0:
            #if this is a rotational case
            deg_x, deg_y, deg_z = deg

            #apply the desired 3D rotation to both the image and the coarse and reference standard AC-PC landmarks
            source_image, physical_landmarks = rotate_image(img, physical_coordinates, (deg_x, deg_y, deg_z)) 
        else:
            #if processing the original image without rotations
            source_image = img
            physical_landmarks = [physical_coordinates[:3], physical_coordinates[3:6],
                                  physical_coordinates[6:9], physical_coordinates[9:]]

        #get the array data from the (possibly rotated) volume, clamped by the 50/100 window
        img_arr = window_image(sitk.GetArrayFromImage(source_image), window_center = 50, window_width = 100)

        #gather the (possibly rotated) physical coordinates of the coarse and reference standard AC-PC
        (physical_landmarks_ac_true, physical_landmarks_pc_true,
         physical_landmarks_ac, physical_landmarks_pc) = physical_landmarks

        #convert the coarse and reference standard AC-PC coordinates from the physical to the image space
        ac_true_coordinates, pc_true_coordinates, ac_coordinates, pc_coordinates = [
            list(source_image.TransformPhysicalPointToContinuousIndex(pnt)) for pnt in physical_landmarks]

        #sanity checks to ensure that image coordinates are not negative, because they can't be.
        #If you used our notebook for preprocessing and coarse localization, the origin of these images would be at 0,0,0 
        #and all these scans are ensured to have a direction cosine of identity. So no image coordinates can be negative 
        all_image_coordinates = (ac_true_coordinates, pc_true_coordinates, ac_coordinates, pc_coordinates)
        if any(coord < 0 for point in all_image_coordinates for coord in point[:3]):
            sys.exit(f'Image coordinates negative for scan {scan_id}. Recheck processing')
        
        #generate inputs for each rotation. These are essentially 3D patches of the brain scan, cropped around the 
        #coarse AC-PC landmarks
        try:
            generate_input_patches(scan_id, id_, img_arr, patch_size, ac_coordinates, pc_coordinates, ip_patches_write_path)
        except Exception as e:
            sys.exit(f'Issue with input patch generation for scan_id {scan_id}: {e}')

        #generate half-processed heatmaps for each rotation. These are intermediately processed, 4 channel heatmaps (AC and PC 
        #channels/gaussians, both cropped around the coarse AC-PC landmarks). 
        #This also returns scaling factors that are required to assemble full heatmaps 
        try: 
            scaling_factors_scan = heatmap_processing(scan_id, id_, img_arr, patch_size, ac_true_coordinates, pc_true_coordinates,
                           ac_coordinates, pc_coordinates, sigma_list, heatmap_patches_write_path)
        except Exception as e:
            sys.exit(f'Issue with heatmap patch generation for scan_id {scan_id}: {e}')

        #put together image coordinates of the coarse localized and true AC-PC landmarks -- along with those corresponding to 
        #rotated images for augmentation
        scan_info_df = pd.DataFrame({'scan_id':[scan_id],
                                     'deg':[deg],
                                     'rot_id':[id_],
                                     'pc':[physical_landmarks_pc],
                                     'ac':[physical_landmarks_ac],
                                     'pc_true':[physical_landmarks_pc_true],
                                     'ac_true':[physical_landmarks_ac_true],
                                     'pc_img':[pc_coordinates],
                                     'ac_img':[ac_coordinates],
                                     'pc_img_true':[pc_true_coordinates],
                                     'ac_img_true':[ac_true_coordinates],
                                     })

        #write the required information for training the 3D-UNet; the header is only written when the file is new
        #NOTE(review): when several worker processes append to the same CSV, two of them can both find the file 
        #missing and each write a header row -- confirm the reader tolerates this, or gather the rows in the parent
        scaling_factors_scan.to_csv(str(scaling_factors_write_path.resolve()), mode='a', header=not os.path.exists(scaling_factors_write_path), 
                                   index = False)     
        scan_info_df.to_csv(str(scan_info_write_path.resolve()), mode='a', header=not os.path.exists(scan_info_write_path), 
                           index = False)


In [6]:
#Setup data read and write paths

In [7]:
#all data locations below are built from this root (an empty Path, i.e. relative to the working directory)
root = Path()

#folder holding the raw nifti volumes
data_path = root.joinpath("brain_vols")

#csv with the ground-truth or reference standard AC-PC annotations
gt_ann_path = root.joinpath("acpc_annotations", "acpc_gt.csv")

#csv with the coarse localized AC-PC landmarks
coarse_acpc_path = root.joinpath("acpc_annotations", "acpc_coarse.csv")

In [8]:
#csv holding both the physical and image coordinates of the coarse localized and reference standard AC-PC landmarks.
#One row is written for every chosen rotation of the input data (used for augmentation during training)
scan_info_write_path = root.joinpath("files_for_unet", "scan_info.csv")

#csv holding the scaling factors of the Gaussian heatmaps at the different sigma levels. These are needed to assemble 
#the full heatmap on the fly during training
scaling_factors_write_path = root.joinpath("files_for_unet", "scaling_factors_info.csv")

In [9]:
#folder receiving the 2 channel input patches (one sub-folder per scan)
ip_patches_write_path = root.joinpath("patched_data_4unet", "ip_patches")
#folder receiving the 4 channel intermediate heatmap patches (one sub-folder per scan)
heatmap_patches_write_path = root.joinpath("patched_data_4unet", "op_patches")

In [10]:
#load both annotation tables and discard the 'Unnamed: 0' column that each csv carries
acpc_df_gt = pd.read_csv(gt_ann_path).drop(columns=['Unnamed: 0'])
acpc_df_coarse = pd.read_csv(coarse_acpc_path).drop(columns=['Unnamed: 0'])

In [11]:
#combine the true and coarse localized AC-PC landmarks into one table, keeping only scans present in both.
#The true landmark columns are renamed to ac_gt / pc_gt so they do not clash with the coarse ac / pc columns
all_landmarks = pd.merge(acpc_df_gt.rename(columns={'ac': 'ac_gt', 'pc': 'pc_gt'}),
                         acpc_df_coarse,
                         how='inner',
                         on='scan_id')

In [12]:
##Define the 3D rotations (theta_x, theta_y, and theta_z), in degrees, needed for augmentation 

####### Original version used in the paper - generates 108 rotations per scan, to enable random sampling of a chosen
####### number of random rotations for augmentation during training - Modify as required
# x_rots = [rot for rot in np.arange(-10,12,2.5) if abs(rot) > 3] #rotation around the x-axis
# #rotation around the y-axis (chose a limited range due to the naturally constrained rotations of patient heads around 
# #the anterior-posterior axis)
# y_rots = [rot for rot in np.arange(-2.5,3.5,2.5)] 
# z_rots = [rot for rot in np.arange(-10,12,2.5) if abs(rot) > 3] #rotation around the z axis

#This demonstration generates 48 rotations per scan (4 x-angles * 3 y-angles * 4 z-angles)
#rotation around the x-axis: steps of 2.5 within [-5, 5], leaving out the zero angle
x_rots = list(filter(lambda rot: abs(rot) > 2, np.arange(-5, 6, 2.5)))
#rotation around the y-axis (chose a limited range due to the naturally constrained rotations of patient heads around 
#the anterior-posterior axis)
y_rots = list(np.arange(-2.5, 3.5, 2.5))
#rotation around the z axis: same angles as around the x-axis
z_rots = list(filter(lambda rot: abs(rot) > 2, np.arange(-5, 6, 2.5)))

#every (x, y, z) combination, preceded by the unrotated original version (corresponding to rotation 0,0,0)
rot_angle_combinations = [(0, 0, 0), *product(x_rots, y_rots, z_rots)]
print(len(rot_angle_combinations))


49


In [13]:
#identifiers of all scans that have both true and coarse landmarks; each one is handed to process_scan
scan_ids = all_landmarks['scan_id'].values

In [14]:
#number of scans that will be processed
len(scan_ids) 
#This will depend upon the size of your dataset for training/inference.
#Note that there are only 5 scans for the sake of this demonstration. 

5

In [15]:
#edge length (in voxels) of the cubic patches cropped around the coarse AC and PC landmarks
patch_size = 32
#sigmas (standard deviations) for which the Gaussian heatmap scaling factors are computed
sigma_list = [4, 6, 8, 10, 12, 14]

In [None]:
##Serial processing for debugging etc. 
# for scan_id in scan_ids:
#     process_scan(scan_id)

In [None]:
#Use parallel processing as it significantly speeds up processing times. Note that you may not see gains in this demonstration
#as we are using a small number of scans. The speed-up becomes noticeable when the dataset size is increased

In [20]:
#process all scans in worker processes and time the whole run
start = time.time()
with concurrent.futures.ProcessPoolExecutor() as executor:
    results = executor.map(process_scan, scan_ids)

#process_scan returns nothing, so this prints one None per scan; consuming the results also surfaces worker errors
print(list(results))

#elapsed wall-clock time in minutes
print((time.time() - start) / 60)

[None, None, None, None, None]
5.040703837076823


In [21]:
#read back the two csv files that process_scan appended to
scaling_factors_scan = pd.read_csv(scaling_factors_write_path.resolve())
scan_info_df = pd.read_csv(scan_info_write_path.resolve())

In [22]:
#preview the first rows of the scan information table
scan_info_df.head()

Unnamed: 0.1,Unnamed: 0,scan_id,deg,rot_id,pc,ac,pc_true,ac_true,pc_img,ac_img,pc_img_true,ac_img_true
0,0,CQ500-CT-105,"(0, 0, 0)",0,[ -2.47016309 -17.91535228 51.14641125],[ -2.14664063 -39.84221152 52.82896874],[ -1.81597164 -18.19665342 51.25 ],[ -1.54456401 -40.40876649 49.69793696],"[98.07016156164178, 94.5227512717641, 52.39641...","[97.74663910523466, 116.44961051112443, 54.078...","[97.41597011435118, 94.80405241173364, 52.5]","[97.14456248283386, 117.01616547394497, 50.947..."
1,0,CQ500-CT-77,"(0, 0, 0)",0,[ -7.51852831 -17.21293653 48.99320025],[ -7.33458761 -39.5442935 47.58979143],[ -6.708891 -16.60974406 44.96885749],[ -6.57737804 -34.96270884 39.625 ],"[121.01852831431479, 117.06583417559821, 58.74...","[120.83458760668319, 139.39719114306388, 57.33...","[120.20889100432396, 116.46264170455491, 54.71...","[120.07737803788473, 134.81560648898676, 49.37..."
2,0,CQ500-CT-74,"(0, 0, 0)",0,[-1.92089237 11.74143367 35.25788138],[-1.96731448 -9.54547283 36.54762288],[-2.05856219 11.13124493 34. ],[-1.81952056 -8.7073242 34. ],"[117.92089237476993, 94.20156712148332, 68.757...","[117.96731447988752, 115.48847362645417, 70.04...","[118.0585621919468, 94.81175586434026, 67.4999...","[117.81952055583608, 114.65032498981941, 67.49..."
3,0,CQ500-CT-105,"(-5.0, -2.5, -5.0)",1,"(-4.970108889516774, -30.459842125450223, 51.6...","(-7.198446916562974, -8.580732496277022, 51.37...","(-5.6498200602436, -30.228945166375123, 51.672...","(-7.685386584953263, -8.3307090201708, 48.1823...","[110.90316326884052, 106.72614061182838, 65.69...","[108.67482524179432, 128.60525024100158, 65.44...","[110.2234520981137, 106.95703757090348, 65.746...","[108.18788557340403, 128.85527371710782, 62.25..."
4,0,CQ500-CT-185,"(0, 0, 0)",0,[ 1.19500692 21.28288092 32.08546628],[-0.54788461 -1.11507267 36.92791684],[ 2.04693699 20.68501611 31.53346886],[ 1.02349901 -0.52271671 32.43025278],"[128.80499307831556, 128.10511956781798, 83.33...","[130.54788461319743, 150.50307315531816, 88.17...","[127.95306301116943, 128.70298438173072, 82.78...","[128.9765009880066, 149.91071719553895, 83.680..."


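Attach the paths (relative to the input-patch and heatmap-patch output folders, i.e. `<scan_id>/<file name>`) of the cropped inputs and half-assembled ground-truth heatmaps to facilitate the data loader for 3D-UNet training, then write the tables out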

In [33]:
#location of each input patch file, as <scan_id>/input_patches_Rot_<rot_id>.npy
scan_info_df['scan_path'] = [os.path.join(scan, f"input_patches_Rot_{rot}.npy")
                             for scan, rot in zip(scan_info_df['scan_id'], scan_info_df['rot_id'])]

In [38]:
#location of each heatmap patch file, as <scan_id>/hm_patches_Rot_<rot_id>.npy
scan_info_df['gt_heatmap_path'] = [os.path.join(scan, f"hm_patches_Rot_{rot}.npy")
                                   for scan, rot in zip(scan_info_df['scan_id'], scan_info_df['rot_id'])]

In [34]:
#same heatmap patch location, attached to every scaling-factor row
scaling_factors_scan['gt_heatmap_path'] = [os.path.join(scan, f"hm_patches_Rot_{rot}.npy")
                                           for scan, rot in zip(scaling_factors_scan['scan_id'],
                                                                scaling_factors_scan['rot_id'])]

In [35]:
#inspect the generated input patch paths
scan_info_df['scan_path'].values

array(['CQ500-CT-105/input_patches_Rot_0.npy',
       'CQ500-CT-77/input_patches_Rot_0.npy',
       'CQ500-CT-74/input_patches_Rot_0.npy',
       'CQ500-CT-105/input_patches_Rot_1.npy',
       'CQ500-CT-185/input_patches_Rot_0.npy',
       'CQ500-CT-50/input_patches_Rot_0.npy',
       'CQ500-CT-77/input_patches_Rot_1.npy',
       'CQ500-CT-74/input_patches_Rot_1.npy',
       'CQ500-CT-105/input_patches_Rot_2.npy',
       'CQ500-CT-105/input_patches_Rot_3.npy',
       'CQ500-CT-77/input_patches_Rot_2.npy',
       'CQ500-CT-185/input_patches_Rot_1.npy',
       'CQ500-CT-74/input_patches_Rot_2.npy',
       'CQ500-CT-105/input_patches_Rot_4.npy',
       'CQ500-CT-50/input_patches_Rot_1.npy',
       'CQ500-CT-77/input_patches_Rot_3.npy',
       'CQ500-CT-74/input_patches_Rot_3.npy',
       'CQ500-CT-105/input_patches_Rot_5.npy',
       'CQ500-CT-185/input_patches_Rot_2.npy',
       'CQ500-CT-77/input_patches_Rot_4.npy',
       'CQ500-CT-105/input_patches_Rot_6.npy',
       'CQ500-CT-74/inpu

In [39]:
#overwrite the scan information csv with the table that now carries the path columns (row index left out)
scan_info_df.to_csv(scan_info_write_path.resolve(), index=False)

In [37]:
#overwrite the scaling factors csv with the table that now carries the heatmap path column (row index left out)
scaling_factors_scan.to_csv(scaling_factors_write_path.resolve(), index=False)