In [19]:
from glob import glob
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random

# image stuff
from PIL import Image
import matplotlib.pyplot as plt

###############################
# modeling with PyTorch dataset

import torch
from torch.utils.data import Dataset
from torchvision import datasets

from torchvision.transforms import ToTensor
from torchvision.io import read_image

# Helpers

In [20]:
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
from tqdm import tqdm

import pydicom
import numpy as np
import glob

df_train_series_descriptions = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv")
folder_train_images = os.listdir("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images")
# exclude .DS_store files
folder_train_images = list(filter(lambda x: x.find('.DS') == -1, folder_train_images))

def get_metadata_object(folder_images, df_series_descriptions):
    '''
    for intially the train_images folder and train_series_descriptions.csv,
    later for the test_images folder and test_series_description.csv
    '''
    
    # a list of tuples like (study_id, study_id's path location)
    images_study_id_dirs = [(int(study_id),    # integer the study_id
                                   f"/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/{study_id}") 
                                  for study_id in folder_images]

    # convert the list of tuples into dictionary/metadata
    metadata_object = {study_id: {'study_id_folder_path': path, 
                                    'SeriesInstanceUIDs': [],
                                    'SeriesDescriptions': []
                                 }
                       for study_id, path in images_study_id_dirs
                      }

    # remove all the .DS files/folders (MacOS) from SeriesInstanceUIDs or series_ids folders
    # then put the names for series_ids as well as the corresponding description to metadata
    for study_id in tqdm(metadata_object):

        # SERIES_ID
        # get all series directories/folders inside each study_id directory
        series_ids_dirs = os.listdir(metadata_object[study_id]['study_id_folder_path'])
        filtered_series_ids_dirs = [int(x) for x in series_ids_dirs if x.find('.DS') == -1]   # integer(series_id)
        # put to metadata_object
        metadata_object[study_id]['SeriesInstanceUIDs'] = filtered_series_ids_dirs

        # SERIES_DESCRIPTIONS
        series_desc_df = df_series_descriptions[df_series_descriptions.study_id==study_id][["series_id", "series_description"]]
        for series_id in metadata_object[study_id]['SeriesInstanceUIDs']:

            series_desc_list = series_desc_df[series_desc_df.series_id==series_id].series_description.values

            if len(series_desc_list) == 0:
                metadata_object[study_id]['SeriesDescriptions'].append("")
            else:
                metadata_object[study_id]['SeriesDescriptions'].append(series_desc_list[0])
                
    return metadata_object

metadata_object_train = get_metadata_object(folder_train_images, df_train_series_descriptions)





def get_series_metadata_object_given_study_id(metadata_object, study_id):
    """
    one study_id can have many series => this func gets all series metadata for this 
    particular study_id
    """
    
    metadata_for_study_id = metadata_object[study_id]
    
    series_metadata_object_given_study_id = {}

    for idx, series_id in enumerate(metadata_for_study_id["SeriesInstanceUIDs"]):

        # create bases for the series_images_metadata: each series contains desc and images files
        series_metadata_object_given_study_id[series_id] = {'image_series_description': metadata_for_study_id["SeriesDescriptions"][idx], 
                                                            'image_files': []
                                                            }

        # glob for patten matching as we want to get image files ending with .dcm
        folder_path_study_id = metadata_for_study_id["study_id_folder_path"]
        # rmb: "SeriesInstancesUIDs" is just a list of images and series_id is the actual series-id
        image_files = glob.glob(f"{folder_path_study_id}/{series_id}/*.dcm")


        # inside image_files, create a metadata for id_image and corresponding dicom readable image file
        sorted_image_files = sorted(image_files, key = lambda x: int(x.split('/')[-1].replace('.dcm', '')))

        
        # iterate through all image files (sorting to make sense, not very necessarily)
        for image_file in sorted_image_files:
            
            dicom_image_id = image_file.split('/')[-1].replace(".dcm", '')
            dicom_image_read = pydicom.dcmread(image_file)
            
            # metadata for one image instance
            one_image_metadata = {"SOPInstanceUID": dicom_image_id,      # id of the dicom image instance file
                                  "dicom_image_file": dicom_image_read}  # actual read of the dcm instance file

            # append this image_metadata to list of image_files
            series_metadata_object_given_study_id[series_id]["image_files"].append(one_image_metadata)

    return series_metadata_object_given_study_id
    
    
    
    
    
    
