In [1]:
import os
import glob
import random
import math

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
from types import SimpleNamespace
import albumentations as A
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import pydicom


import torch.nn as nn
import torch.optim as optim

from torch.optim import lr_scheduler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.model_selection import train_test_split
import bisect
import time

def set_seed(seed=1234):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

  data = fetch_version_info()


# Configurations

In [2]:
# Config
cfg= SimpleNamespace(
    img_dir= "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images",
    device= torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    n_frames=3,
    epochs=10,
    lr=0.0005,
    batch_size=16,
    backbone="resnet18",
    seed= 0,
    model_dir = "/kaggle/working/",
    kernel_type = "resnet18",
    num_workers = 8,
    n_epochs = 5,
    init_lr =0.0005,
    CUDA_VISIBLE_DEVICES = "0",
    sag_axial_slices = 3,
    sag_2_slices =  1,
    sag_1_slices = 4,
    model_dir_point = "/kaggle/input/lumbar_spine_verterbrae_disc-detection/pytorch/1/1/resnet18_0.pt",
    model_dir_lsdd = "/kaggle/input/training-lsdd/resnet18_final_fold0.pth",
    
)
set_seed(seed=cfg.seed) # Makes results reproducable

In [3]:
def load_training_dataframe(cfg,isTrain=True):
    
    test_series_description = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv')

    sagtialla_t1_df = test_series_description[test_series_description['series_description'] == "Sagittal T1"]
    sagtialla_t1_df_u = sagtialla_t1_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)
    
    axial_t2_df = test_series_description[test_series_description['series_description'] == "Axial T2"]
    axial_t2_df_u = axial_t2_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)

    sagtialla_t2_df = test_series_description[test_series_description['series_description'] == "Sagittal T2/STIR"]
    sagtialla_t2_df_u = sagtialla_t2_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)

    merge_sagital_t1 = sagtialla_t1_df_u
    merge_sagital_t1 = merge_sagital_t1.drop(columns=['series_description'])
    merge_sagital_t1.rename(columns={'series_id': 'series_id_sg1'}, inplace=True)
    
    merge_sagital_t2 = pd.merge(merge_sagital_t1,sagtialla_t2_df_u, on=["study_id"],how="inner" )
    merge_sagital_t2 = merge_sagital_t2.drop(columns=['series_description'])
    merge_sagital_t2.rename(columns={'series_id': 'series_id_sg2'}, inplace=True)

    merge_axial_t2 = pd.merge(merge_sagital_t2,axial_t2_df_u, on=["study_id"],how="inner" )
    merge_axial_t2 = merge_axial_t2.drop(columns=['series_description'])
    merge_axial_t2.rename(columns={'series_id': 'series_id_a2'}, inplace=True)

    return merge_axial_t2  
    

# Utils

In [4]:
def batch_to_device(batch, device, skip_keys=[]):
    batch_dict= {}
    for key in batch:
        if key in skip_keys:
             batch_dict[key]= batch[key]
        else:    
            batch_dict[key]= batch[key].to(device)
    return batch_dict

def visualize_prediction(batch, pred, epoch):
    
    mid= cfg.n_frames//2
    
    # Plot
    for idx in range(1):
    
        # Select Data
        img= batch["img"][idx, mid, :, :].cpu().numpy()*255
        cs_true= batch["label"][idx, ...].cpu().numpy()*256
        cs= pred[idx, ...].cpu().numpy()*256
                
        coords_list = [("TRUE", "lightblue", cs_true), ("PRED", "orange", cs)]
        text_labels = [str(x) for x in range(1,21)]
        
        # Plot coords
        fig, axes = plt.subplots(1, len(coords_list), figsize=(10,4))
        fig.suptitle("EPOCH: {}".format(epoch))
        for ax, (title, color, coords) in zip(axes, coords_list):
            ax.imshow(img, cmap='gray')
            ax.scatter(coords[0::2], coords[1::2], c=color, s=50)
            ax.axis('off')
            ax.set_title(title)

            # Add text labels near the coordinates
            for i, (x, y) in enumerate(zip(coords[0::2], coords[1::2])):
                if i < len(text_labels):  # Ensure there are enough labels
                    ax.text(x + 10, y, text_labels[i], color='white', fontsize=15, bbox=dict(facecolor='black', alpha=0.5))


        fig.suptitle("EPOCH: {}".format(epoch))
        plt.show()
#         plt.close(fig)
    return

