In [1]:
import os
import pickle
import random
from tqdm import tqdm

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision.transforms import Compose, RandomHorizontalFlip, ColorJitter, RandomAffine, RandomErasing, ToTensor, Resize
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score

import numpy as np
import pandas as pd
import pydicom
import pdb

import torch
import torch.nn as nn

import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional, Dict

import torch
import torch.nn as nn

from torchvision.ops import Conv2dNormActivation, MLP

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from IPython.display import clear_output, display
import time
import matplotlib.pyplot as plt

def show_img(img):
    """Display ``img`` as a 5x5-inch grayscale figure with the axes hidden.

    After the figure is shown, the cell output is marked for clearing
    (``wait=True`` defers the clear until new output arrives) and the call
    pauses for 10 ms, so calling this in a loop shows one image at a time.
    """
    figure = plt.figure(figsize=(5, 5))
    axes = figure.gca()
    axes.imshow(img, cmap="gray")
    axes.axis(False)
    plt.show()

    clear_output(wait=True)
    time.sleep(0.01)

In [2]:
# Root directory of the RSNA 2023 abdominal trauma detection data.
BASEDIR = '../rsna-2023-abdominal-trauma-detection'

# Every dataset location is BASEDIR joined with a fixed entry name.
(
    TRAIN_IMG_PATH,
    TRAIN_META_PATH,
    TEST_IMG_PATH,
    TEST_META_PATH,
    TRAIN_LABEL_PATH,
) = (
    os.path.join(BASEDIR, entry)
    for entry in (
        'train_images',
        'train_series_meta.csv',
        'test_images',
        'test_series_meta.csv',
        'train.csv',
    )
)

In [3]:
from skimage.transform import resize

def fetch_img_paths():
    """Collect the image file paths of every series under ``TRAIN_IMG_PATH``.

    The directory is walked as ``<patient>/<series>/<image file>``.

    Returns:
        list[list[str]]: one inner list per series, holding the paths of all
        files in that series directory. Patients, series and files are
        visited in sorted name order so the result (and therefore any index
        into it) is the same on every run; ``os.listdir`` on its own returns
        entries in arbitrary order. Entries that are not directories and
        series directories without files are skipped.
    """
    img_paths = []

    for patient in tqdm(sorted(os.listdir(TRAIN_IMG_PATH))):
        patient_dir = os.path.join(TRAIN_IMG_PATH, patient)
        # Stray files next to the patient folders would make listdir raise.
        if not os.path.isdir(patient_dir):
            continue

        for scan in sorted(os.listdir(patient_dir)):
            scan_dir = os.path.join(patient_dir, scan)
            if not os.path.isdir(scan_dir):
                continue

            scans = [os.path.join(scan_dir, img) for img in sorted(os.listdir(scan_dir))]
            # An empty series cannot be loaded later (its first file is indexed).
            if scans:
                img_paths.append(scans)

    return img_paths


def _first_dicom_value(value):
    """Return the first item of a multi-valued attribute, or ``value`` itself if it is scalar."""
    from collections.abc import Sequence

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return value[0]
    return value


def standardize_pixel_array(dcm: "pydicom.dataset.FileDataset") -> np.ndarray:
    """Apply sign correction, rescale slope/intercept and the display window to a DICOM image.

    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217

    Args:
        dcm: dataset exposing ``pixel_array``, ``PixelRepresentation``,
            ``BitsAllocated``, ``BitsStored``, ``RescaleIntercept``,
            ``RescaleSlope``, ``WindowCenter`` and ``WindowWidth``.

    Returns:
        ``pixel_array * slope + intercept`` clipped to
        ``[center - width / 2, center + width / 2]``.

    ``WindowCenter`` / ``WindowWidth`` may hold several values (one per
    stored window); the first window is used in that case instead of failing
    in ``int()``.
    """
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        # Shift left then right in the original dtype so that values stored
        # in fewer than BitsAllocated bits get their sign bit extended.
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype
        pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    center = int(_first_dicom_value(dcm.WindowCenter))
    width = int(_first_dicom_value(dcm.WindowWidth))
    low = center - width / 2
    high = center + width / 2

    pixel_array = (pixel_array * slope) + intercept
    pixel_array = np.clip(pixel_array, low, high)

    return pixel_array

