In [None]:
### please specify your input paths here
import os  # this cell runs before the imports cell, so import what it needs here

PROJECT_FOLDER = "YOUR_PROJECT_FOLDER"  # parent folder of the input images
# os.path.join inserts the separator, so PROJECT_FOLDER works with or without a
# trailing "/" (plain string concatenation produced e.g. "YOUR_PROJECT_FOLDERimages/").
IMAGE_DATA_FOLDER = os.path.join(PROJECT_FOLDER, "images", "")  # folder of the input images (ends with a separator)
# CSV with one row per study; the inference loop reads its `StudyInstanceUID`
# and `image_folder` (path to the study's DICOM folder) columns.
INPUT_TEST_CSV_FILE = "YOUR_TEST_FILE"
OUTPUT_FILE = "YOUR_OUTPUT"  # submission is written here in CSV format

## Install Packages (offline Kaggle wheels — uncomment the lines below when running on Kaggle)

In [1]:
# !cp -r /kaggle/input/python-packages /kaggle/working
# !pip install -q /kaggle/working/python-packages/pylibjpeg_libjpeg-1.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# !pip install -q /kaggle/working/python-packages/pylibjpeg_openjpeg-1.2.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# !pip install -q /kaggle/working/python-packages/pylibjpeg_rle-1.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# !pip install -q /kaggle/working/python-packages/iopath-0.1.9-py3-none-any.whl
# !pip install -q /kaggle/working/python-packages/av-9.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# !pip install -q /kaggle/working/python-packages/fvcore-0.1.5.post20220512/
# !pip install -q /kaggle/working/python-packages/parameterized-0.8.1-py2.py3-none-any.whl
# !pip install -q /kaggle/working/python-packages/pytorchvideo-0.1.5/
# !pip install -q /kaggle/working/python-packages/timm-0.6.7-py3-none-any.whl
# !pip install -q /kaggle/working/python-packages/antlr4-python3-runtime-4.9.3/
# !pip install -q /kaggle/working/python-packages/omegaconf-2.2.2-py3-none-any.whl
# !pip install -q /kaggle/working/python-packages/monai-0.8.1-202202162213-py3-none-any.whl
#
# !cp /kaggle/input/gdcm-conda-install/gdcm.tar /kaggle/working/
# !tar -xzvf gdcm.tar
# !conda install --offline /kaggle/working/gdcm/gdcm-2.8.9-py37h71b2a6d_0.tar.bz2

## Imports

In [None]:
import os

# Select the GPU before torch (or anything else that may initialize CUDA) is
# imported, so the setting is guaranteed to take effect.
# NOTE: on a single-GPU machine "1" hides the only device and inference falls
# back to CPU (see `device` below) -- use "0" there.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys
# Make the project sources (and its `skp` package) importable.
sys.path.append("./rsna-cspine-src/")
sys.path.append("./rsna-cspine-src/skp")

import glob
import os.path as osp
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
# Public location of `zoom`; `scipy.ndimage.interpolation` is a deprecated namespace.
from scipy.ndimage import zoom
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from skp import builder

# Inference only: disable autograd globally for all following cells.
torch.set_grad_enabled(False)

In [3]:
# Run on the (visible) GPU when there is one, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Helper Functions

In [4]:
def window(x, WL, WW):
    """Apply an intensity window and quantize the result to 8 bits.

    Values are clipped to [WL - WW//2, WL + WW//2]. The lower bound is then
    mapped to 0 and the upper bound to 255, and the result is cast
    (truncated) to uint8.

    Args:
        x: numeric numpy array.
        WL: window level (center of the window).
        WW: window width.

    Returns:
        uint8 array with the same shape as ``x``.
    """
    half_width = WW // 2
    lo = WL - half_width
    hi = WL + half_width
    normalized = (np.clip(x, lo, hi) - lo) / (hi - lo)
    return (normalized * 255).astype('uint8')


