# Data & Config

In [None]:
import pandas as pd
import pydicom
import numpy as np
import cv2
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torchmetrics import Accuracy, F1Score, Recall, AUROC

import albumentations as A
import timm
from tqdm import tqdm

In [None]:
# Target size handed to load_dicom (and from there to cv2.resize).
IMG_SIZE = (512, 512)

# Batch sizes: validation runs under torch.inference_mode, so it uses a
# batch twice as large as training.
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 2 * TRAIN_BATCH_SIZE

# Probability used by every augmentation stage in the training pipeline.
AUG_PROB = 0.75

# Prefer the GPU when one is visible.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation schedule.
LR = 0.0001
EPOCHS = 20

# Resume from a previously saved state_dict before training.
LOAD_CHECKPOINT = True

In [None]:
df = pd.read_csv("/kaggle/input/rsna-lsdc-data/train.csv")

# Keep only the Sagittal T1 slices.
df = df.loc[df["series_description"] == "Sagittal T1"]

# Identifier columns first, followed by every spinal_canal_stenosis* label column.
cols_to_include = ["study_id", "series_id", "instance_number"]
cols_to_include += [col for col in df.columns if col.startswith("spinal_canal_stenosis")]

# Select those columns with a fresh 0..n-1 index, then one-hot encode the
# non-numeric label columns as 0/1 ints.
df = pd.get_dummies(df[cols_to_include].reset_index(drop=True), dtype=int)

# Everything after the three identifier columns is a label indicator
# (assumes the identifier columns are numeric, so get_dummies leaves them first).
NUM_CLASSES = df.shape[1] - 3

df.head()

In [None]:
# Subsample for quicker experiments. DataFrame.sample(n=...) without
# replacement raises if n exceeds the row count, so cap n at len(df).
df = df.sample(n=min(5000, len(df)), ignore_index=True, random_state=42)

# Utils

In [None]:
def load_dicom(src_path, resize_shape):
    """Read a DICOM file and return its pixels as a resized 8-bit image.

    Pixels are scaled so the brightest one maps to 255, then truncated to
    uint8 and resized.

    Args:
        src_path: Path to the ``.dcm`` file.
        resize_shape: Passed unchanged as ``dsize`` to ``cv2.resize``, which
            interprets it as ``(width, height)``.

    Returns:
        A ``np.uint8`` image (assumes single-frame, 2-D pixel data).
    """
    pixels = np.asarray(pydicom.dcmread(src_path).pixel_array, dtype=np.float64)
    peak = pixels.max()
    if peak > 0:
        pixels = pixels / peak * 255
    else:
        # Blank slice: dividing by a zero maximum would produce NaNs, whose
        # uint8 cast is undefined.
        pixels = np.zeros_like(pixels)
    # Clip so negative intensities (signed pixel data) cannot wrap around
    # when cast to uint8.
    image = np.clip(pixels, 0, 255).astype(np.uint8)
    return cv2.resize(image, resize_shape)

# Dataset & DataLoader

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for validation (seeded for reproducibility) and
# give each split its own fresh 0..n-1 index.
train_df, valid_df = (
    split.reset_index(drop=True)
    for split in train_test_split(df, test_size=0.2, random_state=42)
)

In [None]:
# Training-time augmentation. Every augmentation stage fires with probability
# AUG_PROB. Images arrive already resized by load_dicom, so no Resize step is
# needed here.
_blur_or_noise = A.OneOf(
    [
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ],
    p=AUG_PROB,
)
_geometric_distortion = A.OneOf(
    [
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.0),
        A.ElasticTransform(alpha=3),
    ],
    p=AUG_PROB,
)

transforms_train = A.Compose(
    [
        A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=AUG_PROB),
        _blur_or_noise,
        _geometric_distortion,
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUG_PROB),
        A.CoarseDropout(max_holes=16, max_height=64, max_width=64, min_holes=1, min_height=8, min_width=8, p=AUG_PROB),
        # Always applied: maps uint8 pixels to roughly [-1, 1]
        # (with albumentations' default max_pixel_value of 255).
        A.Normalize(mean=0.5, std=0.5),
    ]
)

# Validation: normalization only, so evaluation sees un-augmented images.
transforms_valid = A.Compose([A.Normalize(mean=0.5, std=0.5)])

