In [1]:
import numpy as np
import pandas as pd
import os, random
from fastai.vision.all import *
from fastai.medical.imaging import *
import shutil
import pydicom
import cv2
import glob
import time
from rsna_2023_atd_metric import score
import tqdm

from PIL import Image

random.seed(1441)

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [2]:
def standardize_pixel_array(dcm):
    """
    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217
    """
    # Correct DICOM pixel_array if PixelRepresentation == 1.
    #dcm = pydicom.dcmread(fn)
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype 
        pixel_array = (pixel_array << bit_shift).astype(dtype) >>  bit_shift
#         pixel_array = pydicom.pixel_data_handlers.util.apply_modality_lut(new_array, dcm)

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    center = int(dcm.WindowCenter)
    width = int(dcm.WindowWidth)
    low = center - width / 2
    high = center + width / 2    
    
    pixel_array = (pixel_array * slope) + intercept
    pixel_array = np.clip(pixel_array, low, high)

    return pixel_array

In [3]:
class MultiHeadModel(Module):
    
    def __init__(self, body):
    
        self.body = body
        nf = num_features_model(nn.Sequential(*self.body.children()))

        self.bowel = create_head(nf, 1)
        self.extravasation = create_head(nf, 1)
        self.kidney = create_head(nf, 3)
        self.liver = create_head(nf, 3)
        self.spleen = create_head(nf, 3)
        
    def forward(self, x):
        
        y = self.body(x)
        bowel = self.bowel(y)
        extravasation = self.extravasation(y)
        kidney = self.kidney(y)
        liver = self.liver(y)
        spleen = self.spleen(y)
        return [bowel, extravasation, kidney, liver, spleen]

In [4]:
class CombinationLoss(Module):
    "Cross entropy loss on multiple targets"
    def __init__(self, func = F.cross_entropy, weights = [2, 6, 3, 3, 3]):
        self.func = func
        self.w = weights
        
    def forward(self, xs, *ys, reduction = 'mean'):
        loss = 0
    
        for i, w, x, y in zip(range(len(xs)), self.w, xs, ys):
            if i < 2:
                loss += w*F.binary_cross_entropy_with_logits(x, y.unsqueeze(1).float(), reduction = reduction)
            else:
                #import pdb;pdb.set_trace()
                loss += w*F.cross_entropy(x, y, reduction = reduction)
        return loss

In [5]:
from sklearn.metrics import recall_score