def load_dicom_volume(dicom_folder):
    """Read all ``*.dcm`` slices in ``dicom_folder`` into one windowed volume.

    Steps:
    1. Read every file with pydicom.
    2. Resize (``zoom``, order=1) any slice whose in-plane shape differs from
       the per-axis median shape.
    3. Apply the FIRST file's RescaleSlope / RescaleIntercept to the whole stack.
    4. Window with WL=400 / WW=2500 (``window`` returns uint8).
    5. Order slices by ImagePositionPatient z, highest first.

    Args:
        dicom_folder: directory containing the study's DICOM files.

    Returns:
        np.ndarray of dtype uint8, shape (num_slices, height, width).

    Raises:
        FileNotFoundError: if the folder contains no ``*.dcm`` files.
    """
    dicom_files = glob.glob(osp.join(dicom_folder, "*.dcm"))
    if not dicom_files:
        # Explicit error instead of an opaque IndexError on dicoms[0] below.
        raise FileNotFoundError(f"No .dcm files found in {dicom_folder!r}")
    dicoms = [pydicom.dcmread(_) for _ in dicom_files]
    z_positions = [float(_.ImagePositionPatient[2]) for _ in dicoms]
    dicom_arrays = [_.pixel_array.astype("float32") for _ in dicoms]
    # Assumes every slice shares the first file's rescale parameters.
    rescale_slope = float(dicoms[0].RescaleSlope)
    rescale_intercept = float(dicoms[0].RescaleIntercept)
    del dicoms  # release the pydicom datasets early

    # Deal with potential scenario where not all arrays are the same shape.
    # This assumes that all arrays have the same number of dimensions (2).
    array_shapes = np.vstack([_.shape for _ in dicom_arrays])
    # Integer target shape: with an even number of slices the median can be
    # x.5 (e.g. 511.5). That would resample EVERY slice and leave the output
    # shapes to float rounding inside zoom, so np.stack could still fail.
    h = int(round(np.median(array_shapes[:, 0])))
    w = int(round(np.median(array_shapes[:, 1])))
    for ind, arr in enumerate(dicom_arrays):
        if arr.shape[0] != h or arr.shape[1] != w:
            print("Mismatched shape, resizing ...")
            scale_h, scale_w = float(h) / arr.shape[0], float(w) / arr.shape[1]
            dicom_arrays[ind] = zoom(arr, [scale_h, scale_w], order=1, prefilter=False)

    array = np.stack(dicom_arrays)
    del dicom_arrays
    array = rescale_slope * array + rescale_intercept
    array = window(array, WL=400, WW=2500)

    # Sort in DESCENDING order by z-position
    return array[np.argsort(z_positions)[::-1]]


def plot_volume(array, skip=10, sagittal=False):
    """Display every ``skip``-th slice of a 3D volume, ignoring all-zero slices.

    Args:
        array: 3D array. Slices are taken along axis 0, or along the last axis
            when ``sagittal`` is True.
        skip: stride between displayed slices.
        sagittal: slice along the last axis instead of the first.
    """
    slice_axis = 2 if sagittal else 0
    for idx in range(0, array.shape[slice_axis], skip):
        img = array[..., idx] if sagittal else array[idx]
        if np.sum(img) != 0:
            plt.imshow(img, cmap="gray")
            plt.show()

        
def rescale(x):
    """Linearly map values from [0, x.max()] to [-1, 1].

    Assumes ``x`` is non-negative (the caller passes the uint8 output of
    ``load_dicom_volume``). Normalization uses the input's own maximum rather
    than a fixed 255.

    An all-zero input maps to -1 everywhere. Dividing by its zero maximum
    would otherwise give NaN (0 / 0), which would propagate into the models.

    Returns:
        float array/tensor with the same shape as ``x``.
    """
    peak = x.max()
    if peak == 0:
        x = x * 0.0  # float zeros, same dtype promotion as the division branch
    else:
        x = x / peak
    x = x - 0.5
    x = x * 2
    return x


def unscale(x):
    """Map values from [-1, 1] back to [0, 255].

    Exact inverse of ``rescale`` only when the rescaled input peaked at 255.
    """
    return (x + 1) * 255 / 2


