In [1]:
import os
import glob
import random
import math

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
from types import SimpleNamespace
import albumentations as A
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import pydicom


import torch.nn as nn
import torch.optim as optim

from torch.optim import lr_scheduler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.model_selection import train_test_split
import bisect
import time

def set_seed(seed=1234):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

  check_for_updates()


In [2]:
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

Collecting git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
  Cloning https://github.com/ildoonet/pytorch-gradual-warmup-lr.git to /tmp/pip-req-build-6n8qpfeq
  Running command git clone --filter=blob:none --quiet https://github.com/ildoonet/pytorch-gradual-warmup-lr.git /tmp/pip-req-build-6n8qpfeq
  Resolved https://github.com/ildoonet/pytorch-gradual-warmup-lr.git to commit 7021d63a49106e22c79b40564a7d39930e7b0f53
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: warmup_scheduler
  Building wheel for warmup_scheduler (setup.py) ... [?25ldone
[?25h  Created wheel for warmup_scheduler: filename=warmup_scheduler-0.3.2-py3-none-any.whl size=3880 sha256=b2464e8e2e584077616f6a1340f7b43cb36bb54de3fd4584a2f0d6a511df8408
  Stored in directory: /tmp/pip-ephem-wheel-cache-86eggzt_/wheels/49/78/e6/9168d5844935482a171c7880a0626fa1c6c412b55666635f59
Successfully built warmup_scheduler
Installing collected packages: warmup_scheduler
Successf

# Configurations

In [3]:
# Config
cfg= SimpleNamespace(
    img_dir= "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images",
    device= torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    n_frames=3,
    epochs=4,
    lr=0.0005,
    batch_size=16,
    backbone="resnet18",
    seed= 0,
    model_dir = "/kaggle/working/",
    kernel_type = "resnet18",
    num_workers = 4,
    n_epochs = 4,
    init_lr =0.0005,
    CUDA_VISIBLE_DEVICES = "0",
    sag_axial_slices = 3,
    sag_2_slices =  1,
    sag_1_slices = 4
)
set_seed(seed=cfg.seed) # Makes results reproducable

In [4]:
def load_training_dataframe(cfg,isTrain=True):
    
    train_df  = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train.csv')
    train_series_description = pd.read_csv('/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv')
    train_df_df_cleaned = train_df.dropna()

    sagtialla_t1_df = train_series_description[train_series_description['series_description'] == "Sagittal T1"]
    sagtialla_t1_df_u = sagtialla_t1_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)
    
    axial_t2_df = train_series_description[train_series_description['series_description'] == "Axial T2"]
    axial_t2_df_u = axial_t2_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)

    sagtialla_t2_df = train_series_description[train_series_description['series_description'] == "Sagittal T2/STIR"]
    sagtialla_t2_df_u = sagtialla_t2_df.drop_duplicates(subset=['study_id']).reset_index(drop=True)

    merge_sagital_t1 = pd.merge(train_df_df_cleaned,sagtialla_t1_df_u, on=["study_id"],how="inner" )
    merge_sagital_t1 = merge_sagital_t1.drop(columns=['series_description'])
    merge_sagital_t1.rename(columns={'series_id': 'series_id_sg1'}, inplace=True)
    
    merge_sagital_t2 = pd.merge(merge_sagital_t1,sagtialla_t2_df_u, on=["study_id"],how="inner" )
    merge_sagital_t2 = merge_sagital_t2.drop(columns=['series_description'])
    merge_sagital_t2.rename(columns={'series_id': 'series_id_sg2'}, inplace=True)

    merge_axial_t2 = pd.merge(merge_sagital_t2,axial_t2_df_u, on=["study_id"],how="inner" )
    merge_axial_t2 = merge_axial_t2.drop(columns=['series_description'])
    merge_axial_t2.rename(columns={'series_id': 'series_id_a2'}, inplace=True)

    label2id = {'Normal/Mild': 0, 'Moderate':1, 'Severe':2}
    merge_axial_t2 = merge_axial_t2.replace(label2id)
    return merge_axial_t2  
    

# Utils

In [5]:
def batch_to_device(batch, device, skip_keys=[]):
    batch_dict= {}
    for key in batch:
        if key in skip_keys:
             batch_dict[key]= batch[key]
        else:    
            batch_dict[key]= batch[key].to(device)
    return batch_dict

def visualize_prediction(batch, pred, epoch):
    
    mid= cfg.n_frames//2
    
    # Plot
    for idx in range(1):
    
        # Select Data
        img= batch["img"][idx, mid, :, :].cpu().numpy()*255
        cs_true= batch["label"][idx, ...].cpu().numpy()*256
        cs= pred[idx, ...].cpu().numpy()*256
                
        coords_list = [("TRUE", "lightblue", cs_true), ("PRED", "orange", cs)]
        text_labels = [str(x) for x in range(1,21)]
        
        # Plot coords
        fig, axes = plt.subplots(1, len(coords_list), figsize=(10,4))
        fig.suptitle("EPOCH: {}".format(epoch))
        for ax, (title, color, coords) in zip(axes, coords_list):
            ax.imshow(img, cmap='gray')
            ax.scatter(coords[0::2], coords[1::2], c=color, s=50)
            ax.axis('off')
            ax.set_title(title)

            # Add text labels near the coordinates
            for i, (x, y) in enumerate(zip(coords[0::2], coords[1::2])):
                if i < len(text_labels):  # Ensure there are enough labels
                    ax.text(x + 10, y, text_labels[i], color='white', fontsize=15, bbox=dict(facecolor='black', alpha=0.5))


        fig.suptitle("EPOCH: {}".format(epoch))
        plt.show()