class RecallPartial(Metric):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, a=0, **kwargs):
        self.func = partial(recall_score, average='macro', zero_division=0)
        self.a = a

    def reset(self): self.targs,self.preds = [],[]

    def accumulate(self, learn):
        pred = learn.pred[self.a].argmax(-1)
        targ = learn.y[self.a]
        pred,targ = to_detach(pred),to_detach(targ)
        pred,targ = flatten_check(pred,targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        return self.func(targs, preds)

    @property
    def name(self): return 'recall_' + str(self.a+1)
    
class RecallCombine(Metric):
    def accumulate(self, learn):
        scores = [learn.metrics[i].value for i in range(3)]
        self.combine = np.average(scores, weights=[2,1,1])

    @property
    def value(self):
        return self.combine

In [6]:
learn = load_learner('/kaggle/input/rsna2023-atd-2d-cnn-image-level-model-2/model_2.pt', cpu = False)

In [7]:
TEST_PATH = '/kaggle/input/rsna-2023-abdominal-trauma-detection/test_images/'
SAVE_FOLDER = 'temp_folder/'
SIZE = 128
STRIDE = 10

if not os.path.exists(SAVE_FOLDER):
    os.makedirs(SAVE_FOLDER)

print('Number of test patients:', len(os.listdir(TEST_PATH)))

Number of test patients: 3


In [8]:
def convert_dicom_to_png(patient, size = 128):
    
    for study in (sorted(os.listdir(TEST_PATH + patient))):
        imgs = {}
        for f in sorted(glob.glob(TEST_PATH + f"{patient}/{study}/*.dcm"))[::STRIDE]:
            
            dicom = pydicom.dcmread(f)
            pos_z = dicom[(0x20, 0x32)].value[-1]
            img = standardize_pixel_array(dicom)
            
            img = (img - img.min())/(img.max() - img.min() + 1e-6)
            imgs[pos_z] = img
                
        for i, k in enumerate(sorted(imgs.keys())):
            
            img = imgs[k]
            
            img = cv2.resize(img, (size, size))
            cv2.imwrite(SAVE_FOLDER + f"{patient}_{study}_{i}.png", (img * 255).astype(np.uint8))
        
#_ = Parallel(n_jobs = 2)(
#    delayed(convert_dicom_to_png)(patient, size=SIZE)
#    for patient in tqdm(os.listdir(TEST_PATH))
#    )
    

In [9]:
def merge_arr(a, b):
    return np.concatenate((a, b.numpy()), axis = 0)

In [10]:
patients = os.listdir(TEST_PATH)

bowel_preds, extrav_preds = np.array([]).reshape(0), np.array([]).reshape(0)
kidney_preds, liver_preds, spleen_preds = np.array([]).reshape(0, 3), np.array([]).reshape(0, 3), np.array([]).reshape(0, 3)
fnames_list = []

start = time.time()
sigm = torch.nn.Sigmoid()
softm = torch.nn.Softmax(dim = 1)

for idx, patient in enumerate(patients):
    
    convert_dicom_to_png(patient, SIZE)
    files = get_image_files(SAVE_FOLDER)
    test_dl = learn.dls.test_dl(files, with_labels = False, device = 'cuda', bs = 128)

    preds = learn.get_preds(dl = test_dl)[0]
        
    bowel_preds = merge_arr(bowel_preds, sigm(preds[0]).squeeze(-1))
    extrav_preds = merge_arr(extrav_preds, sigm(preds[1]).squeeze(-1))
    kidney_preds = merge_arr(kidney_preds, softm(preds[2]))
    liver_preds = merge_arr(liver_preds, softm(preds[3]))
    spleen_preds = merge_arr(spleen_preds, softm(preds[4]))
       
    fnames_list.append(files)
    
    for file in files:
        os.remove(file)
    
    if (idx + 1) % 5 == 0:
        end = time.time()
        print(f'{idx + 1} patients processed.')
        print(f'Time elapsed: {end - start} ')
        print(f'Avg time per patient: {(end - start)/(idx + 1)}')

In [11]:
from itertools import chain
fnames_list = list(chain.from_iterable(fnames_list))

In [12]:
test_files_probs = pd.DataFrame()

test_files_probs['fname'] = pd.Series(fnames_list, dtype = 'string')

test_files_probs['bowel_injury'] = pd.Series(bowel_preds)
test_files_probs['extravasation_injury'] = pd.Series(extrav_preds)
test_files_probs['kidney_low'] = pd.Series(kidney_preds[:, 1])
test_files_probs['kidney_high'] = pd.Series(kidney_preds[:, 2])
test_files_probs['liver_low'] = pd.Series(liver_preds[:, 1])
test_files_probs['liver_high'] = pd.Series(liver_preds[:, 2])
test_files_probs['spleen_low'] = pd.Series(spleen_preds[:, 1])
test_files_probs['spleen_high'] = pd.Series(spleen_preds[:, 2])

#test_files_probs

In [13]:
test_files_probs.head()

Unnamed: 0,fname,bowel_injury,extravasation_injury,kidney_low,kidney_high,liver_low,liver_high,spleen_low,spleen_high
0,temp_folder/63706_39279_0.png,0.001513,0.007261,9.168278e-07,1e-06,3.14382e-07,3.534692e-07,6.240921e-06,2e-06
1,temp_folder/50046_24574_0.png,0.47894,0.031321,3.017532e-06,0.000131,7.934494e-07,1.840951e-07,3.770563e-06,1e-06
2,temp_folder/48843_62825_0.png,0.000499,0.004182,1.272889e-06,1e-06,5.496062e-07,5.989269e-07,4.965273e-07,2e-06


In [14]:
patient_id_list = []
for idx, fname in enumerate(test_files_probs['fname']):
    patient_id_list.append(fname.split('/')[1].split('_')[0])
    
test_files_probs['patient_id'] = pd.Series(patient_id_list, dtype = 'string')

In [15]:
patients = set(test_files_probs.patient_id)

col_names = ['patient_id', 'bowel_healthy', 'bowel_injury',
            'extravasation_healthy', 'extravasation_injury',
            'kidney_healthy', 'kidney_low', 'kidney_high',
            'liver_healthy', 'liver_low', 'liver_high',
            'spleen_healthy', 'spleen_low', 'spleen_high']

patient_probs = pd.DataFrame([], columns = col_names)
    

for idx, pat in enumerate(patients):
    p = test_files_probs[test_files_probs.patient_id == pat]
    
    bi = p.bowel_injury.quantile(q = 0.95)
    ei = p.extravasation_injury.quantile(q = 0.95)
    kl = p.kidney_low.quantile(q = 0.95)
    kh = p.kidney_high.quantile(q = 0.95)
    ll = p.liver_low.quantile(q = 0.95)
    lh = p.liver_high.quantile(q = 0.95)
    sl = p.spleen_low.quantile(q = 0.95)
    sh = p.spleen_high.quantile(q = 0.95)
    
    eps = 1e-5
    patient_probs.loc[idx] = [str(pat), max(eps, 1 - bi), 2*bi, max(eps, 1 - ei), 6*ei, max(eps, 1 - kl - kh), 2*kl, 4*kh, max(eps, 1 - ll - lh), 2*ll, 4*lh, max(eps, 1 - sl - sh), 2*sl, 4*sh]

In [16]:
patient_probs.head()

Unnamed: 0,patient_id,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high
0,63706,0.998487,0.003026,0.992739,0.043565,0.999998,2e-06,6e-06,0.999999,6.28764e-07,1.413877e-06,0.999992,1.248184e-05,7e-06
1,50046,0.52106,0.95788,0.968679,0.187927,0.999866,6e-06,0.000522,0.999999,1.586899e-06,7.363802e-07,0.999995,7.541127e-06,5e-06
2,48843,0.999501,0.000999,0.995818,0.025095,0.999997,3e-06,5e-06,0.999999,1.099212e-06,2.395708e-06,0.999998,9.930545e-07,7e-06


In [17]:
sample_submission = pd.read_csv('/kaggle/input/rsna-2023-abdominal-trauma-detection/sample_submission.csv')
sample_submission.patient_id = sample_submission.patient_id.astype(str)


patient_probs = patient_probs.set_index('patient_id')
patient_probs = patient_probs.reindex(index = sample_submission['patient_id'])
patient_probs = patient_probs.reset_index()

In [18]:
patient_probs.to_csv('submission.csv', header = True, index = False)