In [None]:
#imports needed packages
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
import pydicom

import os
from os.path import join
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import convnext_small

from tqdm import tqdm
from sklearn.model_selection import KFold

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Central configuration shared by the cells below
class CFG:
    """Namespace of notebook-wide constants and hyperparameters."""

    # Reproducibility / logging
    seed = 21
    verbose = 1  # not read anywhere in the visible cells

    # Label counts; axial_labels is passed as `num_labels` for the axial model.
    # sag_labels is presumably the sagittal counterpart — TODO confirm.
    sag_labels = 45
    axial_labels = 30

    # Both entries are forwarded to transforms.Resize in create_datasets
    image_size = [512, 512]

    # Optimisation settings
    epochs = 1
    batch_size = 12
    patience = 1  # non-improving epochs tolerated before early stopping
    learning_rate = 1e-3


# Seed PyTorch's global RNG so repeated runs draw the same random numbers
torch.manual_seed(CFG.seed)

# Dataset And Transformations

In [None]:
class MRIDataset(Dataset):
    """Dataset that stacks the 15 DICOM slices of one row into a multi-channel tensor.

    Rows of ``dataframe`` are read positionally: columns 1-15 hold DICOM file
    paths relative to ``base_dir`` and columns 16 onward hold the label values
    (described by the original author as one-hot encoded).

    Args:
        dataframe: Table with the column layout described above.
        base_dir: Directory that the per-slice file paths are joined onto.
        transform: Optional callable applied to each slice (a PIL image).
            If it is omitted, or does not return a tensor, the slice is
            converted to a ``(1, H, W)`` float32 tensor scaled to [0, 1].
    """

    def __init__(self, dataframe, base_dir, transform=None):
        self.dataframe = dataframe
        self.base_dir = base_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.dataframe)

    @staticmethod
    def _to_uint8(pixels, invert=False):
        """Min-max scale a pixel array to the full 0-255 ``uint8`` range.

        The previous z-score normalisation produced negative and >1 values
        that were then cast straight to ``uint8``, which wraps around and
        corrupts the image; it also divided by zero for constant images.

        Args:
            pixels: Array-like of raw pixel intensities.
            invert: If True, flip intensities first (used for MONOCHROME1).

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with the same shape as ``pixels``.
        """
        pixels = np.asarray(pixels, dtype=np.float64)
        if invert:
            pixels = pixels.max() - pixels

        low = pixels.min()
        span = pixels.max() - low
        if span == 0:
            # Constant image: nothing to scale, avoid dividing by zero.
            return np.zeros(pixels.shape, dtype=np.uint8)
        return np.round((pixels - low) / span * 255.0).astype(np.uint8)

    def __getitem__(self, idx):
        """Load, normalise and stack the slices of row ``idx``.

        Returns:
            tuple: ``(images, labels)`` where ``images`` is the per-slice
            tensors concatenated along dim 0 and ``labels`` is a float32
            ``numpy.ndarray`` built from columns 16 onward.
        """
        slice_paths = self.dataframe.iloc[idx, 1:16].tolist()  # columns with file paths
        channels = []

        for rel_path in slice_paths:
            dicom_data = pydicom.dcmread(os.path.join(self.base_dir, rel_path))
            # MONOCHROME1 stores inverted intensities, so flip them back.
            is_inverted = dicom_data.PhotometricInterpretation == "MONOCHROME1"
            image = Image.fromarray(self._to_uint8(dicom_data.pixel_array, invert=is_inverted))

            if self.transform is not None:
                image = self.transform(image)

            if not isinstance(image, torch.Tensor):
                # No transform (or one without ToTensor): build a (1, H, W)
                # tensor ourselves so torch.cat below does not fail.
                scaled = np.asarray(image, dtype=np.float32) / 255.0
                image = torch.from_numpy(scaled).unsqueeze(0)

            channels.append(image)

        # Concatenate the per-slice tensors along the channel dimension
        images = torch.cat(channels, dim=0)

        # float32 so the collated labels match the dtype of the model's logits
        labels = self.dataframe.iloc[idx, 16:].values.astype(np.float32)
        return images, labels

