In [1]:
import os
import pandas as pd
import torch
import nibabel as nib
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

from model_v2 import CombinedVNetModel


from data_loader import load_nifti, resample
import logging
import numpy as np



In [2]:

# Process-wide environment switches, set before any model or CUDA work is
# started by the cells below.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    'KMP_DUPLICATE_LIB_OK': 'TRUE',
})

# Input locations; both are relative paths, so they resolve against the
# notebook's current working directory.
directory = '../../../Aims-Tbi'
demographics_file = '../../TestSet_demographics_with_lesion.csv'

# Timestamped INFO-level output for the logging.info calls in this notebook.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


In [3]:

class MRIDataset(Dataset):
    """Dataset pairing T1 MRI volumes with a lesion label and demographic features.

    Rows of the demographics CSV are kept only when a file named
    ``<RandID>_T1.nii.gz`` exists in ``directory``.

    Each item is a tuple ``(mri_tensor, label_tensor, diag_tensor)``:

    * ``mri_tensor``   -- float32 volume, min-max scaled, passed through
      ``resample(..., TARGET_SHAPE)`` and given a leading channel axis
      (presumably ``(1, 256, 256, 256)`` -- depends on ``resample``; confirm).
    * ``label_tensor`` -- float32 tensor of shape ``(1,)`` holding the row's
      ``Lesion`` value.
    * ``diag_tensor``  -- float32 tensor ``[Age, Sex, TSI, manufacturer code]``.

    Args:
        directory: Folder containing the ``*_T1.nii.gz`` files.
        demographics_file: CSV with columns ``RandID``, ``Age``, ``Sex``,
            ``TSI``, ``ScanManufacturer`` and ``Lesion``.
        transform: Optional callable applied to the MRI tensor only.
    """

    # Integer codes for the scanner vendor; a vendor missing from this table
    # raises KeyError in __getitem__.
    MANUFACTURER_ENCODING = {'Siemens': 1, 'Philips': 2, 'GE': 3}
    # Shape handed to resample() for every volume.
    TARGET_SHAPE = (256, 256, 256)
    # File-name suffix that identifies a subject's T1 volume.
    T1_SUFFIX = '_T1.nii.gz'

    def __init__(self, directory, demographics_file, transform=None):
        self.directory = directory
        self.transform = transform

        self.demographics = pd.read_csv(demographics_file)
        self.demographics['Sex'] = self.demographics['Sex'].astype(float)

        self.t1_paths = []
        self.diag_data = []

        # Subject ids that have a T1 file on disk. Only the trailing suffix is
        # stripped so an id that happens to contain the suffix text is kept intact.
        available_mris = {
            name[:-len(self.T1_SUFFIX)]
            for name in os.listdir(directory)
            if name.endswith(self.T1_SUFFIX)
        }

        for _, row in self.demographics.iterrows():
            # str() lets numeric ids parsed by pandas match the file names.
            rand_id = str(row['RandID'])
            if rand_id not in available_mris:
                continue
            mri_file = os.path.join(directory, f"{rand_id}{self.T1_SUFFIX}")
            if os.path.exists(mri_file):
                self.t1_paths.append(mri_file)
                self.diag_data.append(row)

    def __len__(self):
        """Number of subjects that have both a CSV row and a T1 file."""
        return len(self.t1_paths)

    def __getitem__(self, idx):
        """Load, normalise and resample one subject; see the class docstring."""
        t1_path = self.t1_paths[idx]
        diag_row = self.diag_data[idx]

        mri_img = load_nifti(t1_path)

        # Min-max scale to [0, 1]. A constant volume has zero range; dividing
        # by it would fill the tensor with NaN, so emit zeros instead.
        lowest = np.min(mri_img)
        value_range = np.max(mri_img) - lowest
        if value_range == 0:
            mri_img = np.zeros_like(mri_img, dtype=np.float32)
        else:
            mri_img = (mri_img - lowest) / value_range

        mri_img = resample(mri_img, self.TARGET_SHAPE)

        # Leading axis of size 1 acts as the channel dimension.
        mri_tensor = torch.tensor(mri_img, dtype=torch.float32).unsqueeze(0)

        label_tensor = torch.tensor(diag_row['Lesion'], dtype=torch.float32).unsqueeze(0)

        diag_values = [
            float(diag_row['Age']),
            float(diag_row['Sex']),
            float(diag_row['TSI']),
            self.MANUFACTURER_ENCODING[diag_row['ScanManufacturer']],
        ]
        diag_tensor = torch.tensor(np.array(diag_values, dtype=np.float32))

        if self.transform:
            # Only the image is transformed: the label is a scalar target and
            # running an image transform over it would corrupt it.
            mri_tensor = self.transform(mri_tensor)

        return mri_tensor, label_tensor, diag_tensor