def display_images_given_study_id(metadata_object, study_id): 
    '''
    inside there is another function specifically for this function
    '''
    
    # view images for this study_id = 4003253 for a particular series_description
    def display_images(image_files, series_description, max_images_per_row=5):

        # grid for display
        num_images = len(image_files)
        num_rows = (num_images + max_images_per_row - 1) // max_images_per_row  # ceiling division (ignore the remainder, extra)

        # subplot grid
        fig, axes = plt.subplots(num_rows, max_images_per_row, figsize=(10, 1.5 * num_rows))

        # flatten axes for easy looping if there are multiple rows
        if num_rows > 1:
            axes = axes.flatten()
        else:
            axes = [axes] # iterable for consistency

        # plot each image
        for idx, image_file in enumerate(image_files):
            ax = axes[idx]
            ax.imshow(image_file, cmap='gray') # Assuming grayscale for simplicity, change cmap as needed
            ax.axis('off')

        # turn off unused subplots
        for idx in range(num_images, len(axes)):
            axes[idx].axis("off")

        fig.suptitle(series_description, fontsize=12)
        plt.show()
        
        
        
    series_metadata_object = get_series_metadata_object_given_study_id(metadata_object, study_id)
    
    for series_id in series_metadata_object:

        series_description = series_metadata_object[series_id]["image_series_description"]

        # get the image_files for this particular series_id
        image_files = series_metadata_object[series_id]["image_files"]

        # get the dicom files to a list
        dicom_images = []
        for image_metadata in image_files:
            dicom_image = image_metadata['dicom_image_file'].pixel_array
            dicom_images.append(dicom_image)

        # display images for each series
        display_images(dicom_images, series_description)
        
        

        
        
        
coord_df = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv')
# add the columns in coordinates regarding condition+level to match with train_df

cols = ['condition', 'level']
coord_df['m_condition'] = coord_df[cols].apply(lambda row: ' '.join(row.values.astype(str)).lower(), axis=1)
coord_df['m_condition'] = coord_df['m_condition'].str.replace(r'[ /]', '_', regex=True)

train_df = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train.csv')