In [None]:
# Function to create datasets and dataloaders
def create_datasets(train_df, val_df, base_dir):
    """Build the training and validation ``DataLoader`` objects for one split.

    Random augmentation (rotation and translation) is applied to the training
    set only. The validation set gets a deterministic resize + tensor
    conversion so that validation loss is comparable between epochs and folds.

    Args:
        train_df: Dataframe of training rows, passed to ``MRIDataset``.
        val_df: Dataframe of validation rows, passed to ``MRIDataset``.
        base_dir: Directory containing the DICOM files referenced by the rows.

    Returns:
        tuple: ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    target_size = (CFG.image_size[0], CFG.image_size[1])

    train_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
        transforms.ToTensor()
    ])
    # No random transforms here: evaluation must not depend on RNG draws.
    val_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor()
    ])

    train_dataset = MRIDataset(train_df, base_dir=base_dir, transform=train_transform)
    val_dataset = MRIDataset(val_df, base_dir=base_dir, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False)

    return train_loader, val_loader

# Model Prep

In [None]:
def prepare_model(num_labels, in_channels=15):
    """Build a pretrained ConvNeXt-Small adapted to multi-slice input.

    Args:
        num_labels: Number of output logits produced by the classifier head.
        in_channels: Number of input channels (stacked slices). Defaults to 15,
            the value that was previously hard-coded.

    Returns:
        tuple: ``(model, device)`` with the model already moved to ``device``
        (CUDA when available, otherwise CPU).
    """
    try:
        # Newer torchvision replaces the deprecated `pretrained=True` flag
        # with an explicit weights enum.
        from torchvision.models import ConvNeXt_Small_Weights
        model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
    except ImportError:
        # Older torchvision releases only know the boolean flag.
        model = convnext_small(pretrained=True)

    # Replace the stem convolution so it accepts `in_channels` inputs. The
    # geometry is copied from the stock stem; the new layer's weights are
    # freshly initialised (the pretrained 3-channel kernel is not reused).
    stem = model.features[0][0]
    model.features[0][0] = nn.Conv2d(
        in_channels,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
    )

    # Replace the final linear layer to emit one logit per label
    model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_labels)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    return model, device

In [None]:
# Calculates per-label positive class weights to counter class imbalance
def calculate_pos_weight(labels):
    """Compute ``negatives / positives`` for every label column.

    The result is intended for ``nn.BCEWithLogitsLoss(pos_weight=...)``.
    A label with no positive sample gets the neutral weight 1.0. The previous
    fallback of 0.0 multiplied the positive-class loss term by zero, so any
    positives of that label seen later contributed nothing to the loss.

    Args:
        labels: 2-D tensor (or array-like) of shape ``(num_samples, num_labels)``
            holding 0/1 indicators.

    Returns:
        torch.Tensor: float32 CPU tensor of shape ``(num_labels,)``.

    Raises:
        ValueError: If ``labels`` is not 2-dimensional.
    """
    labels = torch.as_tensor(labels).detach().cpu().to(torch.float64)
    if labels.dim() != 2:
        raise ValueError(
            f"labels must be 2-D (num_samples, num_labels), got shape {tuple(labels.shape)}"
        )

    num_samples = labels.shape[0]
    pos_counts = labels.sum(dim=0)
    has_positives = pos_counts > 0

    # Substitute 1 where the count is zero so the division below is always defined
    safe_counts = torch.where(has_positives, pos_counts, torch.ones_like(pos_counts))
    ratios = (num_samples - pos_counts) / safe_counts
    pos_weights = torch.where(has_positives, ratios, torch.ones_like(ratios))

    return pos_weights.to(torch.float32)

In [None]:
def train_model(model, train_loader, val_loader, device, CFG):
    """Train ``model`` with class-weighted BCE loss and early stopping.

    This replaces a cell in which two copies of the function had been pasted
    into each other, which made the cell a syntax error.

    Args:
        model: Network returning one raw logit per label; expected to already
            live on ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches that
            supports ``len()``.
        val_loader: Iterable of ``(images, labels)`` validation batches that
            supports ``len()``.
        device: Device that batches and the loss weights are moved to.
        CFG: Object exposing ``epochs``, ``patience`` and ``learning_rate``.

    Returns:
        tuple: ``(model, best_val_loss)``. The model carries the weights of the
        epoch that achieved ``best_val_loss``. If a NaN loss aborts training
        before any epoch finishes, ``best_val_loss`` is ``float('inf')`` and
        the model keeps its current weights.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Positive-class weights are estimated from the first training batch only,
    # so they are a rough, batch-dependent estimate of the class balance.
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise ValueError("train_loader is empty; cannot estimate pos_weight.")
    pos_weight = calculate_pos_weight(first_batch[1]).to(device)

    # Loss function with class weights, and the optimizer
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=CFG.learning_rate)

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    def _finish():
        """Restore the best weights seen so far (if any) and build the result."""
        if best_state is not None:
            model.load_state_dict(best_state)
        return model, best_val_loss

    for epoch in range(CFG.epochs):
        # ---- Training pass ----
        model.train()
        running_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG.epochs}', leave=False)

        for step, (images, labels) in enumerate(train_loader_tqdm, start=1):
            images = images.to(device)

            optimizer.zero_grad()  # Zero the parameter gradients
            outputs = model(images)
            # Labels may arrive as float64; the loss needs them in the logits' dtype.
            labels = labels.to(device=device, dtype=outputs.dtype)
            loss = criterion(outputs, labels)

            if torch.isnan(loss):
                print("NaN loss detected, breaking out of training loop.")
                # Return a 2-tuple like every other exit so callers can unpack it.
                return _finish()

            loss.backward()
            # Clip gradients to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item()

            # Show the mean loss over the batches processed so far
            train_loader_tqdm.set_postfix(loss=running_loss / step)

        avg_train_loss = running_loss / len(train_loader)

        # ---- Validation pass ----
        val_loss = 0.0
        model.eval()
        val_loader_tqdm = tqdm(val_loader, desc='Validation', leave=False)

        with torch.no_grad():  # No gradients needed for evaluation
            for step, (images, labels) in enumerate(val_loader_tqdm, start=1):
                images = images.to(device)
                outputs = model(images)
                labels = labels.to(device=device, dtype=outputs.dtype)
                loss = criterion(outputs, labels)

                if torch.isnan(loss):
                    print("NaN loss detected during validation, breaking out of validation loop.")
                    return _finish()

                val_loss += loss.item()
                val_loader_tqdm.set_postfix(val_loss=val_loss / step)

        avg_val_loss = val_loss / len(val_loader)

        print(f'Epoch [{epoch+1}/{CFG.epochs}], '
              f'Loss: {avg_train_loss:.4f}, '
              f'Val Loss: {avg_val_loss:.4f}')

        # ---- Early stopping ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Snapshot the weights so the returned model matches best_val_loss
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= CFG.patience:
                print("Early stopping")
                break  # Stop training once patience is exhausted

    return _finish()