#         plt.close(fig)
    return

def load_weights_skip_mismatch(model, weights_path, device):
    # Load Weights
    state_dict = torch.load(weights_path, map_location=device)
    model_dict = model.state_dict()
    
    # Iter models
    params = {}
    for (sdk, sfv), (mdk, mdv) in zip(state_dict.items(), model_dict.items()):
        if sfv.size() == mdv.size():
            params[sdk] = sfv
        else:
            print("Skipping param: {}, {} != {}".format(sdk, sfv.size(), mdv.size()))
    
    # Reload + Skip
    model.load_state_dict(params, strict=False)
    print("Loaded weights from:", weights_path)

# Create csv Dataset

In [6]:
train_df = load_training_dataframe(cfg)
#train_df = train_df.head(10)
# train_df = train_df[train_df['study_id'].isin([
#     4646740
# ])] 

  merge_axial_t2 = merge_axial_t2.replace(label2id)


In [7]:
# train_df

In [8]:
folder_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/4646740/3201256954'

# List all files in the directory
files = os.listdir(folder_path)

# Filter only files (not directories)
files = [f for f in files if os.path.isfile(os.path.join(folder_path, f))]

print(files)

['12.dcm', '18.dcm', '9.dcm', '22.dcm', '25.dcm', '39.dcm', '45.dcm', '14.dcm', '11.dcm', '44.dcm', '24.dcm', '34.dcm', '29.dcm', '23.dcm', '41.dcm', '35.dcm', '10.dcm', '46.dcm', '28.dcm', '43.dcm', '50.dcm', '37.dcm', '17.dcm', '30.dcm', '1.dcm', '15.dcm', '2.dcm', '52.dcm', '48.dcm', '36.dcm', '8.dcm', '7.dcm', '53.dcm', '21.dcm', '49.dcm', '51.dcm', '33.dcm', '5.dcm', '4.dcm', '42.dcm', '54.dcm', '47.dcm', '31.dcm', '38.dcm', '19.dcm', '27.dcm', '6.dcm', '16.dcm', '20.dcm', '40.dcm', '3.dcm', '32.dcm', '26.dcm', '13.dcm']


In [9]:
# train_df

# Co-ordinate Prediction Dataset

In [10]:
class PreTrainDataset(torch.utils.data.Dataset):
    def __init__(self, study_ids,cfg,transform,isTrain = False,is_dataset_for_t1= False):
        self.cfg= cfg
        self.study_ids = study_ids
        self.transform = transform
        self.isTrain = isTrain
        self.is_dataset_for_t1 = is_dataset_for_t1

    
    def convert_to_8bit(self,x):
        lower, upper = np.percentile(x, (1, 99))
        x = np.clip(x, lower, upper)
        x = x - np.min(x)
        x = x / np.max(x) 
        return (x * 255).astype("uint8")


    def load_dicom_stack(self,dicom_folder, plane, reverse_sort=False):
        dicom_files = glob.glob(os.path.join(dicom_folder, "*.dcm"))
        dicoms = [pydicom.dcmread(f) for f in dicom_files]
        plane = {"sagittal": 0, "coronal": 1, "axial": 2}[plane.lower()]
        positions = np.asarray([float(d.ImagePositionPatient[plane]) for d in dicoms])
        # if reverse_sort=False, then increasing array index will be from RIGHT->LEFT and CAUDAL->CRANIAL
        # thus we do reverse_sort=True for axial so increasing array index is craniocaudal
        idx = np.argsort(-positions if reverse_sort else positions)
        ipp = np.asarray([d.ImagePositionPatient for d in dicoms]).astype("float")[idx]
        array = np.stack([d.pixel_array.astype("float32") for d in dicoms])
        array = array[idx]
        return {"array": self.convert_to_8bit(array), "positions": ipp, "pixel_spacing": np.asarray(dicoms[0].PixelSpacing).astype("float")}
    
    def pad_image(self, img):
        n= img.shape[0]
        if n >= self.cfg.n_frames:
            start_idx = (n - self.cfg.n_frames) // 2
            return img[start_idx:start_idx + self.cfg.n_frames,:, :]
        else:
            pad_left = (self.cfg.n_frames - n) // 2
            pad_right = self.cfg.n_frames - n - pad_left
            return np.pad(img, ((pad_left, pad_right),(0,0), (0,0)), 'constant', constant_values=0)
    
    def load_img(self, series_id):
        fname = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(series_id)), plane="sagittal")
        img= fname["array"]
        img= self.pad_image(img)
        img= np.transpose(img, (1,2, 0))
        img= self.transform(image=img)["image"]
        img= np.transpose(img, (2, 0, 1))
        img= (img / 255.0)
        return img
        
        
    def __getitem__(self, idx):
        d= self.study_ids.iloc[idx]
        
        if self.is_dataset_for_t1:
            series_id= "/".join([str(d.study_id),str(d.series_id_sg1)])      
        else:
            series_id= "/".join([str(d.study_id),str(d.series_id_sg2)])      
                
        img= self.load_img(series_id)
        if self.isTrain:
            return {
                'img': img, 
                'label': label,
            }
        else:
            return {
                'img': img 
            }
            
    
    def __len__(self,):
        return len(self.study_ids)
    