def display_images_with_coord(metadata_object, study_id):
    '''
    again, there is an inner function to display imges
    '''
    
    # this function is for one coord, one image, and one title/description/condition/level)
    def display_image_with_coord(center_coord, image_instance_dicom, title):
        '''
        coord_entry is a particular coord for a particular image 
        image_meta is the image that specifically in the series, containing SOPInstanceUID and dicom_image_file, i.e. 
        {2448190387: {'image_series_description': 'Axial T2',
                      'image_files': [{'SOPInstanceUID': '1',
                                        'dicom_image_file': Dataset.file_meta -------------------------------
                                        (0002, 0001) File Meta Informa
        '''
        radius = 10
        color = (255, 0, 0)
        thickness = 2

        # for what?
        image_normalized = cv2.normalize(image_instance_dicom, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        # circling
        image_circle = cv2.circle(image_normalized.copy(), center_coord, radius, color, thickness)

        # convert image from BGR to RGB for correct color display in matplotlib
        image_circle = cv2.cvtColor(image_circle, cv2.COLOR_BAYER_BG2BGR)

        # display
        plt.imshow(image_circle)
        plt.axis('off')
        plt.title(title)
        plt.show()
    
    
    
    # example_study_id = 4003253
    series_meta =  get_series_metadata_object_given_study_id(metadata_object, study_id) # series_meta all based on given study_id
    example_train_df = train_df[train_df.study_id==study_id]
    example_coord_entries_df = coord_df[coord_df.study_id==study_id]

    for d, coord_entry in example_coord_entries_df.iterrows():

        center_coord = (int(coord_entry['x']), int(coord_entry['y']))

        # search for the image available for coord in all imgs in series
        image_meta_instances_list = series_meta[coord_entry.series_id]["image_files"]
        # check matching
        image_instance_id = coord_entry.instance_number

        for image_instance in image_meta_instances_list:

            if int(image_instance["SOPInstanceUID"]) == int(image_instance_id):

                image_instance_dicom = image_instance["dicom_image_file"].pixel_array

                severity = train_df.loc[example_train_df.index[0], coord_entry.m_condition]
                title = f'image_instance_id: {image_instance_id} \n {coord_entry.m_condition} - severity: {severity}'

                display_image_with_coord(center_coord, image_instance_dicom, title)

100%|██████████| 1975/1975 [00:05<00:00, 384.23it/s]


# 1. Preprocess train.csv & train_label_coordinates.csv

In [21]:
df_train = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train.csv")

# fill missing values with negative so that function can ignore them when calculating the loss and score
df_train = df_train.fillna(0)

# label encoder df_train
labels_encoded = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
for col in df_train:
    if col == 'study_id':
        df_train['study_id'] = df_train['study_id'].astype(str)
    else:
        df_train[col] = df_train[col].map(labels_encoded)

df_train = df_train.fillna(0)
print(df_train.shape)
df_train.head(5)

(1975, 26)


Unnamed: 0,study_id,spinal_canal_stenosis_l1_l2,spinal_canal_stenosis_l2_l3,spinal_canal_stenosis_l3_l4,spinal_canal_stenosis_l4_l5,spinal_canal_stenosis_l5_s1,left_neural_foraminal_narrowing_l1_l2,left_neural_foraminal_narrowing_l2_l3,left_neural_foraminal_narrowing_l3_l4,left_neural_foraminal_narrowing_l4_l5,...,left_subarticular_stenosis_l1_l2,left_subarticular_stenosis_l2_l3,left_subarticular_stenosis_l3_l4,left_subarticular_stenosis_l4_l5,left_subarticular_stenosis_l5_s1,right_subarticular_stenosis_l1_l2,right_subarticular_stenosis_l2_l3,right_subarticular_stenosis_l3_l4,right_subarticular_stenosis_l4_l5,right_subarticular_stenosis_l5_s1
0,4003253,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,4646740,0.0,0.0,1.0,2.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,2.0,0.0,0.0,1.0,1.0,1.0,0.0
2,7143189,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,8785691,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,10728036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [22]:
df_train_label_coords = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv')

# 
df_train_label_coords.condition = df_train_label_coords.condition.map(lambda x: str.lower(x).replace(' ', '_'))
df_train_label_coords.level = df_train_label_coords.level.map(lambda x: str.lower(x).replace('/', '_'))

# map series_id to series_desc, then drop series_id
dict_series_id_desc = {}
for v in metadata_object_train.values():
    series = v["SeriesInstanceUIDs"]
    series_descs = v["SeriesDescriptions"]
    for serie, desc in zip(series, series_descs):
        dict_series_id_desc[serie] = desc.replace(' ', '_').replace('/', '_')
df_train_label_coords["series_desc"] = df_train_label_coords.series_id.map(dict_series_id_desc).fillna(np.nan)
df_train_label_coords = df_train_label_coords.drop('series_id', axis=1)

# m_condition
df_train_label_coords["condition_level"] = df_train_label_coords["condition"] + '_' + df_train_label_coords["level"]

long_df_train = pd.melt(df_train, id_vars=['study_id'], var_name='condition_level', value_name='severity').sort_values(by='study_id')

In [23]:
CONDITIONS = ['spinal_canal_stenosis', 
              'left_neural_foraminal_narrowing', 'right_neural_foraminal_narrowing',
              'left_subarticular_stenosis', 'right_subarticular_stenosis']

LEVELS = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']

# 2. Preprocess png images

Unzip output data (png images) from my '/kaggle/input/rsna-making-dataset-png' notebook.

**Run the unzip code once!**

In [24]:
# run once:
!unzip -q /kaggle/input/rsna-making-dataset-png/_output_.zip

# Configs

In [25]:
torch.cuda.is_available()

True

In [26]:
## True -> run normally, False -> debug mode, with lesser computing cost
NOT_DEBUG = True

OUTPUT_DIR = f'my_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# set computation device
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# parallel processing: no. CPU cores to use for data loading
N_WORKERS = os.cpu_count() 


##########################################
########### TRANSFORMATION ###############
##########################################

# if Automatic Mixed Precision (AMP) is used, 
# beneficial for certain GPUs like T4 or newer
USE_AMP = True
SEED = 123

# Image configurations
IMG_SIZE = [512, 512]
IMG_HEIGHT, IMG_WIDTH = 512, 512

# # data augmentation:
AUG_PROB = 0.75 # probability of applying data aug
# # augmentation will be applied to 75% of images during training, improve generalization
AUG = True      # whether to use data augmentation


####################################
############# MODELS ###############
####################################

# cross-validation:
N_FOLDS = 5 if NOT_DEBUG else 2 # use 5 folds in normal mode, 2 in debug mode (faster)

# training configs
EPOCHS = 5 if NOT_DEBUG else 2

# Model I/O
IN_CHANS = 30  # no. input channels (typically 1 for grayscale, 3 for RGB)
N_LABELS = 25  # no. labels: 3 conditions, 2 has left and right = +1 = (2*2+1) * 5 levels = 25 
N_CLASSES = 3 * N_LABELS # classes (for prediction): severity (normal/mid, moderate, severe)
# here, labels might be misleading, but they refer to the object/observation
# that we want to predict the severity is => classify the condition+level

# pretrained, finetuned model name - Efficient Net family of models
if NOT_DEBUG:
    MODEL_NAME = "resnet50.a1_in1k"
else:
    MODEL_NAME = "resnet50.a1_in1k" 
    # (smaller no. of params)
# ns: stands for "Noisy Student," which is a training technique used to improve the model's performance by adding noise to the student model during training.
# jft: indicates that the model has been pretrained on the JFT dataset, which is a large dataset developed by Google containing around 300 million images with 18,000 labels.
# in1k: indicates that the model has also been fine-tuned or evaluated on the ImageNet-1k dataset, which is a standard benchmark in the field of computer vision containing 1,000 classes.
# the names come from timm library


###############################
########### Tuning ############
###############################
    
# Steps to accumulate gradients before updating the model weights
GRAD_ACC = 2 
# instead of updating the model weights after every batch, 
# the gradients are accumulated over several batches (i.e. 2 batches), 
# and then the model weights are updated once => not increase GPU usage

# size = no. of samples in a batch
# Target batch size
TGT_BATCH_SIZE = 32 # total number of samples you want to process before updating the model weights
# Batch size per gradient accumulation step
BATCH_SIZE = TGT_BATCH_SIZE // GRAD_ACC  # for now: = 16 samples/batch
# This means the model will process 16 samples at a time, 
# accumulate the gradients, and after processing two such batches (total 32 samples), 
# it will update the model weights.

# maximum norm for gradient clipping => not used
MAX_GRAD_NORM = None
# Gradient clipping is a technique to prevent the exploding gradient problem 
# by capping (clipping) the gradients during backpropagation to a maximum norm.
# but if not used, no need to stablize when computed gradient norm exceeds some number

# no. epochs with no improvement after which training will be stopped early
EARLY_STOPPING_EPOCH = 3

# learning rate (base = 2e-4), but adjusted based on the batch size
LR = 2e-4 * TGT_BATCH_SIZE / 32   # commonly start at 2e-4
# adjustment is needed b/c, in a case that needs a larger batch size 
# => each training step updates the  model with more data 
# => gradient estimates are more accurate => LR is scaled up proportionally

# weight decay for regularization => to prevent overfitting (= decrease variance)
# penalize large weights in model, introduce more bias, and aim for more generalizble
# also not rely too much on any particular weight, promote more evenly distributed weights
# in the loss function => reduce weight vector towards 0 when backprop
# example like squared L2 norm of the weights
WD = 1e-2 # commonly used starting point in practice
# example: optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


###############################
########## SEEDS ##############
###############################
def set_random_seed(seed, deterministic: bool = False):
    """Set seeds for all packages needed"""
    random.seed(seed)
    np.random.seed(seed)
    # for hash-based randomization in Python operations like dictionary keys:
    os.environ["PYTHONHASHSEED"] = str(seed)
    # PyTorch’s CPU-based random number generator:
    # torch.manual_seed(seed) 
    # PyTorch’s GPU-based random number generator:
    # torch.cuda.manual_seed(seed)  # type: ignore
    
    # enables the cuDNN auto-tuner to find the best algorithms for hardware
    # speed up training on GPUs by optimizing performance => require deterministic
    torch.backends.cudnn.benchmark = True
    # cuDNN should use deterministic algorithms for convolutions
    torch.backends.cudnn.deterministic = deterministic  # make sure no deterministic behaviors
    # non-deterministic in this case b/c want to avoid performance or memory issue
    # due to the use of less optimized, deterministic algo (when search for one in
    # benchmark)
    
set_random_seed(SEED)

# 3. Define Dataset + Transforms + DataLoader

https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html

In [27]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import random

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

<contextlib.ExitStack at 0x7ee1866e0520>

In [28]:
class LumbarDataset(Dataset):
    
    """Lumbar Degenerative Conditions dataset."""

    def __init__(self, df, transform=None, phase='train'):
        self.df = df
        self.transform = transform # from PyTorch data class and transforms torchvision library
        self.phase = phase
    
    def __len__(self):
        return len(self.df)
    
    def _load_image(self, folder_path, file_name):
        """
        loads and converts a single image to grayscale and numpy.
        """
        try:
            img_path = os.path.join(folder_path, file_name)
            img = Image.open(img_path).convert('L')  # grayscale
            img_np = np.array(img, dtype=np.uint8)  # convert to numpy array
            return img_np
        except Exception as e:
            print(f"error loading image: {img_path}, {str(e)}")
            return np.zeros(self.img_size, dtype=np.uint8)  # return blank image on error
    
    # tough one:
    def __getitem__(self, idx_in_train_df):
        """
        get images given idx_in_train_df
        """
        
        # each study_id has 30 images/channels, each image is 512x512 on top of each other (assuming grayscale!)
        stack_2d_images = np.zeros((512, 512, IN_CHANS), dtype=np.uint8) # placeholder for image data, 
        # so later we can populate it with images from 3 methods!
        # stack_2d_images[512, 512, 0] would be img1
        # stack_2d_images[512, 512, 1] would be img2
        # the ... reference the entire 512 x 512 image in the i-th channel
        
        # rmb: df here is the train_df (not the one combined with coord)
        # for each specific study_id
        row = self.df.iloc[idx_in_train_df]
        study_id = int(row['study_id'])
        # numpy array consisted of all the labels for all the conditions for this study_id:
        conditions_levels = row[1:].values.astype(np.int8) 
        
        # get folder
        folder_Sagittal_T1 = f'./train_images_png/{study_id}/Sagittal_T1/'
        folder_Sagittal_T2_STIR = f'./train_images_png/{study_id}/Sagittal_T2_STIR/'
        folder_Axial_T2 = f'./train_images_png/{study_id}/Axial_T2/'
        
        base_path = f'./train_images_png/{study_id}'
        
        # Sagittal T1
        try:
            for i, fname in enumerate(os.listdir(folder_Sagittal_T1)):
                img_np = self._load_image(folder_Sagittal_T1, fname)
                stack_2d_images[..., i] = img_np.astype(np.uint8)  # first ten channels/layers in stack_2d_images
        except:
            pass
            
        # Sagittal T2_STIR
        try:
            for i, fname in enumerate(os.listdir(folder_Sagittal_T2_STIR)):
                img_np = self._load_image(folder_Sagittal_T2_STIR, fname)
                stack_2d_images[..., i+10] = img_np.astype(np.uint8) # next ten channels/layers in stack_2d_images
        except:
            pass
            
            
        # Axial T2
        # select random 10 images from Axial T2
        try:
            axial_fnames = random.sample(os.listdir(folder_Axial_T2), 10)
            for i, fname in enumerate(axial_fnames):
                img_np = self._load_image(folder_Axial_T2, fname)
                stack_2d_images[..., i+20] = img_np.astype(np.uint8) # last ten channels/layers in stack_2d_images 
        except:
            pass
            
        # print("shape of the image stack:", stack_2d_images.shape)  # 512, 521, 30
            
        # transforms
        # transforms if given the transform functions are provided to do so
        if self.transform is not None:
            stack_2d_images = self.transform(image=stack_2d_images)['image']

        # transpose the stack from (height, width, channel) into (channels, height, width)
        stack_2d_images = stack_2d_images.transpose(2, 0, 1)
        
        # lastly, for this study_id, we get:
        return stack_2d_images, conditions_levels

The data augmentation code is from this brilliant [notebook](https://www.kaggle.com/code/haqishen/1st-place-soluiton-code-small-ver). The notebook also works on diagnosis medical images.

In [29]:
import albumentations as A

transforms_train = A.Compose([
    # A.Transpose(p=0.5),
    # A.VerticalFlip(p=0.5),
    # A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.7),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7),

    A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=0.7),

    # A.CLAHE(clip_limit=4.0, p=0.7),
    # A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.7),
    A.Resize(IMG_HEIGHT, IMG_WIDTH),
    
    # replace Cutout with CoarseDropout
    A.CoarseDropout(max_holes=16, max_height=64, max_width=64, min_holes=1, min_height=8, min_width=8, p=0.7),    
    A.Normalize(mean=0.5, std=0.5) # >??
])

transforms_val = A.Compose([
    A.Resize(IMG_HEIGHT, IMG_WIDTH),
    A.Normalize(mean=0.5, std=0.5)
])

if not NOT_DEBUG or not AUG:
    transforms_train = transforms_val

In [30]:
# train_data = LumbarDataset(df_train, phase='train', transform=transforms_train)
# for idx, (stack, conditions_labels) in zip(range(5), train_data):
#     print(f"id in df_train: {idx}")
#     print(f"conditions and labels: {conditions_labels, conditions_labels.shape}")
#     print(f"stack 2D images for this id")
#     print(f"stack shape: {stack.shape}")
#     img_displayable = stack.transpose(0,2,3,1)[0, :, :, :3]
#     # to display image, from (1, channel, height, width) into (1, height, width, channel)
#     # then pick the first image in the stack 
#     # [0, :, :, :3] = [first_image, all values for height, all values for width, first 3 channels RGB of the image]
#     # if [0,512,512,3] is trying to select a specific position within the array. Specifically, trying to index at position 0 in the first dimension, 512 in the second dimension, 512 in the third dimension, and 3 in the fourth dimension
#     img_displayable = (img_displayable + 1) / 2
#     plt.imshow(img_displayable)
#     plt.show()
#     print()
    
# plt.close()

# 4. Define Model + Training

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import timm # pretrained models!

RESNET50_MODEL_NAME = "resnet50"
NO_CLASSES = 75 # 5 conditions x 5 levels x 3 severity

In [32]:
import torch
import torch.nn as nn
import timm

class BuildModel(nn.Module):
    def __init__(self, model, 
                 in_c=30, no_conditions_labels=NO_CLASSES,  # 75
                 pretrained=True, features_only=False):
        
        super(BuildModel, self).__init__()
        
        # Create the model with specified parameters
        self.model = timm.create_model(model,
                                        pretrained=pretrained,
                                        features_only=features_only,
                                        in_chans=in_c,
                                        num_classes=no_conditions_labels,
                                        global_pool='avg')

    def forward(self, x):
        y = self.model(x)
        return y

    # no backward! because use pretrained parameters

In [33]:
model_draft = BuildModel("resnet50", pretrained=False)
random_input = torch.randn(2, 30, 512, 512)
pred_prob = model_draft(random_input)
probabilities = F.softmax(pred_prob, dim=1)
pred_prob.shape, probabilities

(torch.Size([2, 75]),
 tensor([[0.0156, 0.0098, 0.0123, 0.0133, 0.0201, 0.0132, 0.0128, 0.0131, 0.0141,
          0.0112, 0.0114, 0.0120, 0.0135, 0.0142, 0.0139, 0.0167, 0.0090, 0.0114,
          0.0287, 0.0130, 0.0087, 0.0176, 0.0170, 0.0117, 0.0102, 0.0098, 0.0239,
          0.0173, 0.0097, 0.0114, 0.0089, 0.0132, 0.0097, 0.0167, 0.0150, 0.0135,
          0.0145, 0.0085, 0.0096, 0.0172, 0.0105, 0.0125, 0.0087, 0.0114, 0.0148,
          0.0116, 0.0158, 0.0172, 0.0132, 0.0169, 0.0115, 0.0147, 0.0181, 0.0095,
          0.0115, 0.0131, 0.0154, 0.0111, 0.0129, 0.0117, 0.0117, 0.0185, 0.0173,
          0.0112, 0.0138, 0.0139, 0.0117, 0.0128, 0.0119, 0.0109, 0.0129, 0.0121,
          0.0128, 0.0113, 0.0113],
         [0.0154, 0.0098, 0.0120, 0.0136, 0.0203, 0.0130, 0.0124, 0.0130, 0.0142,
          0.0118, 0.0120, 0.0127, 0.0146, 0.0144, 0.0136, 0.0165, 0.0095, 0.0112,
          0.0276, 0.0133, 0.0088, 0.0176, 0.0177, 0.0111, 0.0103, 0.0100, 0.0241,
          0.0177, 0.0096, 0.0116, 0.0085,

Training

In [34]:
from collections import OrderedDict
from tqdm import tqdm
import math
from glob import glob

from sklearn.model_selection import KFold
from torch.optim import AdamW # optimizer for deep learning, more details to learn
from transformers import get_cosine_schedule_with_warmup # improve convergence, details to learn

In [35]:
train_data = LumbarDataset(df_train, phase='train', transform=transforms_train)
train_loader = DataLoader(
                        train_data,  
                        batch_size=1,        # no. samples to load in each batch
                        shuffle=False,       # no shuffle
                        pin_memory=True,     # copy tensors into CUDA pinned memory
                        drop_last=False,     # drop the last incomplete batch
                        num_workers=0        # no parallel subprocesses 
                        )

In [36]:
autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP, init_scale=4096)

df_train = df_train.copy() # since we want validation inside training too 

train_ds = LumbarDataset(df_train, phase='train', transform=transforms_train)
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    num_workers=N_WORKERS
)

# intialize model, optimizer, and loss function
model = BuildModel("resnet50", IN_CHANS, N_CLASSES, pretrained=True)
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WD)

# scheduler settings
warmup_steps = EPOCHS / 10 * len(train_dl) // GRAD_ACC
num_total_steps = EPOCHS * len(train_dl) // GRAD_ACC
num_cycles = 0.475
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                             num_warmup_steps=warmup_steps,
                                             num_training_steps=num_total_steps,
                                             num_cycles=num_cycles)

# Loss function
weights = torch.tensor([1.0, 2.0, 4.0])  # Adjust based on your dataset
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

# Training loop
best_loss = float('inf')  # Set best loss to infinity initially
es_step = 0

for epoch in range(1, EPOCHS + 1):
    print(f'Starting epoch {epoch}')
    model.train()
    total_loss = 0

    with tqdm(train_dl, leave=True) as pbar:
        optimizer.zero_grad()
        for idx, (x, t) in enumerate(pbar):
            x = x.to(device)
            t = t.to(device)
            
#             print("x: ", x)
#             print("t: ", t)

            with autocast:
                loss = 0
                y = model(x)
#                print(y)
                for col in range(N_LABELS):
                    pred = y[:, col * 3:col * 3 + 3]
                    gt = t[:, col].long() 

#                     print("target values:", gt)
#                     print("predictions shape:", pred.shape)

                    loss += criterion(pred, gt) / N_LABELS

                total_loss += loss.item()
                if GRAD_ACC > 1:
                    loss /= GRAD_ACC

            if not math.isfinite(loss):
                print(f"Loss is {loss}, stopping training")
                sys.exit(1)

            pbar.set_postfix(loss=f'{loss.item() * GRAD_ACC:.6f}')
            scaler.scale(loss).backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM or 1e9)

            if (idx + 1) % GRAD_ACC == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

    train_loss = total_loss / len(train_dl)
    print(f'Training Loss: {train_loss:.6f}')


    if train_loss < best_loss:
        print(f'Epoch {epoch}: Best loss updated from {best_loss:.6f} to {train_loss:.6f}')
        best_loss = train_loss
        fname = f'{OUTPUT_DIR}/best_model.pt'
        torch.save(model.state_dict(), fname)

