# Resampling pipeline for RSNA 2023 Abdominal Trauma Detection

This notebook loads the original image data, resamples every series to a common spacing (2 mm, 2 mm, 2 mm) and size (320, 256, 256), and builds datasets and dataloaders with PyTorch.

The functions that load .dcm files and construct 3D images are adapted from [https://www.kaggle.com/code/parhammostame/construct-3d-arrays-from-dcm-nii-3-view-angles](https://www.kaggle.com/code/parhammostame/construct-3d-arrays-from-dcm-nii-3-view-angles); thanks to the author.

With the dataloader's num_workers set to 4, preprocessing the test set takes about 1 h 10 min.

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os
import random
import re

from tqdm import tqdm

import pydicom as dicom
import nibabel as nib
import SimpleITK as sitk

import torch
import torch.nn as nn

In [None]:
SEED = 344
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator and
    the CUDA generator), exports ``PYTHONHASHSEED``, asks cuDNN for
    deterministic algorithms and finally prints a confirmation line.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    # Deterministic cuDNN kernels so repeated runs with the same seed match.
    torch.backends.cudnn.deterministic = True
    print(f'Finish seeding with seed {seed}')


seed_everything(SEED)
print(f'Training on device {device}')

In [None]:
# DICOM tag columns read from the competition's tag tables; only these
# columns are loaded from the parquet files below.
dicom_tag_columns = [
    'Columns',
    'ImageOrientationPatient',
    'ImagePositionPatient',
    'InstanceNumber',
    'PatientID',
    'PatientPosition',
    'PixelSpacing',
    'RescaleIntercept',
    'RescaleSlope',
    'Rows',
    'SeriesNumber',
    'SliceThickness',
    'path',
    'WindowCenter',
    'WindowWidth'
]

# DICOM tag tables for the train and test images. load_resample() looks rows
# up by their `path` column, matching one .dcm file per lookup.
train_dicom_tags = pd.read_parquet('/kaggle/input/rsna-2023-abdominal-trauma-detection/train_dicom_tags.parquet', columns=dicom_tag_columns)
test_dicom_tags = pd.read_parquet('/kaggle/input/rsna-2023-abdominal-trauma-detection/test_dicom_tags.parquet', columns=dicom_tag_columns)

# Series metadata tables; CTDataset reads `patient_id` and `series_id` from them.
train_series_meta = pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/train_series_meta.csv')
test_series_meta = pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/test_series_meta.csv')

# Notebook display of the test tag table.
test_dicom_tags

In [None]:
# Display the train tag rows whose RescaleIntercept equals 0.
train_dicom_tags[train_dicom_tags.RescaleIntercept == 0]

In [None]:
# Boolean mask selecting the train tag rows with a SliceThickness of 3.0.
dfilter = (train_dicom_tags['SliceThickness'] == 3.0)
# The name is reused: it now holds the distinct patient IDs (as int64) of those rows.
dfilter = train_dicom_tags[dfilter].PatientID.drop_duplicates().values.astype('int64')
# Display the series metadata belonging to those patients.
train_series_meta[train_series_meta['patient_id'].isin(dfilter)]

In [None]:
def raw_path_gen(patient_id, series_id, train=True):
    """Build the folder path that holds one series' DICOM files.

    Args:
        patient_id: Patient identifier; converted with ``str``.
        series_id: Series identifier; converted with ``str``.
        train: Truthy selects ``train_images``, falsy selects ``test_images``.

    Returns:
        ``<competition root>/<split>_images/<patient_id>/<series_id>`` as a
        string without a trailing slash.
    """
    split_dir = 'train_images/' if train else 'test_images/'
    pieces = [
        '/kaggle/input/rsna-2023-abdominal-trauma-detection/',
        split_dir,
        str(patient_id),
        '/',
        str(series_id),
    ]
    return ''.join(pieces)

def _window_scalar(value):
    """Return a window tag as an int, taking the first entry of a multi-valued tag."""
    try:
        return int(value)
    except TypeError:
        # A multi-valued tag (several window presets) cannot go through int()
        # directly; use its first entry.
        return int(value[0])