def load_weights_skip_mismatch(model, weights_path, device):
    # Load Weights
    state_dict = torch.load(weights_path, map_location=device)
    model_dict = model.state_dict()
    
    # Iter models
    params = {}
    for (sdk, sfv), (mdk, mdv) in zip(state_dict.items(), model_dict.items()):
        if sfv.size() == mdv.size():
            params[sdk] = sfv
        else:
            print("Skipping param: {}, {} != {}".format(sdk, sfv.size(), mdv.size()))
    
    # Reload + Skip
    model.load_state_dict(params, strict=False)
    print("Loaded weights from:", weights_path)

# Create csv Dataset

In [5]:
train_df = load_training_dataframe(cfg)


In [6]:
train_df

Unnamed: 0,study_id,series_id_sg1,series_id_sg2,series_id_a2
0,44036939,2828203845,3844393089,3481971518


# Co-ordinate Prediction Dataset

In [7]:
class PreTrainDataset(torch.utils.data.Dataset):
    def __init__(self, study_ids,cfg,transform,isTrain = False,is_dataset_for_t1= False):
        self.cfg= cfg
        self.study_ids = study_ids
        self.transform = transform
        self.isTrain = isTrain
        self.is_dataset_for_t1 = is_dataset_for_t1

    
    def convert_to_8bit(self,x):
        lower, upper = np.percentile(x, (1, 99))
        x = np.clip(x, lower, upper)
        x = x - np.min(x)
        x = x / np.max(x) 
        return (x * 255).astype("uint8")


    def load_dicom_stack(self, dicom_folder, plane, reverse_sort=False):
        dicom_files = glob.glob(os.path.join(dicom_folder, "*.dcm"))
        dicoms = [pydicom.dcmread(f) for f in dicom_files]

        # Determine the plane for sorting (sagittal, coronal, axial)
        plane = {"sagittal": 0, "coronal": 1, "axial": 2}[plane.lower()]
        positions = np.asarray([float(d.ImagePositionPatient[plane]) for d in dicoms])

        # Sort DICOM files based on positions (reverse sort for axial plane if needed)
        idx = np.argsort(-positions if reverse_sort else positions)
        ipp = np.asarray([d.ImagePositionPatient for d in dicoms]).astype("float")[idx]

        # Get the shape of each pixel array (height, width)
        shapes = [d.pixel_array.shape for d in dicoms]
        # Check if all DICOM images have the same shape
        if len(set(shapes)) > 1:
            # There's a shape mismatch, find the minimum shape (height, width)
            min_shape = np.min(shapes, axis=0)
            # Resize images to the minimum shape
            resized_arrays = []
            for d in dicoms:
                img = d.pixel_array.astype("float32")
                if img.shape != tuple(min_shape):
                    resized_img = cv2.resize(img, (min_shape[1], min_shape[0]))  # Resize to (width, height)
                else:
                    resized_img = img  # No resizing needed
                resized_arrays.append(resized_img)

            # Stack the resized images along the first axis
            array = np.stack(resized_arrays)
        else:
            # If all shapes are the same, no resizing is needed
            array = np.stack([d.pixel_array.astype("float32") for d in dicoms])

        # Reorder the array according to the sorted positions
        array = array[idx]

        return {
            "array": self.convert_to_8bit(array),
            "positions": ipp,
            "pixel_spacing": np.asarray(dicoms[0].PixelSpacing).astype("float")
        }
    
    def pad_image(self, img):
        n= img.shape[0]
        if n >= self.cfg.n_frames:
            start_idx = (n - self.cfg.n_frames) // 2
            return img[start_idx:start_idx + self.cfg.n_frames,:, :]
        else:
            pad_left = (self.cfg.n_frames - n) // 2
            pad_right = self.cfg.n_frames - n - pad_left
            return np.pad(img, ((pad_left, pad_right),(0,0), (0,0)), 'constant', constant_values=0)
    
    def load_img(self, series_id):
        fname = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(series_id)), plane="sagittal")
        img= fname["array"]
        img= self.pad_image(img)
        img= np.transpose(img, (1,2, 0))
        img= self.transform(image=img)["image"]
        img= np.transpose(img, (2, 0, 1))
        img= (img / 255.0)
        return img
        
        
    def __getitem__(self, idx):
        d= self.study_ids.iloc[idx]
        if self.is_dataset_for_t1:
            series_id= "/".join([str(d.study_id),str(d.series_id_sg1)])      
        else:
            series_id= "/".join([str(d.study_id),str(d.series_id_sg2)])      
                
        img= self.load_img(series_id)
        if self.isTrain:
            return {
                'img': img, 
                'label': label,
            }
        else:
            return {
                'img': img 
            }
            
    
    def __len__(self,):
        return len(self.study_ids)
    