print('Training complete.')

Starting epoch 1


100%|██████████| 123/123 [06:58<00:00,  3.41s/it, loss=0.730773]


Training Loss: 0.918495
Epoch 1: Best loss updated from inf to 0.918495
Starting epoch 2


100%|██████████| 123/123 [06:43<00:00,  3.28s/it, loss=0.680334]


Training Loss: 0.772411
Epoch 2: Best loss updated from 0.918495 to 0.772411
Starting epoch 3


100%|██████████| 123/123 [06:36<00:00,  3.23s/it, loss=0.698507]


Training Loss: 0.746528
Epoch 3: Best loss updated from 0.772411 to 0.746528
Starting epoch 4


100%|██████████| 123/123 [06:45<00:00,  3.29s/it, loss=0.809720]


Training Loss: 0.735933
Epoch 4: Best loss updated from 0.746528 to 0.735933
Starting epoch 5


100%|██████████| 123/123 [06:37<00:00,  3.23s/it, loss=0.770923]


Training Loss: 0.730947
Epoch 5: Best loss updated from 0.735933 to 0.730947
Training complete.


# 5. Test images

The new test data class processing is taken/modified based on this notebook: https://www.kaggle.com/code/itsuki9180/rsna2024-lsdc-submission-baseline

In [37]:
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