def create_3D_scans(folder, downsample_rate=1):
    """Read a folder of DICOM slices into one windowed 3D array.

    Files are ordered by the integer in front of the first ``.`` of their
    name. Every ``downsample_rate``-th file is read, rescaled with its
    RescaleSlope/RescaleIntercept (identity when either tag is absent),
    clipped to the window given by WindowCenter/WindowWidth, divided by the
    slice maximum and multiplied by 255, and subsampled in-plane by the same
    rate.

    Args:
        folder: Directory containing the slice files.
        downsample_rate: Step applied to the file list and to both in-plane
            axes of every slice.

    Returns:
        ``np.int16`` array with the slices stacked along axis 0.

    Raises:
        ValueError: If a file name does not start with an integer, or if no
            slice was read (``np.stack`` on an empty list).
    """
    # Sort the real file names numerically so that e.g. 10.dcm follows 9.dcm.
    filenames = sorted(os.listdir(folder), key=lambda name: int(name.split('.')[0]))

    volume = []
    for filename in filenames[::downsample_rate]:
        ds = dicom.dcmread(os.path.join(folder, filename))
        image = ds.pixel_array

        if ds.PixelRepresentation == 1:
            # Signed pixel data: shifting left then right by the number of
            # unused bits sign-extends the stored value.
            bit_shift = ds.BitsAllocated - ds.BitsStored
            dtype = image.dtype
            image = (image << bit_shift).astype(dtype) >> bit_shift

        # Rescale params; reset for every slice so a slice without the tags
        # neither crashes nor silently reuses the previous slice's values.
        if ("RescaleIntercept" in ds) and ("RescaleSlope" in ds):
            intercept = float(ds.RescaleIntercept)
            slope = float(ds.RescaleSlope)
        else:
            intercept, slope = 0.0, 1.0

        # Clipping params
        center = _window_scalar(ds.WindowCenter)
        width = _window_scalar(ds.WindowWidth)
        low = center - width / 2
        high = center + width / 2

        image = (image * slope) + intercept
        image = np.clip(image, low, high)

        # NOTE(review): scaling divides by the slice maximum only (not
        # min-max), so values below zero stay negative and a slice whose
        # maximum is <= 0 is not handled; kept as is because downstream code
        # may depend on this value range.
        image = (image / np.max(image) * 255).astype(np.int16)
        image = image[::downsample_rate, ::downsample_rate]
        volume.append(image)

    volume = np.stack(volume, axis=0)
    return volume

In [None]:
def test_3D_scans(folder, downsample_rate=1):
    """Probe the first slice of a series for signed pixel data.

    Only the first file of the numerically sorted (and strided) file list is
    read. When its PixelRepresentation is 1 the file path is printed and the
    sign-extension shift is applied to the pixels; the shifted array is
    discarded. The function always returns ``None``.

    Args:
        folder: Directory containing the slice files.
        downsample_rate: Step applied to the sorted file list.
    """
    slice_numbers = sorted(int(name.split('.')[0]) for name in os.listdir(folder))
    selected = [f'{number}.dcm' for number in slice_numbers][::downsample_rate]
    if not selected:
        return

    first_path = os.path.join(folder, selected[0])
    ds = dicom.dcmread(first_path)
    pixels = ds.pixel_array

    if ds.PixelRepresentation == 1:
        print(first_path)
        unused_bits = ds.BitsAllocated - ds.BitsStored
        pixel_dtype = pixels.dtype
        # Result is not used further; the probe only reports the path.
        pixels = (pixels << unused_bits).astype(pixel_dtype) >> unused_bits
    return

In [None]:
# for i in tqdm(range(0, len(train_series_meta)), position=0):
#     patient_id, series_id = train_series_meta.loc[i, ["patient_id", "series_id"]].astype('int')
#     filepath = raw_path_gen(patient_id, series_id)
#     test_3D_scans(filepath)

