In [1]:
!pip -q install '/kaggle/input/rsna24-lsdc-wheels/dicomsdl-0.109.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl'

# Imports

In [2]:
import numpy as np
import pandas as pd
import polars as pl
import os
import random
import time
import datetime
import warnings
import yaml

from collections import OrderedDict
from types import SimpleNamespace

from tqdm.notebook import trange, tqdm

from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold
from sklearn.metrics import log_loss

import pandas.api.types
import sklearn.metrics

import timm
import transformers
import nibabel as nib
import pydicom
import dicomsdl

import cv2
from PIL import Image

warnings.simplefilter("ignore")

# Seeding

In [3]:
def seed_everything(seed: int):    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

# Dataset

In [4]:
class LSDCTestDataset(Dataset):
    def __init__(self, conf, df, study_id_level_df, patient_coords, axial_IPP, transforms): ##AXIAL_FLAG
#     def __init__(self, conf, df, study_id_level_df, transforms):
        super().__init__()
        self.conf = conf
        self.df = df
        self.study_id_level_df = study_id_level_df
        self.patient_coords = patient_coords ##AXIAL_FLAG
        self.axial_IPP = axial_IPP ##AXIAL_FLAG
        self.transforms = transforms

        self.label_names_cond = [
            # sagittal_t1
            'left_neural_foraminal_narrowing_normal_mild',
            'left_neural_foraminal_narrowing_moderate',
            'left_neural_foraminal_narrowing_severe',

            'right_neural_foraminal_narrowing_normal_mild',
            'right_neural_foraminal_narrowing_moderate',
            'right_neural_foraminal_narrowing_severe',

            # sagittal_t2
            'spinal_canal_stenosis_normal_mild',
            'spinal_canal_stenosis_moderate',
            'spinal_canal_stenosis_severe',

            # axial
            'left_subarticular_stenosis_normal_mild',
            'left_subarticular_stenosis_moderate',
            'left_subarticular_stenosis_severe',

            'right_subarticular_stenosis_normal_mild',
            'right_subarticular_stenosis_moderate',
            'right_subarticular_stenosis_severe',
        ]
        
    
    def __len__(self):
        return len(self.study_id_level_df)
    
    def __getitem__(self, idx):

        image_stack = np.zeros((self.conf.image_size, self.conf.image_size, self.conf.n_slices))
        chan = 0
        
        study_id_level_samples = self.study_id_level_df[idx]

        this_study_lvl = self.df.filter(pl.col('study_id_level') == study_id_level_samples)
        this_study = this_study_lvl['study_id'].item(0)
        sagittal_t1 = this_study_lvl.filter(pl.col('series_description') == 'Sagittal T1')
        sagittal_t2 = this_study_lvl.filter(pl.col('series_description') == 'Sagittal T2/STIR')
        axial_t2 = this_study_lvl.filter(pl.col('series_description') == 'Axial T2') ##AXIAL_FLAG
        
        if not sagittal_t1.is_empty():
            sagittal_t1_path = sagittal_t1['series_path'].item(0)
            sagittal_t1_slices = sagittal_t1['slices'][0].to_numpy()

            sagittal_t1_mid_idx = sagittal_t1_slices[len(sagittal_t1_slices) // 2]
            mid_img_size = self._read_dcm(f'{sagittal_t1_path}/{sagittal_t1_mid_idx}.dcm').shape
            sagittal_t1_rel_bbox = sagittal_t1.select(['relative_x', 'relative_y']).to_numpy()[0]
            x, y = sagittal_t1_rel_bbox * mid_img_size
            window_half_size = (np.mean(mid_img_size) * self.conf.sagittal_window_ratio) // 2

            sagittal_t1_bbox = np.array([
                x - window_half_size,
                x + window_half_size,
                y - window_half_size,
                y + window_half_size,
            ]).clip(1, np.mean(mid_img_size))
            
            for ins_num in sagittal_t1_slices:
                try:
                    dicom = self._read_dcm(f'{sagittal_t1_path}/{ins_num}.dcm')
                    cropped = dicom[
                        int(round(sagittal_t1_bbox[2])): int(round(sagittal_t1_bbox[3])),
                        int(round(sagittal_t1_bbox[0])): int(round(sagittal_t1_bbox[1])),
                    ]
                    image_stack[:, :, chan] = cv2.resize(cropped, (self.conf.image_size, self.conf.image_size))
                except:
                    pass
                chan += 1
        else:
            chan += 11 # skip if study does not have sagittal t1
            
        if not sagittal_t2.is_empty():
            sagittal_t2_path = sagittal_t2['series_path'].item(0)
            sagittal_t2_slices = sagittal_t2['slices'][0].to_numpy()

            sagittal_t2_mid_idx = sagittal_t2_slices[len(sagittal_t2_slices) // 2]
            mid_img_size = self._read_dcm(f'{sagittal_t2_path}/{sagittal_t2_mid_idx}.dcm').shape
            sagittal_t2_rel_bbox = sagittal_t2.select(['relative_x', 'relative_y']).to_numpy()[0]
            x, y = sagittal_t2_rel_bbox * mid_img_size
            window_half_size = (np.mean(mid_img_size) * self.conf.sagittal_window_ratio) // 2

            sagittal_t2_bbox = np.array([
                x - window_half_size,
                x + window_half_size,
                y - window_half_size,
                y + window_half_size,
            ]).clip(1, np.mean(mid_img_size))

            for ins_num in sagittal_t2_slices:
                try:
                    dicom = self._read_dcm(f'{sagittal_t2_path}/{ins_num}.dcm')
                    cropped = dicom[
                        int(round(sagittal_t2_bbox[2])): int(round(sagittal_t2_bbox[3])),
                        int(round(sagittal_t2_bbox[0])): int(round(sagittal_t2_bbox[1])),
                    ]
                    image_stack[:, :, chan] = cv2.resize(cropped, (self.conf.image_size, self.conf.image_size))
                except:
                    pass
                chan += 1
                
            if not axial_t2.is_empty(): ##AXIAL_FLAG
                axial_t2_path = axial_t2['series_path'].item(0)
                axial_t2_slices = axial_t2['slices'][0].to_numpy()
                axial_t2_rel_coords = np.array(axial_t2[['left_coords', 'right_coords']].row(0)).mean(axis=0)

                axial_t2_mid_idx = axial_t2_slices[len(axial_t2_slices) // 2]
                mid_img_size = self._read_dcm(f'{axial_t2_path}/{axial_t2_mid_idx}.dcm').shape
                x, y = axial_t2_rel_coords * mid_img_size
                window_half_size = (np.mean(mid_img_size) * self.conf.axial_window_ratio) // 2
                axial_t2_bbox = np.array([
                    x - window_half_size,
                    x + window_half_size,
                    y - window_half_size,
                    y + window_half_size,
                ]).clip(1, np.mean(mid_img_size))

                axial_z_axis = self.axial_IPP.filter(pl.col('study_id') == this_study)[['instance_number', 'SliceLocation']].to_numpy()
                sag_coord = self.patient_coords.filter(
                    pl.col('study_id_level') == study_id_level_samples,
                    pl.col('series_description') == 'Sagittal T2/STIR'
                )['patient_coords'].to_numpy()[0]

                level_idx = self._find_axial(axial_z_axis, sag_coord, self.conf.axial_min_dist_threshold)
                if level_idx is not None:
                    level_pack_slices = axial_t2_slices[
                        level_idx - 2: level_idx + 1
                    ]

                    for ins_num in level_pack_slices:
                        try:
                            dicom = self._read_dcm(f'{axial_t2_path}/{ins_num}.dcm')
#                             cropped = self._center_crop(dicom, (self.conf.axial_center_crop_size, self.conf.axial_center_crop_size))
                            cropped = dicom[
                                int(round(axial_t2_bbox[2])): int(round(axial_t2_bbox[3])),
                                int(round(axial_t2_bbox[0])): int(round(axial_t2_bbox[1])),
                            ]
                            image_stack[:, :, chan] = cv2.resize(cropped, (self.conf.image_size, self.conf.image_size))
                        except:
                            pass
                        chan += 1

        image_stack = (image_stack - image_stack.min()) / (image_stack.max() - image_stack.min() + 1e-7)
        image_stack = image_stack.astype(np.float32)

        new_image_stack = np.concatenate([
            # left
            image_stack[:, :, 0: 3],
            image_stack[:, :, 11: 14],
            # middle
            image_stack[:, :, 3: 8],
            image_stack[:, :, 14: 19],
            # right
            image_stack[:, :, 8: 11],
            image_stack[:, :, 19: 22],
            # axial
            image_stack[:, :, 22: 25], ##AXIAL_FLAG
        ], axis=-1)

        if self.transforms is not None:
            new_image_stack = self.transforms(image=new_image_stack)['image']
            
        batch = {}
        batch['ids'] = study_id_level_samples
        batch['images'] = new_image_stack.to(dtype=torch.float)
        
        return batch

    def _read_dcm(self, path):
        data = dicomsdl.open(path).pixelData(storedvalue=True)
        return data
    
    def _center_crop(self, img, dim):
        # https://gist.github.com/Nannigalaxy/35dd1d0722f29672e68b700bc5d44767
        """Returns center cropped image
        Args:
        img: image to be center cropped
        dim: dimensions (width, height) to be cropped
        """
        width, height = img.shape[1], img.shape[0]

        # process crop width and height for max available dimension
        crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1]
        crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0]
        mid_x, mid_y = int(width / 2), int(height / 2)
        cw2, ch2 = int(crop_width / 2), int(crop_height / 2) 
        crop_img = img[mid_y - ch2: mid_y + ch2, mid_x - cw2: mid_x + cw2]
        return crop_img
    
    def _find_axial(self, axial_array, sag_coords, threshold):
        # axial_array (instance, slice location)
        slice_position = axial_array[:, 1] # axial slicelocation
        distance = np.abs(slice_position - sag_coords[2])
        slice_pos = np.argmin(distance)
        
        if threshold != None:
            if axial_array[slice_pos][1] > threshold:
                return None
        
        return int(axial_array[slice_pos][0])

In [5]:
# test_df = infer_data_preprocess(test_df, test_images_path, sagittal_IPP_df)

In [6]:
# with open('/kaggle/input/rsna24-lsdc-exps/exp110/config.yaml') as file:
#     conf = SimpleNamespace(**yaml.safe_load(file))
# dataset = LSDCTestDataset(conf, test_df, test_df['study_id_level'].unique(), patient_coords_df, axial_IPP_df, get_transforms(conf, types='test'))

In [7]:
# data =  dataset[1]

In [8]:
# for i in range(1, 26):
#     dcm = pydicom.dcmread(f'/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images/44036939/3844393089/{i}.dcm')
#     print([dcm.InstanceNumber, dcm.ImagePositionPatient[0], dcm.SliceLocation])

In [9]:
# fig = plt.figure(figsize=(20, 20))
# for i, img in enumerate(data['images']):
#     num = i + 1
#     ax = fig.add_subplot(6, 6, num, frameon=False)
#     ax.title.set_text(num)
#     plt.imshow(img, cmap='gray')
#     plt.axis('off')
# plt.show()

In [10]:
class LSDCCoordsTestDataset(Dataset):
    def __init__(self, conf, df, series_id_df, transforms=None):
        super().__init__()
        self.conf = conf
        self.transforms = transforms
        self.series_id_df = series_id_df
        self.df = df
        
    def __len__(self):
        return len(self.series_id_df)
    
    def __getitem__(self, idx):
        
        series_sample = self.series_id_df[idx]
        this_series = self.df.filter(pl.col('series_id') == series_sample)
        study_id = this_series['study_id'].item(0)
        series_id = this_series['series_id'].item(0)
        series_path = this_series['series_path'].item(0)
        ins_num_array = this_series['instance_number'].to_numpy()
        ins_num = ins_num_array[len(ins_num_array) // 2] # pick middle dcm assuming passed df is sorted
        
        one_image = self._read_dcm(f'{series_path}/{ins_num}.dcm')
        one_image = cv2.resize(one_image, (self.conf.image_size, self.conf.image_size))
        image = np.stack([one_image, one_image, one_image], axis=-1)
        
        image = (image - image.min()) / (image.max() - image.min() + 1e-7)
        
        if self.transforms is not None:
            transformed = self.transforms(image=image)
            image = transformed['image']
    
        batch = {}
        batch['ids'] = series_sample
        batch['images'] = image
        
        return batch
        
    def _read_dcm(self, path):
        data = dicomsdl.open(path).pixelData(storedvalue=True)
        return data

In [11]:
def get_transforms(conf, types):
    tranforms_dict = {
        'test': A.Compose([
            ToTensorV2(),
        ]),
    }
    return tranforms_dict[types]

# Model

In [12]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + \
                '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'

In [13]:
# https://www.kaggle.com/bminixhofer/a-validation-framework-impact-of-the-random-seed
class Attention(nn.Module):
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        
        self.supports_masking = True

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0
        
        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)
        
        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))
        
    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim), 
            self.weight
        ).view(-1, step_dim)
        
        if self.bias:
            eij = eij + self.b
            
        eij = torch.tanh(eij)
        a = torch.exp(eij)
        
        if mask is not None:
            a = a * mask

        a = a / torch.sum(a, 1, keepdim=True) + 1e-10

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)