import glob
import re

In [38]:
class RSNA24TestDataset(Dataset):
    def __init__(self, df, study_ids, phase='test', transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform
        self.phase = phase
    
    def __len__(self):
        return len(self.study_ids)
    
    def get_img_paths(self, study_id, series_desc):
        pdf = self.df[self.df['study_id']==study_id]
        pdf_ = pdf[pdf['series_description']==series_desc]
        allimgs = []
        for i, row in pdf_.iterrows():
            pimgs = glob.glob(f'{rd}/test_images/{study_id}/{row["series_id"]}/*.dcm')
            pimgs = sorted(pimgs, key=natural_keys)
            allimgs.extend(pimgs)
            
        return allimgs
    
    def read_dcm_ret_arr(self, src_path):
        dicom_data = pydicom.dcmread(src_path)
        image = dicom_data.pixel_array
        image = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255
        img = cv2.resize(image, (IMG_SIZE[0], IMG_SIZE[1]),interpolation=cv2.INTER_CUBIC)
        assert img.shape==(IMG_SIZE[0], IMG_SIZE[1])
        return img

    def __getitem__(self, idx):
        x = np.zeros((IMG_SIZE[0], IMG_SIZE[1], IN_CHANS), dtype=np.uint8)
        st_id = self.study_ids[idx]        
        
        # Sagittal T1
        allimgs_st1 = self.get_img_paths(st_id, 'Sagittal T1')
        if len(allimgs_st1)==0:
            print(st_id, ': Sagittal T1, has no images')
        
        else:
            step = len(allimgs_st1) / 10.0
            st = len(allimgs_st1)/2.0 - 4.0*step
            end = len(allimgs_st1)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_st1[ind2])
                    x[..., j] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Sagittal T1')
                    pass
            
        # Sagittal T2/STIR
        allimgs_st2 = self.get_img_paths(st_id, 'Sagittal T2/STIR')
        if len(allimgs_st2)==0:
            print(st_id, ': Sagittal T2/STIR, has no images')
            
        else:
            step = len(allimgs_st2) / 10.0
            st = len(allimgs_st2)/2.0 - 4.0*step
            end = len(allimgs_st2)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_st2[ind2])
                    x[..., j+10] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Sagittal T2/STIR')
                    pass
            
        # Axial T2
        allimgs_at2 = self.get_img_paths(st_id, 'Axial T2')
        if len(allimgs_at2)==0:
            print(st_id, ': Axial T2, has no images')
            
        else:
            step = len(allimgs_at2) / 10.0
            st = len(allimgs_at2)/2.0 - 4.0*step
            end = len(allimgs_at2)+0.0001

            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_at2[ind2])
                    x[..., j+20] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Axial T2')
                    pass  
            
            
        if self.transform is not None:
            x = self.transform(image=x)['image']

        x = x.transpose(2, 0, 1)
                
        return x, str(st_id)
    