In [None]:
class LSDCDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs.

    ``df`` must have ``study_id``, ``series_id`` and ``instance_number`` as
    its first three columns. Every column after those is treated as a label
    indicator. Images are read from the competition's ``train_images``
    directory via ``load_dicom``.

    Args:
        df: One row per DICOM slice. Rows are addressed by position, so the
            index does not have to be ``0..n-1``.
        transforms: Optional callable invoked as
            ``transforms(image=...)["image"]`` (albumentations style).
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
        # Columns after the three identifiers are the label indicators.
        self.label_columns = list(df.columns)[3:]
        # Precompute the float32 targets once, instead of re-listing the
        # columns and re-slicing the row on every __getitem__ call.
        self._targets = df[self.label_columns].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        # Positional lookup: DataLoader indices are 0..len-1 regardless of
        # the frame's index labels.
        row = self.df.iloc[i]
        # int() keeps ".0" out of the path if the row Series was upcast to float.
        study_id = int(row["study_id"])
        series_id = int(row["series_id"])
        instance_number = int(row["instance_number"])

        dcm_src = f"/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/{study_id}/{series_id}/{instance_number}.dcm"
        image = load_dicom(dcm_src, IMG_SIZE)

        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        # Add a leading channel axis (assumes load_dicom returns a 2-D image).
        image = torch.tensor(np.expand_dims(image, axis=0), dtype=torch.float)

        target = torch.tensor(self._targets[i], dtype=torch.float)

        return image, target

In [None]:
# Wrap each split in a dataset with its own augmentation pipeline.
train_dataset = LSDCDataset(train_df, transforms=transforms_train)
valid_dataset = LSDCDataset(valid_df, transforms=transforms_valid)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=VALID_BATCH_SIZE, shuffle=False)

# Modelling

In [None]:
class Model(nn.Module):
    """EfficientNet-B0 classifier for single-channel input.

    A learnable 1x1 convolution expands the one input channel to three
    channels before the timm backbone.
    """

    def __init__(self, num_classes):
        super().__init__()
        # 1 -> 3 channel adapter in front of the 3-channel backbone.
        self.conv_layer = nn.Conv2d(1, 3, kernel_size=1)
        self.cnn_model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=num_classes)

    def forward(self, X):
        """Return the backbone's raw (un-normalized) outputs for ``X``."""
        return self.cnn_model(self.conv_layer(X))

# Training

In [None]:
model = Model(NUM_CLASSES)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LR)
# Targets are 0/1 indicator columns from get_dummies. When more than one
# spinal_canal_stenosis* column is present, a row has several positives at
# once. BCE scores every output independently through a sigmoid, which
# matches the binary (sigmoid-thresholded) metrics below. CrossEntropyLoss
# would softmax across all outputs as if exactly one were correct, leaving
# the sign of each logit (what the binary metrics threshold on) arbitrary.
loss_fn = nn.BCEWithLogitsLoss()
# Loss scaling only applies to CUDA mixed precision, so disable it explicitly elsewhere.
scaler = GradScaler(enabled=DEVICE.type == "cuda")

# Element-wise binary metrics over all output/label pairs in a batch.
accuracy = Accuracy(task="binary").to(DEVICE)
f1 = F1Score(task="binary").to(DEVICE)
recall = Recall(task="binary").to(DEVICE)
auroc = AUROC(task="binary").to(DEVICE)

In [None]:
if LOAD_CHECKPOINT:
    checkpoint_path = "/kaggle/input/rsna-lsdc-models/0.2/10.pth"
    # map_location remaps stored tensors onto DEVICE, so a checkpoint saved
    # in a GPU session also loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    model.load_state_dict(checkpoint)