def dcm_read(f):
    """Load the DICOM file ``f`` as a 512x512 image min-max scaled to [0, 1].

    The pixel data goes through ``standardize_pixel_array``, is min-max
    scaled (with a 1e-6 guard against a zero range), inverted when the
    photometric interpretation is MONOCHROME1, and finally resized.
    """
    dataset = pydicom.dcmread(f)

    pixels = standardize_pixel_array(dataset)
    lowest, highest = pixels.min(), pixels.max()
    scaled = (pixels - lowest) / (highest - lowest + 1e-6)

    # MONOCHROME1 stores inverted intensities; flip to the usual convention.
    if dataset.PhotometricInterpretation == "MONOCHROME1":
        scaled = 1 - scaled

    # scikit-image (skimage.transform) resize with anti-aliasing
    return resize(scaled, (512, 512), anti_aliasing=True)

## Dataloader

In [14]:
import torch

def interpolate_channels(img_tensor, num_channels=80):
    """Resample a ``(C, H, W)`` stack along its first axis to ``num_channels`` slices.

    Args:
        img_tensor: 3-D tensor whose first dimension is the slice/channel axis.
        num_channels: target depth (default 80).

    Returns:
        * ``img_tensor`` itself, unchanged, when it already has at least
          ``num_channels`` slices;
        * otherwise a new float32 tensor of shape ``(num_channels, H, W)``:
          with one or two input slices each slice is repeated over an equal
          share of the output; with three or more the output is a linear
          interpolation in which the first and last output slices equal the
          first and last input slices.

    Raises:
        ValueError: if ``img_tensor`` is not 3-D or has no slices.

    Interpolation positions are computed arithmetically rather than by
    looking for all-zero placeholder slices, so legitimately blank slices
    are handled like any other data.
    """
    if img_tensor.dim() != 3:
        raise ValueError(f"Expected a (C, H, W) tensor, got shape {tuple(img_tensor.shape)}")

    C = img_tensor.shape[0]
    if C == 0:
        raise ValueError("img_tensor must contain at least one channel")

    # Already deep enough: hand the input back untouched.
    if C >= num_channels:
        return img_tensor

    source = img_tensor.to(torch.float32)

    if C <= 2:
        # Too few slices to interpolate: output slice t copies input slice
        # floor(t * C / num_channels), i.e. equal-sized runs per input slice.
        steps = torch.arange(num_channels, device=source.device)
        nearest = torch.div(steps * C, num_channels, rounding_mode="floor")
        return source[nearest]

    # Fractional input position of every output slice; the end points map
    # exactly onto the first and last input slice.
    positions = torch.linspace(0, C - 1, num_channels, device=source.device)
    lower = positions.floor().long().clamp(max=C - 1)
    upper = (lower + 1).clamp(max=C - 1)
    weight = (positions - lower.to(positions.dtype)).view(-1, 1, 1)

    return (1 - weight) * source[lower] + weight * source[upper]


class AbdominalTestData(Dataset):
    """Dataset yielding one slice stack per series returned by ``fetch_img_paths``.

    Each item is ``(image, ids)``: ``image`` is a float32 tensor made of the
    series' images (read with ``dcm_read``), brought to at least 80 slices by
    ``interpolate_channels`` and reduced to every second slice of the central
    80; ``ids`` is a dict with integer ``patient_id`` and ``series_id`` taken
    from the directory names of the series.
    """

    def __init__(self):
        super().__init__()
        self.img_paths = fetch_img_paths()

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _parse_ids(dicom_path):
        """Return ``(patient_id, series_id)`` from a ``.../<patient>/<series>/<file>`` path.

        Uses ``os.path`` so it works with whatever separator ``os.path.join``
        produced, instead of assuming ``'/'``.
        """
        series_dir = os.path.dirname(dicom_path)
        patient_dir = os.path.dirname(series_dir)
        return int(os.path.basename(patient_dir)), int(os.path.basename(series_dir))

    def __getitem__(self, idx):
        # NOTE(review): sorted() orders paths lexicographically, so a file
        # named "10.dcm" comes before "2.dcm". Left unchanged because the
        # slice order presumably has to match what the checkpoint was trained
        # with - confirm before switching to a numeric sort.
        dicom_images = sorted(self.img_paths[idx])

        patient_id, series_id = self._parse_ids(dicom_images[0])

        images = np.stack([dcm_read(d) for d in dicom_images])
        image = torch.tensor(images, dtype=torch.float32)
        image = interpolate_channels(image)

        # Keep every second slice of the 80 slices around the centre.
        center_idx = image.shape[0] // 2
        image = image[center_idx-40:center_idx+40:2]

        return image, {
            'patient_id': patient_id,
            'series_id': series_id
        }