In [None]:
def plot_image_with_seg(volume, volume_seg=None, orientation='Coronal', num_subplots=20):
    """Plot evenly spaced slices of a 3D volume, optionally with a mask overlay.

    The volume is sliced along axis 0 for ``'Axial'``, axis 1 for
    ``'Coronal'`` and axis 2 for ``'Sagittal'``. ``num_subplots`` slice
    indices are spread evenly from the first to the last index of that axis.

    Args:
        volume: 3D array to display in grayscale.
        volume_seg: Optional 3D mask indexed like ``volume``. Non-zero voxels
            are overlaid half-transparently; ``None`` or an empty sequence
            disables the overlay.
        orientation: ``'Coronal'``, ``'Sagittal'`` or ``'Axial'``.
        num_subplots: Number of slices to show (at least 1).

    Raises:
        ValueError: If ``orientation`` is not one of the three supported
            names or ``num_subplots`` is smaller than 1.
    """
    # Axis order that brings the slicing axis to the front; None leaves the
    # volume untouched.
    axis_orders = {
        'Coronal': [1, 0, 2],
        'Sagittal': [2, 0, 1],
        'Axial': None,
    }
    if orientation not in axis_orders:
        raise ValueError(
            f"orientation must be one of {sorted(axis_orders)}, got {orientation!r}"
        )
    if num_subplots < 1:
        raise ValueError(f"num_subplots must be at least 1, got {num_subplots}")

    plot_mask = volume_seg is not None and len(volume_seg) > 0

    axis_order = axis_orders[orientation]
    if axis_order is not None:
        volume = volume.transpose(axis_order)
        if plot_mask:
            volume_seg = volume_seg.transpose(axis_order)

    # Indices are taken from the axis that is actually sliced (axis 0 after
    # the transpose), so non-cubic volumes cannot index out of range.
    slices = np.linspace(0, volume.shape[0] - 1, num_subplots).astype(int)

    rows = max(int(np.floor(np.sqrt(num_subplots))) - 2, 1)
    cols = int(np.ceil(num_subplots / rows))

    # squeeze=False always yields a 2D array of axes, even for a 1x1 grid.
    fig, ax = plt.subplots(rows, cols, figsize=(cols * 2, rows * 4), squeeze=False)
    fig.tight_layout(h_pad=0.01, w_pad=0)

    ax = ax.ravel()
    for this_ax in ax:
        this_ax.axis('off')

    for counter, this_slice in enumerate(slices):
        plt.sca(ax[counter])

        image = volume[this_slice, :, :]
        plt.imshow(image, cmap='gray')

        if plot_mask:
            # Zero mask values become NaN so that they are not drawn.
            mask = np.where(volume_seg[this_slice, :, :], volume_seg[this_slice, :, :], np.nan)
            plt.imshow(mask, cmap='Set1', alpha=0.5)

In [None]:
# filepath = raw_path_gen(14429, 57624)
# volume = create_3D_scans(filepath)
# print(f'3D Image file shape: {volume.shape}')

In [None]:
# plot_image_with_seg(volume, orientation='Axial', num_subplots=5)
# plot_image_with_seg(volume, orientation='Sagittal', num_subplots=5)

In [None]:
#load series and resample
import re

# Single DICOM slice read once at cell execution; load_resample() returns it
# as a placeholder when a requested series folder does not exist.
fake_img = sitk.ReadImage('/kaggle/input/rsna-2023-abdominal-trauma-detection/test_images/48843/62825/30.dcm')

def _dicom_tag_rows(dicom_tags, filepath, filename):
    """Return the tag rows whose ``path`` equals the last four components of the file path."""
    relative_path = '/'.join((filepath + '/' + filename).split('/')[-4:])
    return dicom_tags[dicom_tags.path == relative_path]


def load_resample(patient_id, series_id, train = True,
                  target_spacing = (2.0, 2.0, 2.0),
                  target_size = (256, 256, 320)):
    """Load one series and resample it onto a fixed, centred grid.

    The slice distance is estimated from two files of the series (difference
    of the last ImagePositionPatient component divided by the difference of
    their InstanceNumber). Series with a slice distance of at most 1.0 are
    read with a downsample rate of 2 in all three directions. The volume is
    then linearly resampled to ``target_spacing``/``target_size`` with its
    physical centre aligned to the centre of the output grid; voxels outside
    the source volume are filled with -150.

    Args:
        patient_id: Patient identifier, used to build the series folder.
        series_id: Series identifier, used to build the series folder.
        train: Truthy reads from the train images/tags, falsy from the test
            ones.
        target_spacing: Output spacing, one value per SimpleITK image axis.
        target_size: Output size in voxels, one value per SimpleITK image axis.

    Returns:
        The resampled ``sitk.Image``. Two early exits return images that are
        NOT resampled: the module-level ``fake_img`` when the series folder
        does not exist, and the raw volume when the folder holds fewer than
        two files.

    Raises:
        IndexError: If one of the two probed files has no row in the tag table.
        ZeroDivisionError: Possibly, when both probed files carry the same
            InstanceNumber as plain Python ints (NumPy ints yield inf/nan
            with a warning instead).
    """
    filepath = raw_path_gen(patient_id, series_id, train=train)
    if not os.path.exists(filepath):
        return fake_img

    dicom_tags = train_dicom_tags if train else test_dicom_tags

    filenames = os.listdir(filepath)
    if len(filenames) < 2:
        return sitk.GetImageFromArray(create_3D_scans(filepath))

    dicom_tags1 = _dicom_tag_rows(dicom_tags, filepath, filenames[0])
    dicom_tags2 = _dicom_tag_rows(dicom_tags, filepath, filenames[1])

    # The tags are stored as strings such as "[x, y, z]"; the last component
    # is the second-to-last piece after splitting on ", " and "]".
    dicom_po1 = float(re.split(r', |]', dicom_tags1.ImagePositionPatient.values[0])[-2])
    dicom_po2 = float(re.split(r', |]', dicom_tags2.ImagePositionPatient.values[0])[-2])

    instance_gap = dicom_tags1.InstanceNumber.values[0] - dicom_tags2.InstanceNumber.values[0]
    dicom_dz = abs((dicom_po1 - dicom_po2) / instance_gap)

    # Image x spacing is the last PixelSpacing entry, y spacing the first one.
    pixel_spacing = re.split(r', |]|\[', dicom_tags1.PixelSpacing.values[0])
    dicom_spacing_x = float(pixel_spacing[-2])
    dicom_spacing_y = float(pixel_spacing[1])

    # Thin-slice series are subsampled by 2 in every direction while loading.
    downsample_rate = 2 if dicom_dz <= 1.0 else 1

    dicom_spacing_x = dicom_spacing_x * downsample_rate
    dicom_spacing_y = dicom_spacing_y * downsample_rate
    dicom_dz = dicom_dz * downsample_rate

    original_image = create_3D_scans(filepath, downsample_rate)
    original_image = sitk.GetImageFromArray(original_image)
    source_spacing = (dicom_spacing_x, dicom_spacing_y, dicom_dz)
    original_image.SetSpacing(source_spacing)

    original_size = original_image.GetSize()

    # Output origin in physical units so that both grids share their centre:
    # half of (output extent - source extent), negated. For a target spacing
    # of 2.0 this equals the previous voxel-based formula.
    target_origin = tuple(
        -(target_size[axis] * target_spacing[axis]
          - original_size[axis] * source_spacing[axis]) / 2
        for axis in range(3)
    )

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(target_size)
    resampler.SetOutputOrigin(target_origin)
    resampler.SetDefaultPixelValue(-150)
    resampled_image = resampler.Execute(original_image)

    return resampled_image

    #49954, 41479
    #10026, 42932
    
    