resize_transform_point= A.Compose([
A.LongestMaxSize(max_size=256, interpolation=cv2.INTER_CUBIC, always_apply=True),
A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), always_apply=True),
])




ds = PreTrainDataset(train_df, cfg,resize_transform_point)    

# Plot a Single Sample
print("---- Sample Shapes -----")
for k, v in ds[0].items():
    print(k, v.shape)


---- Sample Shapes -----
img (3, 256, 256)


# Load Model

In [11]:
# Load backbone for RSNA 2024 task
model_path = "/kaggle/input/lumbar_spine_verterbrae_disc-detection/pytorch/1/1/resnet18_0.pt"
model = timm.create_model('resnet18', pretrained=True, num_classes=20)
model = model.to(cfg.device)
load_weights_skip_mismatch(model, model_path, cfg.device)

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

  state_dict = torch.load(weights_path, map_location=device)


Loaded weights from: /kaggle/input/lumbar_spine_verterbrae_disc-detection/pytorch/1/1/resnet18_0.pt


# Co-ordinate Inferencing

In [12]:
def coordinate_prediction(model,pred_dataloader,isSagitalT1 = False):
    predictions = []
    with torch.no_grad():
        model = model.eval()
        for batch in tqdm(pred_dataloader):
            batch = batch_to_device(batch, cfg.device)

            pred = model(batch["img"].float())
            pred = torch.sigmoid(pred)
            predictions.append(pred)
            
    data_list_cpu = [tensor.cpu().numpy() for tensor in predictions]
    combined_data = np.vstack(data_list_cpu)
    df_cords = pd.DataFrame(combined_data)
    if isSagitalT1:
        df_cords.to_csv("predicted_cordinates_sag_t1.csv",index=False)
    else:
        df_cords.to_csv("predicted_cordinates_sag_t2.csv",index=False)
    return df_cords

# Coordinate Prediction

In [13]:
train_ds = PreTrainDataset(train_df, cfg,resize_transform_point)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=False)

In [14]:
cordinates_t1 = coordinate_prediction(model,train_dl,isSagitalT1 = True)
cordinates_t2 = coordinate_prediction(model,train_dl,isSagitalT1 = False)


100%|██████████| 112/112 [12:44<00:00,  6.82s/it]
100%|██████████| 112/112 [09:03<00:00,  4.86s/it]


In [15]:
#train_df['study_id']

In [16]:
#cordinates_t1

In [17]:
cordinates_t1['study_id'] = train_df['study_id'].values
cordinates_t2['study_id'] = train_df['study_id'].values

In [18]:
cordinates_t1

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,study_id
0,0.358656,0.374705,0.509141,0.371950,0.344499,0.489798,0.511614,0.473629,0.353111,0.605587,...,0.587102,0.373437,0.722634,0.526738,0.671897,0.424140,0.849819,0.551493,0.743065,4003253
1,0.390123,0.272419,0.629843,0.302406,0.361005,0.392987,0.603127,0.413803,0.353677,0.510201,...,0.507031,0.364985,0.626339,0.617148,0.590170,0.398347,0.745176,0.657364,0.680044,4646740
2,0.377804,0.277879,0.551068,0.310396,0.354630,0.411120,0.536126,0.413698,0.358668,0.525839,...,0.527877,0.356278,0.658888,0.532196,0.629970,0.390902,0.790636,0.553370,0.707839,7143189
3,0.422230,0.272064,0.554788,0.318366,0.375878,0.379581,0.531984,0.418145,0.357825,0.490115,...,0.513823,0.356010,0.614302,0.512234,0.588183,0.397038,0.752634,0.528639,0.675257,8785691
4,0.402759,0.257216,0.596857,0.293941,0.369302,0.386106,0.574754,0.422091,0.338726,0.532496,...,0.546451,0.322665,0.671288,0.537063,0.671580,0.354253,0.850769,0.566507,0.759883,10728036
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1785,0.348200,0.222990,0.571816,0.267857,0.310215,0.352571,0.529989,0.385268,0.293090,0.515912,...,0.505981,0.307020,0.670571,0.526042,0.634243,0.350070,0.830489,0.557492,0.720233,4282019580
1786,0.377478,0.307640,0.564883,0.339093,0.337882,0.457434,0.533598,0.461048,0.322099,0.600892,...,0.591484,0.334299,0.757367,0.519319,0.693534,0.415685,0.898814,0.571472,0.791665,4283570761
1787,0.332602,0.327601,0.519570,0.348472,0.317127,0.434848,0.492108,0.478025,0.310772,0.575786,...,0.579262,0.321266,0.717456,0.487239,0.674535,0.357391,0.863263,0.506651,0.764439,4284048608
1788,0.349592,0.288166,0.521606,0.327668,0.318610,0.419837,0.505833,0.439730,0.319585,0.558770,...,0.535285,0.354532,0.707523,0.532307,0.632179,0.429828,0.826329,0.586490,0.707909,4287160193


