In [1]:
import os
from glob import glob
import math
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd
import open3d as o3d
import torchio as tio
from pydicom import dcmread
import cv2
import pgzip
import timm_3d
from spacecutter.losses import CumulativeLinkLoss
from spacecutter.models import LogisticCumulativeLink
from spacecutter.callbacks import AscensionCallback
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
pd.options.display.max_rows = 100
pd.options.display.max_columns = 30

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def retrieve_coordinate_training_data(train_path):
    def reshape_row(row):
        data = {'study_id': [], 'condition': [], 'level': [], 'severity': []}

        for column, value in row.items():
            if column not in ['study_id', 'series_id', 'instance_number', 'x', 'y', 'series_description']:
                parts = column.split('_')
                condition = ' '.join([word.capitalize() for word in parts[:-2]])
                # capitalize(): 字符串第一个字母大写
                level = parts[-2].capitalize() + '/' + parts[-1].capitalize()
                data['study_id'].append(row['study_id'])
                data['condition'].append(condition)
                data['level'].append(level)
                data['severity'].append(value)

        return pd.DataFrame(data)

    train = pd.read_csv(train_path + 'train.csv')
    label = pd.read_csv(train_path + 'train_label_coordinates.csv')
    train_desc = pd.read_csv(train_path + 'train_series_descriptions.csv')
    test_desc = pd.read_csv(train_path + 'test_series_descriptions.csv')
    sub = pd.read_csv(train_path + 'sample_submission.csv')

    new_train_df = pd.concat([reshape_row(row) for _, row in train.iterrows()], ignore_index=True)
    merged_df = new_train_df.merge(label, on=['study_id', 'condition', 'level'], how='inner')
    final_merged_df = merged_df.merge(train_desc, on=['series_id', 'study_id'], how='inner')
    final_merged_df['severity'] = final_merged_df['severity'].map({'Normal/Mild': 'normal_mild', 'Moderate': 'moderate', 'Severe': 'severe'})

    final_merged_df['row_id'] = (
            final_merged_df['study_id'].astype(str) + '_' +
            final_merged_df['condition'].str.lower().str.replace(' ', '_') + '_' +
            final_merged_df['level'].str.lower().str.replace('/', '_')
    )

    # Create the image_path column
    final_merged_df['image_path'] = (
            f'{train_path}/train_images/' +
            final_merged_df['study_id'].astype(str) + '/' +
            final_merged_df['series_id'].astype(str) + '/' +
            final_merged_df['instance_number'].astype(str) + '.dcm'
    )

    return final_merged_df

In [3]:
CONDITIONS = {
    'Sagittal T2/STIR': ['Spinal Canal Stenosis'],
    'Axial T2': ['Left Subarticular Stenosis', 'Right Subarticular Stenosis'],
    'Sagittal T1': ['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'],
}
LABEL_MAP = {'normal_mild': 0, 'moderate': 1, 'severe': 2}

data_path = '../input/rsna-2024-lumbar-spine-degenerative-classification/'

In [4]:
train = pd.read_csv(data_path + 'train.csv')
label = pd.read_csv(data_path + 'train_label_coordinates.csv')
train_desc = pd.read_csv(data_path + 'train_series_descriptions.csv')

test_desc = pd.read_csv(data_path + 'test_series_descriptions.csv')
sub = pd.read_csv(data_path + 'sample_submission.csv')

