In [1]:
!pip install -q timm pydicom

In [2]:
import os
import gc
import sys
from PIL import Image
import cv2
import math, random
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold

from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

import timm
from timm.utils import ModelEmaV2
from transformers import get_cosine_schedule_with_warmup

import albumentations as A

from sklearn.model_selection import KFold

import re
import pydicom

In [3]:
# Root of the competition data (CSV files and the test_images/ DICOM tree).
INPUT_PATH  = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'

# Directory for files written by this notebook.
OUTPUT_PATH = '/kaggle/working'

# Add Input -> Your Work + Notebook -> rsna24lsdc0823train (Version 1)
# Kept in sync with the dataset named in the instruction above and with the
# directory the checkpoint-loading cell globs from (the previous value pointed
# at a different dataset, 'rsna24lsdc-0816-train').
MODEL_PATH = '/kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models'

In [4]:
# --- Runtime -----------------------------------------------------------------
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# os.cpu_count() is documented to return None when the count cannot be
# determined; DataLoader needs an int, so fall back to 0 (load in the main process).
N_WORKERS = os.cpu_count() or 0
USE_AMP = True
SEED = 8620

# --- Input / output layout ---------------------------------------------------
IMG_SIZE = [512, 512]      # used as (height, width) of every slice
IN_CHANS = 40              # number of stacked slices fed to the network
N_LABELS = 25              # one label per (condition, level) pair
N_CLASSES = 3 * N_LABELS   # three logits per label

N_FOLDS = 5

MODEL_NAME = "edgenext_base.in21k_ft_in1k"

BATCH_SIZE = 1

In [5]:
# Rebind `device` as a torch.device object: first CUDA GPU when present, CPU otherwise.
_device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
device

device(type='cuda', index=0)

In [6]:
# One row per (study_id, series_id) with its series_description.
series_csv_path = f'{INPUT_PATH}/test_series_descriptions.csv'
df = pd.read_csv(series_csv_path)
df.head()

Unnamed: 0,study_id,series_id,series_description
0,44036939,2828203845,Sagittal T1
1,44036939,3481971518,Axial T2
2,44036939,3844393089,Sagittal T2/STIR


In [7]:
# Distinct studies in order of first appearance; each one becomes one dataset item.
study_ids = list(pd.unique(df['study_id']))

In [8]:
# The sample submission is read only for its column layout.
sample_sub_path = f'{INPUT_PATH}/sample_submission.csv'
sample_sub = pd.read_csv(sample_sub_path)

In [9]:
# Every submission column except the first (the row id) is a class-probability column.
LABELS = [column for column in sample_sub.columns[1:]]
LABELS

['normal_mild', 'moderate', 'severe']

In [10]:
# Condition and level names used to build submission row ids of the form
# "<study_id>_<condition>_<level>".
#
# Order matters: the inference cell iterates CONDITIONS in the outer loop and
# LEVELS in the inner loop, and assigns the k-th group of three model outputs
# to the k-th (condition, level) pair produced by that iteration.
# NOTE(review): this presumably mirrors the label order used at training time —
# confirm against the training notebook before reordering anything here.
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

# Intervertebral disc levels; 5 conditions x 5 levels = 25 labels.
LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

In [11]:
def atoi(text):
    """Convert a purely decimal string to ``int``; return anything else unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit`` because ``isdigit``
    also accepts characters such as superscript digits, which ``int()``
    rejects with ``ValueError``.

    Args:
        text: A string token.

    Returns:
        ``int(text)`` when every character is a decimal digit, otherwise
        ``text`` itself (including the empty string).
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically.

    The string is split into alternating non-digit / digit runs, and the digit
    runs are converted to integers, so that e.g. ``'9.dcm'`` sorts before
    ``'10.dcm'``.

    Args:
        text: The string (typically a file path) to derive a key from.

    Returns:
        A list alternating ``str`` and ``int`` pieces, starting with a
        (possibly empty) ``str``.
    """
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

In [12]:
import numpy as np
import cv2
import pydicom
import glob
from torch.utils.data import Dataset