In [19]:
def get_padded_roi(orientation, numberOfImageFromCenter):
    # Calculate middle index
    middleImage = len(orientation["array"]) // 2

    # Calculate start and end indices
    start_idx = middleImage - numberOfImageFromCenter
    end_idx = middleImage + numberOfImageFromCenter + 1

    # Handle bounds
    array_length = orientation["array"].shape[0]  # Slicing along the first axis (number of images)

    # Ensure we don't go beyond the array's bounds
    start_pad = max(0, -start_idx)
    end_pad = max(0, end_idx - array_length)

    # Slice the valid part of the array along the first axis
    roi = orientation["array"][max(0, start_idx):min(array_length, end_idx)]

    # Pad with zeros if needed
    if start_pad > 0 or end_pad > 0:
        roi = np.pad(roi, ((start_pad, end_pad), (0, 0), (0, 0)), mode='constant', constant_values=0)

    return roi

In [20]:
orientation = {
    "array": np.random.randint(0, 10, (1, 5, 5))  # Example 3D array of shape (10, 5, 5)
}

numberOfImageFromCenter = 1
orientation['array'] = get_padded_roi(orientation, numberOfImageFromCenter)
print(orientation)

{'array': array([[[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]],

       [[5, 0, 3, 3, 7],
        [9, 3, 5, 2, 4],
        [7, 6, 8, 8, 1],
        [6, 7, 7, 8, 1],
        [5, 9, 8, 9, 4]],

       [[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]])}


# Load Utils for LSDD 

In [21]:
def prepare_level_wise_axial(sagittal_img, imgsag_y_coord_to_axial_slice, coordinates,no_of_axial_slice = 3):
    h, w = sagittal_img.shape
    axial_list = []
    #print("start axial")
    first_key = next(iter(imgsag_y_coord_to_axial_slice))
    # Get the first value
    first_value = imgsag_y_coord_to_axial_slice[first_key]

    keys = list(imgsag_y_coord_to_axial_slice.keys())
    keys.sort()

    for i in range(0, len(coordinates), 4):
        #print(i)
        category = coordinates[i:i+4]  # Extracting 4 elements at a time
        y= [category[1]*h,category[3]*h]
        minimum = math.floor(min(y[0],y[1]))
        maximum  = math.ceil(max(y[0],y[1]))
        filtered_keys = [k for k in imgsag_y_coord_to_axial_slice.keys() if minimum <= k <= maximum]
        
        if len(filtered_keys) >= no_of_axial_slice:
            # Use numpy to select 3 keys at uniform intervals
            selected_keys = np.linspace(0, len(filtered_keys)-1,no_of_axial_slice , dtype=int)
            selected_keys = [filtered_keys[i] for i in selected_keys]
        else:
            if len(filtered_keys) == 0:
                index = bisect.bisect_left(keys, minimum)
                # Check if we can find a nearest value less than current_value
                if index > 0:
                    filtered_keys.append(keys[index - 1])

                index = bisect.bisect_right(keys, maximum)
                if index < len(keys):
                    filtered_keys.append(keys[index])  # Return the nearest value greater than current_value
                    

            filtered_keys.extend([filtered_keys[-1]] * (no_of_axial_slice - len(filtered_keys)))
            selected_keys = filtered_keys
        
        if len(selected_keys) ==  0:
            result  = np.zeros((no_of_axial_slice, first_value.shape[0],first_value.shape[1]))
        else:
            selected_axial =[imgsag_y_coord_to_axial_slice.get(k) for k in selected_keys]
            result  = np.array(selected_axial)

        roi_copy_list  = []
        for j in range(result.shape[0]):
            roi_copy_list.append(resize_transform(image=result[j])["image"])
            
        axial_list.append(np.array(roi_copy_list))
        
    return  np.array(axial_list)

In [22]:
resize_transform= A.Compose([
    A.LongestMaxSize(max_size=256, interpolation=cv2.INTER_CUBIC, always_apply=True),
    A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), always_apply=True),
    A.Normalize()
])

def angle_of_line(x1, y1, x2, y2):
    return math.degrees(math.atan2(-(y2-y1), x2-x1))

def crop_between_keypoints(roi,img, keypoint1, keypoint2):
    h, w = img.shape
    x1, y1 = int(keypoint1[0]), int(keypoint1[1])
    x2, y2 = int(keypoint2[0]), int(keypoint2[1])

    # Calculate bounding box around the keypoints
    left = int(min(x1, x2) - (w * 0.1))
    right = int(max(x1, x2) + (w * 0.1))
    top = int(min(y1, y2) - (h * 0.07))
    bottom = int(max(y1, y2) + (h * 0.1))
    
    left = max(0, left)
    right = min(w, right)
    top = max(0, top)
    bottom = min(h, bottom)
    # Crop the image
    return img[top:bottom, left:right],roi[:,top:bottom, left:right]

