Adapted from https://www.kaggle.com/code/yxyyxy/rsna2024-training-baseline-2nd-stage/edit

# RSNA2024 LSDC Submission Baseline

This notebook runs inference with the trained models and writes the submission file.

### My other Notebooks
- [RSNA2024 LSDC Making Dataset](https://www.kaggle.com/code/itsuki9180/rsna2024-lsdc-making-dataset) 
- [RSNA2024 LSDC Training Baseline](https://www.kaggle.com/code/itsuki9180/rsna2024-lsdc-training-baseline) 
- [RSNA2024 LSDC Submission Baseline](https://www.kaggle.com/code/itsuki9180/rsna2024-lsdc-submission-baseline) <- you're reading now

# Import Libraries

In [1]:
# Flip to True to point the notebook at the small debug copy of the inputs.
DEBUG = False

# Root directory holding the test CSVs and the test image folders.
rd = (
    '/kaggle/input/rsna-lsdc-2024-submission-debug-dataset/debug'
    if DEBUG
    else '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
)

# Directory the later cells read dataset_metadata.csv and the .npy volumes from.
DATA_fromStage1 = '/kaggle/working'


In [2]:
# Run the stage-1 script as a shell command. IPython substitutes {rd} with the
# data root chosen in the first cell. Arguments, in order: the test series
# description CSV, the sample submission CSV, the test image directory and the
# U-Net weight file.
# NOTE(review): later cells read .npy volumes and dataset_metadata.csv from
# /kaggle/working, so this script presumably writes them there - confirm
# against main.py.
!python /kaggle/input/script-deepspine-custom-dataset/main.py \
"{rd}/test_series_descriptions.csv" \
"{rd}/sample_submission.csv" \
"{rd}/test_images" \
'/kaggle/input/2d-segmentation-of-sagittal-lumbar-spine-mri/simple_unet.pth'


Using path: /kaggle/input/2d-segmentation-of-sagittal-lumbar-spine-mri/simple_unet.pth
Using path: /kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv
Using base path: /kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images
Processing studies: 100%|█████████████████████████| 1/1 [00:08<00:00,  8.04s/it]
Pipeline completed successfully!
Processing studies: 100%|█████████████████████████| 1/1 [00:00<00:00,  1.14it/s]
Axial dataset generation completed successfully!


In [3]:
import os
import gc
import sys
from PIL import Image
import cv2
import math, random
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import log_loss

from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.optim import AdamW

import timm
from transformers import get_cosine_schedule_with_warmup

import albumentations as A


# Config

In [4]:
# Device string; the following cell rebinds `device` to a torch.device object.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an integer num_workers, so fall back to 0 (load data in
# the main process) in that case.
N_WORKERS = os.cpu_count() or 0
USE_AMP = True
SEED = 1

# Target (height, width) used by the axial resize transform.
IMG_SIZE = [224, 224]
# Number of conditions predicted per disc level (see CONDITIONS).
N_LABELS = 5
# Three outputs per condition; the inference cell softmaxes each group of 3.
N_CLASSES = 3 * N_LABELS

model_name_sag = 'efficientnet_b0'
model_name_axi = 'resnet34'
# Sagittal input channels: the inference cell concatenates the T1 and T2
# stacks along the channel axis and the dataset defaults to 15 slices each.
in_chans_sag = 30
# Axial input channels: the dataset defaults to 4 axial slices.
in_chans_axi = 4

N_FOLDS = 5

BATCH_SIZE = 1

In [5]:
# Rebind `device` as a torch.device: first CUDA GPU when available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [6]:
# Conditions, in the order the inference cell indexes them via seq_cond.
CONDITIONS = [
    'spinal_canal_stenosis',
    'left_neural_foraminal_narrowing',
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis',
]

# Disc-level labels as they appear in the metadata columns, mapped to the
# lower-case, underscore-separated form used inside submission row ids.
level_mapping = {
    level: level.lower().replace('/', '_')
    for level in ('L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1')
}


In [7]:
def atoi(text):
    """Return ``int(text)`` if ``text`` consists only of digits, else ``text`` unchanged."""
    return int(text) if text.isdigit() else text


def natural_keys(text):
    """Split ``text`` into digit and non-digit runs for natural ("human") sorting.

    Digit runs are converted to ``int`` so that, when used as a sort key,
    ``'f2'`` orders before ``'f10'``.

    Args:
        text: The string to build a sort key for.

    Returns:
        A list alternating between non-digit substrings and integers.
    """
    # Imported locally because the notebook's import cell does not import `re`;
    # without this the function raises NameError on first use.
    import re

    return [atoi(c) for c in re.split(r'(\d+)', text)]

In [8]:
# Metadata table listing, per study, the path of the .npy volume for each
# disc level (one column per level).
csv_file_path = DATA_fromStage1 + '/sagittalT2/dataset_metadata.csv'

# The column pandas loads as 'Unnamed: 0' is renamed to st_id and becomes the
# index. Chaining (instead of inplace=True) builds the frame in one expression.
dataset_metadata = (
    pd.read_csv(csv_file_path)
    .rename(columns={'Unnamed: 0': 'st_id'})
    .set_index('st_id')
)

# Preview only the first rows; the full frame has one row per study.
dataset_metadata.head()


Unnamed: 0_level_0,L4/L5,L3/L4,L2/L3,L1/L2,L5/S1
st_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
44036939,/kaggle/working/sagittalT2/44036939/L4_L5.npy,/kaggle/working/sagittalT2/44036939/L3_L4.npy,/kaggle/working/sagittalT2/44036939/L2_L3.npy,/kaggle/working/sagittalT2/44036939/L1_L2.npy,/kaggle/working/sagittalT2/44036939/L5_S1.npy


In [9]:
# Disabled visual sanity check, kept as a bare string literal so none of it
# executes. Because the string is the last expression of the cell, Jupyter
# echoes its repr as the cell output. Remove the surrounding triple quotes to
# run the preview of the sagittal T2 / sagittal T1 / axial T2 volumes.
'''# one npy file example
print("--------------------- SagittalT2 samples -----------")
example_npy_fn = dataset_metadata['L4/L5'][0]
example_npy_fn = example_npy_fn.split('/')
# Load the .npy file
example_npy = np.load(os.path.join(DATA_fromStage1, *(example_npy_fn[-3:])))
# Print the structure of the data
print(f"Shape of the data: {example_npy.shape}")
print(f"Data type: {example_npy.dtype}")
# Show the 15 slices one by one
for i in range(15):
    plt.imshow(example_npy[i], cmap='gray')
    plt.title(f'Slice {i+1}')
    plt.axis('off')
    plt.show()

print("--------------------- SagittalT1 samples -----------")
example_npy = np.load(os.path.join(DATA_fromStage1, 'sagittalT1', *(example_npy_fn[-2:])))
# Print the structure of the data
print(f"Shape of the data: {example_npy.shape}")
print(f"Data type: {example_npy.dtype}")
# Show the 15 slices one by one
for i in range(15):
    plt.imshow(example_npy[i], cmap='gray')
    plt.title(f'Slice {i+1}')
    plt.axis('off')
    plt.show()

print("--------------------- AxialT2 samples ----------")
# Load the .npy file
example_npy = np.load(os.path.join(DATA_fromStage1, 'axialT2', *(example_npy_fn[-2:])))
print(f"Shape of the data: {example_npy.shape}")
print(f"Data type: {example_npy.dtype}")
for i in range(6):
    img = example_npy[i]
    plt.imshow(img, cmap='gray')
    plt.title(f'Slice {i+1}')
    plt.axis('off')
    plt.show()
'''


'# one npy file example\nprint("--------------------- SagittalT2 samples -----------")\nexample_npy_fn = dataset_metadata[\'L4/L5\'][0]\nexample_npy_fn = example_npy_fn.split(\'/\')\n# Load the .npy file\nexample_npy = np.load(os.path.join(DATA_fromStage1, *(example_npy_fn[-3:])))\n# Print the structure of the data\nprint(f"Shape of the data: {example_npy.shape}")\nprint(f"Data type: {example_npy.dtype}")\n# Show the 15 slices one by one\nfor i in range(15):\n    plt.imshow(example_npy[i], cmap=\'gray\')\n    plt.title(f\'Slice {i+1}\')\n    plt.axis(\'off\')\n    plt.show()\n\nprint("--------------------- SagittalT1 samples -----------")\nexample_npy = np.load(os.path.join(DATA_fromStage1, \'sagittalT1\', *(example_npy_fn[-2:])))\n# Print the structure of the data\nprint(f"Shape of the data: {example_npy.shape}")\nprint(f"Data type: {example_npy.dtype}")\n# Show the 15 slices one by one\nfor i in range(15):\n    plt.imshow(example_npy[i], cmap=\'gray\')\n    plt.title(f\'Slice {i+1}\'

# Define Dataset

In [10]:
class RSNA24TestDataset(Dataset):
    """Test-time dataset returning, per study, the volumes of all five disc levels.

    Each item loads three ``.npy`` volumes per level (sagittal T2, sagittal T1
    and axial T2), resamples each to a fixed number of slices along axis 0 and
    optionally applies a transform.

    Args:
        df_fn: DataFrame indexed by study id with one column per entry of
            ``LEVELS``; each cell is a path whose last three components are
            ``<series dir>/<study dir>/<level file>``.
        Slice_len_Sag: Number of slices both sagittal volumes are resampled to.
        Slice_len_Axi: Number of slices the axial volume is resampled to.
        transform: Callable invoked as ``transform(image=hwc)['image']`` on the
            sagittal volumes, or ``None`` to leave them untouched.
        trainsform_axis: Same kind of callable for the axial volume, or ``None``
            to leave it untouched. (The spelling is kept for compatibility
            with existing callers.)
        data_root: Directory containing the series folders. Defaults to the
            notebook-level ``DATA_fromStage1``.
    """

    # Disc levels in the order their volumes are returned; these must be
    # column names of ``df_fn``.
    LEVELS = ('L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1')

    def __init__(self, df_fn, Slice_len_Sag=15, Slice_len_Axi=4, transform=None,
                 trainsform_axis=None, data_root=None):
        self.df_fn = df_fn
        self.transform = transform
        self.trainsform_axis = trainsform_axis
        self.Slice_len_Sag = Slice_len_Sag
        self.Slice_len_Axi = Slice_len_Axi
        # The global is only looked up when no explicit root is supplied.
        self.data_root = DATA_fromStage1 if data_root is None else data_root

    def __len__(self):
        """Return the number of studies (rows of ``df_fn``)."""
        return len(self.df_fn)

    @staticmethod
    def _resample_slices(volume, target_len):
        """Pick ``target_len`` evenly spaced slices along axis 0 (repeats when too short)."""
        n_slices = volume.shape[0]
        if n_slices == 0:
            raise ValueError('cannot resample a volume that contains no slices')
        if n_slices == target_len:
            return volume
        # The same index formula serves both down-sampling (drops slices) and
        # up-sampling (repeats slices).
        indices = np.linspace(0, n_slices - 1, target_len, dtype=int)
        return volume[indices, :, :]

    def _load_volume(self, rel_parts, target_len):
        """Load ``data_root/<rel_parts>`` as float32 and resample it to ``target_len`` slices."""
        volume = np.load(os.path.join(self.data_root, *rel_parts)).astype(np.float32)
        return self._resample_slices(volume, target_len)

    @staticmethod
    def _apply_transform(transform, volume):
        """Apply ``transform`` with axis 0 moved last, then move it back to the front."""
        hwc = np.transpose(volume, (1, 2, 0))
        if transform is not None:
            hwc = transform(image=hwc)['image']
        return np.transpose(hwc, (2, 0, 1))

    def __getitem__(self, idx):
        """Return ``(st_id, sagT1_list, sagT2_list, axialT2_list, levels)`` for study ``idx``.

        Each list holds one array per level, in ``LEVELS`` order, with the
        slice axis first.
        """
        st_id = self.df_fn.index[idx]
        row_idx = self.df_fn.iloc[idx]
        levels = list(self.LEVELS)

        npy_sagT2_list = []
        npy_sagT1_list = []
        npy_AxiT2_list = []

        for level in levels:
            # Address the path from its end so the lookup does not depend on
            # how deep the directory that produced the metadata was.
            series_dir, study_dir, level_file = row_idx[level].split('/')[-3:]
            npy_sagT2_list.append(
                self._load_volume((series_dir, study_dir, level_file), self.Slice_len_Sag))
            npy_sagT1_list.append(
                self._load_volume(('sagittalT1', study_dir, level_file), self.Slice_len_Sag))
            npy_AxiT2_list.append(
                self._load_volume(('axialT2', study_dir, level_file), self.Slice_len_Axi))

        # Each transform is applied only if it was provided, so a missing axial
        # transform no longer raises when a sagittal one is set.
        npy_sagT1_list = [self._apply_transform(self.transform, npy) for npy in npy_sagT1_list]
        npy_sagT2_list = [self._apply_transform(self.transform, npy) for npy in npy_sagT2_list]
        npy_AxiT2_list = [self._apply_transform(self.trainsform_axis, npy) for npy in npy_AxiT2_list]

        return st_id, npy_sagT1_list, npy_sagT2_list, npy_AxiT2_list, levels


In [11]:
def _resize_and_normalize(height, width):
    """Build a pipeline that resizes to (height, width), then normalizes with mean=0.5, std=0.5."""
    return A.Compose([
        A.Resize(height, width),
        A.Normalize(mean=0.5, std=0.5),
    ])


# Sagittal volumes are resized to 84 x 160, the size the model's probe input uses.
transforms_test_Sag = _resize_and_normalize(84, 160)

# Axial volumes are resized to IMG_SIZE.
transforms_test = _resize_and_normalize(IMG_SIZE[0], IMG_SIZE[1])


In [12]:
# One dataset item per study listed in the metadata table.
test_ds = RSNA24TestDataset(
    dataset_metadata,
    transform=transforms_test_Sag,
    trainsform_axis=transforms_test,
)

# Studies are served one at a time and in order: the inference cell calls
# st_id.item() and the model reads level[0], both of which expect one study
# per batch.
loader_options = dict(
    batch_size=1,
    shuffle=False,
    num_workers=N_WORKERS,
    pin_memory=True,
    drop_last=False,
)
test_dl = DataLoader(test_ds, **loader_options)


In [13]:
# Disabled loader sanity check, kept as a bare string literal so none of it
# executes; Jupyter echoes the string's repr as the cell output. When enabled
# it iterates test_dl, asserts that no volume contains NaN, and tallies the
# shapes of the sagittal T2 and axial T2 tensors.
'''from collections import Counter

sagT2_slicenum = []
AxiT2_slicenum = []

print(test_dl.__len__())

# Iterate through the data loader and append slice numbers
for idx, (st_id, npy_sagT1_list, npy_sagT2_list, npy_AxiT2_list, levels) in enumerate(test_dl):
    for npy_sagT2 in npy_sagT2_list:
        sagT2_slicenum.append(npy_sagT2.shape)
        assert not torch.isnan(npy_sagT2).any(), "NaN values found in npy_sagT2"
    
    for npy_AxiT2 in npy_AxiT2_list:
        AxiT2_slicenum.append(npy_AxiT2.shape)
        assert not torch.isnan(npy_AxiT2).any(), "NaN values found in npy_AxiT2"
    
    print(f"Batch {idx + 1}: Levels - {levels}")

# Count the occurrences of each unique value in the lists
sagT2_counts = Counter(sagT2_slicenum)
AxiT2_counts = Counter(AxiT2_slicenum)

print("Occurrences of each unique value in sagT2_slicenum:")
for value, count in sagT2_counts.items():
    print(f"Value: {value}, Count: {count}")

print("Occurrences of each unique value in AxiT2_slicenum:")
for value, count in AxiT2_counts.items():
    print(f"Value: {value}, Count: {count}")'''


'from collections import Counter\n\nsagT2_slicenum = []\nAxiT2_slicenum = []\n\nprint(test_dl.__len__())\n\n# Iterate through the data loader and append slice numbers\nfor idx, (st_id, npy_sagT1_list, npy_sagT2_list, npy_AxiT2_list, levels) in enumerate(test_dl):\n    for npy_sagT2 in npy_sagT2_list:\n        sagT2_slicenum.append(npy_sagT2.shape)\n        assert not torch.isnan(npy_sagT2).any(), "NaN values found in npy_sagT2"\n    \n    for npy_AxiT2 in npy_AxiT2_list:\n        AxiT2_slicenum.append(npy_AxiT2.shape)\n        assert not torch.isnan(npy_AxiT2).any(), "NaN values found in npy_AxiT2"\n    \n    print(f"Batch {idx + 1}: Levels - {levels}")\n\n# Count the occurrences of each unique value in the lists\nsagT2_counts = Counter(sagT2_slicenum)\nAxiT2_counts = Counter(AxiT2_slicenum)\n\nprint("Occurrences of each unique value in sagT2_slicenum:")\nfor value, count in sagT2_counts.items():\n    print(f"Value: {value}, Count: {count}")\n\nprint("Occurrences of each unique value i

# Define Model

In [14]:
class LevelHead(nn.Module):
    """Three-layer fully connected head (input_dim -> 512 -> 128 -> num_classes).

    Attribute names (fc1, fc2, fc3, dropout) are part of the checkpoint key
    names and must not be renamed.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map features of shape (batch, input_dim) to outputs of shape (batch, num_classes)."""
        hidden = self.dropout(torch.relu(self.fc1(x)))
        # Only fc1 is followed by a ReLU; fc2 feeds dropout directly.
        hidden = self.dropout(self.fc2(hidden))
        return self.fc3(hidden)

class RSNA24Model_Hybrid(nn.Module):
    """Two-backbone model: sagittal and axial features are concatenated and fed to a per-level head.

    Args:
        model_name_sag: timm model name for the sagittal backbone.
        model_name_axi: timm model name for the axial backbone.
        in_chans_sag: Input channels of the sagittal backbone.
        in_chans_axi: Input channels of the axial backbone.
        num_classes: Output size of every level head.
        level_names: Iterable of keys; one ``LevelHead`` is created per key.
    """

    def __init__(self, model_name_sag, model_name_axi, in_chans_sag, in_chans_axi, num_classes, level_names):
        super().__init__()
        backbone_kwargs = dict(global_pool='avg', pretrained=False, features_only=False)
        self.model_sag = timm.create_model(model_name_sag, in_chans=in_chans_sag, **backbone_kwargs)
        self.model_axi = timm.create_model(model_name_axi, in_chans=in_chans_axi, **backbone_kwargs)

        # Turn each backbone into a feature extractor.
        for backbone in (self.model_sag, self.model_axi):
            self._strip_classifier(backbone)

        # Measure the feature widths by pushing one random sample through each backbone.
        # NOTE(review): this probe runs before any .eval() call, so layers that
        # track running statistics may be updated by it; loading a checkpoint
        # afterwards overwrites them - confirm if this class is used for training.
        with torch.no_grad():
            sample_input_sag = torch.randn(1, in_chans_sag, 84, 160)
            sample_input_axi = torch.randn(1, in_chans_axi, 224, 224)
            sag_dim = self.model_sag(sample_input_sag).shape[1]
            axi_dim = self.model_axi(sample_input_axi).shape[1]

        # One independent head per level, all taking the concatenated features.
        fused_dim = sag_dim + axi_dim
        self.fc_heads = nn.ModuleDict()
        for level in level_names:
            self.fc_heads[level] = LevelHead(fused_dim, num_classes)

    @staticmethod
    def _strip_classifier(backbone):
        """Replace the backbone's final layer (``classifier``, else ``fc``) with an identity."""
        if hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Identity()
        elif hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()

    def forward(self, x_sag, x_axi, level):
        """Run both backbones and the head selected by ``level[0]``.

        ``level`` is the batched level name as produced by the data loader;
        only its first element is used, so the whole batch goes through one head.
        """
        fused = torch.cat((self.model_sag(x_sag), self.model_axi(x_axi)), dim=1)
        return self.fc_heads[level[0]](fused)


# Load Models

In [15]:
# Pattern matching the saved fold checkpoints.
CKPT_PATTERN = '/kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-*.pt'
CKPT_PATHS = sorted(glob.glob(CKPT_PATTERN))

# Fail early: with no checkpoints the inference loop would average over zero
# models and silently leave every prediction at 0.
if not CKPT_PATHS:
    raise FileNotFoundError(f'No checkpoints found matching {CKPT_PATTERN}')


In [16]:
# Ensure device is set correctly
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build one model per fold checkpoint, load its weights and switch it to
# evaluation mode; the inference cell averages the predictions of all of them.
models = []
for cp in CKPT_PATHS:
    print(f'loading {cp}...')
    model = RSNA24Model_Hybrid(model_name_sag, model_name_axi,
                               in_chans_sag, in_chans_axi,
                               num_classes=N_CLASSES, level_names=level_mapping.keys())
    # map_location places the loaded tensors on the selected device.
    model.load_state_dict(torch.load(cp, map_location=device))
    model.to(device)
    model.eval()
    models.append(model)
    

loading /kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-0.pt...
loading /kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-1.pt...
loading /kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-2.pt...
loading /kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-3.pt...
loading /kaggle/input/tpu-rsna2024-training-baseline-2nd-stage/rsna24-results/best_wll_model_fold-4.pt...


'print(models)'

# Inference loop

In [17]:
# One autocast context, re-entered for every forward pass, so that USE_AMP
# actually controls whether mixed precision is used.
autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
y_preds = []
row_names = []

# Order in which conditions are written out, as indices into CONDITIONS and
# into the groups of the model output. The same order drives both the
# prediction rows and the row names, which keeps the two aligned.
seq_cond = [1, 3, 2, 4, 0]
# Number of outputs per condition (N_CLASSES is defined as 3 * N_LABELS).
n_severity = N_CLASSES // N_LABELS

with tqdm(test_dl, leave=True) as pbar:
    with torch.no_grad():
        for st_id, npy_sagT1_list, npy_sagT2_list, npy_AxiT2_list, levels in pbar:
            n_levels = len(levels)
            # Rows are grouped by condition (in seq_cond order), then by level.
            pred_per_study = np.zeros((len(seq_cond) * n_levels, n_severity))

            level_batches = zip(npy_sagT1_list, npy_sagT2_list, npy_AxiT2_list, levels)
            for level_idx, (npy_sagT1, npy_sagT2, npy_AxiT2, level) in enumerate(level_batches):
                # Sagittal T1 and T2 stacks are concatenated along the channel axis.
                x_sag = torch.cat((npy_sagT1.to(device), npy_sagT2.to(device)), dim=1)
                x_axi = npy_AxiT2.to(device)

                # Fold-averaged probabilities, one row per condition.
                pred_per_study_level = np.zeros((N_LABELS, n_severity))
                with autocast:
                    for m in models:
                        # [0]: the loader yields one study per batch.
                        y = m(x_sag, x_axi, level)[0]
                        # Softmax over each consecutive group of n_severity outputs.
                        probs = y.float().reshape(N_LABELS, n_severity).softmax(dim=1)
                        pred_per_study_level += probs.cpu().numpy() / len(models)

                for block, cond_idx in enumerate(seq_cond):
                    pred_per_study[block * n_levels + level_idx, :] = pred_per_study_level[cond_idx, :]

            # Row names follow exactly the same condition-then-level order.
            study_id = str(st_id.item())
            for cond_idx in seq_cond:
                cond = CONDITIONS[cond_idx]
                for level in levels:
                    row_names.append(f"{study_id}_{cond}_{level_mapping[''.join(level)]}")

            y_preds.append(pred_per_study)

# np.concatenate raises on an empty list, so handle an empty loader explicitly.
y_preds = np.concatenate(y_preds, axis=0) if y_preds else np.zeros((0, n_severity))

assert len(row_names) == len(y_preds), 'row names and predictions are misaligned'
print(len(row_names))
print(len(y_preds))


100%|██████████| 1/1 [00:01<00:00,  1.59s/it]

25
25





In [18]:
# Load the submission template that ships with the input data.
sample_sub = pd.read_csv(f'{rd}/sample_submission.csv')
# Every column after the first is a prediction column.
LABELS = sample_sub.columns[1:].tolist()
print(sample_sub.index)

RangeIndex(start=0, stop=25, step=1)


# Make Submission

In [19]:
# # Check if "row_id" columns are the same
# row_id_same = sub['row_id'].equals(sample_sub['row_id'])
# print("Row_id columns are the same:", row_id_same)

# # If not the same, print the unequal rows
# if not row_id_same:
#     unequal_rows = sub[sub['row_id'] != sample_sub['row_id']]
#     print("Rows with unequal 'row_id':")
#     print(unequal_rows)

# # Check if columns are the same
# columns_are_same = sub.columns.equals(sample_sub.columns)
# print("Columns are the same:", columns_are_same)

In [20]:
# Assemble the submission frame: prediction columns first, then the row
# identifiers inserted as the leading column.
sub = pd.DataFrame(y_preds, columns=LABELS)
sub.insert(0, 'row_id', row_names)
print(sub.index)


RangeIndex(start=0, stop=25, step=1)


In [21]:
# Write the model predictions. The previous code saved `sample_sub`, i.e. the
# unmodified template, so the inference results never reached the file.
sub.to_csv('submission.csv', index=False)


"pd.read_csv('submission.csv').head()"