def load_models(config_file, checkpoint_folder, model_type="classification", cuda=torch.cuda.is_available(), load_indices=None):
    """Build one eval-mode model per checkpoint found in ``checkpoint_folder``.

    Args:
        config_file: OmegaConf YAML describing the model.
        checkpoint_folder: directory whose entries are model checkpoints.
        model_type: one of "classification", "segmentation", "sequence",
            "tdcnn". It selects which ``pretrained`` flag in the config is
            switched off; "sequence" leaves the config unchanged.
        cuda: move each model to the global ``device`` when True.
        load_indices: optional list/tuple of positions in the sorted
            checkpoint list to load; ``None`` (or any other type) loads all.

    Returns:
        list of models in eval mode, one per selected checkpoint, in sorted
        path order.

    Raises:
        AssertionError: on an unknown ``model_type``.
        FileNotFoundError: if ``checkpoint_folder`` has no entries.
    """
    import copy  # function-local so this helper cell stays self-contained

    assert model_type in ["classification", "segmentation", "sequence", "tdcnn"]
    config = OmegaConf.load(config_file)
    # Weights are loaded from the checkpoints below (via load_pretrained), so the
    # pretrained-backbone flag is turned off for architectures that have one.
    if model_type == "segmentation":
        config.model.params.encoder_params.pretrained = False
    elif model_type == "classification":
        config.model.params.pretrained = False
    elif model_type == "tdcnn":
        config.model.params.cnn_params.pretrained = False
    # Sorted so positional load_indices refer to stable files (e.g. fold0, fold1, ...).
    checkpoints = np.sort(glob.glob(osp.join(checkpoint_folder, "*")))
    if len(checkpoints) == 0:
        # Otherwise an empty ensemble is returned and only fails much later.
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_folder!r}")
    if isinstance(load_indices, (list, tuple)):
        checkpoints = checkpoints[list(load_indices)]
    models = []
    for each_checkpoint in checkpoints:
        # Deep copy: DictConfig.copy() is shallow, so nested nodes such as
        # `model` may be shared with `config` and mutated across iterations.
        _config = copy.deepcopy(config)
        _config.model.load_pretrained = str(each_checkpoint)
        _model = builder.build_model(_config).eval()
        if cuda:
            _model = _model.to(device)
        models.append(_model)
    return models

            
def get_cervical_spine_coordinates(volume, inference_shape, segmentation_models, threshold, adjustment, uid):
    """Find a bounding box for each of the 7 cervical levels (C1-C7).

    The volume is resized (nearest) to ``inference_shape`` and passed through
    every segmentation model. The sigmoid outputs are averaged over the
    ensemble. For each level, the voxels with probability >= ``threshold``
    define the box. If none qualify, the threshold is lowered by
    ``adjustment`` (while it stays above ``adjustment``) and the level is
    re-thresholded. Box indices are then mapped from the inference grid back
    to the original grid (truncated to ints).

    Args:
        volume: tensor whose last three dims are spatial. The caller passes
            (1, 1, num_images, height, width).
        inference_shape: 3-tuple spatial size fed to the segmentation models.
        segmentation_models: callables returning logits whose first 7 channels
            are C1..C7 (assumed from the indexing below -- confirm with the model).
        threshold: initial probability threshold.
        adjustment: step by which the threshold is lowered on failure.
        uid: study id, used only in log messages.

    Returns:
        dict mapping level (0..6) to
        (d_min, d_max, h_min, h_max, w_min, w_max) in original-grid voxel
        indices, or to None when no voxel reached even the lowest threshold.
    """
    orig_shape = volume.shape[2:]
    volume = F.interpolate(volume, size=inference_shape, mode="nearest")
    segmentation = torch.sigmoid(torch.cat([seg_model(volume.to(device)) for seg_model in segmentation_models])).mean(0)
    cspine_coords = {}
    for level in range(7):
        level_probs = segmentation[level]
        coords = torch.stack(torch.where(level_probs >= threshold)).cpu().numpy()
        adjusted_threshold = threshold
        # `adjustment > 0` guards against a threshold that never decreases (endless loop).
        while coords.shape[1] == 0 and adjustment > 0 and adjusted_threshold > adjustment:
            # Round off float drift (0.4 - 0.1 = 0.30000000000000004). The former
            # 1-decimal rounding could undo finer steps (0.35 -> 0.4) and loop forever.
            adjusted_threshold = round(adjusted_threshold - adjustment, 6)
            print(f"uid:{uid} C{level+1} not found, lowering threshold to {adjusted_threshold:0.1f} ...")
            # Re-threshold with the LOWERED value. The original used `threshold`
            # here, which made every retry a no-op.
            coords = torch.stack(torch.where(level_probs >= adjusted_threshold)).cpu().numpy()
        if coords.shape[1] == 0:
            print(f"uid:{uid} Segmentation for C{level+1} failed !")
            cspine_coords[level] = None
            continue
        # Map inference-grid indices back to the original grid. Assigning into the
        # integer array truncates, as before. Scaling happens after the retries so
        # coordinates found at a lowered threshold are scaled too.
        for axis in range(3):
            coords[axis] = coords[axis] * orig_shape[axis] / inference_shape[axis]
        cspine_coords[level] = (coords[0].min(), coords[0].max(), coords[1].min(), coords[1].max(), coords[2].min(), coords[2].max())
    return cspine_coords


