## Introduction

In this notebook, we are going to fine-tune a pretrained CNN model using the CT scan images from the competition. Specifically, we are going to use a ResNet-50 model pre-trained on ImageNet data. We'll make use of the [fastai](https://docs.fast.ai) library.

This being a first baseline, we are not going to make use of segmentation data, DICOM tags, or other metadata. Importantly, we are not going to treat the CT scan images of a series as a sequence, even though doing so would provide more global information. Our focus here is to set up a data pipeline.

## Code

### EDA
We import the necessary modules.

In [None]:
#!pip install -qU python-gdcm pylibjpeg

In [None]:
import numpy as np
import pandas as pd
import os, random
import matplotlib.pyplot as plt
from PIL import Image
from fastai.vision.all import *
from fastai.basics import *
from fastai.callback.all import *
from fastai.medical.imaging import *
import shutil
import pydicom
import cv2
import glob
import time
import seaborn as sns
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed
from tqdm.notebook import tqdm

# Global seed: used for file sampling and for the train/validation split
SEED = 42
random.seed(SEED)

Let's look at one of the DICOM scan images.

In [None]:
f_dicom = pydicom.dcmread('/kaggle/input/rsna-2023-abdominal-trauma-detection/train_images/10004/21057/1000.dcm')
img = f_dicom.pixel_array

plt.figure(figsize=(15, 15))
plt.imshow(img, cmap="gray")
plt.show()

It's a 512x512 image. Let's now look at the metadata (DICOM tags) associated with the above scan.

In [None]:
f_dicom

The _Transfer Syntax UID_ field tells us the image is encoded using RLE Lossless compression, so we probably won't be able to use NVIDIA DALI to speed up decoding. The third entry of the _Image Position (Patient)_ list is the z-coordinate; it lets us stack the 2D slices into a 3D volume if needed.
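As a sketch of that stacking step, the slices of a series can be ordered by the z-coordinate of their _Image Position (Patient)_ tag (`ImagePositionPatient` in pydicom). The helper name `sort_slices_by_z` is hypothetical, and `SimpleNamespace` objects stand in for the datasets `pydicom.dcmread` would return, so the example runs without DICOM files.

```python
from types import SimpleNamespace

def sort_slices_by_z(slices):
    """Return slices ordered by ascending z (3rd entry of ImagePositionPatient)."""
    return sorted(slices, key=lambda ds: float(ds.ImagePositionPatient[2]))

# Stand-ins for pydicom datasets; only ImagePositionPatient is used.
slices = [
    SimpleNamespace(name="a.dcm", ImagePositionPatient=[-150.0, -160.0, 12.5]),
    SimpleNamespace(name="b.dcm", ImagePositionPatient=[-150.0, -160.0, -7.5]),
    SimpleNamespace(name="c.dcm", ImagePositionPatient=[-150.0, -160.0, 2.5]),
]
ordered = sort_slices_by_z(slices)
print([s.name for s in ordered])  # ['b.dcm', 'c.dcm', 'a.dcm']
```

In the DICOM patient coordinate system +z points towards the head, so ascending order runs from the feet towards the head.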

A PNG version of the training images is provided [here](https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427427) and can be imported into a Kaggle notebook using Add Data. In this notebook, however, we read the DICOM files directly, and to test our pipeline we use only a small random sample of them.

In [None]:
BASE_DIR = '/kaggle/input/rsna-2023-abdominal-trauma-detection'

Let us look at the labels provided in *train.csv*. We note that the labels are assigned to each patient and not to each CT scan image. There are many CT images associated with each patient; some of them may reflect the injury and some may not, depending upon their z-coordinate.

For simplicity, however, we assign each CT scan image the label of the corresponding patient.

In [None]:
patient_labels = pd.read_csv(os.path.join(BASE_DIR, 'train.csv'))
patient_labels.head()

We do a sanity check that the *healthy* and *injury* (or *low injury* and *high injury*) probabilities add up to 1 for each type of injury.

In [None]:
#Check data has no NANs
print(patient_labels.isnull().any().any())
#Check healthy and injury labels are complementary
print((patient_labels['bowel_healthy'] == np.abs(1 - patient_labels['bowel_injury'])).all())
print((patient_labels['extravasation_healthy'] == np.abs(1 - patient_labels['extravasation_injury'])).all())
print((patient_labels['kidney_healthy'] == np.abs(1 - patient_labels['kidney_low'] - patient_labels['kidney_high'])).all())
print((patient_labels['liver_healthy'] == np.abs(1 - patient_labels['liver_low'] - patient_labels['liver_high'])).all())
print((patient_labels['spleen_healthy'] == np.abs(1 - patient_labels['spleen_low'] - patient_labels['spleen_high'])).all())

The *any_injury* label is redundant, since it can be derived from the other labels: it indicates whether at least one bowel/extravasation/liver/spleen/kidney injury is present.
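The consistency check that follows relies on a simple identity: each *_healthy* flag is 0 exactly when that organ is injured, so the minimum of the five flags is 0 if and only if at least one injury is present. A plain-Python sketch on two hypothetical label rows (the helper `derive_any_injury` is illustrative, not part of the competition data):

```python
HEALTHY_COLS = ["bowel_healthy", "extravasation_healthy",
                "kidney_healthy", "liver_healthy", "spleen_healthy"]

def derive_any_injury(row):
    """1 if at least one organ is not healthy, else 0 (i.e. 1 - min of healthy flags)."""
    return 1 - min(row[c] for c in HEALTHY_COLS)

healthy_row = dict.fromkeys(HEALTHY_COLS, 1)
injured_row = {**healthy_row, "liver_healthy": 0}  # hypothetical liver injury

print(derive_any_injury(healthy_row), derive_any_injury(injured_row))  # 0 1
```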

In [None]:
#Check consistency of any_injury label with respect to other labels
print((patient_labels['any_injury'] == 1 - np.min(patient_labels[['bowel_healthy', 'extravasation_healthy', 
                                                          'kidney_healthy', 'liver_healthy', 'spleen_healthy']], axis = 1)).all())

#### Checking Data Imbalance and Correlations

Let us find the percentage of samples having different forms of injuries. This will help us know the distribution of the labels and find out whether there is imbalance in positive/negative samples.

In [None]:
def get_pos_percent(df, label):
    num_entries = df.shape[0]
    return (df[label] == 1).sum()*100/num_entries

injury_labels = ['bowel_injury', 'extravasation_injury' , 'kidney_low' , 'kidney_high', 'liver_low',
                 'liver_high', 'spleen_low', 'spleen_high']


for label in injury_labels:    
    print(f'% of {label} samples = {get_pos_percent(patient_labels, label): .3f}')

Indeed, there is a strong imbalance between the numbers of positive and negative samples for every type of injury.

Now, let us check how much the injuries are correlated among themselves. We'll plot a heatmap of the Pearson correlation coefficients (between -1 and 1).
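For two 0/1 columns, the Pearson coefficient that `DataFrame.corr()` computes by default reduces to the phi coefficient of their 2x2 contingency table. A small plain-Python sketch of the formula on hypothetical label vectors (not values from the competition data):

```python
import math

def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

# Hypothetical labels for 8 patients (1 = injury present).
liver  = [1, 0, 0, 1, 0, 0, 0, 0]
spleen = [1, 0, 0, 0, 0, 0, 1, 0]
print(round(pearson(liver, spleen), 3))  # 0.333
```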

In [None]:
sns.heatmap(patient_labels[injury_labels].corr(), vmin = -1, vmax = 1, annot = True)

The correlations are very weak: most values are close to 0, and the maximum is about 0.2.

### Data Preparation

#### Splitting dataset

Let us first form training and validation sets. Because of the imbalance, we would like the injury distribution to be similar in the two sets. We use a plain random split here and check the resulting distributions afterwards.
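One way to make a split respect the imbalance is stratification; `train_test_split` accepts a `stratify` argument for a single label. A minimal plain-Python sketch of the idea, assuming we stratify on `any_injury` alone (the data and the helper `stratified_split` are hypothetical):

```python
import random
from collections import defaultdict

def stratified_split(ids, keys, test_size=0.2, seed=42):
    """Split ids so that each key value keeps roughly the same proportion in both parts."""
    groups = defaultdict(list)
    for i, k in zip(ids, keys):
        groups[k].append(i)
    rng = random.Random(seed)
    train, valid = [], []
    for k in sorted(groups):
        members = groups[k][:]
        rng.shuffle(members)
        n_valid = round(len(members) * test_size)
        valid.extend(members[:n_valid])
        train.extend(members[n_valid:])
    return train, valid

# Hypothetical: 100 patients, 20 of them with any_injury == 1.
ids = list(range(100))
any_injury = [1 if i < 20 else 0 for i in ids]
train_ids, valid_ids = stratified_split(ids, any_injury)
print(len(train_ids), len(valid_ids), sum(any_injury[i] for i in valid_ids))  # 80 20 4
```

Stratifying on several imbalanced labels at once is harder, since combining them into one key produces many small groups.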

In [None]:
train_patients, validation_patients = train_test_split(patient_labels, test_size = 0.2, random_state = SEED) 

In [None]:
print('Number of samples')
print(f'Training set: {train_patients.shape[0]}   Validation set: {validation_patients.shape[0]}')

In [None]:
print(f'% of positive samples - Training, Validation')
for label in injury_labels:    
    print(f'{label}: {get_pos_percent(train_patients, label): .3f}, {get_pos_percent(validation_patients, label): .3f}')

The distribution of the labels is not exactly equal in the two sets, but it is broadly similar. For now, this will do.

#### Multi-label Classification Data

We treat our problem as multi-label classification, because a patient can have any of five types of injury, and these injuries can co-exist.

The easiest way to prepare our training data for passing on to the fastai dataloader is to follow the convention used in this [tutorial](https://docs.fast.ai/tutorial.vision.html#multi-label-classification). We're going to use the patient labels in _train.csv_ to derive labels for each image.

In [None]:
filename_labels = pd.DataFrame()

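# Sample 500 distinct patients (random.sample draws without replacement, unlike random.choices);
# sorting the listing first makes the seeded sample reproducible
patients = random.sample(sorted(os.listdir(os.path.join(BASE_DIR, 'train_images'))), k = 500)

f_list = []

start = time.time()

for pat in tqdm(patients):

    series = sorted(os.listdir(os.path.join(BASE_DIR, 'train_images', str(pat))))
    
    for s in series:
        # Up to 10 distinct slices per series (a series may contain fewer files)
        slices = sorted(os.listdir(os.path.join(BASE_DIR, 'train_images', str(pat), str(s))))
        f_names = random.sample(slices, k = min(10, len(slices)))
        f_names = [os.path.join(str(pat), str(s), x) for x in f_names]
        f_list.append(f_names)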

end = time.time()
print(end - start)

In [None]:
f_list = [item for sublist in f_list for item in sublist]
filename_labels['fname'] = pd.Series(f_list)

From _train.csv_, we construct a dictionary _patient_dict_ that maps each patient_id to that patient's labels and a flag indicating whether the patient is in the validation set. We use five targets: _bowel_injury_ and _extravasation_injury_ are binary (0 = healthy, 1 = injury), while kidney, liver, and spleen injury each take one of three values (0 = healthy, 1 = low, 2 = high).

Accordingly, the model uses a single-logit sigmoid head for bowel and extravasation injury, and a three-way softmax head for each of kidney, liver, and spleen injury.
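The organ encoding used when building the labels, `max(1*healthy, 2*low, 3*high) - 1`, simply returns the position of the one-hot flag that is set (0 = healthy, 1 = low, 2 = high). A plain-Python sketch with a hypothetical helper name, assuming the three flags are one-hot as checked earlier:

```python
def encode_organ(healthy, low, high):
    """Map one-hot (healthy, low, high) flags to a class index 0/1/2."""
    assert healthy + low + high == 1, "exactly one flag should be set"
    return max(1 * healthy, 2 * low, 3 * high) - 1

print(encode_organ(1, 0, 0), encode_organ(0, 1, 0), encode_organ(0, 0, 1))  # 0 1 2
```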

In [None]:
train_patients['is_valid'] = False
validation_patients['is_valid'] = True
train_val_df = pd.concat([train_patients, validation_patients])

In [None]:
train_val_df['is_valid'].value_counts()

In [None]:
patient_dict = {}

target_labels = ['bowel_injury', 'extravasation_injury' , 'kidney_healthy' , 'kidney_low' , 'kidney_high', 'liver_healthy', 'liver_low',
                 'liver_high', 'spleen_healthy', 'spleen_low', 'spleen_high']
reduced_target_labels = ['bowel_injury', 'extravasation_injury' , 'kidney_injury', 'liver_injury', 'spleen_injury']


# Map patient_id -> (label row, is_valid flag)
for _, row in train_val_df.iterrows():
    patient_dict[row['patient_id']] = (row[target_labels], row['is_valid'])

Having constructed _patient_dict_, we can loop through all the sampled image filenames and, for each, look up the appropriate entry using _patient_id_ as the key. Building the dictionary first makes each lookup an average O(1) hash-table access rather than a search through the DataFrame.
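The lookup step can be sketched in plain Python. The patient IDs and labels below are hypothetical, and `lookup` is an illustrative helper; the path format `<patient_id>/<series_id>/<slice>.dcm` matches how the filenames were built, assuming `/` as the path separator (as on Kaggle's Linux machines).

```python
# Hypothetical table: patient_id -> (labels, is_valid)
patient_dict = {
    1: ([0, 1, 0, 2, 0], False),
    2: ([0, 0, 0, 0, 0], True),
}

def lookup(scan_name, table):
    """Return (labels, is_valid) for a path '<patient_id>/<series_id>/<slice>.dcm'."""
    patient_id = int(scan_name.split("/")[0])
    return table[patient_id]  # average O(1) hash-table lookup

labels, is_valid = lookup("1/100/5.dcm", patient_dict)
print(labels, is_valid)  # [0, 1, 0, 2, 0] False
```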

In [None]:
label_list = []
is_valid_list = []

start = time.time()
for scan_name in tqdm(filename_labels['fname']):
    patient_id = int(scan_name.split('/')[0])
    entry = patient_dict[patient_id][0]
    
    if_bowel = entry['bowel_injury']
    if_extravasation = entry['extravasation_injury']
    if_kidney = max(1*entry['kidney_healthy'], 2*entry['kidney_low'], 3*entry['kidney_high']) - 1
    if_liver = max(1*entry['liver_healthy'], 2*entry['liver_low'], 3*entry['liver_high']) - 1
    if_spleen = max(1*entry['spleen_healthy'], 2*entry['spleen_low'], 3*entry['spleen_high']) - 1

    labels = [if_bowel, if_extravasation, if_kidney, if_liver, if_spleen]
    label_list.append(labels)
    
    is_valid = patient_dict[patient_id][1]
    is_valid_list.append(is_valid)

end = time.time()
print(end - start)

In [None]:
filename_labels[reduced_target_labels] = pd.DataFrame(label_list)
filename_labels['is_valid'] = pd.Series(is_valid_list)

In [None]:
# Check that every label takes only its allowed values
# (Series.all() returns a single bool, so "x.all() in [0, 1]" was always True)
assert filename_labels['bowel_injury'].isin([0, 1]).all()
assert filename_labels['extravasation_injury'].isin([0, 1]).all()
assert filename_labels['kidney_injury'].isin([0, 1, 2]).all()
assert filename_labels['liver_injury'].isin([0, 1, 2]).all()
assert filename_labels['spleen_injury'].isin([0, 1, 2]).all()

We have the data in the format we wanted. Now we can construct the dataloader and train a CNN model.

In [None]:
print(filename_labels.shape)
filename_labels.head()

In [None]:
filename_labels.to_csv('train_image_labels.csv', index = False)

### Training

In [None]:
# filename_labels = pd.read_csv('/kaggle/input/rsna-2023-atd-baseline-1-training/train_image_labels.csv')
# filename_labels.columns

In [None]:
# filename_labels = filename_labels.drop(filename_labels.columns[0], axis = 1)
# print(filename_labels.shape)
# filename_labels.head()

In [None]:
SIZE = 512

We're going to take a ResNet-50 model pretrained on ImageNet (directly available via the fastai library) and fine-tune it on our small dataset.

First, we construct our dataloader. Each DICOM slice is read with a custom loader (`PILDicom2`, defined below) that rescales and windows the pixel values, and is then resized to 128x128 with `Resize`; we do not add any other data augmentation for now. We keep fastai's default batch size of 64.

In [None]:
def standardize_pixel_array(fn):
    """
    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217
    """
    # Correct DICOM pixel_array if PixelRepresentation == 1.
    dcm = pydicom.dcmread(fn)
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype 
        pixel_array = (pixel_array << bit_shift).astype(dtype) >>  bit_shift
#         pixel_array = pydicom.pixel_data_handlers.util.apply_modality_lut(new_array, dcm)

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    center = int(dcm.WindowCenter)
    width = int(dcm.WindowWidth)
    low = center - width / 2
    high = center + width / 2    
    
    pixel_array = (pixel_array * slope) + intercept
    pixel_array = np.clip(pixel_array, low, high)
    
    return pixel_array
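`standardize_pixel_array` first maps stored values to Hounsfield units with `slope * x + intercept` and then clips them to the window `[center - width/2, center + width/2]`. A plain-Python sketch of that arithmetic on a few hypothetical pixel values, with an illustrative window (center 40, width 400) rather than one read from a real file:

```python
def window(pixels, slope, intercept, center, width):
    """Rescale stored values to HU, then clip them to the display window."""
    low, high = center - width / 2, center + width / 2
    return [min(max(p * slope + intercept, low), high) for p in pixels]

raw = [0, 1000, 1024, 1100, 2000]  # hypothetical stored values
print(window(raw, slope=1.0, intercept=-1024.0, center=40, width=400))
# [-160.0, -24.0, 0.0, 76.0, 240.0]
```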

In [None]:
class PILDicom2(PILBase):
    "same as PILDicom but changed pixel array dtype to int32 since uint16 cannot be handled by PIL/PyTorch"
    
    _open_args,_tensor_cls,_show_args = {},TensorDicom,TensorDicom._show_args
    @classmethod
    def create(cls, fn:Path|str|bytes, mode=None)->None:
        "Open a `DICOM file` from path `fn` or bytes `fn` and load it as a `PIL Image`"
        if isinstance(fn,bytes): im = Image.fromarray(pydicom.dcmread(pydicom.filebase.DicomBytesIO(fn)).pixel_array)
        if isinstance(fn,(Path,str)): im = Image.fromarray(standardize_pixel_array(fn).astype(np.int32))
        im.load()
        im = im._new(im.im)
        return cls(im.convert(mode) if mode else im)

In [None]:
# Fixed vocabularies: every class exists (matching the head sizes of the model),
# even if a rare injury grade happens to be missing from this small sample
bowel_vocab = [0, 1]
extravasation_vocab = [0, 1]
kidney_vocab = [0, 1, 2]
liver_vocab = [0, 1, 2]
spleen_vocab = [0, 1, 2]

blocks = (ImageBlock(cls=PILDicom2), 
          CategoryBlock(vocab = bowel_vocab),
          CategoryBlock(vocab = extravasation_vocab),
          CategoryBlock(vocab = kidney_vocab),
          CategoryBlock(vocab = liver_vocab),
          CategoryBlock(vocab = spleen_vocab))

In [None]:
kidney_vocab

In [None]:
getters = (ColReader('fname', pref = os.path.join(BASE_DIR, 'train_images/')), 
           ColReader('bowel_injury'), ColReader('extravasation_injury'),
           ColReader('kidney_injury'), ColReader('liver_injury'), ColReader('spleen_injury'))

In [None]:
data_block = DataBlock(blocks = blocks,
                       getters = getters,
                       splitter = ColSplitter('is_valid'),
                       item_tfms = Resize(128, resamples = (0, 0)),
                       n_inp = 1)

In [None]:
dls = data_block.dataloaders(filename_labels)

In [None]:
dls.c

In [None]:
dls.show_batch()

We need metrics appropriate for our multi-head classification problem. Here we track the macro-averaged recall of each head, plus a combined recall score.
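Macro-averaged recall, which `recall_score(average='macro')` computes, is the unweighted mean of the per-class recalls, so a rare injury class counts as much as the healthy class. For a head with a single logit, hard predictions should come from thresholding that logit at 0 rather than from `argmax`. A plain-Python sketch on hypothetical values (the helper is illustrative and intended to mirror scikit-learn's behaviour with `zero_division=0`):

```python
def macro_recall(targets, preds):
    """Unweighted mean of per-class recall TP / (TP + FN); a class that is only
    predicted (no true samples) scores 0, as with zero_division=0."""
    classes = sorted(set(targets) | set(preds))
    recalls = []
    for c in classes:
        n_true = sum(1 for t in targets if t == c)
        tp = sum(1 for t, p in zip(targets, preds) if t == c and p == c)
        recalls.append(tp / n_true if n_true else 0.0)
    return sum(recalls) / len(recalls)

# Hypothetical logits from a single-logit (bowel) head and the 0/1 targets.
logits = [-2.1, 0.3, -0.5, 1.7, -3.0]
targets = [0, 0, 1, 1, 0]

# Threshold the raw logit at 0 (same as sigmoid probability 0.5);
# argmax over a one-column output would always return 0.
preds = [1 if z > 0 else 0 for z in logits]
print(preds, macro_recall(targets, preds))
```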

We create a learner.

In [None]:
dsets = data_block.datasets(filename_labels)
dsets[0]

In [None]:
class MultiHeadModel(Module):
    
    def __init__(self, body):
    
        self.body = body
        nf = num_features_model(nn.Sequential(*self.body.children()))

        self.bowel = create_head(nf, 1)
        self.extravasation = create_head(nf, 1)
        self.kidney = create_head(nf, 3)
        self.liver = create_head(nf, 3)
        self.spleen = create_head(nf, 3)
        
    def forward(self, x):
        
        y = self.body(x)
        bowel = self.bowel(y)
        extravasation = self.extravasation(y)
        kidney = self.kidney(y)
        liver = self.liver(y)
        spleen = self.spleen(y)
        return [bowel, extravasation, kidney, liver, spleen]

In [None]:
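# ResNet-50 with ImageNet weights and a 1-channel first conv (n_in=1); create_body then
# cuts off the 10-output head built here, keeping only the pretrained body
base_model = create_vision_model(models.resnet50, 10, True, n_in = 1)
body = create_body(base_model, pretrained=True)
net = MultiHeadModel(body)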

In [None]:
class CombinationLoss(Module):
    "Weighted sum of BCE (bowel, extravasation heads) and cross-entropy (kidney, liver, spleen heads)"
    def __init__(self, weights = [2, 6, 3, 3, 3]):
        self.w = weights
        
    def forward(self, xs, *ys, reduction = 'mean'):
        loss = 0
        for i, (w, x, y) in enumerate(zip(self.w, xs, ys)):
            if i < 2:
                # Single-logit heads: x has shape (bs, 1), y holds 0/1 class indices
                loss += w*F.binary_cross_entropy_with_logits(x, y.unsqueeze(1).float(), reduction = reduction)
            else:
                # Three-class heads: x has shape (bs, 3), y holds 0/1/2 class indices
                loss += w*F.cross_entropy(x, y, reduction = reduction)
        return loss
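`CombinationLoss` adds a weighted BCE-with-logits term for each single-logit head and a weighted cross-entropy term for each three-class head. The same arithmetic for a single hypothetical sample, in plain Python with the default weights `[2, 6, 3, 3, 3]` (helper names are illustrative; with `reduction='mean'`, PyTorch averages these per-sample terms over the batch):

```python
import math

def bce_with_logits(z, y):
    """Binary cross-entropy on a raw logit z with target y in {0, 1}."""
    # max(z, 0) - z*y + log(1 + exp(-|z|)) is numerically stable for large |z|
    return max(z, 0) - z * y + math.log1p(math.exp(-abs(z)))

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample with an integer class target."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

def combination_loss(outputs, targets, weights=(2, 6, 3, 3, 3)):
    """Heads 0-1 are single-logit (BCE); heads 2-4 are three-class (CE)."""
    total = 0.0
    for i, (w, out, y) in enumerate(zip(weights, outputs, targets)):
        total += w * (bce_with_logits(out, y) if i < 2 else cross_entropy(out, y))
    return total

# Hypothetical outputs: bowel and extravasation logits, then kidney/liver/spleen logits.
outputs = [-1.0, 0.5, [2.0, 0.1, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 0.5]]
targets = [0, 1, 0, 2, 1]
print(round(combination_loss(outputs, targets), 4))
```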

In [None]:
from sklearn.metrics import recall_score

class RecallPartial(Metric):
    "Macro recall for head `a`: stores predictions and targets on CPU and computes the score with `func`."
    def __init__(self, a=0, **kwargs):
        self.func = partial(recall_score, average='macro', zero_division=0)
        self.a = a

    def reset(self): self.targs,self.preds = [],[]

    def accumulate(self, learn):
        out = learn.pred[self.a]
        # Single-logit heads (bowel, extravasation): logit > 0 <=> sigmoid > 0.5
        # (argmax over one column would always give 0); three-class heads: argmax
        pred = (out.squeeze(-1) > 0).long() if out.shape[-1] == 1 else out.argmax(-1)
        targ = learn.y[self.a]
        pred,targ = to_detach(pred),to_detach(targ)
        pred,targ = flatten_check(pred,targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        return self.func(targs, preds)

    @property
    def name(self): return 'recall_' + filename_labels.columns[self.a+1].split('_')[0]
    
class RecallCombine(Metric):
    "Unweighted mean of the per-head RecallPartial values"
    def reset(self): self.combine = None

    def accumulate(self, learn):
        scores = [m.value for m in learn.metrics if isinstance(m, RecallPartial)]
        self.combine = np.mean(scores)

    @property
    def value(self):
        return self.combine

In [None]:
learn = Learner(dls, net, loss_func = CombinationLoss(), metrics=[RecallPartial(a=i) for i in range(len(dls.c))] + [RecallCombine()])

In [None]:
learn.opt_func

We do mixed-precision training to speed up our training process.

In [None]:
learn.to_fp16()

fastai provides a handy function to let us find a good learning rate.

In [None]:
#learn.lr_find()

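Let's now fine-tune our model for one epoch with a learning rate of 1.2e-3 (the `lr_find` call above is commented out). We wrap the call in `cProfile` to see where the training time is spent, for example whether DICOM decoding is a bottleneck.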

In [None]:
import cProfile

cProfile.run('learn.fine_tune(1, 1.2e-3)')

This single epoch mainly verifies that the pipeline runs end to end. The training and validation losses and the per-head recall values in the output give a first indication of how well each injury type is learned, but since we have trained on a very small sample of the data, these numbers should not be taken as representative of performance on the full dataset.

### Save Model

We need to save our model, so that our [inference notebook](https://www.kaggle.com/code/pankajpansari/rsna-2023-atd-baseline-1-inference) can import it to make predictions on the test set.

In [None]:
learn.export('/kaggle/working/model.pt')

In [None]:
len(os.listdir('/kaggle/input/atd-training-128-png-subset/kaggle/working/train_images_128_png_subset'))

In [None]:
!du -sh /kaggle/input/atd-training-128-png-subset/kaggle/working/train_images_128_png_subset