## Imports

In [1]:
!pip install /kaggle/input/rsna-atd-packages/dicomsdl-0.109.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl

Processing /kaggle/input/rsna-atd-packages/dicomsdl-0.109.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
Installing collected packages: dicomsdl
Successfully installed dicomsdl-0.109.2


In [2]:
import sys
sys.path.append("/kaggle/input/efficientnet-pytorch")
sys.path.append("/kaggle/input/monai-v101")
sys.path.append("/kaggle/input/pretrained-models-pytorch")
sys.path.append("/kaggle/input/rsna-atd-packages")
sys.path.append("/kaggle/input/smp-github/segmentation_models.pytorch-master")

import albumentations as A
import attention
import conv3d_same
import cv2
import dicomsdl as dsdl
import gc
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd
import pydicom as dicom
import pytorch_lightning as pl
import random
import timm
import timm.models.layers as layers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import yaml

from albumentations.pytorch import ToTensorV2
from ast import literal_eval
from functools import partial
from heads import ClassificationHead
from monai.transforms import Resize
from segmentation_models_pytorch.base.initialization import initialize_decoder, initialize_head
from sklearn.metrics import log_loss
from timm import create_model
from timm.models.layers import Conv2dSame
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

  if block_type is 'proj':
  elif block_type is 'down':
  assert block_type is 'normal'


## Segmentation

### Modules

In [3]:
def get_activation(activation):
    """Create the activation module identified by ``activation``.

    Args:
        activation: ``None`` for an identity, or one of ``"relu"`` / ``"silu"``.

    Returns:
        A freshly constructed ``nn.Module``; ReLU and SiLU are built with
        ``inplace=True``.

    Raises:
        ValueError: if ``activation`` names anything else.
    """
    if activation is None:
        return nn.Identity()
    for known_name, factory in (("relu", nn.ReLU), ("silu", nn.SiLU)):
        if activation == known_name:
            return factory(inplace=True)
    raise ValueError(f"Activation {activation} is not supported.")


class SeparableConv2d(nn.Sequential):
    """Depthwise-separable 2D convolution.

    Consists of two stages run in order:
      1. a depthwise ``kernel_size`` conv with one filter per input channel
         (``groups=in_channels``) and no bias;
      2. a pointwise 1x1 conv mapping ``in_channels`` -> ``out_channels``,
         which carries the optional bias.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        spatial_stage = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        mixing_stage = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        super().__init__(spatial_stage, mixing_stage)


class SeparableConvBnAct(nn.Sequential):
    """``SeparableConv2d`` -> optional ``BatchNorm2d`` -> activation.

    The convolution bias is disabled when batch norm follows (the BN shift
    makes it redundant). ``activation`` is resolved through ``get_activation``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        use_batchnorm=True,
        activation="silu"
    ):
        stages = [
            SeparableConv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity(),
            get_activation(activation),
        ]
        super(SeparableConvBnAct, self).__init__(*stages)


class ConvBnAct(nn.Sequential):
    """``nn.Conv2d`` -> optional ``BatchNorm2d`` -> activation (identity by default).

    The convolution bias is only enabled when batch norm is not used.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
        activation=None
    ):
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity(),
            get_activation(activation),
        ]
        super(ConvBnAct, self).__init__(*stages)


class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Global-average-pools the input to 1x1, applies a 1x1 ``ConvBnAct``, then
    bilinearly upsamples back to the input's last two spatial dimensions.
    """

    def __init__(self, in_channels, out_channels, activation="silu"):
        global_pool = nn.AdaptiveAvgPool2d(1)
        projection = ConvBnAct(in_channels, out_channels, kernel_size=1, activation=activation)
        super().__init__(global_pool, projection)

    def forward(self, x):
        target_size = x.shape[-2:]
        # nn.Sequential.forward runs the pool and the projection in order.
        pooled = super().forward(x)
        return F.interpolate(pooled, size=target_size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    The following parallel branches see the same input:
      * a 1x1 ``ConvBnAct``;
      * one separable 3x3 ``ConvBnAct`` per entry of ``atrous_rates``
        (``padding == dilation == rate``, which keeps the spatial size);
      * an image-level ``ASPPPooling`` branch.
    Their outputs are concatenated on the channel axis. The result is projected
    back to ``out_channels`` with a 1x1 ``ConvBnAct`` followed by dropout.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the projected output.
        atrous_rates: dilation rates of the separable branches.
        reduction: every branch emits ``out_channels // reduction`` channels.
        dropout: dropout probability after the projection.
        activation: activation name used by every branch and the projection.

    NOTE(review): the pooling branch relies on 2D-only ops (``AdaptiveAvgPool2d``,
    bilinear interpolation) that ``inflate_module`` does not convert. This module
    therefore looks usable only in a 2D network; confirm before enabling
    ``use_aspp`` in the inflated U-Net.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        atrous_rates,
        reduction=1,
        dropout=0.2,
        activation="silu"
    ):
        super(ASPP, self).__init__()
        branch_channels = out_channels // reduction
        modules = [
            ConvBnAct(
                in_channels,
                branch_channels,
                kernel_size=1,
                padding=0,
                stride=1,
                use_batchnorm=True,
                activation=activation
            )
        ]
        for r in atrous_rates:
            modules.append(
                SeparableConvBnAct(
                    in_channels,
                    branch_channels,
                    kernel_size=3,
                    stride=1,
                    padding=r,
                    dilation=r,
                    use_batchnorm=True,
                    activation=activation
                )
            )
        modules.append(ASPPPooling(in_channels, branch_channels, activation=activation))
        self.body = nn.ModuleList(modules)
        # The concatenated tensor has exactly len(modules) * branch_channels channels.
        # Computing it as (n * out_channels) // reduction floors the whole product
        # instead, which disagrees whenever out_channels % reduction != 0.
        self.project = nn.Sequential(
            ConvBnAct(
                len(modules) * branch_channels,
                out_channels,
                kernel_size=1,
                padding=0,
                stride=1,
                use_batchnorm=True,
                activation=activation
            ),
            nn.Dropout(dropout)
        )

    def forward(self, x, scale_factor=1):
        """Optionally rescale ``x`` bilinearly, run all branches, concatenate and project."""
        if scale_factor != 1:
            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
        results = torch.cat([module(x) for module in self.body], dim=1)
        return self.project(results)


class SegmentationHead(nn.Sequential):
    """Final per-pixel classifier.

    A plain convolution (no batch norm, no activation) maps decoder features to
    ``out_channels`` logits. It is followed by bilinear upsampling when
    ``upsampling > 1`` and by an identity otherwise.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        upsampling=1
    ):
        classifier = ConvBnAct(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            use_batchnorm=False,
            activation=None
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super(SegmentationHead, self).__init__(classifier, resize)


class SCSEModule(nn.Module):
    """Concurrent channel and spatial squeeze-and-excitation (scSE).

    ``cSE`` gates channels from globally pooled statistics through a bottleneck
    of ``in_channels // reduction`` units. ``sSE`` gates spatial positions through
    a 1x1 conv to a single map. The two gated copies of ``x`` are averaged.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        bottleneck = in_channels // reduction
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, bottleneck, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        channel_gated = x * self.cSE(x)
        spatial_gated = x * self.sSE(x)
        return (channel_gated + spatial_gated) / 2.
    