resize_transform_point= A.Compose([
A.LongestMaxSize(max_size=256, interpolation=cv2.INTER_CUBIC, always_apply=True),
A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), always_apply=True),
])




ds = PreTrainDataset(train_df, cfg,resize_transform_point)    

# Plot a Single Sample
print("---- Sample Shapes -----")
for k, v in ds[0].items():
    print(k, v.shape)


---- Sample Shapes -----
img (3, 256, 256)


# Load Model

In [8]:
# Load backbone for RSNA 2024 task
model_path = "/kaggle/input/lumbar_spine_verterbrae_disc-detection/pytorch/1/1/resnet18_0.pt"
model = timm.create_model('resnet18', pretrained=False, num_classes=20)
model = model.to(cfg.device)
load_weights_skip_mismatch(model, model_path, cfg.device)

  state_dict = torch.load(weights_path, map_location=device)


Loaded weights from: /kaggle/input/lumbar_spine_verterbrae_disc-detection/pytorch/1/1/resnet18_0.pt


# Co-ordinate Inferencing

In [9]:
def coordinate_prediction(model,pred_dataloader,isSagitalT1 = False):
    predictions = []
    with torch.no_grad():
        model = model.eval()
        for batch in tqdm(pred_dataloader):
            batch = batch_to_device(batch, cfg.device)

            pred = model(batch["img"].float())
            pred = torch.sigmoid(pred)
            predictions.append(pred)
            
    data_list_cpu = [tensor.cpu().numpy() for tensor in predictions]
    combined_data = np.vstack(data_list_cpu)
    df_cords = pd.DataFrame(combined_data)
    if isSagitalT1:
        df_cords.to_csv("predicted_cordinates_sag_t1.csv",index=False)
    else:
        df_cords.to_csv("predicted_cordinates_sag_t2.csv",index=False)
    return df_cords

# Coordinate Prediction

In [10]:
train_ds = PreTrainDataset(train_df, cfg,resize_transform_point)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=False)

In [11]:
cordinates_t1 = coordinate_prediction(model,train_dl,isSagitalT1 = True)
cordinates_t2 = coordinate_prediction(model,train_dl,isSagitalT1 = False)


100%|██████████| 1/1 [00:01<00:00,  1.88s/it]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it]


In [12]:
cordinates_t1['study_id'] = train_df['study_id'].values
cordinates_t2['study_id'] = train_df['study_id'].values

In [13]:
cordinates_t1

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,study_id
0,0.40336,0.31897,0.624557,0.36428,0.359472,0.430439,0.565006,0.454914,0.314881,0.540189,...,0.563507,0.322633,0.682666,0.545385,0.653523,0.385318,0.841513,0.562395,0.713457,44036939


# Load Utils for LSDD 

In [14]:
def get_padded_roi(orientation, numberOfImageFromCenter):
    # Calculate middle index
    middleImage = len(orientation["array"]) // 2

    # Calculate start and end indices
    start_idx = middleImage - numberOfImageFromCenter
    end_idx = middleImage + numberOfImageFromCenter + 1

    # Handle bounds
    array_length = orientation["array"].shape[0]  # Slicing along the first axis (number of images)

    # Ensure we don't go beyond the array's bounds
    start_pad = max(0, -start_idx)
    end_pad = max(0, end_idx - array_length)

    # Slice the valid part of the array along the first axis
    roi = orientation["array"][max(0, start_idx):min(array_length, end_idx)]

    # Pad with zeros if needed
    if start_pad > 0 or end_pad > 0:
        roi = np.pad(roi, ((start_pad, end_pad), (0, 0), (0, 0)), mode='constant', constant_values=0)

    return roi