In [14]:
class LSDCModel(nn.Module):
    def __init__(self, conf, pretrained=False):
        super().__init__()
        self.head_type = conf.head_type
        if self.head_type == 'lstm_attn':
            self.head_feats = 512
        elif self.head_type == 'lstm_mean_max':
            self.head_feats = 512 * 2
        elif self.head_type == 'avg':
            self.head_feats = 512
        
        if 'vit' in conf.backbone or 'coat' in conf.backbone: 
            self.backbone = timm.create_model(
                conf.backbone,
                pretrained=pretrained,
                features_only=False,
                in_chans=conf.in_chans,
                img_size=conf.image_size
            )
        else:
            self.backbone = timm.create_model(
                conf.backbone,
                pretrained=pretrained,
                features_only=False,
                in_chans=conf.in_chans,
            )
        
        if 'efficientnet' in conf.backbone:
            in_features = self.backbone.classifier.in_features
            self.backbone.global_pool = nn.Identity()
            self.backbone.classifier = nn.Identity()
        elif 'convnext' in conf.backbone:
            in_features = self.backbone.head.fc.in_features
            self.backbone.head = nn.Identity()
        elif 'maxvit' in conf.backbone:
            in_features = self.backbone.head.fc.in_features
            self.backbone.head = nn.Identity()
        elif 'coat' in conf.backbone:
            in_features = self.backbone.head.fc.in_features
            self.backbone.head = nn.Identity()
        else:
            raise NotImplementedError

        self.pooler = self._get_pooling(conf.pooling)

        self.lstm_lnfn = nn.LSTM(in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True)
        self.lstm_rnfn = nn.LSTM(in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True)
        self.lstm_scs = nn.LSTM(in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True)
        self.lstm_lss = nn.LSTM(in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True)
        self.lstm_rss = nn.LSTM(in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True)

        if 'attn' in self.head_type:
            self.attn_lnfn = Attention(512, conf.n_slices)
            self.attn_rnfn = Attention(512, conf.n_slices)
            self.attn_scs = Attention(512, conf.n_slices)
            self.attn_lss = Attention(512, conf.n_slices)
            self.attn_rss = Attention(512, conf.n_slices)

        self.head_lnfn = nn.Sequential(
            nn.Linear(self.head_feats, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(256, conf.num_class // 5)
        )
        self.head_rnfn = nn.Sequential(
            nn.Linear(self.head_feats, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(256, conf.num_class // 5)
        )
        self.head_scs = nn.Sequential(
            nn.Linear(self.head_feats, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(256, conf.num_class // 5)
        )
        self.head_lss = nn.Sequential(
            nn.Linear(self.head_feats, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(256, conf.num_class // 5)
        )
        self.head_rss = nn.Sequential(
            nn.Linear(self.head_feats, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(256, conf.num_class // 5)
        )

    def _get_pooling(self, pooling_name):
        if pooling_name == 'avg':
            return nn.AdaptiveAvgPool2d((1))
        elif pooling_name == 'max':
            return nn.AdaptiveMaxPool2d((1))
        elif pooling_name == 'gem':
            return GeM()
        else:
            raise NotImplementedError

    def forward(self, inputs):
        
        bs, slices, img_h, img_w = inputs.shape

        inputs = inputs.view(bs * slices, 1, img_h, img_w)

        outputs = self.backbone(inputs)
        outputs = self.pooler(outputs)
        outputs = outputs.view(bs, slices, -1)
        
        outputs1, _ = self.lstm_lnfn(outputs)
        outputs2, _ = self.lstm_rnfn(outputs)
        outputs3, _ = self.lstm_scs(outputs)
        outputs4, _ = self.lstm_lss(outputs)
        outputs5, _ = self.lstm_rss(outputs)

        if self.head_type == 'lstm_attn':
            outputs1 = self.attn_lnfn(outputs1)
            outputs2 = self.attn_rnfn(outputs2)
            outputs3 = self.attn_scs(outputs3)
            outputs4 = self.attn_lss(outputs4)
            outputs5 = self.attn_rss(outputs5)
            
        elif self.head_type == 'lstm_mean_max':
            outputs1 = torch.concat([outputs1.mean(dim=1), outputs1.amax(dim=1)], dim=-1)
            outputs2 = torch.concat([outputs2.mean(dim=1), outputs2.amax(dim=1)], dim=-1)
            outputs3 = torch.concat([outputs3.mean(dim=1), outputs3.amax(dim=1)], dim=-1)
            outputs4 = torch.concat([outputs4.mean(dim=1), outputs4.amax(dim=1)], dim=-1)
            outputs5 = torch.concat([outputs5.mean(dim=1), outputs5.amax(dim=1)], dim=-1)
            
        elif self.head_type == 'avg':
            outputs1 = outputs1.contiguous().view(bs * slices, -1)
            outputs2 = outputs2.contiguous().view(bs * slices, -1)
            outputs3 = outputs3.contiguous().view(bs * slices, -1)
            outputs4 = outputs4.contiguous().view(bs * slices, -1)
            outputs5 = outputs5.contiguous().view(bs * slices, -1)
        
        outputs1 = self.head_lnfn(outputs1)
        outputs2 = self.head_rnfn(outputs2)
        outputs3 = self.head_scs(outputs3)
        outputs4 = self.head_lss(outputs4)
        outputs5 = self.head_rss(outputs5)
        
        if self.head_type == 'avg':
            outputs1 = outputs1.view(bs, slices, -1).mean(dim=1)
            outputs2 = outputs2.view(bs, slices, -1).mean(dim=1)
            outputs3 = outputs3.view(bs, slices, -1).mean(dim=1)
            outputs4 = outputs4.view(bs, slices, -1).mean(dim=1)
            outputs5 = outputs5.view(bs, slices, -1).mean(dim=1)

        return torch.concat([outputs1, outputs2, outputs3, outputs4, outputs5], dim=-1)

In [15]:
class LSDCCoordsModel(nn.Module):
    def __init__(self, conf, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            conf.backbone,
            pretrained=pretrained,
            features_only=False,
            in_chans=conf.in_chans,
            num_classes=conf.num_class,
            global_pool='avg'
        )

    def forward(self, inputs):
        outputs = self.backbone(inputs)
        return outputs

# Utils

In [16]:
def infer_data_preprocess(df, data_path, sagittal_IPP_df):
    print('data preprocessing..')
    
    df = df.with_columns(pl.concat_str([
        pl.lit(data_path),
        pl.col('study_id'),
        pl.col('series_id'),
    ], separator='/').alias('series_path'))
    
    df = df.with_columns(pl.concat_str([
        pl.col('study_id'),
        pl.col('level'),
    ],separator='_').alias('study_id_level'))
    
    slice_samples_stack = {
        'series_id': [],
        'slices': [],
    }

    for name, data in tqdm(df.group_by(['series_path'])):
        plane = data['series_description'][0]
        series_path = Path(name[0])
        series_id = series_path.name

        if plane != 'Axial T2':
            series_info = sagittal_IPP_df.filter(pl.col('series_id') == int(series_id))
            # handling if series fail to retrive info from dicom then manually picking as fall back
            if series_info.is_empty():
                dcm_array = np.sort([int(path.stem) for path in series_path.glob('*.dcm')])
                dcm_array_mid_len = len(dcm_array) // 2
                mid_idx = dcm_array[dcm_array_mid_len]
                side1_idx = dcm_array[dcm_array_mid_len - 4] # distance around 16~18mm
                side2_idx = dcm_array[dcm_array_mid_len + 4] # distance around 16~18mm

                side1_pack = dcm_array[side1_idx - 2: side1_idx + 1][: 3] # make sure number of slices
                mid_pack = dcm_array[mid_idx - 3: mid_idx + 2][: 5] # make sure number of slices
                side2_pack = dcm_array[side2_idx - 2: side2_idx + 1][: 3] # make sure number of slices
                
                pack = np.concatenate([side1_pack, mid_pack, side2_pack])

            # if series exists pick based on distance on mid with left or right in mind
            else:
                dcm_array = series_info['instance_number'].to_numpy()
                mid_idx = dcm_array[len(dcm_array) // 2] # should be fine in most case but may be inaccurate

                first = series_info.filter(pl.col('instance_number') == series_info['instance_number'].min())
                first_IPP_x = first["ImagePositionPatient_x"].item(0)
                first_sl = first["SliceLocation"].item(0)

                last = series_info.filter(pl.col('instance_number') == series_info['instance_number'].max())
                last_IPP_x = last["ImagePositionPatient_x"].item(0)
                last_sl = last["SliceLocation"].item(0)

                mid_info = series_info.filter(pl.col('instance_number') == mid_idx)

                neutral = series_info['SliceLocation'].to_numpy() - mid_info['SliceLocation'].to_numpy()
                side1_idx = np.argmin(np.abs(neutral - 17)) + 1
                side2_idx = np.argmin(np.abs(neutral + 17)) + 1

                side1_pack = dcm_array[side1_idx - 2: side1_idx + 1]
                mid_pack = dcm_array[mid_idx - 3: mid_idx + 2]
                side2_pack = dcm_array[side2_idx - 2: side2_idx + 1]

                cond1 = (first_IPP_x > last_IPP_x) and (first_sl > last_sl)
                cond2 = (first_IPP_x < last_IPP_x) and (first_sl < last_sl)

                if cond1 or cond2:
                    pack = np.concatenate([side1_pack, mid_pack, side2_pack])
                else:
                    pack = np.concatenate([side2_pack, mid_pack, side1_pack])
            
        else: # Axial T2
            pack = np.sort([int(path.stem) for path in series_path.glob('*.dcm')])

        slice_samples_stack['series_id'].append(series_id)
        slice_samples_stack['slices'].append(pack)
        
    slice_samples_stack = pl.DataFrame(slice_samples_stack).with_columns(pl.col('series_id').cast(pl.Int64))
    
    df = df.join(slice_samples_stack, on='series_id')
    
    return df

In [17]:
def formatting_predictions(oof: pl.DataFrame):
    formatted_df = pl.DataFrame()
    
    oof = (
        oof
        .melt(id_vars='study_id_level')
        .with_columns(pl.col('study_id_level').str.head(-6).alias('study_id'))
        .with_columns(pl.col('study_id_level').str.slice(-5).alias('level'))
    )

    for cond in ['normal_mild', 'moderate', 'severe']:
        tmp = (
            oof
            .filter(pl.col('variable').str.contains(f'{cond}'))
            .with_columns(pl.col('variable').str.head(-len(f'_{cond}')).alias('condition'))
            .with_columns(pl.concat_str([
                pl.col('study_id'),
                pl.col('condition'),
                pl.col('level'),
            ], separator='_').alias('row_id'))
            .rename({'value': f'{cond}'})
            .select(['row_id', f'{cond}'])
        )
        if formatted_df.is_empty():
            formatted_df = tmp
        else:
            formatted_df = formatted_df.join(tmp, on='row_id')
            
    return formatted_df.sort('row_id')

## Create df utils

In [18]:
def create_test_df(test_des, test_images_path):
    test_images_path_str = test_images_path
    print('creating test df')
    test_images_path = Path(test_images_path).glob('**/**/*.dcm')
    test_dcm_dict = {
        'study_id': [],
        'series_id': [],
        'instance_number': [],
    }

    for path in tqdm(test_images_path):
        test_dcm_dict['study_id'].append(int(path.parent.parent.name))
        test_dcm_dict['series_id'].append(int(path.parent.name))
        test_dcm_dict['instance_number'].append(int(path.stem))

    test_df = pl.DataFrame(test_dcm_dict).with_columns(pl.concat_str([
                pl.lit(test_images_path_str),
                pl.col('study_id'),
                pl.col('series_id'),
            ], separator='/').alias('series_path')).sort(['study_id', 'series_id', 'instance_number'])

    test_df = test_df.join(test_des, on=['study_id', 'series_id'])
    del test_dcm_dict

    return test_df

def create_patient_coords_df(test_df, test_images_path): # need revision
    print('creating patient coordinates df')
    
    patient_coords_dict = {
        'study_id': [],
        'series_id': [],
        'study_id_level': [],
        'series_description': [],
        'level': [],
        'patient_coords': []
    }

    for name, data in tqdm(test_df.filter(pl.col('series_description') != 'Axial T2').group_by(['study_id'])):

        for name2, data2 in data.group_by(['series_id']):
            study_id = data2['study_id'].unique().item(0)
            series_id = data2['series_id'].unique().item(0)
            series_description = data2['series_description'].unique().item(0)

            ins_nums = data2['instance_number'].to_numpy()
            rel_coords = data2.select(['relative_x', 'relative_y']).to_numpy()
            levels = data2['level'].to_list()

            for ins_num, (rel_x, rel_y), level in zip(ins_nums, rel_coords, levels):
                path = os.path.join(test_images_path, f'{study_id}/{series_id}/{int(ins_num)}.dcm')
                dcm = dicomsdl.open(path)

                pixel_spacing = dcm.PixelSpacing
                image_position = dcm.ImagePositionPatient
                image_orientation = dcm.ImageOrientationPatient

                x, y = rel_x * dcm.Rows, rel_y * dcm.Columns

                row_vector = np.array(image_orientation[:3])
                col_vector = np.array(image_orientation[3:])

                patient_coords = [
                    image_position[0] + x * pixel_spacing[1] * row_vector[0] + y * pixel_spacing[0] * col_vector[0],
                    image_position[1] + x * pixel_spacing[1] * row_vector[1] + y * pixel_spacing[0] * col_vector[1],
                    image_position[2] + x * pixel_spacing[1] * row_vector[2] + y * pixel_spacing[0] * col_vector[2]
                ]
                patient_coords_dict['study_id'].append(study_id)
                patient_coords_dict['series_id'].append(series_id)
                patient_coords_dict['study_id_level'].append(str(study_id) + '_' + level)
                patient_coords_dict['series_description'].append(series_description)
                patient_coords_dict['level'].append(level)
                patient_coords_dict['patient_coords'].append(patient_coords)

    patient_coords_df = pl.DataFrame(patient_coords_dict)
    del patient_coords_dict

    return patient_coords_df

def create_axial_IPP_df(test_df, test_images_path): # need fix reduce unnecessary stuff
    print('creating axial IPP df')
    
    axial_img_pos_dict = {
        'study_id': [],
        'series_id': [],
        'instance_number': [],
        'ImagePositionPatient_z': [],
        'SliceLocation': [],
    }

    for name, data in tqdm(test_df.filter(pl.col('series_description') == 'Axial T2').group_by(['study_id', 'series_id'])):
        study_id = name[0]
        series_id = name[1]

        dcm_path = Path(os.path.join(test_images_path, f'{study_id}/{series_id}/')).glob('*.dcm')
        for path in dcm_path:
            dcm = dicomsdl.open(path.as_posix()).getValues(['InstanceNumber' ,'ImagePositionPatient', 'SliceLocation'])
            
            axial_img_pos_dict['study_id'].append(study_id)
            axial_img_pos_dict['series_id'].append(series_id)
            axial_img_pos_dict['instance_number'].append(dcm[0])
            axial_img_pos_dict['ImagePositionPatient_z'].append(dcm[1][2])
            axial_img_pos_dict['SliceLocation'].append(dcm[2])
            
    axial_img_pos_df = pl.DataFrame(axial_img_pos_dict).sort(['study_id', 'series_id', 'instance_number'])

    return axial_img_pos_df

def create_sagittal_IPP_df(test_df, test_images_path):
    print('creating sagittal IPP df')
    
    none_series = []

    sagittal_img_pos_dict = {
        'study_id': [],
        'series_id': [],
        'instance_number': [],
        'ImagePositionPatient_x': [],
        'SliceLocation': [],
    }

    for name, data in tqdm(test_df.filter(pl.col('series_description') != 'Axial T2').group_by(['study_id', 'series_id'])):
        study_id = name[0]
        series_id = name[1]

        dcm_path = sorted(list(Path(os.path.join(test_images_path, f'{study_id}/{series_id}/')).glob('*.dcm')), key=lambda x: int(x.stem))
        dcm_array = np.array([int(path.stem) for path in dcm_path])

        for path in dcm_path:
            dcm = dicomsdl.open(path.as_posix()).getValues(['InstanceNumber' ,'ImagePositionPatient', 'SliceLocation'])

            if None in dcm:
                none_series.append(series_id)
                break

            sagittal_img_pos_dict['study_id'].append(study_id)
            sagittal_img_pos_dict['series_id'].append(series_id)
            sagittal_img_pos_dict['instance_number'].append(dcm[0])
            sagittal_img_pos_dict['ImagePositionPatient_x'].append(dcm[1][0])
            sagittal_img_pos_dict['SliceLocation'].append(dcm[2])

            
    sagittal_img_pos_df = pl.DataFrame(sagittal_img_pos_dict).sort(['study_id', 'series_id', 'instance_number'])
    sagittal_img_pos_df = sagittal_img_pos_df.filter(~pl.col('series_id').is_in(none_series))
    del sagittal_img_pos_dict

    return sagittal_img_pos_df

# InferenceRunner

In [19]:
class InferenceRunner:
    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.cls_model_dict = OrderedDict()
        self.sag_coords_model_dict = OrderedDict()
        self.axl_coords_model_dict = OrderedDict()
    
        self.label_names_cond = [
            # sagittal_t1
            'left_neural_foraminal_narrowing_normal_mild',
            'left_neural_foraminal_narrowing_moderate',
            'left_neural_foraminal_narrowing_severe',

            'right_neural_foraminal_narrowing_normal_mild',
            'right_neural_foraminal_narrowing_moderate',
            'right_neural_foraminal_narrowing_severe',

            # sagittal_t2
            'spinal_canal_stenosis_normal_mild',
            'spinal_canal_stenosis_moderate',
            'spinal_canal_stenosis_severe',

            # axial
            'left_subarticular_stenosis_normal_mild',
            'left_subarticular_stenosis_moderate',
            'left_subarticular_stenosis_severe',

            'right_subarticular_stenosis_normal_mild',
            'right_subarticular_stenosis_moderate',
            'right_subarticular_stenosis_severe',
        ]
    
    def register_cls(self, model_name: str, path: str, fold_list: list):
        model_dict = self._get_cls_model_dict(path, fold_list)
        self.cls_model_dict[model_name] = model_dict
        print('REGISTERED CLS: ', model_name)
        
    def register_sag_coords(self, model_name: str, path: str, fold_list: list):
        model_dict = self._get_coords_model_dict(path, fold_list)
        self.sag_coords_model_dict[model_name] = model_dict
        print('REGISTERED SAG COORDS: ', model_name)

    def register_axl_coords(self, model_name: str, path: str, fold_list: list):
        model_dict = self._get_coords_model_dict(path, fold_list)
        self.axl_coords_model_dict[model_name] = model_dict
        print('REGISTERED AXL COORDS: ', model_name)
    
    def _get_cls_model_dict(self, path, fold_list):
        model_dict = {
            'models': {f: f'{path}/best_score_fold{f}.pt' for f in fold_list},
#             'models': {f: f'{path}/best_loss_fold{f}.pt' for f in fold_list},
            'config': f'{path}/config.yaml',
        }
        return model_dict
    
    def _get_coords_model_dict(self, path, fold_list):
        model_dict = {
            'models': {f: f'{path}/best_score_fold{f}.pt' for f in fold_list},
            'config': f'{path}/config.yaml',
        }
        return model_dict
    
    def show(self):
        print("seg models:")
        print(yaml.dump(dict(self.coords_model_dict), allow_unicode=True, default_flow_style=False))
        print('')
        print("cls models:")
        print(yaml.dump(dict(self.cls_model_dict), allow_unicode=True, default_flow_style=False))
    
    def predict_sagittal_coords(self, test_df):
        
        for model_name, model_dict in self.sag_coords_model_dict.items():

            with open(model_dict['config']) as file:
                conf = SimpleNamespace(**yaml.safe_load(file))
            
            for fold, model_path in model_dict['models'].items(): ### using only 1 fold for now, maybe good enough
                print(f'INFERENCING: {model_name}, FOLD: {fold}')
                dataloader = self._get_coords_dataloader(conf, test_df.filter(pl.col('series_description') != 'Axial T2'))
                model = self._load_coords_model(conf, model_path)
                predictions_df = self._infer_sag_coords_fn(dataloader, model)
                
                predictions_df = predictions_df.explode(['relative_x', 'relative_y'])
                lvl_series = pl.Series(['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1'] * int(len(predictions_df) // 5))
                predictions_df = predictions_df.with_columns(level=lvl_series)
                
        axial_null_coords_df = (
            test_df
            .filter(pl.col('series_description') == 'Axial T2')
            .select(['series_id'])
            .unique()
            .with_columns(relative_x=pl.lit(None), relative_y=pl.lit(None))
            .with_columns(level=pl.lit(['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']))
            .explode('level')
        )
        predictions_df = predictions_df.with_columns(pl.col('relative_x').clip(0.01, 0.99), pl.col('relative_y').clip(0.01, 0.99)) # prevent zero
        coords_df = pl.concat([predictions_df, axial_null_coords_df], how='vertical')
                
        del dataloader
        del model
                
        return coords_df
    
    def predict_axial_coords(self, test_df):
        
        for model_name, model_dict in self.axl_coords_model_dict.items():

            with open(model_dict['config']) as file:
                conf = SimpleNamespace(**yaml.safe_load(file))
                
            for fold, model_path in model_dict['models'].items(): ### using only 1 fold for now, maybe good enough
                print(f'INFERENCING: {model_name}, FOLD: {fold}')
                dataloader = self._get_coords_dataloader(conf, test_df.filter(pl.col('series_description') == 'Axial T2'))
                model = self._load_coords_model(conf, model_path)
                predictions_df = self._infer_axl_coords_fn(dataloader, model)

                predictions_df = predictions_df.with_columns(pl.col('series_id').repeat_by(5)).explode('series_id')
                lvl_series = pl.Series(['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1'] * int(len(predictions_df) // 5))
                predictions_df = predictions_df.with_columns(level=lvl_series)
                
            del dataloader
            del model
            
            return predictions_df
    
    def _get_coords_dataloader(self, conf, test_df):
        test_dataset = LSDCCoordsTestDataset(conf, test_df, test_df['series_id'].unique(), get_transforms(conf, types='test'))
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=False,
        )
        return test_loader
    
    def _load_coords_model(self, conf, model_path):
        model = LSDCCoordsModel(conf, pretrained=None)
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model
    
    def _infer_sag_coords_fn(self, dataloader, model):
        model.to(self.device)
        model.eval()
            
        ids_list = []
        outputs_list = []

        for batch in tqdm(dataloader):
            batch_ids = batch['ids']
            inputs = batch['images'].to(self.device, dtype=torch.float)
            batch_size = len(inputs)
            
            with torch.no_grad():
                outputs = model(inputs)
                
            outputs = outputs.cpu().sigmoid().view(batch_size, 5, 2)
                
            ids_list.append(batch_ids)
            outputs_list.append(outputs)
            
        ids_list = torch.concat(ids_list).numpy()
        outputs_list = torch.concat(outputs_list).numpy()

        return pl.DataFrame({
            'series_id': ids_list,
            'relative_x': outputs_list[:, :, 0],
            'relative_y': outputs_list[:, :, 1],
        })
    
    def _infer_axl_coords_fn(self, dataloader, model):
        model.to(self.device)
        model.eval()
            
        ids_list = []
        outputs_list = []

        for batch in tqdm(dataloader):
            batch_ids = batch['ids']
            inputs = batch['images'].to(self.device, dtype=torch.float)
            batch_size = len(inputs)
            
            with torch.no_grad():
                outputs = model(inputs)
                
            outputs = outputs.cpu().sigmoid().view(batch_size, 2, 2)
                
            ids_list.append(batch_ids)
            outputs_list.append(outputs)
            
        ids_list = torch.concat(ids_list).numpy()
        outputs_list = torch.concat(outputs_list).numpy()

        return pl.DataFrame({
            'series_id': ids_list,
            'left_coords': outputs_list[:, 0], # [x, y]
            'right_coords': outputs_list[:, 1], # [x, y]
        })
    
    def predict_cv(self, test_df):

        preds_value_stack = []
        
        for model_name, model_dict in self.cls_model_dict.items():
        
            with open(model_dict['config']) as file:
                conf = SimpleNamespace(**yaml.safe_load(file))

            for fold, model_path in model_dict['models'].items():
                print(f'INFERENCING: {model_name}, config_exp: {conf.exp}, FOLD: {fold}')
                dataloader = self._get_cls_dataloader(conf, test_df)
                model = self._load_cls_model(conf, model_path)
                fold_predictions = self._infer_cls_fn(dataloader, model)
                preds_value_stack.append(fold_predictions[1])
                
        preds_value_stack = torch.stack(preds_value_stack, dim=0)
        preds_value_mean = torch.mean(preds_value_stack, dim=0)
        item_num = preds_value_mean.shape[0]

        preds_value_mean = preds_value_mean.view(item_num, 5, 3).softmax(dim=2).view(item_num, -1).numpy()

        preds_df = pl.DataFrame({
            'study_id_level': fold_predictions[0],
            'preds': preds_value_mean,
        }).with_columns(pl.col('preds').arr.to_struct()).unnest('preds')

        preds_df = preds_df.rename({
            f'field_{f_n}': cond for f_n, cond in enumerate(self.label_names_cond)
        })

        preds_df = formatting_predictions(preds_df)
        preds_df = preds_df.cast({pl.selectors.numeric(): pl.Float64}).fill_null(0.33).fill_nan(0.33)

        return preds_df
    
    def _get_cls_dataloader(self, conf, test_df):
        test_dataset = LSDCTestDataset(
            conf,
            test_df,
            test_df['study_id_level'].unique(maintain_order=True),
            patient_coords_df, ##AXIAL_FLAG
            axial_IPP_df, ##AXIAL_FLAG
            get_transforms(conf, types='test'),
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=False,
        )
        return test_loader
    
    def _load_cls_model(self, conf, model_path):
        model = LSDCModel(conf, pretrained=False)
        checkpoint = torch.load(model_path, map_location=self.device) # maybe trigger error cause it load gpu first
        model.load_state_dict(checkpoint['model_state_dict'])
        return model
    
    def _infer_cls_fn(self, dataloader, model):
        model.to(self.device)
        model.eval()
        
        ids_list = []
        outputs_list = []
        for batch in tqdm(dataloader):

            batch_ids = batch['ids']
            inputs = batch['images'].to(self.device, dtype=torch.float)

            batch_size = len(inputs)

            with torch.no_grad():
                outputs = model(inputs)

            ids_list.extend(batch_ids)
            outputs_list.append(outputs)

        outputs_list = torch.concat(outputs_list).cpu()
        
        del dataloader
        del model
        
        return ids_list, outputs_list

In [20]:
def create_dummies_sub(test_df, severity=[0.333, 0.333, 0.333]):
    df = (
        test_df
        .select(['study_id'])
        .unique()
        .with_columns(condition=pl.lit([
            'left_neural_foraminal_narrowing',
            'right_neural_foraminal_narrowing',
            'spinal_canal_stenosis',
            'left_subarticular_stenosis',
            'right_subarticular_stenosis',
        ]))
        .explode('condition')
        .with_columns(level=pl.lit(['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']))
        .explode('level')
        .with_columns(pl.concat_str([
            pl.col('study_id'),
            pl.col('condition'),
            pl.col('level')
        ], separator='_').alias('row_id'))
        .with_columns(normal_mild=severity[0])
        .with_columns(moderate=severity[1])
        .with_columns(severe=severity[2])
        .select(['row_id', 'normal_mild', 'moderate', 'severe'])
    )
    return df.sort('row_id')

# Load Data

In [21]:
# test_images_path = '/kaggle/input/rsna-lsdc-2024-submission-debug-dataset/debug/test_images'
# test_des = pl.read_csv('/kaggle/input/rsna-lsdc-2024-submission-debug-dataset/debug/test_series_descriptions.csv')

In [22]:
# test_images_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images'
# test_des = pl.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv')

In [23]:
test_images_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images'
test_des = pl.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv')

# Registers

In [24]:
test_df = create_test_df(test_des, test_images_path)

axial_IPP_df = create_axial_IPP_df(test_df, test_images_path) ##AXIAL_FLAG
sagittal_IPP_df = create_sagittal_IPP_df(test_df, test_images_path)

runner = InferenceRunner()

runner.register_sag_coords('expC002', '/kaggle/input/rsna24-lsdc-train-coor-sagittal-ds', [0])
runner.register_axl_coords('expCA001', '/kaggle/input/rsna24-lsdc-train-coor-axial-ds', [0])

runner.register_cls('exp114', '/kaggle/input/rsna24-lsdc-exps/exp114_new', [0, 1, 2, 3])
runner.register_cls('exp150', '/kaggle/input/rsna24-lsdc-exps/exp150', [0, 1, 2, 3])
runner.register_cls('exp164', '/kaggle/input/rsna24-lsdc-exps/exp164', [0, 1, 2, 3])
runner.register_cls('exp165', '/kaggle/input/rsna24-lsdc-exps/exp165', [0, 1, 2, 3])
runner.register_cls('exp166', '/kaggle/input/rsna24-lsdc-exps/exp166', [0, 1, 2, 3])
runner.register_cls('exp187', '/kaggle/input/rsna24-lsdc-exps/exp187', [0, 1, 2, 3])

creating test df


0it [00:00, ?it/s]

creating axial IPP df


0it [00:00, ?it/s]

creating sagittal IPP df


0it [00:00, ?it/s]

REGISTERED SAG COORDS:  expC002
REGISTERED AXL COORDS:  expCA001
REGISTERED CLS:  exp114
REGISTERED CLS:  exp150
REGISTERED CLS:  exp164
REGISTERED CLS:  exp165
REGISTERED CLS:  exp166
REGISTERED CLS:  exp187


# Run Inference

In [25]:
coords_df = runner.predict_sagittal_coords(test_df)
test_df = test_df.join(coords_df, on='series_id')

patient_coords_df = create_patient_coords_df(test_df, test_images_path) ##AXIAL_FLAG
test_df = infer_data_preprocess(test_df, test_images_path, sagittal_IPP_df)

axl_coords_df = runner.predict_axial_coords(test_df)

test_df = test_df.join(axl_coords_df, on=['series_id', 'level'], how='left')

preds_df = runner.predict_cv(test_df)

INFERENCING: expC002, FOLD: 0


  0%|          | 0/2 [00:00<?, ?it/s]

creating patient coordinates df


0it [00:00, ?it/s]

data preprocessing..


0it [00:00, ?it/s]

INFERENCING: expCA001, FOLD: 0


  0%|          | 0/1 [00:00<?, ?it/s]

INFERENCING: exp114, config_exp: exp114, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp114, config_exp: exp114, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp114, config_exp: exp114, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp114, config_exp: exp114, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp150, config_exp: exp150, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp150, config_exp: exp150, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp150, config_exp: exp150, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp150, config_exp: exp150, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp164, config_exp: exp164, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp164, config_exp: exp164, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp164, config_exp: exp164, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp164, config_exp: exp164, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp165, config_exp: exp165, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp165, config_exp: exp165, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp165, config_exp: exp165, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp165, config_exp: exp165, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp166, config_exp: exp166, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp166, config_exp: exp166, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp166, config_exp: exp166, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp166, config_exp: exp166, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp187, config_exp: exp187, FOLD: 0


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp187, config_exp: exp187, FOLD: 1


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp187, config_exp: exp187, FOLD: 2


  0%|          | 0/5 [00:00<?, ?it/s]

INFERENCING: exp187, config_exp: exp187, FOLD: 3


  0%|          | 0/5 [00:00<?, ?it/s]

In [26]:
display(preds_df)
preds_df.write_csv('submission.csv')

row_id,normal_mild,moderate,severe
str,f64,f64,f64
"""44036939_left_neural_foraminal…",0.787128,0.205131,0.007741
"""44036939_left_neural_foraminal…",0.276859,0.582092,0.141049
"""44036939_left_neural_foraminal…",0.083866,0.399899,0.516235
"""44036939_left_neural_foraminal…",0.026191,0.200216,0.773593
"""44036939_left_neural_foraminal…",0.392305,0.526893,0.080802
…,…,…,…
"""44036939_spinal_canal_stenosis…",0.59372,0.356247,0.050033
"""44036939_spinal_canal_stenosis…",0.022777,0.176318,0.800905
"""44036939_spinal_canal_stenosis…",0.111004,0.564906,0.32409
"""44036939_spinal_canal_stenosis…",0.138922,0.435522,0.425556
