In [6]:
from kaggle_secrets import UserSecretsClient
# Fetch the GitHub token from Kaggle's secret store instead of hard-coding it;
# it is interpolated into the clone URL of the next cell.
GITHUB_SECRET_LABEL = "github_token"
user_secrets = UserSecretsClient()
token = user_secrets.get_secret(GITHUB_SECRET_LABEL)

In [None]:
!git clone https://{token}@github.com/seprow/BrainMriClassification-IAAA-.git

In [None]:
#!pip install -r '/kaggle/working/BrainMriClassification-IAAA-/requirements.txt'
!pip install SimpleITK
!pip install pydicom
!pip install albumentations
!pip install antspyx
!pip install plotly ipywidgets
!pip install monai

In [14]:
import pandas as pd
import os
import sys
from pathlib import Path
import SimpleITK as sitk

sys.path.append('/kaggle/working/BrainMriClassification-IAAA-')
from Preprocess import *
from DataGenerator import *

import torch
from torch.utils.data import DataLoader
torch.cuda.empty_cache() # Clears all the cache allocated by PyTorch on the GPU

!nvidia-smi

/bin/bash: nvidia-smi: command not found


In [17]:
ROOT_DATA_DIR = Path('/kaggle/input/iaaa-mri-challenge').expanduser().absolute()
# Downstream consumers get plain strings (the same type os.path.join produces).
DATA_DIR = str(ROOT_DATA_DIR / 'data')
LABELS_PATH = str(ROOT_DATA_DIR / 'train.csv')

# Data Handling

In [18]:
# Load the labels and drop the series listed below from the dataset.
# NOTE(review): the reason these series are excluded is not recorded here — TODO document.
annotations = pd.read_csv(LABELS_PATH)
list_id = [
    '1.3.46.670589.11.10042.5.0.7984.2024022316295067732',
    '1.3.46.670589.11.10042.5.0.8184.2024011321595084988',
    '1.3.46.670589.11.10042.5.0.5548.2024010521045198196',
    '1.3.46.670589.11.10042.5.0.6048.2024030612191717163',
    '1.3.46.670589.11.10042.5.0.3364.2024011206110656762',
    '1.3.46.670589.11.10042.5.0.5244.2024011800080517866',
    '1.3.46.670589.11.10042.5.0.5596.2024031317194490629',
    '1.3.46.670589.11.10042.5.0.5864.2024030418020328916',
    '1.3.46.670589.11.10042.5.0.6596.2024021917471807658',
    '1.3.46.670589.11.10042.5.0.5244.2024011514063439527',
    '1.3.46.670589.11.10042.5.0.7620.2023122711014093653',
    '1.3.46.670589.11.10042.5.0.6596.2024012110370925812',
    '1.3.46.670589.11.10042.5.0.5484.2024030209230264275',
    '1.3.46.670589.11.10042.5.0.1412.2024020409533565254',
    '1.3.46.670589.11.10042.5.0.6596.2024022109510462198'
]

is_excluded = annotations['SeriesInstanceUID'].isin(list_id)
removable_id = annotations.index[is_excluded]
annotations = annotations.drop(index=removable_id)

# Class balance after exclusion.
print(annotations['prediction'].value_counts())

prediction
0    2731
1     386
Name: count, dtype: int64


In [19]:
from helper import estimate_class_weights 
# Per-class weights for the cleaned labels ('mfb' — presumably median-frequency balancing;
# see helper.estimate_class_weights).
# NOTE(review): the ratio of the two weights in the recorded output (~7.07) is the usual
# pos_weight for BCEWithLogitsLoss, but the Trainer cell hard-codes 6.3 — confirm which is intended.
class_weights = estimate_class_weights(annotations, method='mfb')
class_weights

[0.57067007, 4.0375648]

In [20]:
from sklearn.model_selection import train_test_split

validation_size = 0.3
split_seed = 7  # fixed so the train/validation split is reproducible across runs

# Stratify on the label so both splits keep the same negative/positive ratio
# (roughly 7:1 according to the value counts above).
stratify_df = annotations['prediction']
training_df, validation_df = train_test_split(
    annotations,
    test_size=validation_size,
    stratify=stratify_df,
    random_state=split_seed,
)

# Sanity check: every labelled series lands in exactly one split.
assert len(training_df) + len(validation_df) == len(annotations)

In [None]:
# Hyperparameters
target_size = [288, 288, 16]
batch_size = 8

# One reference series per MRI sequence type, passed to the data generator
# (presumably for registration/normalisation — TODO confirm in DataGenerator).
# Paths are built from DATA_DIR so they follow any change of the data location.
reference_image = {
    'T1W_SE': os.path.join(DATA_DIR, '1.3.46.670589.11.10042.5.0.1412.2024020321545257411'),
    'T2W_TSE': os.path.join(DATA_DIR, '1.3.46.670589.11.10042.5.0.1412.2024020313391873234'),
    'T2W_FLAIR': os.path.join(DATA_DIR, '1.3.46.670589.11.10042.5.0.1412.2024020410193570629'),
}