transforms_test = A.Compose([
    A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
    A.Normalize(mean=0.5, std=0.5)
])

df_test = pd.read_csv(f'/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv')
study_ids_test = list(df_test['study_id'].unique())

test_ds = RSNA24TestDataset(df_test, study_ids_test, transform=transforms_test)
test_dl = DataLoader(
    test_ds, 
    batch_size=1, 
    shuffle=False,
    num_workers=N_WORKERS,
    pin_memory=True,
    drop_last=False
)

In [39]:
sample_sub = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/sample_submission.csv")
LABELS = list(sample_sub.columns[1:])
LABELS

['normal_mild', 'moderate', 'severe']

In [40]:
class RSNA24Model(nn.Module):
    def __init__(self, model_name, in_c=30, n_classes=75, pretrained=True, features_only=False):
        super().__init__()
        self.model = timm.create_model(
                                    model_name,
                                    pretrained=pretrained, 
                                    features_only=features_only,
                                    in_chans=in_c,
                                    num_classes=n_classes,
                                    global_pool='avg'
                                    )
    
    def forward(self, x):
        y = self.model(x)
        return y

# get the trained model with learned parameters
model_path = f'{OUTPUT_DIR}/best_model.pt'  
print(f'Loading model from {model_path}...')
model = RSNA24Model("resnet50", IN_CHANS, N_CLASSES, pretrained=False)
model.load_state_dict(torch.load(model_path))
model.eval()
model.half()  
model.to(device)