def center_crop(x, crop_size):
    """Crop the last two axes of ``x`` to ``crop_size`` around their center.

    Args:
        x: array/tensor with at least 2 dims; leading dims are kept as-is.
        crop_size: (h, w) target size for the last two axes.

    Returns:
        slice of ``x`` whose last two axes are at most (h, w). If the crop is
        larger than the input along an axis, that axis is returned whole.
        Previously a negative start offset wrapped around and returned the
        wrong (tail) slice.
    """
    h, w = crop_size
    orig_h, orig_w = x.shape[-2], x.shape[-1]
    # Clamp at 0: a negative start index would count from the end of the axis.
    top = max((orig_h - h) // 2, 0)
    left = max((orig_w - w) // 2, 0)
    return x[..., top:top + h, left:left + w]

## Load Models

In [5]:
# Every model family is a 5-checkpoint ensemble (fold0 ... fold4 per the load log).
_ALL_FOLDS = [0, 1, 2, 3, 4]
_CONFIG_DIR = "./rsna-cspine-src/configs"

# Segmentation models used to localize C1-C7.
cspine_segmentation_models = load_models(f"{_CONFIG_DIR}/seg/pseudoseg000.yaml",
                                         "./rsna-cspine-pseudoseg000/",
                                         model_type="segmentation",
                                         load_indices=_ALL_FOLDS)

# 3D feature extractors applied to each vertebra crop, plus the sequence
# models that consume their 7 per-level features.
feature_extractors_3d = load_models(f"{_CONFIG_DIR}/chunk/chunk000.yaml",
                                    "./rsna-cspine-chunk000/",
                                    model_type="classification",
                                    load_indices=_ALL_FOLDS)
chunk_sequence_models = load_models(f"{_CONFIG_DIR}/chunkseq/chunkseq003.yaml",
                                    "./rsna-cspine-chunkseq003/",
                                    model_type="sequence",
                                    load_indices=_ALL_FOLDS)

# 2D ("tdcnn") feature extractors and their sequence models.
feature_extractors_2d = load_models(f"{_CONFIG_DIR}/chunk/chunk101.yaml",
                                    "./rsna-cspine-chunk101/",
                                    model_type="tdcnn",
                                    load_indices=_ALL_FOLDS)
slice_sequence_models = load_models(f"{_CONFIG_DIR}/chunkseq/chunkseq005.yaml",
                                    "./rsna-cspine-chunkseq005/",
                                    model_type="sequence",
                                    load_indices=_ALL_FOLDS)

# Sequence models over the concatenated 3D + 2D features.
fused_sequence_models = load_models(f"{_CONFIG_DIR}/chunkseq/chunkseq006.yaml",
                                    "./rsna-cspine-chunkseq006/",
                                    model_type="sequence",
                                    load_indices=_ALL_FOLDS)

Creating model <NetSegment3D> ...
Confirmed encoder output stride 16 !
  Loading pretrained checkpoint from ./rsna-cspine-pseudoseg000/fold0.ckpt
Creating model <NetSegment3D> ...
Confirmed encoder output stride 16 !
  Loading pretrained checkpoint from ./rsna-cspine-pseudoseg000/fold1.ckpt
Creating model <NetSegment3D> ...
Confirmed encoder output stride 16 !
  Loading pretrained checkpoint from ./rsna-cspine-pseudoseg000/fold2.ckpt
Creating model <NetSegment3D> ...
Confirmed encoder output stride 16 !
  Loading pretrained checkpoint from ./rsna-cspine-pseudoseg000/fold3.ckpt
Creating model <NetSegment3D> ...
Confirmed encoder output stride 16 !
  Loading pretrained checkpoint from ./rsna-cspine-pseudoseg000/fold4.ckpt
Creating model <Net3D> ...
  Using backbone <x3d_l> ...
  Pretrained : False
  Loading pretrained checkpoint from ./rsna-cspine-chunk000/fold0.ckpt
Creating model <Net3D> ...
  Using backbone <x3d_l> ...
  Pretrained : False
  Loading pretrained checkpoint from ./rsna-c

In [6]:
test_df = pd.read_csv(INPUT_TEST_CSV_FILE)

# The inference loop reads these columns from every row. Fail here with a clear
# message rather than with a bare KeyError inside the loop.
_required_columns = {"StudyInstanceUID", "image_folder"}
_missing_columns = _required_columns - set(test_df.columns)
if _missing_columns:
    raise ValueError(f"{INPUT_TEST_CSV_FILE} is missing required column(s): {sorted(_missing_columns)}")

print('test shape:', test_df.shape)

test shape: (5161, 3)


In [None]:
# Segmentation thresholding: start at `threshold`; get_cervical_spine_coordinates
# lowers it in steps of `adjustment` when a level is not found.
threshold = 0.4
adjustment = 0.1
# Resampling sizes, ordered like X's spatial dims (num_images, height, width).
segmentation_inference_size = (192, 192, 192)
chunk_inference_size = (64, 288, 288)  # crop size fed to the 3D feature extractors
slice_inference_size = (32, 288, 288)  # crop size fed to the 2D ("tdcnn") feature extractors

# Per-study predictions of each model family, keyed by StudyInstanceUID.
chunk_prediction_dict, slice_prediction_dict, fused_prediction_dict = {}, {}, {}
for index, row in tqdm(test_df.iterrows(), total=test_df.shape[0]):
    uid = row['StudyInstanceUID']
    image_folder = row['image_folder']
    # uint8 windowed volume (slices in descending z order), then mapped to [-1, 1].
    X = load_dicom_volume(image_folder)
    X = rescale(X)
    X = torch.from_numpy(X).float().unsqueeze(0).unsqueeze(0)
    # X.shape = (1, 1, num_images, height, width)

    # level (0..6) -> (min, max) bounds per spatial axis of X, or None if segmentation failed.
    cspine_coords = get_cervical_spine_coordinates(X, segmentation_inference_size, cspine_segmentation_models, threshold, adjustment, uid)
    # fold index -> list of per-level feature tensors, appended in level order.
    chunk_features = defaultdict(list)
    slice_features = defaultdict(list)

    for level, coords in cspine_coords.items():
        if not isinstance(coords, tuple):
            # Level not localized: substitute zero features so that every fold
            # still receives one entry per level.
            # NOTE(review): 432 / 256 presumably equal the feature sizes of the
            # 3D / 2D extractors -- TODO confirm against the configs.
            print(f"C{level+1} not found ... Using 0-vector ...")
            for fold, model in enumerate(feature_extractors_3d):
                chunk_features[fold].append(torch.zeros((1, 432)).float().to(device))
            for fold, model in enumerate(feature_extractors_2d):
                slice_features[fold].append(torch.zeros((1, 256)).float().to(device))
        else:
            # Bounds along X's (num_images, height, width) axes; despite the
            # names, x/y/z here are not DICOM patient axes.
            x1, x2, y1, y2, z1, z2 = coords
            if ((z2-z1==0) or (y2-y1==0) or (x2-x1==0)):
                # Flat box along some axis: same zero-vector fallback as above.
                print(f"C{level+1} has 0 shape ... Using 0-vector ...")
                for fold, model in enumerate(feature_extractors_3d):
                    chunk_features[fold].append(torch.zeros((1, 432)).float().to(device))
                for fold, model in enumerate(feature_extractors_2d):
                    slice_features[fold].append(torch.zeros((1, 256)).float().to(device))
                continue
            # NOTE(review): slice ends are exclusive, so the voxel at each max
            # bound is not part of the crop.
            orig_chunk = X[:, :, x1:x2, y1:y2, z1:z2]
            
            # The same crop is resampled twice: once for the 3D extractors ...
            chunk = F.interpolate(orig_chunk, size=chunk_inference_size, mode="trilinear")
            for fold, model in enumerate(feature_extractors_3d):
                chunk_features[fold].append(model.extract_features(chunk.to(device)))

            # ... and once for the 2D extractors.
            chunk = F.interpolate(orig_chunk, size=slice_inference_size, mode="trilinear")
            for fold, model in enumerate(feature_extractors_2d):
                slice_features[fold].append(model.extract_features(chunk.to(device)))

         

    # Per fold: stack the 7 level features into a (1, 7, F) sequence (assuming
    # each entry is (1, F), as the zero vectors are) and pair it with an
    # all-ones (1, 7) tensor.
    # NOTE(review): the ones are used even for zero-vector levels -- if this is
    # a mask, missing levels are not masked out; confirm intent.
    for fold, features in chunk_features.items():
        chunk_features[fold] = (torch.cat(features).unsqueeze(0).to(device), torch.ones((1, 7)).float().to(device))
    for fold, features in slice_features.items():
        slice_features[fold] = (torch.cat(features).unsqueeze(0).to(device), torch.ones((1, 7)).float().to(device))
        
    # Fused input: 3D and 2D features concatenated along the last (feature) axis.
    fused_features = {}
    for fold in [*chunk_features]:
        fused_features[fold] = (torch.cat([chunk_features[fold][0], slice_features[fold][0]], dim=-1), torch.ones((1, 7)).float().to(device))
        
    # Sequence model i consumes the features of extractor fold i, so both
    # lists must load the same folds in the same order. Each family's
    # prediction is the fold-average of the sigmoid outputs.
    chunk_pred_list = []
    for fold, model in enumerate(chunk_sequence_models):
        chunk_pred_list.append(torch.sigmoid(model(chunk_features[fold])).cpu().numpy())
    chunk_prediction_dict[uid] = np.mean(np.stack(chunk_pred_list, axis=0), axis=0)
    slice_pred_list = []
    for fold, model in enumerate(slice_sequence_models):
        slice_pred_list.append(torch.sigmoid(model(slice_features[fold])).cpu().numpy())
    slice_prediction_dict[uid] = np.mean(np.stack(slice_pred_list, axis=0), axis=0)
    fused_pred_list = []
    for fold, model in enumerate(fused_sequence_models):
        fused_pred_list.append(torch.sigmoid(model(fused_features[fold])).cpu().numpy())
    fused_prediction_dict[uid] = np.mean(np.stack(fused_pred_list, axis=0), axis=0)

  3%|▎         | 175/5161 [22:01<8:03:58,  5.82s/it] 

uid:1.2.826.0.1.3680043.10.474.634358.865657 C6 not found, lowering threshold to 0.3 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 C6 not found, lowering threshold to 0.2 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 C6 not found, lowering threshold to 0.1 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 Segmentation for C6 failed !
uid:1.2.826.0.1.3680043.10.474.634358.865657 C7 not found, lowering threshold to 0.3 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 C7 not found, lowering threshold to 0.2 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 C7 not found, lowering threshold to 0.1 ...
uid:1.2.826.0.1.3680043.10.474.634358.865657 Segmentation for C7 failed !


  3%|▎         | 176/5161 [22:06<7:19:53,  5.29s/it]

C6 not found ... Using 0-vector ...
C7 not found ... Using 0-vector ...


  4%|▎         | 193/5161 [23:55<8:46:46,  6.36s/it]

uid:1.2.826.0.1.3680043.10.474.634358.874233 C5 not found, lowering threshold to 0.3 ...
uid:1.2.826.0.1.3680043.10.474.634358.874233 C5 not found, lowering threshold to 0.2 ...
uid:1.2.826.0.1.3680043.10.474.634358.874233 C5 not found, lowering threshold to 0.1 ...
uid:1.2.826.0.1.3680043.10.474.634358.874233 Segmentation for C5 failed !
C5 not found ... Using 0-vector ...


  4%|▍         | 208/5161 [25:38<9:05:41,  6.61s/it] 

In [None]:
# def competition_metric(p, t):
#     # p.shape = t.shape = (N, 8)
#     p = torch.from_numpy(p).float()
#     t = torch.from_numpy(t).float()
#     loss_matrix = F.binary_cross_entropy(p, t, reduction="none")
#     # loss_matrix.shape = (N, 8)
#     columnwise_losses = []
#     for col in range(loss_matrix.shape[1]):
#         weights = t[:, col] + 1 # positives are weighted 2x
#         columnwise_losses.append(((loss_matrix[:, col] * weights).sum() / weights.sum()).item())
#     columnwise_losses[-1] *= 7.0
#     return np.sum(columnwise_losses) / 14.0


# def auc(p, t):
#     if len(np.unique(t)) == 1:
#         return 0.5
#     return roc_auc_score(t, p)

In [None]:
# study_id_list = []
# predictions_list = []

# for study_id, pred in chunk_prediction_dict.items():
#     study_id_list.append(study_id.split("/")[-1])
#     predictions_list.append(pred)
    
# pred_df = pd.DataFrame(np.concatenate(predictions_list))
# pred_df.columns = [f"C{_+1}_pred" for _ in range(7)] + ["patient_overall_pred"]
# pred_df["StudyInstanceUID"] = study_id_list

# train_df = pd.read_csv("./rsna-2022-cervical-spine-fracture-detection/train.csv")
# pred_df = pred_df.merge(train_df, on="StudyInstanceUID")
# pred_df

# t_columns = [f"C{i+1}" for i in range(7)] + ["patient_overall"]
# p_columns = [c + "_pred" for c in t_columns]

# print(f"COMP. METRIC : {competition_metric(pred_df[p_columns].values, pred_df[t_columns].values):0.3f}")

# for i in range(len(t_columns)):
#     prefix = "AUC[overall] : " if i == len(t_columns) - 1 else f"AUC[C{i+1}]      : "
#     p = pred_df[p_columns[i]].values
#     t = pred_df[t_columns[i]].values
#     print(f"{prefix}{auc(p, t):0.3f}")

In [None]:
# study_id_list = []
# predictions_list = []

# for study_id, pred in slice_prediction_dict.items():
#     study_id_list.append(study_id.split("/")[-1])
#     predictions_list.append(pred)
    
# pred_df = pd.DataFrame(np.concatenate(predictions_list))
# pred_df.columns = [f"C{_+1}_pred" for _ in range(7)] + ["patient_overall_pred"]
# pred_df["StudyInstanceUID"] = study_id_list

# train_df = pd.read_csv("./rsna-2022-cervical-spine-fracture-detection/train.csv")
# pred_df = pred_df.merge(train_df, on="StudyInstanceUID")
# pred_df

# t_columns = [f"C{i+1}" for i in range(7)] + ["patient_overall"]
# p_columns = [c + "_pred" for c in t_columns]

# print(f"COMP. METRIC : {competition_metric(pred_df[p_columns].values, pred_df[t_columns].values):0.3f}")

# for i in range(len(t_columns)):
#     prefix = "AUC[overall] : " if i == len(t_columns) - 1 else f"AUC[C{i+1}]      : "
#     p = pred_df[p_columns[i]].values
#     t = pred_df[t_columns[i]].values
#     print(f"{prefix}{auc(p, t):0.3f}")

In [None]:
# study_id_list = []
# predictions_list = []

# for study_id, pred in fused_prediction_dict.items():
#     study_id_list.append(study_id.split("/")[-1])
#     predictions_list.append(pred)
    
# pred_df = pd.DataFrame(np.concatenate(predictions_list))
# pred_df.columns = [f"C{_+1}_pred" for _ in range(7)] + ["patient_overall_pred"]
# pred_df["StudyInstanceUID"] = study_id_list

# train_df = pd.read_csv("./rsna-2022-cervical-spine-fracture-detection/train.csv")
# pred_df = pred_df.merge(train_df, on="StudyInstanceUID")
# pred_df

# t_columns = [f"C{i+1}" for i in range(7)] + ["patient_overall"]
# p_columns = [c + "_pred" for c in t_columns]

# print(f"COMP. METRIC : {competition_metric(pred_df[p_columns].values, pred_df[t_columns].values):0.3f}")

# for i in range(len(t_columns)):
#     prefix = "AUC[overall] : " if i == len(t_columns) - 1 else f"AUC[C{i+1}]      : "
#     p = pred_df[p_columns[i]].values
#     t = pred_df[t_columns[i]].values
#     print(f"{prefix}{auc(p, t):0.3f}")

In [None]:
# study_id_list = []
# predictions_list = []

# chunk_weight, chunk_slice_weight, slice_weight = 0.35, 0.35, 0.3
# assert chunk_weight + chunk_slice_weight + slice_weight == 1
# for study_id in [*slice_prediction_dict]:
#     study_id_list.append(study_id.split("/")[-1])
#     predictions_list.append(chunk_weight * chunk_prediction_dict[study_id] + \
#                             chunk_slice_weight * chunk_slice_prediction_dict[study_id] + \
#                             slice_weight * slice_prediction_dict[study_id])

# pred_df = pd.DataFrame(np.concatenate(predictions_list))
# pred_df.columns = [f"C{_+1}_pred" for _ in range(7)] + ["patient_overall_pred"]
# pred_df["StudyInstanceUID"] = study_id_list

# train_df = pd.read_csv("./rsna-2022-cervical-spine-fracture-detection/train.csv")
# pred_df = pred_df.merge(train_df, on="StudyInstanceUID")
# pred_df

# t_columns = [f"C{i+1}" for i in range(7)] + ["patient_overall"]
# p_columns = [c + "_pred" for c in t_columns]

# print(f"COMP. METRIC : {competition_metric(pred_df[p_columns].values, pred_df[t_columns].values):0.3f}")

# for i in range(len(t_columns)):
#     prefix = "AUC[overall] : " if i == len(t_columns) - 1 else f"AUC[C{i+1}]      : "
#     p = pred_df[p_columns[i]].values
#     t = pred_df[t_columns[i]].values
#     print(f"{prefix}{auc(p, t):0.3f}")

## Create Submission DataFrame

In [None]:
# Blend the fold-averaged probabilities of the three model families (weights sum to 1).
chunk_weight, slice_weight, fused_weight = 0.25, 0.25, 0.5
ensemble_pred_dict = {
    study_id: (chunk_weight * chunk_prediction_dict[study_id]
               + slice_weight * slice_prediction_dict[study_id]
               + fused_weight * fused_prediction_dict[study_id])
    for study_id in slice_prediction_dict
}

# One submission row per (study, label): indices 0-6 are C1-C7; any later
# index is reported as patient_overall.
row_id_list, fractured_list = [], []
for study_id, pred in ensemble_pred_dict.items():
    for label_ind, prob in enumerate(pred[0]):
        row_id_list.append(f"{study_id}_C{label_ind + 1}" if label_ind < 7 else f"{study_id}_patient_overall")
        fractured_list.append(prob)

sub_df = pd.DataFrame({"row_id": row_id_list, "fractured": fractured_list})
sub_df

In [None]:
# Persist the submission; the DataFrame index is not written.
sub_df.to_csv(path_or_buf=OUTPUT_FILE, index=False)

In [None]:
#plot_volume(spine_map.numpy(), sagittal=True)