class RSNA24TestDataset(Dataset):
    """Test-time dataset that stacks the slices of one study into a multi-channel image.

    For each study, slices are taken from three series types and written into
    fixed channel ranges of a single ``(H, W, IN_CHANS)`` array:
    10 channels for 'Sagittal T1', 10 for 'Sagittal T2/STIR' and 20 for
    'Axial T2'. Channels of a missing series, or of a slice that fails to
    load, are left at zero.

    Args:
        df: DataFrame with ``study_id``, ``series_id`` and
            ``series_description`` columns.
        study_ids: Sequence of study ids; one dataset item per entry.
        phase: Stored on the instance; not used by this class.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=x)['image']``.
    """

    def __init__(self, df, study_ids, phase='test', transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform
        self.phase = phase

    def __len__(self):
        """Return the number of studies."""
        return len(self.study_ids)

    def get_img_paths(self, study_id, series_desc):
        """Return the naturally-sorted DICOM paths of the largest matching series.

        When a study has several series with the same description, the one
        containing the most ``.dcm`` files is used (the first one wins ties).
        Series whose directory does not exist are skipped.

        Returns:
            A list of file paths, or ``[]`` when no series with files exists.
        """
        pdf = self.df[(self.df['study_id'] == study_id) & (self.df['series_description'] == series_desc)]

        # Find the series with the most images by actually counting files.
        max_images = 0
        series_with_most_images = None
        for _, row in pdf.iterrows():
            series_path = f'{INPUT_PATH}/test_images/{study_id}/{row["series_id"]}'
            if not os.path.isdir(series_path):
                # Listed in the CSV but absent on disk: treat as "no images"
                # instead of letting os.listdir raise.
                continue
            image_count = sum(1 for f in os.listdir(series_path) if f.endswith('.dcm'))
            if image_count > max_images:
                max_images = image_count
                series_with_most_images = row["series_id"]

        if series_with_most_images is None:
            return []

        allimgs = glob.glob(f'{INPUT_PATH}/test_images/{study_id}/{series_with_most_images}/*.dcm')
        return sorted(allimgs, key=natural_keys)

    def read_dcm_ret_arr(self, src_path):
        """Read one DICOM slice, min-max scale it to [0, 255] and resize it.

        Returns:
            A float array of shape ``(IMG_SIZE[0], IMG_SIZE[1])`` with values
            clipped to ``[0, 255]``.
        """
        dicom_data = pydicom.dcmread(src_path)
        image = dicom_data.pixel_array
        image = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255
        # cv2.resize takes dsize as (width, height); IMG_SIZE is (height, width).
        img = cv2.resize(image, (IMG_SIZE[1], IMG_SIZE[0]), interpolation=cv2.INTER_CUBIC)
        assert img.shape == (IMG_SIZE[0], IMG_SIZE[1])
        # Cubic interpolation can overshoot the input range; clip so that the
        # caller's uint8 cast cannot wrap around.
        return np.clip(img, 0, 255)

    def select_images(self, allimgs, num_images, image_type):
        """Pick exactly ``num_images`` paths from ``allimgs``.

        * Fewer available than requested: cycle through the list until the
          requested count is reached (a warning is printed).
        * Exactly as many: returned unchanged.
        * More, and ``num_images == 10``: a window of 10 consecutive slices
          around the middle of the series.
        * More, any other count: indices evenly spaced over the whole series.

        Args:
            allimgs: Ordered list of image paths.
            num_images: Number of paths to return.
            image_type: Series description, used only in the warning message.

        Returns:
            A list of paths; ``[]`` when ``allimgs`` is empty.
        """
        if not allimgs:
            # Nothing to duplicate from; avoids a modulo-by-zero below.
            return []

        if len(allimgs) < num_images:
            print(f"Warning: Only {len(allimgs)} images available for {image_type}. Duplicating images to reach {num_images}.")
            return [allimgs[i % len(allimgs)] for i in range(num_images)]

        if len(allimgs) == num_images:
            return allimgs

        if num_images == 10:  # For Sagittal T1 and Sagittal T2/STIR
            start = max(0, len(allimgs) // 2 - 5)
            return allimgs[start:start + 10]

        # For Axial T2 (20 slices) and any other requested count.
        indices = np.linspace(0, len(allimgs) - 1, num_images, dtype=int)
        return [allimgs[i] for i in indices]

    def __getitem__(self, idx):
        """Build the stacked image for study ``idx``.

        Returns:
            ``(x, study_id)`` where ``x`` is channel-first
            ``(IN_CHANS, H, W)`` and ``study_id`` is a string.
        """
        x = np.zeros((IMG_SIZE[0], IMG_SIZE[1], IN_CHANS), dtype=np.uint8)
        st_id = self.study_ids[idx]

        image_types = [('Sagittal T1', 10), ('Sagittal T2/STIR', 10), ('Axial T2', 20)]

        # Running channel offset: each series owns the next `num_images` channels.
        next_channel = 0
        for series_desc, num_images in image_types:
            first_channel = next_channel
            next_channel += num_images

            allimgs = self.get_img_paths(st_id, series_desc)
            if not allimgs:
                print(f"{st_id}: {series_desc}, has no images")
                continue

            selected_imgs = self.select_images(allimgs, num_images, series_desc)

            for j, img_path in enumerate(selected_imgs):
                try:
                    img = self.read_dcm_ret_arr(img_path)
                    x[..., first_channel + j] = img.astype(np.uint8)
                except Exception as err:
                    # Best effort: a broken slice leaves its channel at zero,
                    # but the reason is reported and interrupts still propagate.
                    print(f'failed to load on {st_id}, {series_desc}: {err!r}')

        if self.transform is not None:
            x = self.transform(image=x)['image']
        x = x.transpose(2, 0, 1)

        return x, str(st_id)

In [13]:
# Test-time preprocessing: resize to the network input size, then normalise
# with mean 0.5 and std 0.5.
transforms_test = A.Compose(
    [
        A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
        A.Normalize(mean=0.5, std=0.5),
    ]
)

In [14]:
# One dataset item per study, served in the original study order.
test_ds = RSNA24TestDataset(df=df, study_ids=study_ids, transform=transforms_test)
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=1,  # the inference loop reads element 0 of every batch
    num_workers=N_WORKERS,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [15]:
# class RSNA24Model(nn.Module):
#     def __init__(self, model_name, in_c=40, n_classes=75, pretrained=True, features_only=False):
#         super().__init__()
#         self.model = timm.create_model(
#                                     model_name,
#                                     pretrained=pretrained, 
#                                     features_only=features_only,
#                                     in_chans=in_c,
#                                     num_classes=n_classes,
#                                     global_pool='avg'
#                                     )
    
#     def forward(self, x):
#         y = self.model(x)
#         return y

In [16]:
class ClassificationModule(nn.Module):
    """Convolutional classification head applied to a backbone feature map.

    Two 3x3 conv + batch-norm + ReLU stages (both ``hidden_dim`` wide) are
    followed by global average pooling and a linear layer producing
    ``num_classes`` outputs. This two-conv head replaces an earlier
    single-conv variant.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output logits.
        hidden_dim: Width of both convolutional stages.
    """

    def __init__(self, in_channels, num_classes, hidden_dim=512):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys that saved checkpoints are loaded against.
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` feature map to ``(B, num_classes)`` logits."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        pooled = self.pool(x)
        # (B, hidden_dim, 1, 1) -> (B, hidden_dim)
        flat = torch.flatten(pooled, start_dim=1)
        return self.fc(flat)

class RSNA24Model(nn.Module):
    """timm feature-extractor backbone followed by a ``ClassificationModule`` head.

    Args:
        base_model_name: timm model name passed to ``timm.create_model``.
        in_chans: Number of input channels of the stacked image.
        num_classes: Total number of output logits (the head's output width).
        pretrained: Whether timm should load pretrained backbone weights.
    """

    def __init__(self, base_model_name, in_chans, num_classes, pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(base_model_name, pretrained=pretrained, in_chans=in_chans, features_only=True)

        # Keep every backbone weight trainable (set to False here to freeze the backbone).
        for param in self.base_model.parameters():
            param.requires_grad = True

        # Channel count of the deepest feature map, read from timm's feature
        # metadata. This avoids running a full-resolution dummy forward pass on
        # every construction and removes the dependency on the global IMG_SIZE.
        # Only the channel dimension is needed because the head starts with a conv layer.
        in_features = self.base_model.feature_info.channels()[-1]

        # Add custom classification module
        self.classification_module = ClassificationModule(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch_size, num_classes)``."""
        features = self.base_model(x)
        output = self.classification_module(features[-1])  # Use the last feature map
        # The head already returns (batch_size, num_classes); the view only
        # guarantees a 2-D result.
        return output.view(output.size(0), -1)

In [17]:
import glob

# One trained checkpoint per cross-validation fold; all of them are ensembled.
models = []

CKPT_GLOB = '/kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-*.pt'
CKPT_PATHS = sorted(glob.glob(CKPT_GLOB))

# Fail loudly: with zero models the inference loop would otherwise emit
# all-zero predictions without any error.
if not CKPT_PATHS:
    raise FileNotFoundError(f'No checkpoints found matching {CKPT_GLOB}')

for cp in CKPT_PATHS:
    print(f'loading {cp}...')
    model = RSNA24Model(MODEL_NAME, IN_CHANS, N_CLASSES, pretrained=False)
    # map_location='cpu' lets checkpoints saved from a GPU load on any machine;
    # the model is moved to the target device afterwards.
    model.load_state_dict(torch.load(cp, map_location='cpu'))
    model.eval()
    if torch.cuda.is_available():
        # Half-precision weights are only used on GPU; on CPU the model stays in fp32.
        model.half()
    model.to(device)
    models.append(model)

loading /kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-0.pt...
loading /kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-1.pt...
loading /kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-2.pt...
loading /kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-3.pt...
loading /kaggle/input/rsna24lsdc0823train/rsna-2024-lsdc-trained-models/best_wll_model_fold-4.pt...


In [18]:
# Ensemble inference: for every study, average the per-label softmax
# probabilities of all fold models.
autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
y_preds = []
row_names = []

# Guard against silently producing all-zero predictions.
if not models:
    raise RuntimeError('No models loaded; cannot run inference.')
# Row ids and model outputs are matched purely by position.
assert len(CONDITIONS) * len(LEVELS) == N_LABELS

with tqdm(test_dl, leave=True) as pbar:
    with torch.no_grad():
        for x, batch_study_ids in pbar:
            x = x.to(device)
            n_studies = x.shape[0]
            pred_per_batch = np.zeros((n_studies, N_LABELS, 3))

            # Same ordering as the model output: condition-major, level-minor.
            for study_id in batch_study_ids:
                for cond in CONDITIONS:
                    for level in LEVELS:
                        row_names.append(study_id + '_' + cond + '_' + level)

            with autocast:
                for m in models:
                    # (B, N_LABELS * 3) -> (B, N_LABELS, 3); softmax over the
                    # three severity classes of each label, one transfer per model.
                    logits = m(x).float().reshape(n_studies, N_LABELS, 3)
                    probs = logits.softmax(dim=-1).cpu().numpy()
                    pred_per_batch += probs / len(models)
            y_preds.append(pred_per_batch.reshape(-1, 3))

y_preds = np.concatenate(y_preds, axis=0)

100%|██████████| 1/1 [00:02<00:00,  2.88s/it]


In [19]:
# Assemble the submission: probability columns first, then the row id in front.
sub = pd.DataFrame(y_preds, columns=LABELS)
sub.insert(0, 'row_id', row_names)
sub.head(25)

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.321904,0.365173,0.312923
1,44036939_spinal_canal_stenosis_l2_l3,0.254861,0.367973,0.377166
2,44036939_spinal_canal_stenosis_l3_l4,0.304131,0.448013,0.247857
3,44036939_spinal_canal_stenosis_l4_l5,0.417411,0.295237,0.287352
4,44036939_spinal_canal_stenosis_l5_s1,0.841112,0.100403,0.058485
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.488327,0.482438,0.029236
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.328101,0.592873,0.079027
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.255925,0.523179,0.220896
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.163932,0.441611,0.394457
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.137688,0.39521,0.467102


In [20]:
# Write the submission without the index, then read it back as a sanity check.
sub.to_csv(path_or_buf='submission.csv', index=False)
pd.read_csv('submission.csv').head(5)

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.321904,0.365173,0.312923
1,44036939_spinal_canal_stenosis_l2_l3,0.254861,0.367973,0.377166
2,44036939_spinal_canal_stenosis_l3_l4,0.304131,0.448013,0.247857
3,44036939_spinal_canal_stenosis_l4_l5,0.417411,0.295237,0.287352
4,44036939_spinal_canal_stenosis_l5_s1,0.841112,0.100403,0.058485