## Inference Net

In [5]:
class ConvStemConfig(NamedTuple):
    """Settings for one conv/norm/activation layer of the optional convolutional stem.

    ``VisionTransformer`` forwards these fields to ``Conv2dNormActivation``
    for each entry of ``conv_stem_configs``.
    """
    # Number of output channels of the convolution (input of the next layer).
    out_channels: int
    # Convolution kernel size.
    kernel_size: int
    # Convolution stride.
    stride: int
    # Factory for the normalization layer; BatchNorm2d unless overridden.
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    # Factory for the activation layer; ReLU unless overridden.
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(MLP):
    """Transformer feed-forward block: an ``MLP`` with layer sizes ``[mlp_dim, in_dim]`` and GELU activation."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        # Xavier-uniform weights and near-zero biases for every linear layer.
        linear_layers = (module for module in self.modules() if isinstance(module, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.normal_(linear.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load parameters, first renaming legacy ``linear_1`` / ``linear_2`` keys.

        Checkpoints without a version, or with a version below 2, are treated
        as legacy. See https://github.com/pytorch/vision/pull/6053
        """
        version = local_metadata.get("version", None)
        is_legacy = version is None or version < 2

        if is_legacy:
            # linear_1 -> index 0, linear_2 -> index 3 of the MLP sequence.
            renames = [
                (f"{prefix}linear_{layer + 1}.{suffix}", f"{prefix}{3 * layer}.{suffix}")
                for layer in range(2)
                for suffix in ("weight", "bias")
            ]
            for old_key, new_key in renames:
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """One transformer encoder layer: pre-norm self-attention and pre-norm MLP, each with a residual connection."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Self-attention sub-block (batch is the leading dimension).
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            hidden_dim,
            num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Feed-forward sub-block.
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        """Map a ``(batch_size, seq_length, hidden_dim)`` tensor to one of the same shape."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        normed = self.ln_1(input)
        attended, _ = self.self_attention(query=normed, key=normed, value=normed, need_weights=False)
        after_attention = self.dropout(attended) + input

        feed_forward = self.mlp(self.ln_2(after_attention))
        return after_attention + feed_forward


class Encoder(nn.Module):
    """Stack of ``EncoderBlock`` layers over a token sequence with a learned positional embedding.

    The positional embedding is added to the input, followed by dropout, the
    encoder layers in order, and a final normalization layer.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Batch comes first; the leading 1 lets the embedding broadcast over it.
        initial_embedding = torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)  # from BERT
        self.pos_embedding = nn.Parameter(initial_embedding)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.Sequential(
            OrderedDict(
                (
                    f"encoder_layer_{index}",
                    EncoderBlock(
                        num_heads,
                        hidden_dim,
                        mlp_dim,
                        dropout,
                        attention_dropout,
                        norm_layer,
                    ),
                )
                for index in range(num_layers)
            )
        )
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        """Encode a ``(batch_size, seq_length, hidden_dim)`` tensor; the output has the same shape."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        embedded = input + self.pos_embedding
        encoded = self.layers(self.dropout(embedded))
        return self.ln(encoded)