In [4]:
def load_data(directory, demographics_file, batch_size,
              train_proportion=0.25, val_proportion=0.1, seed=None):
    """Build train/validation/test DataLoaders from a random split of MRIDataset.

    Args:
        directory: Folder with the T1 volumes, forwarded to ``MRIDataset``.
        demographics_file: Demographics CSV, forwarded to ``MRIDataset``.
        batch_size: Batch size used by all three loaders.
        train_proportion: Fraction of samples placed in the training split.
        val_proportion: Fraction of samples placed in the validation split.
            The test split receives every remaining sample (0.65 of the data
            with the default proportions).
        seed: Optional integer; when given, the split is drawn from a
            dedicated ``torch.Generator`` seeded with it so it is reproducible.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If any of the three splits would be empty.
    """
    dataset = MRIDataset(directory=directory, demographics_file=demographics_file)

    total_size = len(dataset)
    print(f"Total dataset size: {total_size}")

    train_size = int(total_size * train_proportion)
    val_size = int(total_size * val_proportion)
    # The remainder goes to the test split so the three sizes always sum to total_size.
    test_size = total_size - train_size - val_size

    if train_size == 0 or val_size == 0 or test_size == 0:
        raise ValueError("Dataset too small to split according to the specified ratios. Increase dataset size or adjust the ratios.")

    # Only pass a generator when a seed was requested, so the default call
    # keeps using torch's global RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    # Only the training loader shuffles; evaluation order is kept fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # len(dataset) counts samples, len(loader) counts batches -- report both.
    print(f"Training set size: {len(train_dataset)} samples ({len(train_loader)} batches)")
    print(f"Validation set size: {len(val_dataset)} samples ({len(val_loader)} batches)")
    print(f"Test set size: {len(test_dataset)} samples ({len(test_loader)} batches)")

    return train_loader, val_loader, test_loader