# Dataset & DataLoader
# NOTE(review): shuffling is requested both in the generator and in the DataLoader;
# confirm the generator's `shuffle` is still needed.
train_dataset = DICOMDataGenerator(df=training_df, target_size=target_size, data_dir=DATA_DIR, reference_image_path=reference_image, shuffle=True)
val_dataset = DICOMDataGenerator(df=validation_df, target_size=target_size, data_dir=DATA_DIR, reference_image_path=reference_image, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Pull one batch to check the tensor shapes the model will receive.
images, labels = next(iter(train_loader))

print(f'Image batch shape: {images.shape}')
print(f'Label batch shape: {labels.shape}')

# Trainer

In [None]:
from models import *
from helper import print_trainable_parameters

# Build the network and report how many of its parameters will be trained.
network = SimpleVASNet()
print_trainable_parameters(network)
model = network

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np
from sklearn.metrics import roc_auc_score, recall_score, precision_score, accuracy_score, precision_recall_curve, auc, roc_curve
from torch.utils.tensorboard import SummaryWriter  
import logging  
from tqdm import tqdm
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): very broad — this also hides metric warnings (e.g. undefined AUC).
warnings.simplefilter("ignore")

# Reduce CUDA memory fragmentation. PyTorch reads this variable when its CUDA caching
# allocator initialises, so it only takes effect if set before the first CUDA allocation.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