In [None]:
# K-fold cross-validation driver
def cross_validate_model(dataframe, base_dir, num_labels, k_folds=5):
    """Train one fresh model per fold and keep the one with the lowest loss.

    Args:
        dataframe: Full table of samples; split positionally by ``KFold``.
        base_dir: Directory with the image files, forwarded to ``create_datasets``.
        num_labels: Number of output labels, forwarded to ``prepare_model``.
        k_folds: Number of folds (default 5).

    Returns:
        tuple: ``(best_model, mean_val_loss)`` — the trained model from the fold
        with the smallest validation loss, and the mean loss over all folds.
    """
    splitter = KFold(n_splits=k_folds)
    losses_per_fold = []
    champion, champion_loss = None, float('inf')

    for fold, (train_rows, val_rows) in enumerate(splitter.split(dataframe)):
        print(f'Fold {fold+1}/{k_folds}')

        # Carve out this fold's rows and wrap them in loaders
        fold_train_df = dataframe.iloc[train_rows]
        fold_val_df = dataframe.iloc[val_rows]
        train_loader, val_loader = create_datasets(fold_train_df, fold_val_df, base_dir)

        # Every fold starts from a newly prepared model
        model, device = prepare_model(num_labels)
        trained_model, val_loss = train_model(model, train_loader, val_loader, device, CFG)

        # Remember the fold that validated best so far
        if val_loss < champion_loss:
            champion, champion_loss = trained_model, val_loss

        losses_per_fold.append(val_loss)
        print(f'Fold {fold+1}/{k_folds} - Validation Loss: {val_loss:.4f}')

    # Summarise performance across folds
    mean_val_loss = np.mean(losses_per_fold)
    print(f'Mean Validation Loss: {mean_val_loss:.4f}')

    return champion, mean_val_loss

In [None]:
# Persist a model's learned parameters to disk
def save_model(model, file_path):
    """Write ``model.state_dict()`` to ``file_path`` with ``torch.save``.

    Args:
        model: Object exposing ``state_dict()``.
        file_path: Destination handed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, file_path)


# Training

In [None]:
# Input locations: DICOM image root and the folder holding the CSV tables
base_dir = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/'
df_path = '/kaggle/input/mri-datasets'

# Table fed to the cross-validation cell below
axial_df = pd.read_csv(f'{df_path}/axial_df.csv')

In [None]:
# Run cross-validation; swap the dataframe and label count for another model type
best_model, mean_val_loss = cross_validate_model(
    dataframe=axial_df,
    base_dir=base_dir,
    num_labels=CFG.axial_labels,
)

In [None]:
# Write the best fold's weights to the working directory
save_model(model=best_model, file_path='axial_model.pth')