Loading model from my_results/best_model.pt...


RSNA24Model(
  (model): ResNet(
    (conv1): Conv2d(30, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act2): ReLU(inplace=True)
        (aa): Identity()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, m

In [41]:
autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
y_preds = []
row_names = []

with tqdm(test_dl, leave=True) as pbar:
    with torch.no_grad():
        for idx, (x, si) in enumerate(pbar):
            x = x.to(device)
            pred_per_study = np.zeros((25, 3))
            
            for cond in CONDITIONS:
                for level in LEVELS:
                    row_names.append(si[0] + '_' + cond + '_' + level)
            
            with autocast:
                y = model(x)[0]
                for col in range(N_LABELS):
                    pred = y[col*3:col*3+3]
                    y_pred = pred.float().softmax(0).cpu().numpy()
                    pred_per_study[col] += y_pred 
                y_preds.append(pred_per_study)



100%|██████████| 1/1 [00:01<00:00,  1.99s/it]


In [44]:
sub = pd.DataFrame()
sub['row_id'] = row_names
sub[LABELS] = y_preds[0]  

sub.to_csv('submission.csv', index=False)
pd.read_csv('submission.csv').head()

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.555168,0.237786,0.207046
1,44036939_spinal_canal_stenosis_l2_l3,0.44905,0.303324,0.247626
2,44036939_spinal_canal_stenosis_l3_l4,0.371784,0.321507,0.306709
3,44036939_spinal_canal_stenosis_l4_l5,0.316991,0.285,0.398009
4,44036939_spinal_canal_stenosis_l5_s1,0.55771,0.229779,0.212511