class Trainer:
    """Train and validate a binary classifier with ``BCEWithLogitsLoss``.

    The model output is compared against the targets with ``BCEWithLogitsLoss``, so
    outputs must be logits with the same shape as the targets (presumably one logit
    per sample — TODO confirm against the model). The weights with the lowest
    validation loss seen so far are saved to ``best_model.pth`` in the current working
    directory. Metrics go to TensorBoard under ``log_dir`` and to
    ``<log_dir>/training.log`` through the root logger.

    Args:
        train_data: iterable of ``(inputs, targets)`` batches that supports ``len()``
            (e.g. a ``DataLoader``).
        val_data: same, used for validation.
        model: ``nn.Module`` to train; it is moved to ``device``.
        num_epochs: number of passes over ``train_data``.
        lr: initial Adam learning rate.
        lr_decay_epoch: once the 0-based epoch index exceeds this value the learning
            rate is set to ``lr * 0.5``.
        threshold: probability cut-off used to turn sigmoid outputs into 0/1
            predictions for the accuracy metrics. Used by ``train_one_epoch``,
            ``validate`` and ``train_and_evaluate`` whenever they are called without
            an explicit threshold.
        device: torch device; defaults to CUDA when available, otherwise CPU.
        log_dir: directory for TensorBoard events and ``training.log``.
        pos_weight: loss weight of positive examples (counters class imbalance).
            The default keeps the previously hard-coded value.
    """

    def __init__(self, train_data, val_data, model, num_epochs, lr, lr_decay_epoch, threshold=0.5, device=None, log_dir='logs', pos_weight=6.3):
        self.train_data = train_data
        self.val_data = val_data
        self.model = model
        self.num_epochs = num_epochs
        self.lr = lr
        self.lr_decay_epoch = lr_decay_epoch
        self.best_val_loss = float('inf')
        self.threshold = threshold

        # Device configuration
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weight], dtype=torch.float32, device=self.device)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)

        # NOTE(review): silently resumes from any best_model.pth left in the working
        # directory by an earlier run. best_val_loss still starts at inf, so the first
        # epoch of this run overwrites that checkpoint.
        if os.path.exists('best_model.pth'):
            self.model.load_state_dict(torch.load('best_model.pth', map_location=self.device))
            print("Loaded best model from 'best_model.pth'")

        # Initialize TensorBoard
        self.writer = SummaryWriter(log_dir=log_dir)

        # Set up logging. basicConfig does nothing if the root logger already has
        # handlers, in which case training.log is not written.
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(filename=os.path.join(log_dir, 'training.log'), level=logging.INFO)
        self.logger = logging.getLogger()

    def _resolve_threshold(self, threshold):
        """Return ``threshold`` when given, otherwise the one configured in ``__init__``."""
        return self.threshold if threshold is None else threshold

    @staticmethod
    def _binarize(probabilities, threshold):
        """Map probabilities to 0/1 labels: 1 where ``probability >= threshold``."""
        return (np.asarray(probabilities) >= threshold).astype(int)

    def adjust_learning_rate(self, epoch):
        """Set the learning rate to ``lr * 0.5`` once ``epoch > lr_decay_epoch``.

        ``epoch`` is the 0-based index supplied by ``train_and_evaluate``; before the
        decay point the optimizer's learning rate is left unchanged.
        """
        if epoch > self.lr_decay_epoch:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr * 0.5

    def train_one_epoch(self, threshold=None):
        """Run one optimisation pass over ``train_data``.

        Args:
            threshold: cut-off for the accuracy metric; ``None`` uses ``self.threshold``,
                so training and validation accuracy are computed the same way.

        Returns:
            ``(mean batch loss, accuracy)``.
        """
        threshold = self._resolve_threshold(threshold)
        self.model.train()
        train_loss = 0.0
        all_train_targets = []
        all_train_outputs = []

        with tqdm(self.train_data, desc="Training") as pbar:
            for batch_idx, (inputs, targets) in enumerate(pbar, start=1):
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                all_train_targets.extend(targets.cpu().numpy())
                all_train_outputs.extend(torch.sigmoid(outputs).detach().cpu().numpy())

                # Running mean over the batches processed so far (dividing the partial
                # sum by the total batch count under-reports the loss mid-epoch).
                pbar.set_description(f"Training (Loss: {train_loss / batch_idx:.4f})")

        train_loss /= len(self.train_data)
        train_accuracy = accuracy_score(
            np.array(all_train_targets), self._binarize(all_train_outputs, threshold)
        )

        return train_loss, train_accuracy

    def validate(self, threshold=None):
        """Evaluate the model on ``val_data`` without gradient tracking.

        Args:
            threshold: cut-off for the accuracy metric; ``None`` (default) uses
                ``self.threshold``.

        Returns:
            ``(mean batch loss, accuracy, ROC AUC, precision-recall AUC)``.
        """
        threshold = self._resolve_threshold(threshold)
        self.model.eval()
        val_loss = 0.0
        all_val_targets = []
        all_val_outputs = []

        with torch.no_grad():
            for inputs, targets in self.val_data:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                val_loss += loss.item()

                all_val_targets.extend(targets.cpu().numpy())
                all_val_outputs.extend(torch.sigmoid(outputs).cpu().numpy())

        targets_arr = np.array(all_val_targets)
        probs_arr = np.array(all_val_outputs)

        val_loss /= len(self.val_data)
        val_accuracy = accuracy_score(targets_arr, self._binarize(probs_arr, threshold))
        fpr, tpr, _ = roc_curve(targets_arr, probs_arr)
        roc_auc = auc(fpr, tpr)
        # Trapezoidal area under the PR curve; this can be more optimistic than
        # sklearn's average_precision_score.
        precision, recall, _ = precision_recall_curve(targets_arr, probs_arr)
        pr_auc = auc(recall, precision)

        return val_loss, val_accuracy, roc_auc, pr_auc

    def train_and_evaluate(self, threshold=None):
        """Train for ``num_epochs`` epochs, validating and logging after each one.

        Args:
            threshold: cut-off for the accuracy metrics; ``None`` (default) uses the
                ``threshold`` given to the constructor.

        Side effects: writes TensorBoard scalars, log lines, and ``best_model.pth``
        whenever the validation loss improves.
        """
        threshold = self._resolve_threshold(threshold)
        try:
            with tqdm(range(self.num_epochs), desc="Epoch") as pbar:
                for epoch in pbar:
                    self.adjust_learning_rate(epoch)

                    train_loss, train_accuracy = self.train_one_epoch(threshold=threshold)
                    val_loss, val_accuracy, roc_auc, pr_auc = self.validate(threshold=threshold)

                    # Log metrics to TensorBoard
                    self.writer.add_scalar('Loss/Train', train_loss, epoch)
                    self.writer.add_scalar('Loss/Validation', val_loss, epoch)
                    self.writer.add_scalar('Accuracy/Train', train_accuracy, epoch)
                    self.writer.add_scalar('Accuracy/Validation', val_accuracy, epoch)
                    self.writer.add_scalar('AUC/Validation', roc_auc, epoch)
                    self.writer.add_scalar('pr_auc/Validation', pr_auc, epoch)

                    # Log to the console and file
                    log_message = (
                        f'Epoch {epoch+1}/{self.num_epochs}, '
                        f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                        f'Validation Loss: {val_loss:.4f}, AUC: {roc_auc:.4f}, '
                        f'pr_auc: {pr_auc:.4f}, '
                        f'Validation Accuracy: {val_accuracy:.4f}'
                    )
                    print(log_message)
                    self.logger.info(log_message)

                    # Save the best model (lowest validation loss so far)
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        torch.save(self.model.state_dict(), 'best_model.pth')
                        print(f'Best model saved with validation loss: {val_loss:.4f}')
                        self.logger.info(f'Best model saved with validation loss: {val_loss:.4f}')
        finally:
            # SummaryWriter buffers events; flush so the final epochs reach TensorBoard
            # even if training is interrupted.
            self.writer.flush()

In [None]:
num_epochs = 10
learning_rate = 0.0001
# NOTE(review): the learning rate is halved only when the 0-based epoch index exceeds
# lr_decay_epoch, so with num_epochs == lr_decay_epoch == 10 the decay never fires.
lr_decay_epoch = 10
threshold = 0.4

trainer = Trainer(
    train_data=train_loader,
    val_data=val_loader,
    model=model,
    num_epochs=num_epochs,
    lr=learning_rate,
    lr_decay_epoch=lr_decay_epoch,
    threshold=threshold,
)

# Pass the threshold explicitly so evaluation uses the configured value instead of
# whatever default train_and_evaluate declares.
trainer.train_and_evaluate(threshold=threshold)