class Attention(nn.Module):
    """Attention selected by name.

    ``name=None`` gives an identity (extra ``params`` are accepted and ignored by
    ``nn.Identity``). ``name="scse"`` builds ``SCSEModule(**params)``. Any other
    name raises ``ValueError``.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name is not None and name != "scse":
            raise ValueError("Attention type {} is not implemented".format(name))
        self.attention = nn.Identity(**params) if name is None else SCSEModule(**params)

    def forward(self, x):
        return self.attention(x)

### Encoder

In [4]:
def create_encoder(encoder_params):
    """Instantiate the encoder class named by ``encoder_params["class"]``.

    The class is looked up by name among this module's globals. It is called as
    ``cls(name=encoder_params["encoder_name"], **encoder_params["params"])``.
    """
    this_module = sys.modules[__name__]
    encoder_cls = getattr(this_module, encoder_params["class"])
    return encoder_cls(name=encoder_params["encoder_name"], **encoder_params["params"])


class BaseEncoder(nn.Module):
    """Encoder skeleton that runs ``get_stages()`` in sequence.

    ``forward`` returns the output of every stage, in order. Subclasses override
    ``get_stages``; ``out_channels`` records the channel count of each returned
    feature (it is read by the decoder). Extra keyword arguments are ignored.
    """

    def __init__(self, out_channels, **kwargs):
        super().__init__()
        self.out_channels = out_channels

    def get_stages(self):
        return [nn.Identity()]

    def forward(self, x):
        outputs = []
        current = x
        for stage in self.get_stages():
            current = stage(current)
            outputs += [current]
        return outputs


class EfficientNetEncoder2d(BaseEncoder):
    """Multi-scale encoder around a timm EfficientNet-style backbone.

    ``forward`` (inherited from ``BaseEncoder``) returns, in order:
      * the input itself (``nn.Identity``),
      * the output of ``conv_stem`` + ``bn1``,
      * one feature per slice of ``encoder.blocks``, split at the indices in
        ``stage_idx`` (e.g. ``[2, 3, 5]`` -> ``blocks[0:2]``, ``[2:3]``, ``[3:5]``,
        ``[5:]``).

    Args:
        name: timm model name passed to ``create_model``.
        stage_idx: block indices at which to split into stages (list or tuple).
        backbone_params: extra keyword arguments for ``create_model``
            (``None`` means none).
        **kwargs: forwarded to ``BaseEncoder`` (must include ``out_channels``).
    """

    def __init__(
        self,
        name,
        stage_idx,
        backbone_params=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # ``None`` default instead of a shared mutable ``{}``.
        self.encoder = create_model(name, **(backbone_params or {}))
        # Store a list copy: get_stages concatenates with lists, so a tuple would fail.
        stage_idx = list(stage_idx)
        assert len(stage_idx) <= len(self.encoder.blocks)
        self.stage_idx = stage_idx
        self.depth = len(stage_idx) + 2

    def get_stages(self):
        blocks = self.encoder.blocks
        starts = [0] + self.stage_idx
        ends = self.stage_idx + [len(blocks)]
        return [nn.Identity(), nn.Sequential(self.encoder.conv_stem, self.encoder.bn1)] + \
            [blocks[i : j] for i, j in zip(starts, ends)]

    def forward_head(self, x):
        """Apply ``conv_head`` and ``bn2``, then timm's ``forward_head``."""
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.forward_head(x)
        return x

### Decoder

In [5]:
class DecoderBlock(nn.Module):
    """One U-Net decoder stage.

    The forward pass performs, in order:
      1. optional upsampling;
      2. optional ASPP;
      3. concatenation with the skip feature, followed by attention;
      4. a stem conv to ``out_channels``;
      5. ``block_depth`` further convs;
      6. final attention.

    Args:
        in_channels: channels of the incoming (deeper) feature map.
        skip_channels: channels of the skip feature concatenated to it (0 if none).
        out_channels: channels produced by the block.
        block_depth: number of convs after the stem.
        separable: use ``SeparableConvBnAct`` instead of ``ConvBnAct``.
        use_aspp: insert an ``ASPP`` module before the skip fusion.
        use_batchnorm: add batch norm after each conv.
        attention_type: ``None`` or ``"scse"`` (see ``Attention``).
        activation: activation name for every conv.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        block_depth=1,
        separable=False,
        use_aspp=False,
        use_batchnorm=True,
        attention_type=None,
        activation="relu"
    ):
        super().__init__()
        # The attribute names (attention, aspp, stem, body) are checkpoint keys, and
        # the registration order fixes parameter-initialisation order; keep both.
        fused_channels = in_channels + skip_channels
        self.attention = nn.ModuleList([
            Attention(attention_type, in_channels=fused_channels),
            Attention(attention_type, in_channels=out_channels),
        ])
        if use_aspp:
            self.aspp = ASPP(
                in_channels,
                in_channels,
                atrous_rates=[1, 2, 4],
                reduction=2,
                dropout=0.2,
                activation=activation
            )
        else:
            self.aspp = nn.Identity()
        conv_unit = SeparableConvBnAct if separable else ConvBnAct
        conv_kwargs = dict(
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
            activation=activation,
        )
        self.stem = conv_unit(fused_channels, out_channels, **conv_kwargs)
        self.body = nn.Sequential(
            *(conv_unit(out_channels, out_channels, **conv_kwargs) for _ in range(block_depth))
        )

    def forward(self, x, skip=None, scale_factor=1):
        if scale_factor != 1:
            # Trilinear interpolation requires a 5D tensor; this matches the segmentation
            # model, which create_segmentation_model inflates to 3D.
            x = F.interpolate(x, scale_factor=scale_factor, mode="trilinear")
        x = self.aspp(x)
        if skip is not None:
            x = self.attention[0](torch.cat([x, skip], dim=1))
        x = self.body(self.stem(x))
        return self.attention[1](x)


class UnetDecoder(nn.Module):
    """U-Net decoder: a chain of ``DecoderBlock``s from the deepest encoder feature upward.

    Args:
        encoder_channels: channel count of each encoder feature, in the order the
            encoder returns them. The first entry is dropped; for
            ``EfficientNetEncoder2d`` that entry is the raw-input passthrough stage.
        decoder_channels: output channels of each decoder block.
        scale_factors: upsampling factor applied at the start of each block.
        num_blocks: number of blocks. Must equal ``len(decoder_channels)`` and
            ``len(scale_factors)``, and be >= ``len(encoder_channels) - 1``.
        block_depth, separable, use_batchnorm, attention_type, activation:
            forwarded to every ``DecoderBlock``.
        use_aspp: forwarded only to the second-to-last block.
    """
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        scale_factors,
        num_blocks=5,
        block_depth=1,
        separable=False,
        use_aspp=False,
        use_batchnorm=True,
        attention_type=None,
        activation="relu"
    ):
        super().__init__()
        assert num_blocks >= len(encoder_channels) - 1
        assert num_blocks == len(decoder_channels)
        assert num_blocks == len(scale_factors)
        self.scale_factors = scale_factors
        # Drop the first encoder feature and order the rest deepest-first.
        encoder_channels = encoder_channels[1:][::-1]
        # Block 0 consumes the deepest encoder feature; block k > 0 consumes block k-1's output.
        in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
        # Block k is fused with the next-shallower encoder feature. Blocks beyond the
        # available encoder features get no skip (0 channels).
        skip_channels = list(encoder_channels[1:])
        skip_channels += [0] * (len(in_channels) - len(skip_channels))
        out_channels = decoder_channels
        # ASPP, when enabled, is only inserted into the second-to-last block.
        aspp_idx = len(in_channels) - 2
        blocks = []
        for i, (i_ch, s_ch, o_ch) in enumerate(zip(in_channels, skip_channels, out_channels)):
            blocks.append(
                DecoderBlock(
                    i_ch, 
                    s_ch, 
                    o_ch, 
                    block_depth,
                    separable=separable,
                    use_aspp=use_aspp if i == aspp_idx else False,
                    use_batchnorm=use_batchnorm, 
                    attention_type=attention_type,
                    activation=activation
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        # Mirror __init__: drop the first feature and walk from the deepest one upward.
        features = features[1:][::-1]
        x = features[0]
        skips = features[1:]
        for i, (block, scale_factor) in enumerate(zip(self.blocks, self.scale_factors)):
            # Later blocks have no matching encoder feature and run without a skip.
            skip = skips[i] if i < len(skips) else None
            x = block(x, skip, scale_factor)
        return x

### Model

In [6]:
def create_segmentation_model(config):
    """Build the segmentation network described by ``config`` and inflate it to 3D.

    ``config["family"]`` selects the architecture. Only ``"unet"`` is supported,
    and the remaining keys are passed to ``Unet``. ``config`` itself is not
    modified; a shallow copy is made before ``"family"`` is removed.

    Raises:
        ValueError: for an unsupported family.
    """
    model_kwargs = dict(config)
    family = model_kwargs.pop("family")
    if family != "unet":
        raise ValueError(f"Model family {family} is not supported.")
    return inflate_module(Unet(**model_kwargs))
    

def _copy_inflated_conv_params(conv2d, conv3d):
    """Copy a 2D conv's kernel (repeated along a new trailing axis) and its bias into ``conv3d``."""
    depth = conv2d.kernel_size[0]
    conv3d.weight = torch.nn.Parameter(conv2d.weight.unsqueeze(-1).repeat(1, 1, 1, 1, depth))
    if conv2d.bias is not None:
        # Without this copy the 3D conv would keep its random default bias.
        conv3d.bias = torch.nn.Parameter(conv2d.bias.detach().clone())