In [15]:
def prepare_level_wise_axial(sagittal_img, imgsag_y_coord_to_axial_slice, coordinates,no_of_axial_slice = 3):
    h, w = sagittal_img.shape
    axial_list = []
    
    first_key = next(iter(imgsag_y_coord_to_axial_slice))
    # Get the first value
    first_value = imgsag_y_coord_to_axial_slice[first_key]

    keys = list(imgsag_y_coord_to_axial_slice.keys())
    keys.sort()

    for i in range(0, len(coordinates), 4):
        #print(i)
        category = coordinates[i:i+4]  # Extracting 4 elements at a time
        y= [category[1]*h,category[3]*h]
        minimum = math.floor(min(y[0],y[1]))
        maximum  = math.ceil(max(y[0],y[1]))
        filtered_keys = [k for k in imgsag_y_coord_to_axial_slice.keys() if minimum <= k <= maximum]
        
        if len(filtered_keys) >= no_of_axial_slice:
            # Use numpy to select 3 keys at uniform intervals
            selected_keys = np.linspace(0, len(filtered_keys)-1,no_of_axial_slice , dtype=int)
            selected_keys = [filtered_keys[i] for i in selected_keys]
        else:
            if len(filtered_keys) == 0:
                index = bisect.bisect_left(keys, minimum)
                # Check if we can find a nearest value less than current_value
                if index > 0:
                    filtered_keys.append(keys[index - 1])

                index = bisect.bisect_right(keys, maximum)
                if index < len(keys):
                    filtered_keys.append(keys[index])  # Return the nearest value greater than current_value
                    

            filtered_keys.extend([filtered_keys[-1]] * (no_of_axial_slice - len(filtered_keys)))
            selected_keys = filtered_keys
        
        if len(selected_keys) ==  0:
            result  = np.zeros((no_of_axial_slice, first_value.shape[0],first_value.shape[1]))
        else:
            selected_axial =[imgsag_y_coord_to_axial_slice.get(k) for k in selected_keys]
            result  = np.array(selected_axial)

        roi_copy_list  = []
        for j in range(result.shape[0]):
            roi_copy_list.append(resize_transform(image=result[j])["image"])
            
        axial_list.append(np.array(roi_copy_list))
        
    return  np.array(axial_list)

In [16]:
resize_transform= A.Compose([
    A.LongestMaxSize(max_size=256, interpolation=cv2.INTER_CUBIC, always_apply=True),
    A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), always_apply=True),
    A.Normalize()
])

def angle_of_line(x1, y1, x2, y2):
    return math.degrees(math.atan2(-(y2-y1), x2-x1))

def crop_between_keypoints(roi,img, keypoint1, keypoint2):
    h, w = img.shape
    x1, y1 = int(keypoint1[0]), int(keypoint1[1])
    x2, y2 = int(keypoint2[0]), int(keypoint2[1])

    # Calculate bounding box around the keypoints
    left = int(min(x1, x2) - (w * 0.1))
    right = int(max(x1, x2) + (w * 0.1))
    top = int(min(y1, y2) - (h * 0.07))
    bottom = int(max(y1, y2) + (h * 0.1))
    
    left = max(0, left)
    right = min(w, right)
    top = max(0, top)
    bottom = min(h, bottom)
    # Crop the image
    return img[top:bottom, left:right],roi[:,top:bottom, left:right]

def plot_5_crops(orientataion,coords_temp,numberOfImageFromCenter = 3):
    # Create a figure and axis for the grid
    #fig = plt.figure(figsize=(10, 10))
    #gs = gridspec.GridSpec(1, 5, width_ratios=[1]*5)
    
    #print(coords_temp)
    # Plot the crops
    #print("plot-5-crop-img")
    orientataion['array'] = get_padded_roi(orientataion,numberOfImageFromCenter)
    middleImage = len(orientataion["array"])//2
    img = orientataion["array"][middleImage]
    
    roi = orientataion["array"][middleImage-numberOfImageFromCenter:middleImage+numberOfImageFromCenter+1]
    croppedImage = []
    #print(p)
    for i in range(0, len(coords_temp), 4):
        # Copy of img
        img_copy= img.copy()
        h, w = img.shape
        roi_copy = roi.copy()
        # Extract Keypoints
        category = coords_temp[i:i+4]  # Extracting 4 elements at a time
        a= (category[0]*w, category[1]*h)
        b= (category[2]*w, category[3]*h)
        
        # Rotate
        rotate_angle= angle_of_line(a[0], a[1], b[0], b[1])
        transform = A.Compose([
            A.Rotate(limit=(-rotate_angle, -rotate_angle), p=1.0),
        ], keypoint_params= A.KeypointParams(format='xy', remove_invisible=False),
        )

        t= transform(image=img_copy, keypoints=[a,b])
        img_copy= t["image"]
        a,b= t["keypoints"]
        #print(img_copy.shape)
        # Crop + Resize
        img_copy,roi_copy = crop_between_keypoints(roi_copy,img_copy, a, b)
        #print(roi_copy.shape)
        #print(img_copy.shape)
        roi_copy_list  = []
        for j in range (0,numberOfImageFromCenter*2+1):
            roi_copy_list.append(resize_transform(image=roi_copy[j])["image"])
        roi_copy_list = np.array(roi_copy_list)
        
        img_copy = roi_copy_list[numberOfImageFromCenter]
        croppedImage.append(roi_copy_list)
        # Plot
        #ax = plt.subplot(gs[i//4])
        #ax.imshow(img_copy, cmap='gray')
        #ax.set_title(f"L{i//4+1}")
        #ax.axis('off')
    #plt.show()
    return np.array(croppedImage)