In [5]:
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialise ``state`` with ``torch.save`` without clobbering a good checkpoint.

    For path targets the data is written to ``<filename>.tmp`` first and then
    moved over ``filename`` with ``os.replace``. If saving fails, the temporary
    file is removed and any existing checkpoint at ``filename`` is left as it was.

    Args:
        state: Object to serialise (the training loop passes a dict with
            ``epoch``, ``state_dict`` and ``optimizer`` entries).
        filename: Destination path, or a file-like object accepted by
            ``torch.save``.

    Raises:
        Whatever ``torch.save`` or the file-system calls raise.
    """
    if not isinstance(filename, (str, os.PathLike)):
        # File-like target: there is no path to swap, so write straight through.
        torch.save(state, filename)
        return

    target = os.fspath(filename)
    tmp_path = f"{target}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint(filename, model, optimizer):
    """Restore model and optimizer state from ``filename`` if it exists.

    Args:
        filename: Path of a checkpoint holding the keys ``epoch``,
            ``state_dict`` and ``optimizer``.
        model: Module whose ``load_state_dict`` receives ``state_dict``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``optimizer``.

    Returns:
        The stored ``epoch`` value, or 0 when ``filename`` is not a file.
    """
    if not os.path.isfile(filename):
        logging.info(f"No checkpoint found at '{filename}'")
        return 0

    # Deserialise onto the CPU so a checkpoint written on a GPU machine can be
    # resumed where CUDA is unavailable; both load_state_dict calls below copy
    # the values onto whatever device the model's parameters live on.
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    start_epoch = checkpoint['epoch']
    logging.info(f"Loaded checkpoint '{filename}' (epoch {start_epoch})")
    return start_epoch

In [6]:
def check_values(output, n):
    """Return 1 if at least ``n`` elements of ``output`` exceed 0.5, otherwise 0.

    Args:
        output: Tensor whose elements are compared against 0.5.
        n: Minimum number of elements above 0.5 needed to return 1.
    """
    hits = torch.sum(output > 0.5).item()
    return int(hits >= n)

In [8]:
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, device='cpu', checkpoint_file='checkpoint.pth.tar'):
    """Train ``model``, validate after every epoch and checkpoint as it goes.

    Training resumes from ``checkpoint_file`` when that file exists. After the
    last epoch the model weights are also written to ``final_model.pth``.

    Args:
        model: Module called as ``model(t1, extra_features)``.
        criterion: Loss called as ``criterion(outputs, lesion)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_loader: Loader yielding ``(t1, lesion, extra_features)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Total number of epochs, counting any already completed
            according to the checkpoint.
        device: Device the batches are moved to.
        checkpoint_file: Path used both for resuming and for per-epoch saves.
    """
    start_epoch = load_checkpoint(checkpoint_file, model, optimizer)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0

        # The context manager closes the bar at the end of each epoch so bars
        # do not accumulate across epochs.
        with tqdm(total=len(train_loader.dataset), desc=f"Epoch {epoch + 1}/{num_epochs}") as progress_bar:
            for t1, lesion, extra_features in train_loader:
                t1 = t1.to(device)
                lesion = lesion.to(device).float()
                extra_features = extra_features.to(device)

                optimizer.zero_grad()
                outputs = model(t1, extra_features)
                loss = criterion(outputs, lesion)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                progress_bar.update(len(t1))

            # Mean loss per batch; NaN rather than ZeroDivisionError when the
            # loader yields no batches.
            num_train_batches = len(train_loader)
            avg_train_loss = train_loss / num_train_batches if num_train_batches else float('nan')
            progress_bar.set_postfix(loss=avg_train_loss)

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for t1, lesion, extra_features in val_loader:
                t1 = t1.to(device)
                # Same float cast as in the training loop.
                lesion = lesion.to(device).float()
                extra_features = extra_features.to(device)

                outputs = model(t1, extra_features)
                loss = criterion(outputs, lesion)
                val_loss += loss.item()

        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float('nan')

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")

        # Store epoch + 1 so that a resumed run starts with the next epoch.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_file)

    torch.save(model.state_dict(), 'final_model.pth')
    print('Model saved as final_model.pth')

In [None]:
# Fail early with a clear message when either input location is missing.
if not os.path.exists(directory):
    raise FileNotFoundError(f"Data directory not found: {directory}")
if not os.path.exists(demographics_file):
    raise FileNotFoundError(f"Demographics file not found: {demographics_file}")

# One value each for Age, Sex, TSI and ScanManufacturer, matching the feature
# vector built by MRIDataset.__getitem__.
num_extra_features = 4

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model = CombinedVNetModel(num_extra_features, use_downsampling=True).to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model is expected to
# emit raw logits (assumption -- confirm against CombinedVNetModel).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 1
num_epochs = 5

# test_loader is created here but not consumed by the training call below.
train_loader, val_loader, test_loader = load_data(directory, demographics_file, batch_size)

train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs, device=device, checkpoint_file='checkpoint.pth.tar')