def inflate_module(module):
    """Recursively convert a 2D module tree into its 3D counterpart.

    Leaf conversions:
      * ``nn.BatchNorm2d`` -> ``nn.BatchNorm3d``, sharing the affine parameters
        and running statistics of the original;
      * timm ``Conv2dSame`` -> ``conv3d_same.Conv3dSame``;
      * ``nn.Conv2d`` -> ``nn.Conv3d``;
      * ``nn.MaxPool2d`` / ``nn.AvgPool2d`` -> 3D pooling with the same settings.

    Conv kernels are inflated by repeating the 2D kernel ``kernel_size[0]`` times
    along a new trailing axis, without rescaling. Conv biases are copied. Only
    element ``[0]`` of each conv hyper-parameter tuple is used, so square kernels,
    strides, paddings and dilations are assumed. Any other module is kept as-is,
    and its children are converted and re-registered in place.

    Args:
        module: the module to convert.

    Returns:
        The converted module, or ``module`` itself when it is not a convertible leaf.
    """
    inflated = module
    if isinstance(module, nn.BatchNorm2d):
        inflated = nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                inflated.weight = module.weight
                inflated.bias = module.bias
        inflated.running_mean = module.running_mean
        inflated.running_var = module.running_var
        inflated.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            inflated.qconfig = module.qconfig
    elif isinstance(module, Conv2dSame):
        # Must be checked before nn.Conv2d, since Conv2dSame subclasses it.
        inflated = conv3d_same.Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _copy_inflated_conv_params(module, inflated)
    elif isinstance(module, torch.nn.Conv2d):
        inflated = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        _copy_inflated_conv_params(module, inflated)
    elif isinstance(module, torch.nn.MaxPool2d):
        inflated = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        inflated = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )
    # For containers ``inflated is module``: add_module replaces each child in place.
    for name, child in module.named_children():
        inflated.add_module(name, inflate_module(child))
    return inflated


class Unet(nn.Module):
    """U-Net assembled from a configurable encoder, a ``UnetDecoder`` and a 3x3 head.

    Args:
        encoder_params: passed to ``create_encoder``; the resulting encoder must
            expose ``out_channels``.
        decoder_params: keyword arguments for ``UnetDecoder``; must contain
            ``decoder_channels``.
        num_classes: number of output channels of the segmentation head.
    """

    def __init__(
        self,
        encoder_params,
        decoder_params,
        num_classes=1
    ):
        super().__init__()
        self.encoder = create_encoder(encoder_params)
        self.decoder = UnetDecoder(self.encoder.out_channels, **decoder_params)
        head_in_channels = decoder_params["decoder_channels"][-1]
        self.head = SegmentationHead(
            head_in_channels,
            num_classes,
            kernel_size=3,
            padding=1,
            upsampling=1
        )
        initialize_decoder(self.decoder)
        initialize_head(self.head)

    def forward(self, x):
        """Return raw (pre-activation) logits."""
        return self.head(self.decoder(*self.encoder(x)))
    
    
class SegmentationModule(pl.LightningModule):
    """LightningModule shell whose ``model`` attribute holds the segmentation network.

    At inference ``load_model`` loads a checkpoint into this shell and then uses
    only ``.model``.
    """

    def __init__(self, config):
        super().__init__()
        model_config = config["model"]
        self.model = create_segmentation_model(model_config)

## Slice Model

In [7]:
def create_slice_classification_model(encoder_params):
    """Instantiate the slice classifier class named by ``encoder_params["class"]``.

    The class is resolved by name among this module's globals and called as
    ``cls(name=encoder_params["encoder_name"], **encoder_params["params"])``.
    """
    model_cls = getattr(sys.modules[__name__], encoder_params["class"])
    kwargs = encoder_params["params"]
    return model_cls(name=encoder_params["encoder_name"], **kwargs)