# Define DataSet for training

In [17]:
# import pandas as pd
# import numpy as np

# # Sample DataFrame
# df = pd.DataFrame({
#     'A': [0, 1, 2],
#     'B': [1, 0, 2],
#     'C': [2, 2, 1],
#     'D': [0, 1, 0],
#     'E': [1, 0, 1]
#     # add 25 columns in actual case
# })

# # One-hot encode each column in the DataFrame
# def one_hot_encode_row(row):
#     # One hot encode each value in the row (0,1,2)
#     return np.array([np.eye(3)[int(val)] for val in row])

# # Apply the one-hot encoding function to each row
# encoded_data = np.array([one_hot_encode_row(row) for row in df.values])

# # Example: to check one row's shape (25, 3)
# print(encoded_data[0].shape)  # Should output (5, 3) for this example, adjust to 25 columns

# # Example: print the one-hot encoded array for a row
# print(encoded_data[0])  # One-hot encoded first row

# # Final shape of the encoded data for all rows
# print(encoded_data.shape)  # (number_of_rows, 25, 3)

In [18]:
# def one_hot_encode_row(row):
#     # One hot encode each value in the row (0,1,2)
#     return np.array([np.eye(3)[int(val)] for val in row])

# s = one_hot_encode_row([0,1,2,0,0])
# # Apply the one-hot encoding function to each row