In [None]:
def train_epoch(epoch, model, dataloader, loss_fn, optimizer, scaler):
    """Run one mixed-precision training epoch.

    Args:
        epoch: Epoch number, used only in the progress printout.
        model: Network to train (switched to train mode).
        dataloader: Yields ``(X, y)`` batches; moved to ``DEVICE`` here.
        loss_fn: Callable ``loss_fn(logits, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: AMP ``GradScaler`` wrapping the backward/step calls.

    Returns:
        ``(loss, accuracy, f1, recall, auroc)``. Each value is a per-batch
        value weighted by batch size and averaged over
        ``len(dataloader.dataset)``.
    """
    model.train()
    # The metric objects are module-level and shared with validation. Clear
    # their accumulated state so it does not keep growing across epochs.
    for metric in (accuracy, f1, recall, auroc):
        metric.reset()

    running_loss = .0
    running_accuracy = .0
    running_f1 = .0
    running_recall = .0
    running_auroc = .0

    print(f"Epoch [{epoch}/{EPOCHS}]")

    progress_bar = tqdm(dataloader, desc="Training", total=len(dataloader), unit="batch")
    for X, y in progress_bar:
        X, y = X.to(DEVICE), y.to(DEVICE)

        with autocast():
            output = model(X)
            loss = loss_fn(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # torchmetrics' binary metrics need integer targets (binary AUROC
        # raises on float targets). Detach and upcast the autocast logits so
        # metric computation stays out of the autograd graph and in float32.
        preds = output.detach().float()
        target = y.long()
        accuracy_score = accuracy(preds, target).item()
        f1_score = f1(preds, target).item()
        recall_score = recall(preds, target).item()
        auroc_score = auroc(preds, target).item()

        batch_size = X.size(0)
        running_loss += loss.item() * batch_size
        running_accuracy += accuracy_score * batch_size
        running_f1 += f1_score * batch_size
        running_recall += recall_score * batch_size
        running_auroc += auroc_score * batch_size

        metrics_dict = {
            "Batch Loss": loss.item(),
            "Accuracy": accuracy_score,
            "F1": f1_score,
            "Recall": recall_score,
            "AUCROC": auroc_score
        }
        progress_bar.set_postfix(metrics_dict)

    num_samples = len(dataloader.dataset)
    epoch_loss = running_loss / num_samples
    epoch_accuracy = running_accuracy / num_samples
    epoch_f1 = running_f1 / num_samples
    epoch_recall = running_recall / num_samples
    epoch_auroc = running_auroc / num_samples

    return epoch_loss, epoch_accuracy, epoch_f1, epoch_recall, epoch_auroc

In [None]:
def valid_epoch(epoch, model, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        epoch: Unused; kept for signature symmetry with ``train_epoch``.
        model: Network to evaluate (switched to eval mode).
        dataloader: Yields ``(X, y)`` batches; moved to ``DEVICE`` here.
        loss_fn: Callable ``loss_fn(logits, y)``.

    Returns:
        ``(loss, accuracy, f1, recall, auroc)``. Each value is a per-batch
        value weighted by batch size and averaged over
        ``len(dataloader.dataset)``.
    """
    model.eval()
    # The metric objects are module-level and shared with training. Clear
    # their accumulated state so it does not keep growing across epochs.
    for metric in (accuracy, f1, recall, auroc):
        metric.reset()

    running_loss = .0
    running_accuracy = .0
    running_f1 = .0
    running_recall = .0
    running_auroc = .0

    progress_bar = tqdm(dataloader, desc="Validation", total=len(dataloader), unit="batch")
    with torch.inference_mode():
        for X, y in progress_bar:
            X, y = X.to(DEVICE), y.to(DEVICE)

            with autocast():
                output = model(X)
                loss = loss_fn(output, y)

            # torchmetrics' binary metrics need integer targets (binary AUROC
            # raises on float targets); upcast the autocast logits to float32.
            preds = output.float()
            target = y.long()
            accuracy_score = accuracy(preds, target).item()
            f1_score = f1(preds, target).item()
            recall_score = recall(preds, target).item()
            auroc_score = auroc(preds, target).item()

            # Each batch is accumulated exactly once (the loss was previously
            # added twice, doubling the reported validation loss).
            batch_size = X.size(0)
            running_loss += loss.item() * batch_size
            running_accuracy += accuracy_score * batch_size
            running_f1 += f1_score * batch_size
            running_recall += recall_score * batch_size
            running_auroc += auroc_score * batch_size

            metrics_dict = {
                "Batch Loss": loss.item(),
                "Accuracy": accuracy_score,
                "F1": f1_score,
                "Recall": recall_score,
                "AUCROC": auroc_score
            }
            progress_bar.set_postfix(metrics_dict)

    num_samples = len(dataloader.dataset)
    epoch_loss = running_loss / num_samples
    epoch_accuracy = running_accuracy / num_samples
    epoch_f1 = running_f1 / num_samples
    epoch_recall = running_recall / num_samples
    epoch_auroc = running_auroc / num_samples

    return epoch_loss, epoch_accuracy, epoch_f1, epoch_recall, epoch_auroc

In [None]:
# Directory that receives one state_dict file per epoch.
os.makedirs("saved-models", exist_ok=True)

def save_model(model, epoch):
    """Write ``model``'s state_dict to ``saved-models/<epoch>.pth``."""
    torch.save(model.state_dict(), f"saved-models/{epoch}.pth")

In [None]:
# The loaded checkpoint file is named 10.pth, and save_model names files by
# epoch, so a resumed run continues numbering at 11. A fresh run starts at 1.
start_epoch = 11 if LOAD_CHECKPOINT else 1
metric_names = ("loss", "accuracy", "f1", "recall", "auroc")

for epoch in range(start_epoch, start_epoch + EPOCHS):
    # Both calls return (loss, accuracy, f1, recall, auroc).
    train_epoch_loss = train_epoch(epoch, model, train_dataloader, loss_fn, optimizer, scaler)
    valid_epoch_loss = valid_epoch(epoch, model, valid_dataloader, loss_fn)

    print("Training   - " + ", ".join(f"{name} {value:.4f}" for name, value in zip(metric_names, train_epoch_loss)))
    print("Validation - " + ", ".join(f"{name} {value:.4f}" for name, value in zip(metric_names, valid_epoch_loss)))

    save_model(model, epoch)