In [1]:
import os
import pandas as pd
import torch
import nibabel as nib
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import os

# Process-wide runtime switches.
#   CUDA_LAUNCH_BLOCKING=1  : debugging aid that makes CUDA kernel launches
#                             synchronous (slower) -- consider removing it
#                             for real training runs.
#   KMP_DUPLICATE_LIB_OK    : tolerates a duplicated OpenMP runtime being
#                             loaded instead of aborting the process.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    'KMP_DUPLICATE_LIB_OK': 'TRUE',
})

# Input locations, relative to the notebook's working directory.
demographics_file = '../../TestSet_demographics.xlsx'
directory = '../../../Aims-Tbi'

class MRIDataset(Dataset):
    """Dataset pairing T1 MRI volumes with lesion masks and demographic features.

    Subjects are taken from the ``RandID`` column of the demographics
    spreadsheet. A subject is kept only when both ``<RandID>_T1.nii.gz`` and
    ``<RandID>_Lesion.nii.gz`` exist in ``directory``.

    Args:
        directory: Folder that contains the NIfTI files.
        demographics_file: Excel file providing the columns ``RandID``,
            ``Age``, ``Sex``, ``TSI`` and ``ScanManufacturer``.
        transform: Optional callable applied to the MRI tensor.
        label_transform: Optional callable applied to the lesion tensor.
            When omitted, ``transform`` is reused for the lesion tensor as
            well (the historical behaviour). Pass a dedicated transform
            (e.g. a nearest-neighbour resize) when ``transform`` interpolates,
            otherwise the mask values stop being binary.

    Raises:
        KeyError: If the spreadsheet lacks one of the required columns.
    """

    # Columns read from the spreadsheet, either in __init__ or in __getitem__.
    _REQUIRED_COLUMNS = ('RandID', 'Age', 'Sex', 'TSI', 'ScanManufacturer')

    def __init__(self, directory, demographics_file, transform=None, label_transform=None):
        self.directory = directory
        self.transform = transform
        self.label_transform = label_transform

        self.demographics = pd.read_excel(demographics_file)

        # Fail early with a clear message instead of a late KeyError raised
        # from inside a DataLoader iteration.
        missing = [c for c in self._REQUIRED_COLUMNS if c not in self.demographics.columns]
        if missing:
            raise KeyError(f"Demographics file is missing required column(s): {missing}")

        # Encode the manufacturer as integer category codes and make sure
        # 'Sex' is numeric so both can be placed in a float tensor.
        self.demographics['ScanManufacturer'] = pd.Categorical(self.demographics['ScanManufacturer']).codes
        self.demographics['Sex'] = self.demographics['Sex'].astype(float)

        self.mri_paths = []
        self.labels_paths = []
        self.diag_data = []

        # IDs for which a T1 volume is present in the directory listing.
        available_mris = set(f.replace('_T1.nii.gz', '') for f in os.listdir(directory) if f.endswith('_T1.nii.gz'))

        for _, row in self.demographics.iterrows():
            rand_id = row['RandID']
            if rand_id in available_mris:
                mri_file = os.path.join(directory, f"{rand_id}_T1.nii.gz")
                label_file = os.path.join(directory, f"{rand_id}_Lesion.nii.gz")
                # Keep the subject only if image AND lesion mask both exist.
                if os.path.exists(mri_file) and os.path.exists(label_file):
                    self.mri_paths.append(mri_file)
                    self.labels_paths.append(label_file)
                    self.diag_data.append(row)

    def __len__(self):
        """Return the number of subjects that have both an MRI and a lesion file."""
        return len(self.mri_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(mri_tensor, label_tensor, diag_tensor)``: the two volumes
            as float32 tensors with a leading channel dimension (after the
            optional transforms), and a float32 vector
            ``[Age, Sex, TSI, ScanManufacturer]``.
        """
        mri_path = self.mri_paths[idx]
        label_path = self.labels_paths[idx]
        diag_row = self.diag_data[idx]

        mri_img = nib.load(mri_path).get_fdata()
        label_img = nib.load(label_path).get_fdata()

        # Add a leading channel dimension.
        mri_tensor = torch.tensor(mri_img, dtype=torch.float32).unsqueeze(0)
        label_tensor = torch.tensor(label_img, dtype=torch.float32).unsqueeze(0)

        diag_values = [float(diag_row['Age']), diag_row['Sex'], float(diag_row['TSI']), diag_row['ScanManufacturer']]
        diag_tensor = torch.tensor(diag_values, dtype=torch.float32)

        if self.transform is not None:
            mri_tensor = self.transform(mri_tensor)

        # Fall back to the image transform when no dedicated label transform
        # was supplied, which preserves the previous behaviour.
        label_transform = self.label_transform if self.label_transform is not None else self.transform
        if label_transform is not None:
            label_tensor = label_transform(label_tensor)

        return mri_tensor, label_tensor, diag_tensor



# ---------------------------------------------------------------------------
# Dataset construction and train / validation / test split
# ---------------------------------------------------------------------------
SPLIT_SEED = 42  # fixed seed: the split must be identical across kernel restarts
BATCH_SIZE = 1

# NOTE(review): MRIDataset applies this same transform to the lesion masks.
# Resize interpolates bilinearly by default, so resized masks contain
# fractional values rather than strictly 0/1.
transform = transforms.Compose([
    transforms.Resize(256)  # Adjust size if necessary
])
dataset = MRIDataset(directory=directory, demographics_file=demographics_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print("The length of the dataloader:", len(dataloader))


total_size = len(dataset)
print(f"Total dataset size: {total_size}")

train_size = int(total_size * 0.70)
val_size = int(total_size * 0.15)
test_size = total_size - train_size - val_size  # remainder, so every sample is used

if val_size == 0 or test_size == 0:
    raise ValueError("Dataset too small to split according to the specified ratios. Increase dataset size or adjust the ratios.")

# Seed the split explicitly. With an unseeded split, every run draws a
# different partition, so a checkpoint trained in one session could later be
# "tested" on samples it was trained on.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training data needs shuffling; evaluation loaders keep a fixed
# order so per-sample results are comparable between runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Report sample counts (not batch counts) so the labels stay accurate if
# BATCH_SIZE is changed.
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-epoch checkpoints written by train_model end up in this folder.
weights_dir = 'weights'
os.makedirs(weights_dir, exist_ok=True)

def train_model(model, train_loader, criterion, optimizer, num_epochs=20, label_threshold=0.5):
    """Train ``model`` and write a checkpoint after every epoch.

    Relies on two module-level names: ``device`` (where batches are moved)
    and ``weights_dir`` (folder receiving ``model_epoch_<n>.pth``).

    Args:
        model: Network called as ``model(mri, diag)``.
        train_loader: Iterable yielding ``(mri, labels, diag)`` batches; must
            support ``len()``.
        criterion: Loss called as ``criterion(outputs, labels)`` where
            ``outputs`` has already been passed through a sigmoid.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        label_threshold: Mask values strictly above this become 1.0, all
            others 0.0.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError after the loop.
        raise ValueError("train_loader is empty; nothing to train on.")

    for epoch in range(num_epochs):
        # Re-assert training mode each epoch in case the caller evaluated
        # the model in between.
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for i, (mri, labels, diag) in progress_bar:
            optimizer.zero_grad()

            mri, labels, diag = mri.to(device), labels.to(device), diag.to(device)

            # Targets must be a real {0, 1} mask. The former
            # `torch.sigmoid(labels)` mapped 0 -> 0.5 and 1 -> ~0.73, so the
            # network was never asked to predict background or lesion
            # confidently. Thresholding also removes fractional values that
            # an interpolating resize leaves on the mask.
            labels = (labels > label_threshold).float()

            # NOTE(review): if the model already ends with a sigmoid, this
            # applies it twice -- confirm against the model definition.
            outputs = torch.sigmoid(model(mri, diag))

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Show the running mean of the loss over the batches seen so far.
            progress_bar.set_postfix(loss=running_loss / (i + 1))
            if mri.is_cuda:
                torch.cuda.empty_cache()

        epoch_loss = running_loss / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
        torch.save(model.state_dict(), os.path.join(weights_dir, f'model_epoch_{epoch+1}.pth'))


The length of the dataloader: 388
Total dataset size: 388
Training set size: 271
Validation set size: 58
Test set size: 59


In [None]:
from vnet import VNetWithDiagnosis

# Settings for this training run.
LEARNING_RATE = 0.001
FINAL_WEIGHTS_PATH = 'trained_vnet.pth'

# Build the network and place it on the selected device.
model = VNetWithDiagnosis()
model = model.to(device)

# Loss, optimizer and AMP gradient scaler.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()  # created here but not passed to train_model in this cell

# Run training with train_model's default number of epochs.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
)

# Persist the final weights.
torch.save(model.state_dict(), FINAL_WEIGHTS_PATH)