def plot_5_crops(orientataion,coords_temp,numberOfImageFromCenter = 3):
    # Create a figure and axis for the grid
    #fig = plt.figure(figsize=(10, 10))
    #gs = gridspec.GridSpec(1, 5, width_ratios=[1]*5)
    #print("start crop")
    #print(coords_temp)
    # Plot the crops
    #print("plot-5-crop-img")
    orientataion['array'] = get_padded_roi(orientataion,numberOfImageFromCenter)
    middleImage = len(orientataion["array"])//2
    img = orientataion["array"][middleImage]
    
    roi = orientataion["array"][middleImage-numberOfImageFromCenter:middleImage+numberOfImageFromCenter+1]
    croppedImage = []
    #print(p)
    for i in range(0, len(coords_temp), 4):
        # Copy of img
        img_copy= img.copy()
        h, w = img.shape
        #print("Level ",i//4)
        #print("img shape",img.shape)
        roi_copy = roi.copy()
        # Extract Keypoints
        category = coords_temp[i:i+4]  # Extracting 4 elements at a time
        a= (category[0]*w, category[1]*h)
        b= (category[2]*w, category[3]*h)
        
        # Rotate
        rotate_angle= angle_of_line(a[0], a[1], b[0], b[1])
        transform = A.Compose([
            A.Rotate(limit=(-rotate_angle, -rotate_angle), p=1.0),
        ], keypoint_params= A.KeypointParams(format='xy', remove_invisible=False),
        )
    
        t= transform(image=img_copy, keypoints=[a,b])
        img_copy= t["image"]
        a,b= t["keypoints"]
        #print("before",img_copy.shape)
        # Crop + Resize
        img_copy,roi_copy = crop_between_keypoints(roi_copy,img_copy, a, b)
        #print("Before Transformation",roi_copy.shape)
        #print("after Tra img Shape",img_copy.shape)
        roi_copy_list  = []
        for j in range (0,numberOfImageFromCenter*2+1):
            roi_copy_list.append(resize_transform(image=roi_copy[j])["image"])
        roi_copy_list = np.array(roi_copy_list)
        #print("After transformation",roi_copy_list.shape)
        img_copy = roi_copy_list[numberOfImageFromCenter]
        croppedImage.append(roi_copy_list)
        # Plot
        #ax = plt.subplot(gs[i//4])
        #ax.imshow(img_copy, cmap='gray')
        #ax.set_title(f"L{i//4+1}")
        #ax.axis('off')
    #plt.show()
    #print(np.array(croppedImage).shape)
    return np.array(croppedImage)

# Define DataSet for training

In [23]:
# import pandas as pd
# import numpy as np

# # Sample DataFrame
# df = pd.DataFrame({
#     'A': [0, 1, 2],
#     'B': [1, 0, 2],
#     'C': [2, 2, 1],
#     'D': [0, 1, 0],
#     'E': [1, 0, 1]
#     # add 25 columns in actual case
# })

# # One-hot encode each column in the DataFrame
# def one_hot_encode_row(row):
#     # One hot encode each value in the row (0,1,2)
#     return np.array([np.eye(3)[int(val)] for val in row])

# # Apply the one-hot encoding function to each row
# encoded_data = np.array([one_hot_encode_row(row) for row in df.values])

# # Example: to check one row's shape (25, 3)
# print(encoded_data[0].shape)  # Should output (5, 3) for this example, adjust to 25 columns

# # Example: print the one-hot encoded array for a row
# print(encoded_data[0])  # One-hot encoded first row

# # Final shape of the encoded data for all rows
# print(encoded_data.shape)  # (number_of_rows, 25, 3)

In [24]:
# def one_hot_encode_row(row):
#     # One hot encode each value in the row (0,1,2)
#     return np.array([np.eye(3)[int(val)] for val in row])

# s = one_hot_encode_row([0,1,2,0,0])
# # Apply the one-hot encoding function to each row