class VisionTransformer(nn.Module):
    """Vision Transformer encoder adapted from https://arxiv.org/abs/2010.11929.

    Points where this class differs from a plain classification ViT, as
    visible in the code below:

    * the patchify convolution takes 16 input channels (the optional conv
      stem still starts from 3 channels);
    * the class token is disabled (commented out), so the sequence length is
      exactly ``(image_size // patch_size) ** 2``;
    * ``forward`` returns the encoder output of shape
      ``(n, seq_length, hidden_dim)``; ``self.heads`` is still constructed
      and initialised but is never applied.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        """Build the patch projection, the encoder and the (unused in ``forward``) heads.

        Args:
            image_size: height and width expected of the input; must be divisible by ``patch_size``.
            patch_size: side length of the square patches.
            num_layers: number of encoder layers.
            num_heads: attention heads per encoder layer.
            hidden_dim: embedding size of every token.
            mlp_dim: hidden size of the MLP inside each encoder layer.
            dropout: dropout probability passed to the encoder.
            attention_dropout: dropout probability passed to the attention layers.
            num_classes: output size of ``heads.head``.
            representation_size: if given, a ``pre_logits`` linear layer and Tanh precede ``heads.head``.
            norm_layer: factory for the normalization layers.
            conv_stem_configs: if given, a convolutional stem replaces the single patchify convolution.
        """
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            # NOTE: the stem starts from 3 channels, unlike the 16-channel patchify conv below.
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            # Non-overlapping patches: kernel and stride both equal patch_size.
            self.conv_proj = nn.Conv2d(
                in_channels=16, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        # One token per patch.
        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        # self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        # Classification heads; constructed here but not called in forward().
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an ``(n, c, h, w)`` batch into patch tokens of shape ``(n, n_h * n_w, hidden_dim)``.

        Asserts that ``h`` and ``w`` both equal ``image_size``.
        """
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        """Return the encoded patch tokens, shape ``(n, seq_length, hidden_dim)``.

        No class token is prepended and ``self.heads`` is not applied; the
        corresponding statements are kept below as comments.
        """
        # Reshape and permute the input tensor
        x = self._process_input(x)
        # Only needed by the disabled class-token code below.
        n = x.shape[0]

        # Expand the class token to the full batch
        # batch_class_token = self.class_token.expand(n, -1, -1)
        # x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        # x = x[:, 0]

        # x = self.heads(x)


        return x