In [5]:
train_desc.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6294 entries, 0 to 6293
Data columns (total 3 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   study_id            6294 non-null   int64 
 1   series_id           6294 non-null   int64 
 2   series_description  6294 non-null   object
dtypes: int64(2), object(1)
memory usage: 147.6+ KB


In [6]:
train_desc.head()

Unnamed: 0,study_id,series_id,series_description
0,4003253,702807833,Sagittal T2/STIR
1,4003253,1054713880,Sagittal T1
2,4003253,2448190387,Axial T2
3,4646740,3201256954,Axial T2
4,4646740,3486248476,Sagittal T1


In [7]:
train_desc['series_description'].value_counts()

series_description
Axial T2            2340
Sagittal T1         1980
Sagittal T2/STIR    1974
Name: count, dtype: int64

In [8]:
train.head()

Unnamed: 0,study_id,spinal_canal_stenosis_l1_l2,spinal_canal_stenosis_l2_l3,spinal_canal_stenosis_l3_l4,spinal_canal_stenosis_l4_l5,spinal_canal_stenosis_l5_s1,left_neural_foraminal_narrowing_l1_l2,left_neural_foraminal_narrowing_l2_l3,left_neural_foraminal_narrowing_l3_l4,left_neural_foraminal_narrowing_l4_l5,left_neural_foraminal_narrowing_l5_s1,right_neural_foraminal_narrowing_l1_l2,right_neural_foraminal_narrowing_l2_l3,right_neural_foraminal_narrowing_l3_l4,right_neural_foraminal_narrowing_l4_l5,right_neural_foraminal_narrowing_l5_s1,left_subarticular_stenosis_l1_l2,left_subarticular_stenosis_l2_l3,left_subarticular_stenosis_l3_l4,left_subarticular_stenosis_l4_l5,left_subarticular_stenosis_l5_s1,right_subarticular_stenosis_l1_l2,right_subarticular_stenosis_l2_l3,right_subarticular_stenosis_l3_l4,right_subarticular_stenosis_l4_l5,right_subarticular_stenosis_l5_s1
0,4003253,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
1,4646740,Normal/Mild,Normal/Mild,Moderate,Severe,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Moderate,Normal/Mild,Normal/Mild,Moderate,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Severe,Normal/Mild,Normal/Mild,Moderate,Moderate,Moderate,Normal/Mild
2,7143189,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
3,8785691,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
4,10728036,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild


In [9]:
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1975 entries, 0 to 1974
Data columns (total 26 columns):
 #   Column                                  Non-Null Count  Dtype 
---  ------                                  --------------  ----- 
 0   study_id                                1975 non-null   int64 
 1   spinal_canal_stenosis_l1_l2             1974 non-null   object
 2   spinal_canal_stenosis_l2_l3             1974 non-null   object
 3   spinal_canal_stenosis_l3_l4             1974 non-null   object
 4   spinal_canal_stenosis_l4_l5             1974 non-null   object
 5   spinal_canal_stenosis_l5_s1             1974 non-null   object
 6   left_neural_foraminal_narrowing_l1_l2   1973 non-null   object
 7   left_neural_foraminal_narrowing_l2_l3   1973 non-null   object
 8   left_neural_foraminal_narrowing_l3_l4   1973 non-null   object
 9   left_neural_foraminal_narrowing_l4_l5   1973 non-null   object
 10  left_neural_foraminal_narrowing_l5_s1   1973 non-null   object
 11  righ

In [10]:
label.head()

Unnamed: 0,study_id,series_id,instance_number,condition,level,x,y
0,4003253,702807833,8,Spinal Canal Stenosis,L1/L2,322.831858,227.964602
1,4003253,702807833,8,Spinal Canal Stenosis,L2/L3,320.571429,295.714286
2,4003253,702807833,8,Spinal Canal Stenosis,L3/L4,323.030303,371.818182
3,4003253,702807833,8,Spinal Canal Stenosis,L4/L5,335.292035,427.327434
4,4003253,702807833,8,Spinal Canal Stenosis,L5/S1,353.415929,483.964602


In [11]:
label.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48692 entries, 0 to 48691
Data columns (total 7 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   study_id         48692 non-null  int64  
 1   series_id        48692 non-null  int64  
 2   instance_number  48692 non-null  int64  
 3   condition        48692 non-null  object 
 4   level            48692 non-null  object 
 5   x                48692 non-null  float64
 6   y                48692 non-null  float64
dtypes: float64(2), int64(3), object(2)
memory usage: 2.6+ MB


In [12]:
sub

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_left_neural_foraminal_narrowing_l1_l2,0.333333,0.333333,0.333333
1,44036939_left_neural_foraminal_narrowing_l2_l3,0.333333,0.333333,0.333333
2,44036939_left_neural_foraminal_narrowing_l3_l4,0.333333,0.333333,0.333333
3,44036939_left_neural_foraminal_narrowing_l4_l5,0.333333,0.333333,0.333333
4,44036939_left_neural_foraminal_narrowing_l5_s1,0.333333,0.333333,0.333333
5,44036939_left_subarticular_stenosis_l1_l2,0.333333,0.333333,0.333333
6,44036939_left_subarticular_stenosis_l2_l3,0.333333,0.333333,0.333333
7,44036939_left_subarticular_stenosis_l3_l4,0.333333,0.333333,0.333333
8,44036939_left_subarticular_stenosis_l4_l5,0.333333,0.333333,0.333333
9,44036939_left_subarticular_stenosis_l5_s1,0.333333,0.333333,0.333333


In [13]:
for i, row in train.iterrows():
    # print(i)
    # print(row)
    # print(type(row))
    for column, value in row.items():
        print(f'column = {column}')
        print(f'value = {value}')
        # break
    
    break

column = study_id
value = 4003253
column = spinal_canal_stenosis_l1_l2
value = Normal/Mild
column = spinal_canal_stenosis_l2_l3
value = Normal/Mild
column = spinal_canal_stenosis_l3_l4
value = Normal/Mild
column = spinal_canal_stenosis_l4_l5
value = Normal/Mild
column = spinal_canal_stenosis_l5_s1
value = Normal/Mild
column = left_neural_foraminal_narrowing_l1_l2
value = Normal/Mild
column = left_neural_foraminal_narrowing_l2_l3
value = Normal/Mild
column = left_neural_foraminal_narrowing_l3_l4
value = Normal/Mild
column = left_neural_foraminal_narrowing_l4_l5
value = Moderate
column = left_neural_foraminal_narrowing_l5_s1
value = Normal/Mild
column = right_neural_foraminal_narrowing_l1_l2
value = Normal/Mild
column = right_neural_foraminal_narrowing_l2_l3
value = Normal/Mild
column = right_neural_foraminal_narrowing_l3_l4
value = Moderate
column = right_neural_foraminal_narrowing_l4_l5
value = Moderate
column = right_neural_foraminal_narrowing_l5_s1
value = Normal/Mild
column = left_s

In [14]:
train_data = retrieve_coordinate_training_data(data_path)
train_data.head(30)

Unnamed: 0,study_id,condition,level,severity,series_id,instance_number,x,y,series_description,row_id,image_path
0,4003253,Spinal Canal Stenosis,L1/L2,normal_mild,702807833,8,322.831858,227.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l1_l2,../input/rsna-2024-lumbar-spine-degenerative-c...
1,4003253,Spinal Canal Stenosis,L2/L3,normal_mild,702807833,8,320.571429,295.714286,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l2_l3,../input/rsna-2024-lumbar-spine-degenerative-c...
2,4003253,Spinal Canal Stenosis,L3/L4,normal_mild,702807833,8,323.030303,371.818182,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l3_l4,../input/rsna-2024-lumbar-spine-degenerative-c...
3,4003253,Spinal Canal Stenosis,L4/L5,normal_mild,702807833,8,335.292035,427.327434,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l4_l5,../input/rsna-2024-lumbar-spine-degenerative-c...
4,4003253,Spinal Canal Stenosis,L5/S1,normal_mild,702807833,8,353.415929,483.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l5_s1,../input/rsna-2024-lumbar-spine-degenerative-c...
5,4003253,Left Neural Foraminal Narrowing,L1/L2,normal_mild,1054713880,11,196.070671,126.021201,Sagittal T1,4003253_left_neural_foraminal_narrowing_l1_l2,../input/rsna-2024-lumbar-spine-degenerative-c...
6,4003253,Left Neural Foraminal Narrowing,L2/L3,normal_mild,1054713880,12,191.321555,170.120141,Sagittal T1,4003253_left_neural_foraminal_narrowing_l2_l3,../input/rsna-2024-lumbar-spine-degenerative-c...
7,4003253,Left Neural Foraminal Narrowing,L3/L4,normal_mild,1054713880,12,187.878354,217.245081,Sagittal T1,4003253_left_neural_foraminal_narrowing_l3_l4,../input/rsna-2024-lumbar-spine-degenerative-c...
8,4003253,Left Neural Foraminal Narrowing,L4/L5,moderate,1054713880,11,186.504472,251.592129,Sagittal T1,4003253_left_neural_foraminal_narrowing_l4_l5,../input/rsna-2024-lumbar-spine-degenerative-c...
9,4003253,Left Neural Foraminal Narrowing,L5/S1,normal_mild,1054713880,11,197.100569,289.457306,Sagittal T1,4003253_left_neural_foraminal_narrowing_l5_s1,../input/rsna-2024-lumbar-spine-degenerative-c...


In [15]:
train_data[train_data['study_id'] == 4646740][['study_id', 'series_id', 'series_description', 'condition', 'level', 'severity']]

Unnamed: 0,study_id,series_id,series_description,condition,level,severity
25,4646740,3666319702,Sagittal T2/STIR,Spinal Canal Stenosis,L1/L2,normal_mild
26,4646740,3666319702,Sagittal T2/STIR,Spinal Canal Stenosis,L2/L3,normal_mild
27,4646740,3666319702,Sagittal T2/STIR,Spinal Canal Stenosis,L3/L4,moderate
28,4646740,3666319702,Sagittal T2/STIR,Spinal Canal Stenosis,L4/L5,severe
29,4646740,3666319702,Sagittal T2/STIR,Spinal Canal Stenosis,L5/S1,normal_mild
30,4646740,3486248476,Sagittal T1,Left Neural Foraminal Narrowing,L1/L2,normal_mild
31,4646740,3486248476,Sagittal T1,Left Neural Foraminal Narrowing,L2/L3,normal_mild
32,4646740,3486248476,Sagittal T1,Left Neural Foraminal Narrowing,L3/L4,normal_mild
33,4646740,3486248476,Sagittal T1,Left Neural Foraminal Narrowing,L4/L5,moderate
34,4646740,3486248476,Sagittal T1,Left Neural Foraminal Narrowing,L5/S1,moderate


In [16]:
train_data['series_description'].value_counts()

series_description
Sagittal T1         19724
Axial T2            19220
Sagittal T2/STIR     9748
Name: count, dtype: int64

In [17]:
train_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48692 entries, 0 to 48691
Data columns (total 11 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   study_id            48692 non-null  int64  
 1   condition           48692 non-null  object 
 2   level               48692 non-null  object 
 3   severity            48657 non-null  object 
 4   series_id           48692 non-null  int64  
 5   instance_number     48692 non-null  int64  
 6   x                   48692 non-null  float64
 7   y                   48692 non-null  float64
 8   series_description  48692 non-null  object 
 9   row_id              48692 non-null  object 
 10  image_path          48692 non-null  object 
dtypes: float64(2), int64(3), object(6)
memory usage: 4.1+ MB


In [18]:
df_tmp = train_data.groupby(['study_id']).transform('size')
df_tmp

0        25
1        25
2        25
3        25
4        25
         ..
48687    25
48688    25
48689    25
48690    25
48691    25
Length: 48692, dtype: int64

In [19]:
df_tmp.value_counts()

25    44725
23     1311
21      987
19      722
22      660
24      240
15       30
17       17
Name: count, dtype: int64

In [20]:
df_tmp = train_data.groupby(['series_description', 'study_id'])['level'].agg('count').reset_index().rename(columns={'level': 'count'})
df_tmp

Unnamed: 0,series_description,study_id,count
0,Axial T2,4003253,10
1,Axial T2,4646740,10
2,Axial T2,7143189,10
3,Axial T2,8785691,10
4,Axial T2,10728036,10
...,...,...,...
5914,Sagittal T2/STIR,4282019580,5
5915,Sagittal T2/STIR,4283570761,5
5916,Sagittal T2/STIR,4284048608,5
5917,Sagittal T2/STIR,4287160193,5


In [21]:
df_tmp[df_tmp['series_description'] == 'Sagittal T2/STIR']['count'].value_counts()

count
5    1898
3      39
4      35
1       1
Name: count, dtype: int64

In [22]:
df_tmp[df_tmp['series_description'] != 'Sagittal T2/STIR']['count'].value_counts()

count
10    3768
8       84
6       82
9        8
4        1
2        1
7        1
15       1
Name: count, dtype: int64

In [23]:
df_tmp = train_data.groupby(['study_id'])['level'].agg('count').reset_index().rename(columns={'level': 'count'})
df_tmp['count'].value_counts()

count
25    1789
23      57
21      47
19      38
22      30
24      10
15       2
17       1
Name: count, dtype: int64

In [24]:
type((df_tmp))

pandas.core.frame.DataFrame

In [25]:
dataframe = train_data[['study_id', 'series_id', 'series_description', 'condition', 'level', 'severity']].drop_duplicates()
dataframe.head()

Unnamed: 0,study_id,series_id,series_description,condition,level,severity
0,4003253,702807833,Sagittal T2/STIR,Spinal Canal Stenosis,L1/L2,normal_mild
1,4003253,702807833,Sagittal T2/STIR,Spinal Canal Stenosis,L2/L3,normal_mild
2,4003253,702807833,Sagittal T2/STIR,Spinal Canal Stenosis,L3/L4,normal_mild
3,4003253,702807833,Sagittal T2/STIR,Spinal Canal Stenosis,L4/L5,normal_mild
4,4003253,702807833,Sagittal T2/STIR,Spinal Canal Stenosis,L5/S1,normal_mild


In [26]:
labels = dict()
for name, group in dataframe.groupby(['study_id']):
    group = group[['condition', 'level', 'severity']].drop_duplicates().sort_values(['condition', 'level'])
    label_indices = []
    for index, row in group.iterrows():
        if row['severity'] in LABEL_MAP:
            label_indices.append(LABEL_MAP[row['severity']])
        else:
            raise ValueError()
    
    study_id = name[0]
    labels[study_id] = label_indices
    break

In [27]:
labels

{4003253: [0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]}

In [29]:
len([0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0])

25

In [28]:
group

Unnamed: 0,condition,level,severity
5,Left Neural Foraminal Narrowing,L1/L2,normal_mild
6,Left Neural Foraminal Narrowing,L2/L3,normal_mild
7,Left Neural Foraminal Narrowing,L3/L4,normal_mild
8,Left Neural Foraminal Narrowing,L4/L5,moderate
9,Left Neural Foraminal Narrowing,L5/S1,normal_mild
15,Left Subarticular Stenosis,L1/L2,normal_mild
16,Left Subarticular Stenosis,L2/L3,normal_mild
17,Left Subarticular Stenosis,L3/L4,normal_mild
18,Left Subarticular Stenosis,L4/L5,moderate
19,Left Subarticular Stenosis,L5/S1,normal_mild


In [30]:
from torch import nn


heads = nn.ModuleList([
    nn.Sequential(
        nn.Linear(512, 1),
        LogisticCumulativeLink(3)
    ) for i in range(25)
])
heads

ModuleList(
  (0-24): 25 x Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): LogisticCumulativeLink()
  )
)