In [25]:
class RSNADataset(torch.utils.data.Dataset):
    def __init__(self, study_ids, s1_coords, s2_coords,cfg,transform,isTrain = True,transform_axial = None,transform_sag = None):
        self.cfg= cfg
        self.study_ids = study_ids
        self.transform = transform
        self.isTrain = isTrain
        self.transform_sag = transform_sag
        self.transform_axial = transform_axial
        self.s1_coords = self.align_cord(study_ids,s1_coords)
        self.s2_coords = self.align_cord(study_ids,s2_coords)
        self.labeldf = study_ids[[col for col in study_ids.columns if col not in ['study_id','series_id_sg1','series_id_sg2','series_id_a2']]]

    def convert_to_8bit(self,x):
        lower, upper = np.percentile(x, (1, 99))
        x = np.clip(x, lower, upper)
        x = x - np.min(x)
        x = x / np.max(x) 
        return (x * 255).astype("uint8")
    
    def align_cord(self,study_ids,cords):
        std_id = study_ids[['study_id']] 
        merged = pd.merge(std_id,cords,on=['study_id'],how="inner")
        return merged[[col for col in merged.columns if col not in ['study_id']]]

    def load_dicom_stack(self, dicom_folder, plane, reverse_sort=False):
        dicom_files = glob.glob(os.path.join(dicom_folder, "*.dcm"))
        dicoms = [pydicom.dcmread(f) for f in dicom_files]

        # Determine the plane for sorting (sagittal, coronal, axial)
        plane = {"sagittal": 0, "coronal": 1, "axial": 2}[plane.lower()]
        positions = np.asarray([float(d.ImagePositionPatient[plane]) for d in dicoms])

        # Sort DICOM files based on positions (reverse sort for axial plane if needed)
        idx = np.argsort(-positions if reverse_sort else positions)
        ipp = np.asarray([d.ImagePositionPatient for d in dicoms]).astype("float")[idx]

        # Get the shape of each pixel array (height, width)
        shapes = [d.pixel_array.shape for d in dicoms]
        # Check if all DICOM images have the same shape
        if len(set(shapes)) > 1:
            # There's a shape mismatch, find the minimum shape (height, width)
            min_shape = np.min(shapes, axis=0)
            # Resize images to the minimum shape
            resized_arrays = []
            for d in dicoms:
                img = d.pixel_array.astype("float32")
                if img.shape != tuple(min_shape):
                    resized_img = cv2.resize(img, (min_shape[1], min_shape[0]))  # Resize to (width, height)
                else:
                    resized_img = img  # No resizing needed
                resized_arrays.append(resized_img)

            # Stack the resized images along the first axis
            array = np.stack(resized_arrays)
        else:
            # If all shapes are the same, no resizing is needed
            array = np.stack([d.pixel_array.astype("float32") for d in dicoms])

        # Reorder the array according to the sorted positions
        array = array[idx]

        return {
            "array": self.convert_to_8bit(array),
            "positions": ipp,
            "pixel_spacing": np.asarray(dicoms[0].PixelSpacing).astype("float")
        }

     
    def one_hot_encode_row(self, row):
        return np.array([np.eye(3)[int(val)] for val in row]) 
    def __getitem__(self, idx):
        row = self.study_ids.iloc[idx]
        #print(row.study_id)
        sag_t1 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_sg1)), plane="sagittal")
        ax_t2 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_a2)), plane="axial", reverse_sort=True)
        sag_t2 = self.load_dicom_stack(os.path.join(self.cfg.img_dir, str(row.study_id), str(row.series_id_sg2)), plane="sagittal")
        
        top_left_hand_corner_sag_t2 = sag_t2["positions"][len(sag_t2["array"]) // 2]
        sag_y_axis_to_pixel_space = [top_left_hand_corner_sag_t2[2]]
        while len(sag_y_axis_to_pixel_space) < sag_t2["array"].shape[1]: 
            sag_y_axis_to_pixel_space.append(sag_y_axis_to_pixel_space[-1] - sag_t2["pixel_spacing"][1])
        
        sag_y_coord_to_axial_slice = {}
        for ax_t2_slice, ax_t2_pos in zip(ax_t2["array"], ax_t2["positions"]):
            diffs = np.abs(np.asarray(sag_y_axis_to_pixel_space) - ax_t2_pos[2])
            sag_y_coord = np.argmin(diffs)
            sag_y_coord_to_axial_slice[sag_y_coord] = ax_t2_slice
        
        sag1_cord = self.s1_coords.iloc[idx].tolist()
        sag2_cord = self.s2_coords.iloc[idx].tolist()
        
        #print("s1 cord",sag1_cord)
        #print("s2 cord",sag2_cord)
        img= sag_t2["array"][len(sag_t2["array"])//2]
        corresponding_axial = self.transform_axial(img, sag_y_coord_to_axial_slice, sag2_cord,no_of_axial_slice = self.cfg.sag_axial_slices)
        #print(corresponding_axial.shape)

        
        crop_result_t2 = self.transform_sag(sag_t2, sag2_cord,numberOfImageFromCenter = self.cfg.sag_2_slices)
        #print(crop_result_t2.shape)
        
        crop_result_t1 = self.transform_sag(sag_t1, sag1_cord,numberOfImageFromCenter = self.cfg.sag_1_slices)
        #print(crop_result_t1.shape)
        
        
        
        if self.isTrain:
            label  = self.one_hot_encode_row( self.labeldf.iloc[idx])
            return (crop_result_t2,crop_result_t1,corresponding_axial),label
        else:
            return (crop_result_t2,crop_result_t1, corresponding_axial)
            
    
    def __len__(self,):
        return len(self.study_ids)

ds = RSNADataset(train_df,cordinates_t1,cordinates_t2 ,cfg,resize_transform,isTrain= True,transform_axial = prepare_level_wise_axial,transform_sag = plot_5_crops)    

# Plot a Single Sample
print("---- Sample Shapes -----")
(k1,k2,k3), v  =  ds[1]
print("k1 shape",k1.shape)
print("k1 shape",k2.shape)
print("k3 shape",k3.shape)
print("label shape",v.shape)

---- Sample Shapes -----
k1 shape (5, 3, 256, 256)
k1 shape (5, 9, 256, 256)
k3 shape (5, 3, 256, 256)
label shape (25, 3)


# Create Model

In [26]:
import torch
import torch.nn as nn
import torchvision.models as models

In [27]:
#  model1 = models.video.r3d_18(pretrained=True)

In [28]:
# model1.stem[0].weight

In [29]:
# print(model1)

In [30]:
import torch
import torch.nn as nn
import torchvision.models as models

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()

        # Define three 3D ResNet-18 models
        self.model1 = models.video.r3d_18(pretrained=True)  # First model (input shape: 5,3,256,256)
        self.model2 = models.video.r3d_18(pretrained=True)  # Second model (input shape: 5,3,256,256)
        self.model3 = models.video.r3d_18(pretrained=True)  # Third model (input shape: 5,9,256,256)
        
        original_conv1 = self.model1.stem[0]
        self.model1.stem[0] = nn.Conv3d(5, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.model2.stem[0] = nn.Conv3d(5, 64, kernel_size=(9, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        self.model3.stem[0] = nn.Conv3d(5, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)

        
        with torch.no_grad():
            self.model1.stem[0].weight[:, :3, :, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model1.stem[0].weight[:, 3:, :, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
            
            self.model2.stem[0].weight[:, :3, :3, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model2.stem[0].weight[:, 3:, :3, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
            
            #self.model2.stem[0].weight[:, :, :3, :, :] = self.model2.stem[0].weight[:, :, :, :, :]  # Copy the pre-trained weights for 3 channels
            self.model2.stem[0].weight[:, :, 3:6, :, :] = self.model2.stem[0].weight[:, :, :3, :, :]  # Repeat/modify to match 5 channels
            self.model2.stem[0].weight[:, :, 6:9, :, :] = self.model2.stem[0].weight[:, :, :3, :, :]  # Repeat/modify to match 5 channels
            
            
            self.model3.stem[0].weight[:, :3, :, :, :] = original_conv1.weight  # Copy the pre-trained weights for 3 channels
            self.model3.stem[0].weight[:, 3:, :, :, :] = original_conv1.weight[:, :2, :, :, :]  # Repeat/modify to match 5 channels
        
        self.model1.fc = nn.Identity()
        self.model2.fc = nn.Identity()
        self.model3.fc = nn.Identity()
        
        # Extract the feature size from the output of the models
        self.hidden_size = 32  # Adjust this if necessary, based on the model's output size
        self.flatten_size = 512
        # Concatenation layer
        self.fc = nn.Linear(self.flatten_size * 3, self.hidden_size)   # Final linear layer after concatenation

        # Subclass outputs (25 classes, each has 3 subclasses)
        self.subclass_layers = nn.ModuleList([nn.Linear(self.hidden_size, 3) for _ in range(25)])

    def forward(self, x1, x2, x3):
        # Forward pass through each model
        output1 = self.model1(x1)  # Shape: [batch_size, hidden_size]
        output2 = self.model2(x2)  # Shape: [batch_size, hidden_size]
        output3 = self.model3(x3)  # Shape: [batch_size, hidden_size]
        
        flatten1 = torch.flatten(output1, 1)
        flatten2 = torch.flatten(output2, 1)
        flatten3 = torch.flatten(output3, 1)
        #print(flatten1.shape)
        #print(flatten2.shape)
       # print(flatten3.shape)
        # Concatenate outputs from the three models
        concatenated_output = torch.cat((flatten1, flatten2,flatten2), dim=1)  # Shape: [batch_size, hidden_size * 3]
       #print(concatenated_output.shape)
        # Pass concatenated output through the final linear layer
        combined_output = self.fc(concatenated_output)  # Shape: [batch_size, hidden_size]

        # Subclass prediction for each class
        subclass_outputs = [torch.softmax(layer(combined_output), dim=1) for layer in self.subclass_layers]

        return subclass_outputs

# Example usage
#if __name__ == "__main__":
  

In [31]:
# model = CustomModel()
# model = model.to(cfg.device)

# # Example input tensors
# input1 = torch.randn(8, 5, 3, 256, 256)  # Batch size 8 for the first model
# input2 = torch.randn(8, 5, 9, 256, 256)  # Batch size 8 for the second model
# input3 = torch.randn(8, 5, 3, 256, 256)  # Batch size 8 for the third model

# outputs = model(input1.to(cfg.device), input2.to(cfg.device), input3.to(cfg.device))
# for i, output in enumerate(outputs):
#     print(f"Subclass Output for Class {i + 1}: Shape {output.shape}")

In [32]:
# import torch
# import torch.nn.functional as F

# def prepare_target(raw_labels):
#     """
#     Convert raw labels into one-hot encoded labels for 25 classes.
    
#     Args:
#         raw_labels (torch.Tensor): Tensor of shape [batch_size, 25] where each value is the subclass label (0, 1, or 2).
        
#     Returns:
#         torch.Tensor: One-hot encoded labels of shape [batch_size, 25, 3].
#     """
#     batch_size = raw_labels.shape[0]
#     num_classes = raw_labels.shape[1]
#     num_subclasses = 3  # There are 3 subclasses

#     # One-hot encode the subclass labels
#     one_hot_labels = F.one_hot(raw_labels, num_classes=num_subclasses)  # Shape: [batch_size, 25, 3]
    
#     return one_hot_labels.float()  # Return as float for compatibility with loss functions

# # Example usage:
# batch_size = 8
# num_classes = 25

# # Randomly generated raw labels where each value is 0, 1, or 2
# raw_labels = torch.randint(0, 3, (batch_size, num_classes))  # Shape: [batch_size, 25]

# # Prepare one-hot encoded targets
# targets = prepare_target(raw_labels)

# print(targets.shape)  # Should print: torch.Size([8, 25, 3])

# Training And Validation Functions

In [33]:
def train_epoch(model, loader, optimizer):

    model.train()
    train_loss = []
    bar = tqdm(loader)
    for (data, target) in bar:

        optimizer.zero_grad()
        
        axial_t2,sagittal_t2,sagittal_t1 = data
        axial_t2, sagittal_t2, sagittal_t1, target = axial_t2.to(device), sagittal_t2.to(device), sagittal_t1.to(device),   target.to(device)
        logits = model(axial_t2, sagittal_t2,sagittal_t1)
             
        loss = 0
        for i in range(25):
            # Select the i-th subclass prediction and corresponding target
            subclass_pred = logits[i]  # Output from model, shape: [batch_size, 3]
            subclass_target = target[:, i]  # Raw target label (not one-hot), shape: [batch_size]

            # Compute the loss for the i-th subclass
            loss += criterion(subclass_pred, subclass_target)
        
        loss.backward()
        
        optimizer.step()

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        bar.set_description('loss: %.5f, smth: %.5f' % (loss_np, smooth_loss))

    train_loss = np.mean(train_loss)
    return train_loss

In [34]:
def val_epoch(model, loader):

    model.eval()
    val_loss = []
    bar = tqdm(loader)
    with torch.no_grad():
        for (data, target) in bar:
            
            axial_t2,sagittal_t2,sagittal_t1 = data
            axial_t2, sagittal_t2, sagittal_t1, target = axial_t2.to(device), sagittal_t2.to(device), sagittal_t1.to(device),   target.to(device)
           
            logits = model(axial_t2, sagittal_t2,sagittal_t1)
             
            loss = 0
            for i in range(25):
                # Select the i-th subclass prediction and corresponding target
                subclass_pred = logits[i]  # Output from model, shape: [batch_size, 3]
                subclass_target = target[:, i]  # Raw target label (not one-hot), shape: [batch_size]
                #print(subclass_pred)
                # Compute the loss for the i-th subclass
                loss += criterion(subclass_pred, subclass_target)

            val_loss.append(loss.detach().cpu().numpy())

    return  np.mean(val_loss)

In [35]:
# Fix Warmup Bug
from warmup_scheduler import GradualWarmupScheduler  # https://github.com/ildoonet/pytorch-gradual-warmup-lr


class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

In [36]:
def run(df_train,df_valid,fold=0):

    
    dataset_train = RSNADataset(df_train,cordinates_t1,cordinates_t2,cfg,resize_transform, isTrain= True,transform_axial = prepare_level_wise_axial,transform_sag = plot_5_crops)
    dataset_valid = RSNADataset(df_valid, cordinates_t1,cordinates_t2,cfg,resize_transform,isTrain= True,transform_axial = prepare_level_wise_axial,transform_sag = plot_5_crops)
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=cfg.batch_size, sampler=RandomSampler(dataset_train), num_workers=cfg.num_workers)
    valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=cfg.batch_size, num_workers=cfg.num_workers)

    model = CustomModel()
    
    model = model.to(device)

    #auc_20_max = 0.
    model_file3 = os.path.join(cfg.model_dir, f'{cfg.kernel_type}_final_fold{fold}.pth')

    optimizer = optim.Adam(model.parameters(), lr=cfg.init_lr)
    
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, cfg.n_epochs - 1)
    scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine)
    
    print(len(dataset_train), len(dataset_valid))

    for epoch in range(1, cfg.n_epochs + 1):
        print(time.ctime(), f'Fold {fold}, Epoch {epoch}')

        train_loss = train_epoch(model, train_loader, optimizer)
        val_loss = val_epoch(model, valid_loader)

        content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {(val_loss):.5f}.'
        print(content)

        scheduler_warmup.step()    
        if epoch==2: scheduler_warmup.step() # bug workaround   
        
        model_file  = os.path.join(cfg.model_dir, f'{cfg.kernel_type}_best_fold{epoch}.pth')
        print('auc_max ({:.6f} --> {:.6f}). Saving model ...'.format(train_loss, val_loss))
        torch.save(model.state_dict(), model_file)
    

    torch.save(model.state_dict(), model_file3)

In [37]:
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.CUDA_VISIBLE_DEVICES
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

In [38]:
df_train, df_valid = train_test_split(train_df, test_size=0.2, random_state=42)

In [39]:
run(df_train,df_valid,fold=0)

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /root/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:00<00:00, 200MB/s] 


1432 358
Fri Oct 11 10:38:56 2024 Fold 0, Epoch 1


loss: 19.38962, smth: 19.91698: 100%|██████████| 90/90 [14:01<00:00,  9.35s/it] 
100%|██████████| 23/23 [03:36<00:00,  9.41s/it]


Fri Oct 11 10:56:34 2024 Fold 0, Epoch 1, lr: 0.0005000, train loss: 19.91698, valid loss: 19.05346.
auc_max (19.916983 --> 19.053460). Saving model ...
Fri Oct 11 10:56:35 2024 Fold 0, Epoch 2


loss: 18.78666, smth: 19.30875: 100%|██████████| 90/90 [13:25<00:00,  8.95s/it]
100%|██████████| 23/23 [03:14<00:00,  8.47s/it]
  _warn_get_lr_called_within_step(self)


Fri Oct 11 11:13:15 2024 Fold 0, Epoch 2, lr: 0.0050000, train loss: 19.30875, valid loss: 19.16362.
auc_max (19.308750 --> 19.163620). Saving model ...
Fri Oct 11 11:13:15 2024 Fold 0, Epoch 3


loss: 16.03252, smth: 19.26404: 100%|██████████| 90/90 [12:46<00:00,  8.52s/it]
100%|██████████| 23/23 [03:09<00:00,  8.23s/it]


Fri Oct 11 11:29:12 2024 Fold 0, Epoch 3, lr: 0.0037500, train loss: 19.26404, valid loss: 19.07790.
auc_max (19.264043 --> 19.077900). Saving model ...
Fri Oct 11 11:29:12 2024 Fold 0, Epoch 4


loss: 17.55321, smth: 19.16020: 100%|██████████| 90/90 [12:13<00:00,  8.15s/it]
100%|██████████| 23/23 [03:14<00:00,  8.45s/it]


Fri Oct 11 11:44:40 2024 Fold 0, Epoch 4, lr: 0.0012500, train loss: 19.16020, valid loss: 18.90640.
auc_max (19.160198 --> 18.906395). Saving model ...


In [40]:
df_valid.to_csv("hello.csv")