# img = load_resample(10004, 21057)
# sitk.WriteImage(img, './test.nii.gz')
# image_a = sitk.GetArrayFromImage(img)
# #image_a = np.flip(image_a, 0)
# #image_a = np.clip(image_a, -150, 250)

# print(image_a.shape)
# plot_image_with_seg(image_a, orientation='Axial', num_subplots=5)
# plot_image_with_seg(image_a, orientation='Sagittal', num_subplots=5)

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class CTDataset(Dataset):
    """Dataset that yields one resampled CT volume per series-metadata row.

    Attributes are set in ``__init__``: ``train`` (which split to read),
    ``device`` (where the returned tensor is moved to) and ``series_meta``
    (the module-level train or test series table).
    """

    def __init__(self, train=True, device='cpu'):
        self.train = train
        self.device = device
        self.series_meta = train_series_meta if train else test_series_meta

    def __len__(self):
        """Number of rows in the series table."""
        return len(self.series_meta)

    def __getitem__(self, idx):
        """Load, resample and return the volume of the row labelled ``idx``."""
        id_pair = self.series_meta.loc[idx, ["patient_id", "series_id"]].astype('int')
        patient_id, series_id = id_pair
        resampled = load_resample(patient_id, series_id, train=self.train)
        voxels = sitk.GetArrayFromImage(resampled)
        return torch.from_numpy(voxels).to(self.device)

In [None]:
import multiprocessing
# Use one DataLoader worker per CPU reported by the machine.
num_cpus = multiprocessing.cpu_count()

# Train split: batches of 4 resampled volumes in shuffled order.
train_ds = CTDataset(train=True)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=num_cpus)

# Test split, same settings.
# NOTE(review): shuffle=True also applies here and CTDataset returns only the
# image tensor, so batches cannot be mapped back to their series — confirm
# this is intended.
test_ds = CTDataset(train=False)
test_dl = DataLoader(test_ds, batch_size=4, shuffle=True, num_workers=num_cpus)

In [None]:
# Iterate once over the test loader (with a progress bar), which loads and
# resamples every test series, and print the shape of each batch.
for imgs in tqdm(test_dl):
    print(imgs.shape)

In [None]:
# img_a = sitk.GetArrayFromImage(load_resample(49954, 41479, train=True))
# plot_image_with_seg(img_a, orientation='Axial', num_subplots=5)
# plot_image_with_seg(img_a, orientation='Sagittal', num_subplots=5)

In [None]:
# Copy the competition's sample submission into the working directory as submission.csv.
!cp /kaggle/input/rsna-2023-abdominal-trauma-detection/sample_submission.csv submission.csv