In [19]:
class RSNADataset(torch.utils.data.Dataset):
    def __init__(self, study_ids, s1_coords, s2_coords,cfg,transform,isTrain = True,transform_axial = None,transform_sag = None):
        self.cfg= cfg
        self.study_ids = study_ids
        self.transform = transform
        self.isTrain = isTrain
        self.transform_sag = transform_sag
        self.transform_axial = transform_axial
        self.s1_coords = self.align_cord(study_ids,s1_coords)
        self.s2_coords = self.align_cord(study_ids,s2_coords)
        self.labeldf = study_ids[[col for col in study_ids.columns if col not in ['study_id','series_id_sg1','series_id_sg2','series_id_a2']]]

    def convert_to_8bit(self,x):
        lower, upper = np.percentile(x, (1, 99))
        x = np.clip(x, lower, upper)
        x = x - np.min(x)
        x = x / np.max(x) 
        return (x * 255).astype("uint8")
    
    def align_cord(self,study_ids,cords):
        std_id = study_ids[['study_id']] 
        merged = pd.merge(std_id,cords,on=['study_id'],how="inner")
        return merged[[col for col in merged.columns if col not in ['study_id']]]

    def load_dicom_stack(self, dicom_folder, plane, reverse_sort=False):
        dicom_files = glob.glob(os.path.join(dicom_folder, "*.dcm"))
        dicoms = [pydicom.dcmread(f) for f in dicom_files]

        # Determine the plane for sorting (sagittal, coronal, axial)
        plane = {"sagittal": 0, "coronal": 1, "axial": 2}[plane.lower()]
        positions = np.asarray([float(d.ImagePositionPatient[plane]) for d in dicoms])

        # Sort DICOM files based on positions (reverse sort for axial plane if needed)
        idx = np.argsort(-positions if reverse_sort else positions)
        ipp = np.asarray([d.ImagePositionPatient for d in dicoms]).astype("float")[idx]

        # Get the shape of each pixel array (height, width)
        shapes = [d.pixel_array.shape for d in dicoms]
        # Check if all DICOM images have the same shape
        if len(set(shapes)) > 1:
            # There's a shape mismatch, find the minimum shape (height, width)
            min_shape = np.min(shapes, axis=0)
            # Resize images to the minimum shape
            resized_arrays = []
            for d in dicoms:
                img = d.pixel_array.astype("float32")
                if img.shape != tuple(min_shape):
                    resized_img = cv2.resize(img, (min_shape[1], min_shape[0]))  # Resize to (width, height)
                else:
                    resized_img = img  # No resizing needed
                resized_arrays.append(resized_img)

            # Stack the resized images along the first axis
            array = np.stack(resized_arrays)
        else:
            # If all shapes are the same, no resizing is needed
            array = np.stack([d.pixel_array.astype("float32") for d in dicoms])

        # Reorder the array according to the sorted positions
        array = array[idx]

        return {
            "array": self.convert_to_8bit(array),
            "positions": ipp,
            "pixel_spacing": np.asarray(dicoms[0].PixelSpacing).astype("float")
        }
     
    def one_hot_encode_row(self, row):
        return np.array([np.eye(3)[int(val)] for val in row]) 
    def __getitem__(self, idx):
        row = self.study_ids.iloc[idx]
        
        sag_t1 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_sg1)), plane="sagittal")
        ax_t2 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_a2)), plane="axial", reverse_sort=True)
        sag_t2 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_sg2)), plane="sagittal")
        
        top_left_hand_corner_sag_t2 = sag_t2["positions"][len(sag_t2["array"]) // 2]
        sag_y_axis_to_pixel_space = [top_left_hand_corner_sag_t2[2]]
        while len(sag_y_axis_to_pixel_space) < sag_t2["array"].shape[1]: 
            sag_y_axis_to_pixel_space.append(sag_y_axis_to_pixel_space[-1] - sag_t2["pixel_spacing"][1])
        
        sag_y_coord_to_axial_slice = {}
        for ax_t2_slice, ax_t2_pos in zip(ax_t2["array"], ax_t2["positions"]):
            diffs = np.abs(np.asarray(sag_y_axis_to_pixel_space) - ax_t2_pos[2])
            sag_y_coord = np.argmin(diffs)
            sag_y_coord_to_axial_slice[sag_y_coord] = ax_t2_slice
        
        sag1_cord = self.s1_coords.iloc[idx].tolist()
        sag2_cord = self.s2_coords.iloc[idx].tolist()
        
        #print("s1 cord",sag1_cord)
        #print("s2 cord",sag2_cord)
        img= sag_t2["array"][len(sag_t2["array"])//2]
        corresponding_axial = self.transform_axial(img, sag_y_coord_to_axial_slice, sag2_cord,no_of_axial_slice = self.cfg.sag_axial_slices)
        #print(corresponding_axial.shape)

        
        crop_result_t2 = self.transform_sag(sag_t2, sag2_cord,numberOfImageFromCenter = self.cfg.sag_2_slices)
        #print(crop_result_t2.shape)
        
        crop_result_t1 = self.transform_sag(sag_t1, sag1_cord,numberOfImageFromCenter = self.cfg.sag_1_slices)
        #print(crop_result_t1.shape)
        
        
        
        if self.isTrain:
            label  = self.one_hot_encode_row( self.labeldf.iloc[idx])
            return (crop_result_t2,crop_result_t1,corresponding_axial),label
        else:
            return (crop_result_t2,crop_result_t1, corresponding_axial)
            
    
    def __len__(self,):
        return len(self.study_ids)

ds = RSNADataset(train_df,cordinates_t1,cordinates_t2 ,cfg,resize_transform,isTrain= False,transform_axial = prepare_level_wise_axial,transform_sag = plot_5_crops)    

# Plot a Single Sample
print("---- Sample Shapes -----")
(k1,k2,k3)  =  ds[0]
print("k1 shape",k1.shape)
print("k1 shape",k2.shape)
print("k3 shape",k3.shape)


---- Sample Shapes -----
k1 shape (5, 3, 256, 256)
k1 shape (5, 9, 256, 256)
k3 shape (5, 3, 256, 256)


# Create Model

In [20]:
import torch
import torch.nn as nn
import torchvision.models as models

In [21]:
#  model1 = models.video.r3d_18(pretrained=True)

In [22]:
# model1.stem[0].weight

In [23]:
# print(model1)

In [24]:
import torch
import torch.nn as nn
import torchvision.models as models

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()

        # Define three 3D ResNet-18 models
        self.model1 = models.video.r3d_18(pretrained=True)  # First model (input shape: 5,3,256,256)
        self.model2 = models.video.r3d_18(pretrained=True)  # Second model (input shape: 5,3,256,256)
        self.model3 = models.video.r3d_18(pretrained=True)  # Third model (input shape: 5,9,256,256)
        
        original_conv1 = self.model1.stem[0]
        self.model1.stem[0] = nn.Conv3d(5, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.model2.stem[0] = nn.Conv3d(5, 64, kernel_size=(9, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.model3.stem[0] = nn.Conv3d(5, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)

        
        with torch.no_grad():
            self.model1.stem[0].weight[:, :3, :, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model1.stem[0].weight[:, 3:, :, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
            
            self.model2.stem[0].weight[:, :3, :3, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model2.stem[0].weight[:, 3:, :3, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
            
            #self.model2.stem[0].weight[:, :, :3, :, :] = self.model2.stem[0].weight[:, :, :, :, :]  # Copy the pre-trained weights for 3 channels
            self.model2.stem[0].weight[:, :, 3:6, :, :] = self.model2.stem[0].weight[:, :, :3, :, :]  # Repeat/modify to match 5 channels
            self.model2.stem[0].weight[:, :, 6:9, :, :] = self.model2.stem[0].weight[:, :, :3, :, :]  # Repeat/modify to match 5 channels
            
            
            self.model3.stem[0].weight[:, :3, :, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model3.stem[0].weight[:, 3:, :, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
        
        self.model1.fc = nn.Identity()
        self.model2.fc = nn.Identity()
        self.model3.fc = nn.Identity()
        
        # Extract the feature size from the output of the models
        self.hidden_size = 32  # Adjust this if necessary, based on the model's output size
        self.flatten_size = 512
        # Concatenation layer
        self.fc = nn.Linear(self.flatten_size * 3, self.hidden_size)   # Final linear layer after concatenation

        # Subclass outputs (25 classes, each has 3 subclasses)
        self.subclass_layers = nn.ModuleList([nn.Linear(self.hidden_size, 3) for _ in range(25)])

    def forward(self, x1, x2, x3):
        # Forward pass through each model
        output1 = self.model1(x1)  # Shape: [batch_size, hidden_size]
        output2 = self.model2(x2)  # Shape: [batch_size, hidden_size]
        output3 = self.model3(x3)  # Shape: [batch_size, hidden_size]
        
        flatten1 = torch.flatten(output1, 1)
        flatten2 = torch.flatten(output2, 1)
        flatten3 = torch.flatten(output3, 1)
        #print(flatten1.shape)
        #print(flatten2.shape)
       # print(flatten3.shape)
        # Concatenate outputs from the three models
        concatenated_output = torch.cat((flatten1, flatten2,flatten2), dim=1)  # Shape: [batch_size, hidden_size * 3]
       #print(concatenated_output.shape)
        # Pass concatenated output through the final linear layer
        combined_output = self.fc(concatenated_output)  # Shape: [batch_size, hidden_size]

        # Subclass prediction for each class
        subclass_outputs = [torch.softmax(layer(combined_output), dim=1) for layer in self.subclass_layers]

        return subclass_outputs

# Example usage
#if __name__ == "__main__":
  

In [25]:
# model = CustomModel()
# model = model.to(cfg.device)

# # Example input tensors
# input1 = torch.randn(8, 5, 3, 256, 256)  # Batch size 8 for the first model
# input2 = torch.randn(8, 5, 9, 256, 256)  # Batch size 8 for the second model
# input3 = torch.randn(8, 5, 3, 256, 256)  # Batch size 8 for the third model

# outputs = model(input1.to(cfg.device), input2.to(cfg.device), input3.to(cfg.device))
# for i, output in enumerate(outputs):
#     print(f"Subclass Output for Class {i + 1}: Shape {output.shape}")

In [26]:
# import torch
# import torch.nn.functional as F

# def prepare_target(raw_labels):
#     """
#     Convert raw labels into one-hot encoded labels for 25 classes.
    
#     Args:
#         raw_labels (torch.Tensor): Tensor of shape [batch_size, 25] where each value is the subclass label (0, 1, or 2).
        
#     Returns:
#         torch.Tensor: One-hot encoded labels of shape [batch_size, 25, 3].
#     """
#     batch_size = raw_labels.shape[0]
#     num_classes = raw_labels.shape[1]
#     num_subclasses = 3  # There are 3 subclasses

#     # One-hot encode the subclass labels
#     one_hot_labels = F.one_hot(raw_labels, num_classes=num_subclasses)  # Shape: [batch_size, 25, 3]
    
#     return one_hot_labels.float()  # Return as float for compatibility with loss functions

# # Example usage:
# batch_size = 8
# num_classes = 25

# # Randomly generated raw labels where each value is 0, 1, or 2
# raw_labels = torch.randint(0, 3, (batch_size, num_classes))  # Shape: [batch_size, 25]

# # Prepare one-hot encoded targets
# targets = prepare_target(raw_labels)

# print(targets.shape)  # Should print: torch.Size([8, 25, 3])

# Training And Validation Functions

In [27]:
def val_epoch(model, loader):

    model.eval()
    predictions = []
    bar = tqdm(loader)
    with torch.no_grad():
        for data in bar:
            
            sagittal_t2,sagittal_t1,axial_t2, = data
            axial_t2, sagittal_t2, sagittal_t1 = axial_t2.to(device), sagittal_t2.to(device), sagittal_t1.to(device)
           
            logits = model(sagittal_t2,sagittal_t1,axial_t2)
             
            loss = 0
            batch_prediction = []
            for i in range(25):
                # Select the i-th subclass prediction and corresponding target
                subclass_pred = logits[i]  # Output from model, shape: [batch_size, 3]
                #print(subclass_pred.detach().cpu().numpy().shape)
                batch_prediction.append(subclass_pred.detach().cpu().numpy())
                #print(subclass_pred)
                # Compute the loss for the i-th subclass

            predictions.append(np.array(batch_prediction))

    return predictions

In [28]:
dataset_train = RSNADataset(train_df,cordinates_t1,cordinates_t2,cfg,resize_transform, isTrain= False,transform_axial = prepare_level_wise_axial,transform_sag = plot_5_crops)
prediction_dl = torch.utils.data.DataLoader(dataset_train, batch_size=cfg.batch_size, sampler=RandomSampler(dataset_train), num_workers=cfg.num_workers)




In [29]:
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.CUDA_VISIBLE_DEVICES
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [31]:
 cfg.model_dir_lsdd

'/kaggle/input/training-lsdd/resnet18_final_fold0.pth'

In [32]:
model = CustomModel()
model = model.to(device)

# Load the model weights (assuming 'model_weights.pth' is the saved file)
model_weights_path = "/kaggle/input/lumbarspinediseasedetection/pytorch/v1/1/resnet18_final_fold0.pth"
model.load_state_dict(torch.load(model_weights_path,map_location=device))

# Set the model to evaluation mode (if you are using it for inference)
model.eval()

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /root/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth


URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

In [None]:
predictions = val_epoch(model,prediction_dl)

In [None]:
len(predictions)

In [None]:
predictions_arr = np.concatenate(predictions, axis=1)

In [None]:
predictions_arr.shape

In [None]:
# Class names (for example)
class_names = [
 'spinal_canal_stenosis_l1_l2',
 'spinal_canal_stenosis_l2_l3',
 'spinal_canal_stenosis_l3_l4',
 'spinal_canal_stenosis_l4_l5',
 'spinal_canal_stenosis_l5_s1',
 'left_neural_foraminal_narrowing_l1_l2',
 'left_neural_foraminal_narrowing_l2_l3',
 'left_neural_foraminal_narrowing_l3_l4',
 'left_neural_foraminal_narrowing_l4_l5',
 'left_neural_foraminal_narrowing_l5_s1',
 'right_neural_foraminal_narrowing_l1_l2',
 'right_neural_foraminal_narrowing_l2_l3',
 'right_neural_foraminal_narrowing_l3_l4',
 'right_neural_foraminal_narrowing_l4_l5',
 'right_neural_foraminal_narrowing_l5_s1',
 'left_subarticular_stenosis_l1_l2',
 'left_subarticular_stenosis_l2_l3',
 'left_subarticular_stenosis_l3_l4',
 'left_subarticular_stenosis_l4_l5',
 'left_subarticular_stenosis_l5_s1',
 'right_subarticular_stenosis_l1_l2',
 'right_subarticular_stenosis_l2_l3',
 'right_subarticular_stenosis_l3_l4',
 'right_subarticular_stenosis_l4_l5',
 'right_subarticular_stenosis_l5_s1']

# Prepare a list to hold the data for the DataFrame
data = []

# Populate the data list with class names and subclass predictions
for class_idx, class_name in enumerate(class_names):
    for example_idx in range(predictions_arr.shape[1]):  # Loop through examples
        # Get subclass predictions for this class and example
        subclass_predictions = predictions_arr[class_idx, example_idx, :].tolist()
        # Append to the data list
        data.append([f"{train_df.iloc[example_idx].study_id}_{class_name}"] + subclass_predictions)

# Create a DataFrame
df = pd.DataFrame(data, columns=['row_id', 'normal_mild', 'moderate', 'severe'])

# Display the DataFrame
df_sorted = df.sort_values(by='row_id')

In [None]:
df_sorted['normal_mild'] = df_sorted['normal_mild'].round(7)
df_sorted['moderate'] = df_sorted['moderate'].round(7)
df_sorted['severe'] = df_sorted['severe'].round(7)

In [None]:
df_sorted.tail(10)

In [None]:
df_sorted.to_csv("submission.csv",index=False)

In [None]:
df_sorted