class RSNA_model(nn.Module):
    """Convolutional front-end + ViT encoder with five softmax heads.

    Shapes per sample: ``[40, 512, 512]`` input -> ``[16, 32, 32]`` after the
    two convolutional blocks -> ``[64, 256]`` tokens from the ViT ->
    ``[16384]`` flattened features shared by all heads.

    ``forward`` returns a tuple of class-probability tensors in the order
    (bowel, extravasation, kidney, liver, spleen) with 2, 2, 3, 3 and 3
    classes respectively.
    """

    @staticmethod
    def _downsampling_block(in_channels, out_channels):
        """Stride-2 5x5 conv to 32 channels, 3x3 conv to ``out_channels``, then 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.Conv2d(in_channels=32, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    @staticmethod
    def _classifier_head(num_classes):
        """Two linear layers (16384 -> 1024 -> ``num_classes``) followed by a softmax over classes."""
        return nn.Sequential(
            nn.Linear(in_features=16384, out_features=1024),
            nn.Linear(in_features=1024, out_features=num_classes),
            nn.Softmax(dim=1),
        )

    def __init__(self):
        super().__init__()

        # [40, 512, 512] -> [32, 256, 256] -> [32, 128, 128]
        self.block_1 = self._downsampling_block(40, 32)
        # [32, 128, 128] -> [32, 64, 64] -> [16, 32, 32]
        self.block_2 = self._downsampling_block(32, 16)
        # block_1 / block_2 stay registered under their own names as well,
        # so the module keeps the same set of state-dict keys.
        self.conv_blocks = nn.Sequential(self.block_1, self.block_2)

        # [16, 32, 32] -> [64, 256]
        self.vit = VisionTransformer(
            image_size=32,
            patch_size=4,
            num_layers=12,
            num_heads=4,
            hidden_dim=256,
            mlp_dim=1024,
        )

        self.flatten = nn.Flatten(1)  # [16384]

        self.head_bowel = self._classifier_head(2)
        self.head_ext = self._classifier_head(2)
        self.head_kidney = self._classifier_head(3)
        self.head_liver = self._classifier_head(3)
        self.head_spleen = self._classifier_head(3)

    def forward(self, x):
        features = self.conv_blocks(x)
        tokens = self.vit(features)
        latent_space = self.flatten(tokens)

        heads = (
            self.head_bowel,
            self.head_ext,
            self.head_kidney,
            self.head_liver,
            self.head_spleen,
        )
        return tuple(head(latent_space) for head in heads)

In [15]:
# Build the dataset and carve a validation split out of a random half of it.
data = AbdominalTestData()

SUBSET_FRACTION = 0.5  # share of all series that takes part in the train/val split
TRAIN_FRACTION = 0.7   # share of that subset assigned to the (discarded) train split
SPLIT_SEED = 42        # fixes which series end up in val_data

subset_size = int(SUBSET_FRACTION * len(data))
ten_percent = subset_size  # legacy name (the value is 50 %, not 10 %), kept for existing references
train_size = int(TRAIN_FRACTION * subset_size)
val_size = subset_size - train_size
unused_size = len(data) - subset_size

# Seed the split explicitly so that re-running the notebook evaluates the same series.
_, val_data, _ = random_split(
    data,
    [train_size, val_size, unused_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

100%|██████████| 3147/3147 [00:04<00:00, 729.15it/s]


In [16]:
# Smoke-test one sample: show its tensor shape and ids rather than dumping the tensor itself.
sample_image, sample_meta = data[0]
sample_image.shape, sample_meta

KeyboardInterrupt: 

: 

In [6]:


dataloader = DataLoader(val_data, batch_size=1, shuffle=False)

# Use GPU 2 when the host has it; otherwise fall back to the default GPU or
# the CPU so the notebook still runs on other machines.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    device = torch.device('cuda:2')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# `unet` is a historical name: the network is RSNA_model (conv blocks + ViT).
unet = RSNA_model().to(device)
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
unet.load_state_dict(torch.load('./unet_sgd_sub_1.0_16_5_0.01.pth', map_location=device))
unet.eval()

def calculate_any_injury(submission: pd.DataFrame) -> pd.Series:
    """Derive the ``any_injury`` score for every row of ``submission``.

    The score is the largest ``1 - p(healthy)`` over the five targets
    (bowel, extravasation, kidney, liver, spleen), i.e. one minus the
    smallest ``<target>_healthy`` probability in the row.
    """
    targets = ('bowel', 'extravasation', 'kidney', 'liver', 'spleen')
    healthy_cols = [f'{target}_healthy' for target in targets]

    # max(1 - p) over the columns equals 1 - min(p).
    least_healthy = submission[healthy_cols].min(axis=1)
    return 1 - least_healthy


# Class-name suffixes of each target, in the order RSNA_model returns its outputs.
TARGET_CLASSES = [
    ('bowel', ('healthy', 'injury')),
    ('extravasation', ('healthy', 'injury')),
    ('kidney', ('healthy', 'low', 'high')),
    ('liver', ('healthy', 'low', 'high')),
    ('spleen', ('healthy', 'low', 'high')),
]

rows = []

# Inference only: without no_grad every forward pass would record an autograd graph.
with torch.no_grad():
    for X, y in dataloader:
        y_pred = unet(X.to(device))

        row = {
            'patient_id': y['patient_id'].item(),
            'series_id': y['series_id'].item(),
        }
        for (target, class_names), probabilities in zip(TARGET_CLASSES, y_pred):
            # batch_size is 1, so index 0 is the only sample in the batch.
            sample_probabilities = probabilities[0].cpu().tolist()
            for class_name, probability in zip(class_names, sample_probabilities):
                row[f'{target}_{class_name}'] = probability
        rows.append(row)

# One row per series; any_injury is derived per series before averaging.
series_df = pd.DataFrame(rows)

# Average the predictions of all series belonging to the same patient.
submit_df = (
    series_df
    .assign(any_injury=calculate_any_injury(series_df))
    .drop(columns=['series_id'])
    .groupby('patient_id')
    .mean()
    .reset_index()
)

submit_df

100%|██████████| 3147/3147 [00:04<00:00, 737.02it/s]


KeyboardInterrupt: 

In [None]:
# Write the per-patient predictions to submission.csv, omitting the DataFrame index.
submit_df.to_csv('submission.csv', index=False) 