class SliceClassificationModel(nn.Module):
    """2D slice classifier: timm backbone (no classifier) -> dropout -> ``ClassificationHead``.

    Args:
        name: timm model name.
        backbone_params: keyword arguments for ``create_model``; must contain
            ``in_chans``.
        dropout: dropout probability before the head (0 disables it).

    The head output sizes ``(2, 3, 4, 4, 4, 5)`` are hard-coded.
    NOTE(review): their per-target meaning is not visible here.
    """

    def __init__(self, name, backbone_params, dropout):
        super().__init__()
        self.num_channels = backbone_params["in_chans"]
        self.encoder = create_model(name, num_classes=0, **backbone_params)
        feature_dim = self.feature_dim()
        self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.head = ClassificationHead(feature_dim, (2, 3, 4, 4, 4, 5))

    def feature_dim(self, image_size=256):
        """Return the backbone's pooled feature size, measured with a random probe batch.

        The probe runs in eval mode under ``torch.no_grad`` so that it does not
        update BatchNorm running statistics or build an autograd graph. The
        encoder's previous train/eval mode is restored afterwards. The probe is
        created on the device of the encoder's parameters (CPU if it has none).
        """
        first_param = next(self.encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                probe = torch.randn(2, self.num_channels, image_size, image_size, device=device)
                return self.encoder(probe).shape[-1]
        finally:
            self.encoder.train(was_training)

    def forward_features(self, x):
        """Pooled backbone features, before dropout and the head."""
        return self.encoder(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.drop(x)
        return self.head(x)
    
    
class SliceClassificationModule(pl.LightningModule):
    """LightningModule shell whose ``model`` attribute holds the stage-1 slice classifier."""

    def __init__(self, config):
        super().__init__()
        model_config = config["model"]
        self.model = create_slice_classification_model(model_config)

## Scan Model

In [8]:
def create_scan_classification_model(config):
    """Build a ``ScanClassificationModel`` from the keyword arguments in ``config``."""
    model = ScanClassificationModel(**config)
    return model


class ScanClassificationModel(torch.nn.Module):
    """Stage-2 sequence model over per-slice feature vectors.

    A (bidirectional by default) GRU encodes the sequence. Its outputs feed two heads:
      * ``seg_head``: per-timestep logits (5 classes) on mask-zeroed outputs;
      * ``clf_head``: scan-level logits of sizes ``[2, 2, 3, 3, 3]`` from the
        concatenation of a max-pool over time and ``attention.Attention`` pooling.

    ``forward`` returns ``clf_head`` outputs followed by the ``seg_head`` output.
    ``predict_scan`` drops that last element.

    The GRU is stored as ``self.lstm``; the name is kept because it is a
    checkpoint key.
    NOTE(review): the max-pool also covers zero-padded timesteps. That is
    presumably how the checkpoints were trained; confirm before changing it.
    """

    def __init__(
        self,
        time_dim,
        feature_dim,
        hidden_dim,
        num_layers=1,
        dropout=0.2,
        bidirectional=True
    ):
        super().__init__()
        self.lstm = nn.GRU(
            feature_dim,
            hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )
        directions = 2 if bidirectional else 1
        rnn_dim = hidden_dim * directions
        self.attention = attention.Attention(time_dim, rnn_dim)
        self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.seg_head = ClassificationHead(rnn_dim, [5])
        self.clf_head = ClassificationHead(rnn_dim * 2, [2, 2, 3, 3, 3])

    def forward(self, x, mask):
        hidden, _ = self.lstm(x)
        hidden = self.drop(hidden)
        seg_logits = self.seg_head(hidden * mask.unsqueeze(-1))
        max_pool = torch.max(hidden, dim=1).values
        att_pool = self.attention(hidden, mask)
        pooled = torch.cat([max_pool, att_pool], dim=1)
        return self.clf_head(pooled) + seg_logits
    
    
class ScanClassificationModule(pl.LightningModule):
    """LightningModule shell whose ``model`` attribute holds the stage-2 scan classifier."""

    def __init__(self, config):
        super().__init__()
        model_config = config["model"]
        self.model = create_scan_classification_model(model_config)

## Dataset

In [9]:
# Root of the competition data (contains <split>_images/ and <split>_series_meta.csv).
DATA_DIR = "/kaggle/input/rsna-2023-abdominal-trauma-detection"
# (depth, height, width) the CT volume is resized to for the segmentation model.
SEGMENTATION_DIM = [128, 224, 224]
INJURY_CATEGORIES = ["any", "extravasation", "bowel", "liver", "spleen", "kidney"]

In [10]:
def rescale_slice_to_array_dicomsdl(slice):
    """Decode one dicomsdl dataset into a rescaled, window-clipped float array.

    Steps:
      1. read the stored pixel values;
      2. for signed data (``PixelRepresentation == 1``), sign-extend from
         ``BitsStored`` to the container width;
      3. apply ``value * RescaleSlope + RescaleIntercept`` (defaults 1 and 0);
      4. clip to ``[WindowCenter - WindowWidth / 2, WindowCenter + WindowWidth / 2]``
         (defaults 50 and 400).

    Each rescale/window tag pair is only used when both of its keys appear in
    ``getPixelDataInfo()``.
    """
    raw = slice.pixelData(storedvalue=True)
    pixel_info = slice.getPixelDataInfo()
    if slice.PixelRepresentation == 1:
        # Push the sign bit of the BitsStored-wide value into the container's top bit,
        # then right-shift back (an arithmetic shift for signed dtypes) to sign-extend.
        unused_bits = slice.BitsAllocated - slice.BitsStored
        container_dtype = raw.dtype
        raw = (raw << unused_bits).astype(container_dtype) >> unused_bits
    if "RescaleSlope" in pixel_info and "RescaleIntercept" in pixel_info:
        slope, intercept = slice.RescaleSlope, slice.RescaleIntercept
    else:
        slope, intercept = 1.0, 0.0
    if "WindowCenter" in pixel_info and "WindowWidth" in pixel_info:
        center, width = slice.WindowCenter, slice.WindowWidth
    else:
        center, width = 50, 400
    half_width = width / 2
    return np.clip(raw * slope + intercept, center - half_width, center + half_width)


def normalize_min_max(image, eps=1.0e-8):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to just under 1.

    ``eps`` is added to the value range so a constant image yields zeros instead
    of dividing by zero.
    """
    lowest = image.min()
    value_range = image.max() - lowest
    return (image - lowest) / (value_range + eps)


def convert_to_uint8(image):
    """Scale a [0, 1] array to 0..255 and cast to ``uint8``; fractions are truncated, not rounded."""
    scaled = image * 255.
    return scaled.astype(np.uint8)


def resize_volume(x):
    """Trilinearly resize a channel-first volume to ``SEGMENTATION_DIM`` and cast to ``uint8``.

    NOTE(review): assumes ``x`` is laid out (channels, depth, height, width), as
    MONAI's ``Resize`` treats the first axis as channels.
    """
    resizer = Resize(SEGMENTATION_DIM, mode="trilinear")
    resized = resizer(x)
    return resized.numpy().astype(np.uint8)


def flip_scan(slices):
    """Return ``True`` when the second slice has a larger ``ImagePositionPatient`` z than the first.

    ``slices`` must contain exactly two datasets.
    """
    first, second = slices
    return bool(second.ImagePositionPatient[2] > first.ImagePositionPatient[2])


def get_scan_path(patient_id, series_id, split="train"):
    """Directory of one series: ``DATA_DIR/<split>_images/<patient_id>/<series_id>``."""
    images_dir = os.path.join(DATA_DIR, f"{split}_images")
    return os.path.join(images_dir, str(patient_id), str(series_id))


def get_slice_path(scan_path, slice_idx):
    """Path of the DICOM file ``<slice_idx>.dcm`` inside ``scan_path``."""
    file_name = f"{slice_idx}.dcm"
    return os.path.join(scan_path, file_name)


def get_sorted_slices(scan_path):
    """Ascending integer slice numbers parsed from the file names in ``scan_path``.

    Each name's part before the first ``.`` is parsed as an int (``"12.dcm"`` -> 12).
    """
    slice_numbers = [int(file_name.split(".")[0]) for file_name in os.listdir(scan_path)]
    slice_numbers.sort()
    return slice_numbers


def sample_indices(n, m):
    """``m`` evenly spaced integer indices spanning ``0 .. n-1`` (duplicates when ``m > n``).

    Uses linear-interpolation quantiles of ``range(n)`` rounded with ``np.round``
    (half to even).
    """
    quantile_levels = np.linspace(0., 1., m)
    positions = np.quantile(np.arange(n), quantile_levels)
    return positions.round().astype(int)

In [11]:
def get_step_size(n):
    """Stride for sub-sampling a series of ``n`` slices.

    Returns 1 up to 256 slices, 2 up to 512, and 4 above that.
    """
    if n > 512:
        return 4
    if n > 256:
        return 2
    return 1


def load_scan(scan_path):
    """Load a DICOM series as a stacked volume of windowed slices.

    Slices are read in slice-number order, sub-sampled with ``get_step_size``.
    The order is reversed when ``flip_scan`` reports that the second loaded slice
    has a larger ``ImagePositionPatient`` z than the first, so z decreases along
    axis 0; this is judged from the first two slices only. Files that fail to
    open or decode are skipped (best effort).

    Returns:
        Array of shape (slices, H, W) from ``rescale_slice_to_array_dicomsdl``.
        If no slice could be decoded, returns a (128, 256, 256) array of zeros.
    """
    sorted_slices = get_sorted_slices(scan_path)
    step = get_step_size(len(sorted_slices))
    slices = []
    for i in sorted_slices[::step]:
        try:
            slices.append(dsdl.open(get_slice_path(scan_path, i)))
        except Exception:
            # Unreadable file: skip it. `Exception` (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            pass
    if len(slices) > 1 and flip_scan(slices[:2]):
        slices = slices[::-1]
    image = []
    for dcm in slices:
        try:
            image.append(rescale_slice_to_array_dicomsdl(dcm))
        except Exception:
            pass
    if not image:
        # The shape must be a tuple: np.zeros(128, 256, 256) raises TypeError.
        return np.zeros((128, 256, 256))
    return np.stack(image, axis=0)


def clean_bounds(bounds):
    """Union bounding box of several boxes, using only their first two axes.

    Each element is indexed as ``box[axis][0]`` (start) / ``box[axis][1]`` (end).
    Returns ``((x0, x1), (y0, y1))``.
    """
    x_ranges = [box[0] for box in bounds]
    y_ranges = [box[1] for box in bounds]
    x0 = min(r[0] for r in x_ranges)
    x1 = max(r[1] for r in x_ranges)
    y0 = min(r[0] for r in y_ranges)
    y1 = max(r[1] for r in y_ranges)
    return (x0, x1), (y0, y1)


def crop(x, bounds, i):
    """Crop the first two axes of ``x`` to the union box of all organs.

    ``bounds`` holds one ``[seg_box, scan_box]`` pair per organ (see
    ``predict_segmentation``). ``i`` selects which box of each pair to use.
    """
    boxes = [organ_bounds[i] for organ_bounds in bounds]
    (x0, x1), (y0, y1) = clean_bounds(boxes)
    return x[x0:x1, y0:y1]


def resize_slice(image, image_size, numpy=True):
    """Bilinearly resize a 2D slice to ``image_size`` x ``image_size``.

    With ``numpy=True`` the input is a NumPy image resized through albumentations
    (OpenCV linear interpolation). Otherwise it is a 2D torch tensor resized with
    ``F.interpolate``.
    """
    if not numpy:
        batched = image[None, None, :]
        resized = F.interpolate(batched, size=(image_size, image_size), mode="bilinear")
        return resized[0, 0]
    resizer = A.Compose([
        A.Resize(image_size, image_size, interpolation=cv2.INTER_LINEAR, always_apply=True)
    ])
    return resizer(image=image)["image"]


def normalize_slice(image, num_channels=3):
    """Normalise an HxWxC image and convert it to a CxHxW torch tensor.

    Each value ``v`` becomes ``(v - 0.5) / 0.5`` (``max_pixel_value=1.0``), so
    inputs in [0, 1] map to [-1, 1].
    """
    means = [0.5] * num_channels
    stds = [0.5] * num_channels
    pipeline = A.Compose([
        A.Normalize(mean=means, std=stds, max_pixel_value=1.0, always_apply=True),
        ToTensorV2(always_apply=True),
    ])
    return pipeline(image=image)["image"]


def preprocess_segmentation(image, device="cuda"):
    """Prepare a (slices, H, W) volume as the segmentation model's input tensor.

    Steps:
      1. if there are fewer than ``SEGMENTATION_DIM[0]`` slices, upsample the
         depth by repeating slices (index sampling);
      2. min-max normalise and quantise to uint8;
      3. replicate to 3 channels and resize to ``SEGMENTATION_DIM``;
      4. scale to [-1, 1].

    For a 3D input the result is a float tensor of shape (1, 3, *SEGMENTATION_DIM)
    on ``device``.
    """
    target_depth = SEGMENTATION_DIM[0]
    if len(image) < target_depth:
        image = np.stack([image[k] for k in sample_indices(len(image), target_depth)])
    volume = convert_to_uint8(normalize_min_max(image))
    if volume.ndim < 4:
        volume = volume[None, :].repeat(3, 0)
    volume = resize_volume(volume)
    volume = volume / 127.5 - 1.0
    return torch.tensor(volume).float()[None, :].to(device)


def preprocess_slice(image, mask, bounds, image_size=384, num_channels=3, device="cuda"):
    """Build the stage-1 batch: each slice as image channels plus one organ-mask channel.

    Up to 256 slices are used; longer series are sub-sampled with ``sample_indices``.

    Args:
        image: (slices, H, W) volume at original resolution.
        mask: segmentation probabilities indexed ``mask[channel, depth, h, w]``;
            channel 0 is ignored and the remaining channels are max-reduced.
        bounds: per-organ ``[seg_box, scan_box]`` pairs from ``predict_segmentation``.
            The image is cropped with the scan boxes, the mask with the seg boxes.
        image_size: output height/width.
        num_channels: number of replicated image channels.
        device: device of the image tensors (the mask stays on its own device).

    Returns:
        Tensor of shape (n_slices, num_channels + 1, image_size, image_size);
        all channels are scaled to [-1, 1].
    """
    n_slices = len(image)
    seg_depth = SEGMENTATION_DIM[0]
    if n_slices < 256:
        indices = range(n_slices)
    else:
        indices = sample_indices(n_slices, 256)
    stacked = []
    for idx in indices:
        img = normalize_min_max(image[idx])
        img = resize_slice(crop(img, bounds, i=1), image_size)
        img = np.transpose(img[None, :].repeat(num_channels, 0), (1, 2, 0))
        img = normalize_slice(img, num_channels=num_channels).to(device)
        # Matching depth in segmentation space, clamped to the last segmentation slice.
        z = min(int(idx * seg_depth / n_slices), seg_depth - 1)
        organ_mask = mask[1:, z].max(dim=0)[0]
        organ_mask = resize_slice(crop(organ_mask, bounds, i=0), image_size, numpy=False)
        organ_mask = 2.0 * organ_mask - 1.0
        stacked.append(torch.cat([img, organ_mask[None, :]], axis=0))
    return torch.stack(stacked, axis=0)


def preprocess_scan(features, time_dim, device="cuda"):
    """Fit a (slices, feature_dim) sequence to exactly ``time_dim`` steps.

    Longer sequences are bilinearly resampled down to ``time_dim`` steps. Shorter
    ones are zero-padded at the end, and the padded steps are zeroed in the mask.

    Returns:
        x: tensor of shape (1, time_dim, feature_dim).
        attention_mask: float32 tensor of shape (1, time_dim) on ``device``;
            1 marks real steps and 0 marks padding.
    """
    num_steps = features.shape[0]
    mask = torch.ones((time_dim,), dtype=torch.float32).to(device)
    if num_steps <= time_dim:
        padding = (0, 0, 0, time_dim - num_steps)
        x = F.pad(features, padding, mode="constant", value=0.0)[None, :]
        mask[num_steps:] = 0.0
    else:
        grid = features[None, None, :]
        x = F.interpolate(grid, size=(time_dim, features.shape[-1]), mode="bilinear")[0]
    return x, mask[None, :]

## Prediction

In [12]:
def get_segmentation_bounds(n, s):
    """Turn one organ's sorted voxel coordinates into crop boxes.

    Args:
        n: ``(nx, ny, nz)``, the sorted coordinates of the organ's voxels in
            segmentation space, along height, width and depth.
        s: ``(sx, sy, sz)``, the original scan size along the same axes.

    Returns:
        ``(seg_box, scan_box)``, each a list ``[(x0, x1), (y0, y1), (z0, z1)]``.
        ``seg_box`` is in ``SEGMENTATION_DIM`` resolution; ``scan_box`` is the same
        box rescaled to the original scan. Both are clamped to their valid range,
        and an empty ``n`` yields all-zero boxes.

    The coordinates at sorted positions ``int(0.001 * m)`` and ``int(0.999 * m)``
    are used (roughly the 0.1th/99.9th percentiles). In-plane limits are further
    multiplied by 0.8 / 1.2, which scales relative to index 0 rather than the box
    centre. The depth limits are not scaled.
    """
    nx, ny, nz = n
    sx, sy, sz = s
    seg_sizes = (SEGMENTATION_DIM[1], SEGMENTATION_DIM[2], SEGMENTATION_DIM[0])
    scan_sizes = (sx, sy, sz)
    margins = ((0.8, 1.2), (0.8, 1.2), None)
    count = len(nx)
    lo_pos, hi_pos = int(0.001 * count), int(0.999 * count)

    seg_box, scan_box = [], []
    for coords, margin, seg_size, scan_size in zip((nx, ny, nz), margins, seg_sizes, scan_sizes):
        if count == 0:
            lo, hi = 0, 0
        elif margin is None:
            lo, hi = int(coords[lo_pos]), int(coords[hi_pos])
        else:
            lo, hi = int(margin[0] * coords[lo_pos]), int(margin[1] * coords[hi_pos])
        # Rescale the unclamped limits to the scan first, then clamp both boxes.
        scan_lo, scan_hi = int(lo * scan_size / seg_size), int(hi * scan_size / seg_size)
        seg_box.append((max(lo, 0), min(hi, seg_size)))
        scan_box.append((max(scan_lo, 0), min(scan_hi, scan_size)))
    return seg_box, scan_box


def predict_segmentation(models, image, sample=1, device="cuda"):
    """Segment a scan with ``sample`` randomly chosen models and derive per-organ crop boxes.

    Model choice uses the global ``random`` state, so results vary between runs
    unless it is seeded.

    Returns:
        probs: sigmoid of the averaged logits for the single batch item
            (presumably (classes, *SEGMENTATION_DIM)).
        bounds: for organ channels 1..5, a ``[seg_box, scan_box]`` pair from
            ``get_segmentation_bounds``. Voxels with probability > 0.3 are used;
            if there are none, the threshold falls back to 0.1.
    """
    chosen = random.sample(models, sample)
    sz = len(image)
    sx, sy = image.shape[1], image.shape[2]
    volume = preprocess_segmentation(image, device=device)
    with torch.no_grad():
        mean_logits = sum([model(volume) for model in chosen]) / len(chosen)
        probs = torch.sigmoid(mean_logits)[0]
    bounds = []
    for organ_id in range(1, 6):
        organ_prob = probs[organ_id]
        confident, permissive = organ_prob > 0.3, organ_prob > 0.1
        coords = torch.nonzero(confident, as_tuple=True)
        if len(coords[1]) == 0:
            coords = torch.nonzero(permissive, as_tuple=True)
        nz, nx, ny = coords
        sorted_axes = tuple(torch.sort(axis)[0] for axis in (nx, ny, nz))
        bounds.append(list(get_segmentation_bounds(sorted_axes, (sx, sy, sz))))
    return probs, bounds


def predict_slice(models, image, mask, bounds, image_size=384, num_channels=3, batch_size=16, sample=3, device="cuda"):
    """Stage-1 slice features, averaged over ``sample`` randomly chosen models.

    Slices are built by ``preprocess_slice`` and pushed through each model's
    ``forward_features`` in chunks of ``batch_size``. Returns the per-slice
    features concatenated along dim 0.
    """
    chosen = random.sample(models, sample)
    batch = preprocess_slice(image, mask, bounds, image_size=image_size, num_channels=num_channels, device=device)
    chunks = []
    with torch.no_grad():
        for minibatch in torch.split(batch, batch_size, dim=0):
            chunks.append(sum([model.forward_features(minibatch) for model in chosen]) / len(chosen))
    return torch.cat(chunks, dim=0)


def predict_scan(models, features, time_dim=256, sample=5, device="cuda"):
    """Stage-2 class probabilities averaged over ``sample`` randomly chosen models.

    The last model output (the per-timestep head) is discarded. Each remaining
    head's logits for batch item 0 are softmaxed and averaged across models.

    Returns:
        A list of 5 NumPy arrays, one per classification head.
    """
    chosen = random.sample(models, sample)
    x, attention_mask = preprocess_scan(features, time_dim=time_dim, device=device)
    per_head = [[] for _ in range(5)]
    for model in chosen:
        with torch.no_grad():
            head_logits = model(x, attention_mask)[:-1]
        for head_idx, logits in enumerate(head_logits):
            per_head[head_idx].append(F.softmax(logits, dim=-1)[0].cpu().numpy())
    return [sum(head_probs) / len(chosen) for head_probs in per_head]

## Model Configuration

In [13]:
# Create the directories that the %%writefile config cells below write into.
for _config_dir in (
    "/kaggle/working/configs/segmentation",
    "/kaggle/working/configs/stage_1",
    "/kaggle/working/configs/stage_2",
):
    os.makedirs(_config_dir, exist_ok=True)

In [14]:
%%writefile configs/segmentation/tf_efficientnetv2_s_128_224_224.yaml
type: "segmentation"
model:
    family: "unet"
    num_classes: 6
    encoder_params:
        class: EfficientNetEncoder2d
        encoder_name: "tf_efficientnetv2_s.in21k_ft_in1k"
        params:
            out_channels: [3, 24, 48, 64, 160, 256]
            stage_idx: [2, 3, 5]
            backbone_params: 
                pretrained: false
                in_chans: 3
                drop_path_rate: 0.2
    decoder_params: 
        decoder_channels: [256, 128, 64, 32, 16]
        scale_factors: [2, 2, 2, 2, 2]
        num_blocks: 5
        block_depth: 1
        separable: false
        use_aspp: false
        use_batchnorm: true 
        attention_type: "scse"
        activation: "silu"

Writing configs/segmentation/tf_efficientnetv2_s_128_224_224.yaml


In [15]:
%%writefile configs/stage_1/tf_efficientnetv2_s_384.yaml
type: "stage_1"
model:
    class: SliceClassificationModel
    encoder_name: "tf_efficientnetv2_s.in21k_ft_in1k"
    params:
        dropout: 0.2
        backbone_params: 
            pretrained: false
            in_chans: 4
            drop_path_rate: 0.2
data:
    image_size: 384
    num_channels: 3

Writing configs/stage_1/tf_efficientnetv2_s_384.yaml


In [16]:
%%writefile configs/stage_1/convnextv2_tiny_384.yaml
model:
    class: SliceClassificationModel
    encoder_name: "convnextv2_tiny.fcmae_ft_in22k_in1k_384"
    params:
        dropout: 0.2
        backbone_params: 
            pretrained: false
            in_chans: 4
            drop_path_rate: 0.2
data:
    image_size: 384
    num_channels: 3

Writing configs/stage_1/convnextv2_tiny_384.yaml


In [17]:
%%writefile configs/stage_1/maxxvitv2_nano_256.yaml
model:
    class: SliceClassificationModel
    encoder_name: "maxxvitv2_nano_rw_256.sw_in1k"
    params:
        dropout: 0.2
        backbone_params: 
            pretrained: false
            in_chans: 2
            drop_path_rate: 0.2
data:
    image_size: 256
    num_channels: 1

Writing configs/stage_1/maxxvitv2_nano_256.yaml


In [18]:
%%writefile configs/stage_2/lstm_256_128.yaml
model:
    time_dim: 256
    feature_dim: 2816
    hidden_dim: 128
    num_layers: 2
    dropout: 0.2
    bidirectional: true
data:
    time_dim: 256

Writing configs/stage_2/lstm_256_128.yaml


In [19]:
%%writefile configs/ensemble.yaml
device: "cuda"
batch_size: 16

models: 
    segmentation:
        config_path: "/kaggle/working/configs/segmentation/tf_efficientnetv2_s_128_224_224.yaml"
        checkpoint_paths: [
            "/kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_0.ckpt",
            "/kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_1.ckpt",
            "/kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_2.ckpt",
            "/kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_3.ckpt",
            "/kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_4.ckpt"
        ]
    stage_1:
        tf_efficientnetv2_s_384:
            config_path: "/kaggle/working/configs/stage_1/tf_efficientnetv2_s_384.yaml"
            checkpoint_paths: [
                "/kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s.in21k_ft_in1k__image_size_384__seed_0__fold_2.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s.in21k_ft_in1k__image_size_384__seed_0__fold_3.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s.in21k_ft_in1k__image_size_384__seed_0__fold_4.ckpt"
            ]
        convnextv2_tiny_384:
            config_path: "/kaggle/working/configs/stage_1/convnextv2_tiny_384.yaml"
            checkpoint_paths: [
                "/kaggle/input/rsna-atd-final-stage-1/convnextv2_tiny.fcmae_ft_in22k_in1k_384__image_size_384__seed_0__fold_2.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/convnextv2_tiny.fcmae_ft_in22k_in1k_384__image_size_384__seed_0__fold_3.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/convnextv2_tiny.fcmae_ft_in22k_in1k_384__image_size_384__seed_0__fold_4.ckpt"
            ]
        maxxvitv2_nano_256:
            config_path: "/kaggle/working/configs/stage_1/maxxvitv2_nano_256.yaml"
            checkpoint_paths: [
                "/kaggle/input/rsna-atd-final-stage-1/maxxvitv2_nano_rw_256.sw_in1k__image_size_256__seed_0__fold_2.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/maxxvitv2_nano_rw_256.sw_in1k__image_size_256__seed_0__fold_3.ckpt",
                "/kaggle/input/rsna-atd-final-stage-1/maxxvitv2_nano_rw_256.sw_in1k__image_size_256__seed_0__fold_4.ckpt"
            ]
    stage_2:
        lstm_256_128:
            config_path: "/kaggle/working/configs/stage_2/lstm_256_128.yaml"
            checkpoint_paths: [
                "/kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_0.ckpt",
                "/kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_1.ckpt",
                "/kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_2.ckpt",
                "/kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_3.ckpt",
                "/kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_4.ckpt"
            ]

Writing configs/ensemble.yaml


In [20]:
def load_model(module, config, checkpoint_path, map_location="cpu"):
    """Build ``module(config)``, load a checkpoint into it and return its inner network.

    Args:
        module: class constructed as ``module(config)`` whose instance exposes the
            network to return as ``.model`` (presumably a LightningModule — confirm).
        config: training config passed to ``module``.
        checkpoint_path: checkpoint file holding a ``"state_dict"`` entry that
            matches ``module(config)``.
        map_location: where checkpoint tensors are materialised while loading.
            Defaults to ``"cpu"`` so loading works regardless of the device the
            checkpoint was saved on and does not allocate a second GPU copy; the
            callers in this notebook move the returned network with ``.to(device)``.

    Returns:
        The ``.model`` attribute of the loaded module, switched to eval mode.
    """
    print(f"Checkpoint: {checkpoint_path}")
    model = module(config)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    # Release the rest of the checkpoint (everything besides the weights) early.
    del checkpoint
    network = model.model
    network.eval()
    return network


def load_segmentation_models(config, device="cuda"):
    """Load every checkpoint of the segmentation model onto ``device``.

    Args:
        config: mapping with ``"config_path"`` (YAML config used to build
            ``SegmentationModule``) and ``"checkpoint_paths"`` (list of checkpoints).
        device: device each loaded network is moved to.

    Returns:
        List of networks (as returned by ``load_model``), one per checkpoint,
        in the order of ``config["checkpoint_paths"]``.
    """
    print("Loading segmentation model...")
    with open(config["config_path"], "rb") as yaml_file:
        model_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

    loaded = []
    for ckpt_path in config["checkpoint_paths"]:
        network = load_model(SegmentationModule, model_config, ckpt_path)
        loaded.append(network.to(device))
    return loaded


def load_stage_1_models(config, device="cuda"):
    """Load every stage-1 (slice classification) model family listed in ``config``.

    Args:
        config: mapping family name -> ``{"config_path": YAML config,
            "checkpoint_paths": list of checkpoints}``.
        device: device each loaded network is moved to.

    Returns:
        Dict family name -> ``{"models": networks (one per checkpoint, in config
        order), "image_size": ..., "num_channels": ...}``, the last two read from
        the family's ``data`` config section.
    """
    families = {}
    for family_name, family_cfg in config.items():
        print(f"Loading stage 1 family `{family_name}`...")
        with open(family_cfg["config_path"], "rb") as yaml_file:
            train_cfg = yaml.load(yaml_file, Loader=yaml.FullLoader)

        networks = [
            load_model(SliceClassificationModule, train_cfg, ckpt).to(device)
            for ckpt in family_cfg["checkpoint_paths"]
        ]
        data_cfg = train_cfg["data"]
        families[family_name] = {
            "models": networks,
            "image_size": data_cfg["image_size"],
            "num_channels": data_cfg["num_channels"],
        }
    return families


def load_stage_2_models(config, device="cuda"):
    """Load every stage-2 (scan classification) model family listed in ``config``.

    Args:
        config: mapping family name -> ``{"config_path": YAML config,
            "checkpoint_paths": list of checkpoints}``.
        device: device each loaded network is moved to.

    Returns:
        Dict family name -> ``{"models": networks (one per checkpoint, in config
        order), "time_dim": the family's ``data.time_dim`` config value}``.
    """
    families = {}
    for family_name, family_cfg in config.items():
        print(f"Loading stage 2 family `{family_name}`...")
        with open(family_cfg["config_path"], "rb") as yaml_file:
            train_cfg = yaml.load(yaml_file, Loader=yaml.FullLoader)

        networks = [
            load_model(ScanClassificationModule, train_cfg, ckpt).to(device)
            for ckpt in family_cfg["checkpoint_paths"]
        ]
        families[family_name] = {"models": networks, "time_dim": train_cfg["data"]["time_dim"]}
    return families

## Inference

In [21]:
# Inference: segmentation -> stage 1 (per-slice features) -> stage 2 (per-scan
# predictions) for every series whose scan exists on disk, then aggregation to
# one row per patient.
split = "test"
df = pd.read_csv(os.path.join(DATA_DIR, f"{split}_series_meta.csv"))

with open("/kaggle/working/configs/ensemble.yaml", "rb") as f:
    CONFIG = yaml.load(f, Loader=yaml.FullLoader)

segmentation = CONFIG["models"]["segmentation"]
stage_1 = CONFIG["models"]["stage_1"]
stage_2 = CONFIG["models"]["stage_2"]
device = CONFIG["device"]
batch_size = CONFIG["batch_size"]

segmentation_models = load_segmentation_models(segmentation, device=device)
stage_1_models = load_stage_1_models(stage_1, device=device)
stage_2_models = load_stage_2_models(stage_2, device=device)

# Submission columns. Concatenating the NUM_TARGET_GROUPS stage-2 outputs in
# order yields exactly the columns of `target_cols`.
pos_cols = [
    "extravasation_injury",
    "bowel_injury",
    "liver_low",
    "liver_high",
    "spleen_low",
    "spleen_high",
    "kidney_low",
    "kidney_high"
]
neg_cols = [
    "extravasation_healthy",
    "bowel_healthy",
    "liver_healthy",
    "spleen_healthy",
    "kidney_healthy"
]
target_cols = [
    "extravasation_healthy",
    "extravasation_injury",
    "bowel_healthy",
    "bowel_injury",
    "liver_healthy",
    "liver_low",
    "liver_high",
    "spleen_healthy",
    "spleen_low",
    "spleen_high",
    "kidney_healthy",
    "kidney_low",
    "kidney_high",
]
NUM_TARGET_GROUPS = 5  # extravasation, bowel, liver, spleen, kidney

# Keep only series whose scan is present on disk.
keys = []
scan_paths = []
for patient_id, series_id in df.groupby(["patient_id", "series_id"]).groups.keys():
    scan_path = get_scan_path(patient_id, series_id, split=split)
    if not os.path.exists(scan_path):
        continue
    keys.append((patient_id, series_id))
    scan_paths.append(scan_path)

scan_predictions = [[] for _ in range(NUM_TARGET_GROUPS)]
# Iterating the list (not a zip) lets tqdm show a real total.
for scan_path in tqdm(scan_paths):
    image = load_scan(scan_path)
    mask, bounds = predict_segmentation(segmentation_models, image, device=device)

    # Stage 1: features from every family, concatenated along dim 1.
    features = torch.cat(
        [
            predict_slice(family["models"], image, mask, bounds, family["image_size"], family["num_channels"], batch_size, device=device)
            for family in stage_1_models.values()
        ],
        dim=1,
    )

    # Stage 2: average each target group over families. A plain sum would
    # scale the predictions by the number of configured families.
    family_predictions = [
        predict_scan(family["models"], features, family["time_dim"], device=device)
        for family in stage_2_models.values()
    ]
    for i in range(NUM_TARGET_GROUPS):
        scan_predictions[i].append(sum(p[i] for p in family_predictions) / len(family_predictions))

if keys:
    scan_predictions = np.concatenate([np.stack(p, axis=0) for p in scan_predictions], axis=1)
else:
    # np.stack raises on an empty list; keep a correctly shaped empty result.
    scan_predictions = np.zeros((0, len(target_cols)))

id_cols = ["patient_id", "series_id"]
# Cast ids to the metadata dtypes so an empty frame does not become `object`
# (which would break the later merge on patient_id).
df_submit = pd.DataFrame(keys, columns=id_cols).astype(df[id_cols].dtypes.to_dict())
df_submit[target_cols] = scan_predictions

df_submit = df_submit.drop(columns=["series_id"])
df_submit_pos = df_submit[["patient_id"] + pos_cols].groupby("patient_id", as_index=False).max(numeric_only=True) # Max aggregation over injury columns
df_submit_neg = df_submit[["patient_id"] + neg_cols].groupby("patient_id", as_index=False).min(numeric_only=True) # Min aggregation over healthy columns
df_submit = df_submit_pos.merge(df_submit_neg, on=["patient_id"], how="left").reset_index(drop=True)
df_submit.head()

Loading segmentation model...
Checkpoint: /kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_0.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_1.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_2.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_3.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-segmentation/tf_efficientnetv2_s.in21k_ft_in1k__128_224_224__seed_0__fold_4.ckpt
Loading stage 1 family `tf_efficientnetv2_s_384`...
Checkpoint: /kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s.in21k_ft_in1k__image_size_384__seed_0__fold_2.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s.in21k_ft_in1k__image_size_384__seed_0__fold_3.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-1/tf_efficientnetv2_s

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Checkpoint: /kaggle/input/rsna-atd-final-stage-1/maxxvitv2_nano_rw_256.sw_in1k__image_size_256__seed_0__fold_3.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-1/maxxvitv2_nano_rw_256.sw_in1k__image_size_256__seed_0__fold_4.ckpt
Loading stage 2 family `lstm_256_128`...
Checkpoint: /kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_0.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_1.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_2.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_3.ckpt
Checkpoint: /kaggle/input/rsna-atd-final-stage-2/lstm_256_128__seed_0__fold_4.ckpt


0it [00:00, ?it/s]

Unnamed: 0,patient_id,extravasation_injury,bowel_injury,liver_low,liver_high,spleen_low,spleen_high,kidney_low,kidney_high,extravasation_healthy,bowel_healthy,liver_healthy,spleen_healthy,kidney_healthy
0,48843,0.479187,0.29989,0.309901,0.224372,0.328844,0.290889,0.195177,0.2937,0.520813,0.70011,0.465727,0.380267,0.511123
1,50046,0.485541,0.374966,0.274808,0.24772,0.303287,0.243375,0.209228,0.208284,0.514459,0.625034,0.477472,0.453338,0.582488
2,63706,0.379059,0.444275,0.298089,0.196277,0.412387,0.228513,0.220229,0.190025,0.620941,0.555725,0.505634,0.3591,0.589746


In [22]:
# Align predictions to the official patient list. Patients without any
# prediction (no readable scan) fall back to the sample-submission values, and
# only then to 0.0, instead of all-zero rows.
# NOTE(review): the metric presumably normalises each target group, so an
# all-zero group would divide by zero; this also assumes the sample submission
# holds valid default probabilities — confirm both.
sample_submission = pd.read_csv(os.path.join(DATA_DIR, "sample_submission.csv"))
defaults = sample_submission.set_index("patient_id")[target_cols]
submission = (
    sample_submission.drop(columns=target_cols)
    .merge(df_submit, on=["patient_id"], how="left")
    .set_index("patient_id")
    .fillna(defaults)
    .fillna(0.0)
    .reset_index()
)
submission.head()

Unnamed: 0,patient_id,extravasation_injury,bowel_injury,liver_low,liver_high,spleen_low,spleen_high,kidney_low,kidney_high,extravasation_healthy,bowel_healthy,liver_healthy,spleen_healthy,kidney_healthy
0,48843,0.479187,0.29989,0.309901,0.224372,0.328844,0.290889,0.195177,0.2937,0.520813,0.70011,0.465727,0.380267,0.511123
1,50046,0.485541,0.374966,0.274808,0.24772,0.303287,0.243375,0.209228,0.208284,0.514459,0.625034,0.477472,0.453338,0.582488
2,63706,0.379059,0.444275,0.298089,0.196277,0.412387,0.228513,0.220229,0.190025,0.620941,0.555725,0.505634,0.3591,0.589746


In [23]:
# Write the final patient-level submission without the pandas index column.
submission.to_csv("submission.csv", index=False)