In [49]:
import os
from pathlib import Path
import torchio as tio
import nibabel as nib
import numpy as np
import glob
import pandas as pd
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from monai.networks.nets import resnet
import torch.nn as nn
import torch
from sklearn.model_selection import KFold

Skull stripping with HD-BET (runs in the background; progress is logged to `hdbet.log`)

In [None]:
# Skull-strip the images given via `-i` with HD-BET. `nohup ... &` detaches the job, so this cell returns
# immediately; stdout and stderr are redirected to hdbet.log. Wait for the run to finish before building the dataset.
# NOTE(review): outputs are written to the current directory (`-o .`), while the dataset cell reads from a
# different `bet_dir` path — presumably the results are copied there by hand; confirm.
# NOTE(review): `--save_bet_mask` presumably produces the `<id>_bet.nii.gz` files that the dataset cell
# loads as `brain_mask` — verify against the HD-BET version in use.
!nohup hd-bet -i /home/ubuntu/volume/Tinception/VBM -o . --save_bet_mask > hdbet.log 2>&1 &

Creating the dataset with TorchIO

In [33]:
# NOTE(review): absolute, user-specific paths; move these into a config cell
# (ideally relative to a DATA_DIR) so the notebook runs on other machines.
CSV_PATH = Path("/Users/payamsadeghishabestari/tinception/material/with_qc/behavioural/tinception_master_ready_for_matching_qc.csv")
bet_dir = Path("/Users/payamsadeghishabestari/temp_folder/bet")

df = pd.read_csv(CSV_PATH)

# Expected layout of `bet_dir`: `<subject ID>.nii.gz` (skull-stripped T1) next to
# `<subject ID>_bet.nii.gz` (brain mask). Matching the full ".nii.gz" suffix avoids
# picking up unrelated .gz files, and only a *trailing* "_bet" marks a mask, so
# subject IDs that merely contain "_bet" are not silently dropped.
IMAGE_SUFFIX = ".nii.gz"
MASK_SUFFIX = "_bet.nii.gz"
subjects = sorted(
    fname.name[: -len(IMAGE_SUFFIX)]
    for fname in bet_dir.iterdir()
    if fname.name.endswith(IMAGE_SUFFIX) and not fname.name.endswith(MASK_SUFFIX)
)
assert subjects, f"No skull-stripped volumes found in {bet_dir}"

# One metadata row per subject; as with the previous `.iloc[0]`, the first row wins on duplicates.
meta = df.drop_duplicates(subset='subject ID').set_index('subject ID')
missing = [sid for sid in subjects if sid not in meta.index]
if missing:
    raise KeyError(f"Subjects without a row in {CSV_PATH.name}: {missing}")

# nn.CrossEntropyLoss expects int64 class indices in [0, n_classes), so map the raw
# 'group' values onto 0..K-1 in sorted order (groups already coded 0/1 map to themselves).
groups = meta.loc[subjects, 'group']
if groups.isna().any():
    raise ValueError(f"Missing 'group' for subjects: {groups[groups.isna()].index.tolist()}")
label_to_index = {group: idx for idx, group in enumerate(sorted(groups.unique()))}
print(f"Label mapping: {label_to_index}")

tio_subjects = []
for sid in subjects:
    row = meta.loc[sid]
    tio_subjects.append(
        tio.Subject(
            t1=tio.ScalarImage(bet_dir / f"{sid}.nii.gz"),
            brain_mask=tio.LabelMap(bet_dir / f"{sid}_bet.nii.gz"),
            sex=row['sex'],
            age=row['age'],
            site=row['site'],
            PTA=row['PTA'],
            label=label_to_index[row['group']],
        )
    )

preprocess = tio.Compose([
    tio.ToCanonical(),                                # reorient to canonical (RAS+) orientation
    tio.Resample(1),                                  # resample to 1 mm isotropic voxels
    tio.ZNormalization(masking_method='brain_mask'),  # z-score intensities using brain-mask voxels only
    tio.CropOrPad((160, 192, 160)),                   # fixed spatial shape so volumes can be batched
])

dataset = tio.SubjectsDataset(tio_subjects, transform=preprocess)

Training (5-fold cross-validation with early stopping)

In [None]:
# Cross-validation setup. KFold's random_state fixes the fold assignment; torch is
# seeded as well so weight initialisation and DataLoader shuffling are reproducible
# under Restart & Run All (previously only the split was seeded).
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
N_FOLDS = 5

torch.manual_seed(SEED)
subjects_list = list(range(len(dataset)))  # positional indices into `dataset`
kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    A reported loss counts as an improvement when it is strictly lower than the
    best loss so far minus ``min_delta``. The very first reported value is always
    accepted as the initial best. Once ``patience`` consecutive reports fail to
    improve, :meth:`step` returns ``True``.

    Attributes:
        patience: number of consecutive non-improving reports tolerated.
        min_delta: minimum decrease required to count as an improvement.
        best_loss: lowest accepted loss so far (``None`` before the first report).
        counter: consecutive non-improving reports; reset to 0 on every improvement.
    """

    def __init__(self, patience=10, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = None
        self.counter = 0

    def step(self, val_loss):
        """Record ``val_loss`` and return ``True`` when training should stop."""
        # Written as `not (x < y)` rather than `x >= y` so NaN losses keep
        # counting as non-improvements.
        if self.best_loss is not None and not (val_loss < self.best_loss - self.min_delta):
            self.counter += 1
            return self.counter >= self.patience
        # First report, or a sufficient improvement: adopt it and reset patience.
        self.best_loss, self.counter = val_loss, 0
        return False

def save_checkpoint(model, optimizer, epoch, path="model.pth"):
    """Serialise a training checkpoint to ``path`` with ``torch.save``.

    The saved dict holds ``'epoch'`` (the ``epoch`` argument as given),
    ``'model_state'`` (``model.state_dict()``) and ``'optimizer_state'``
    (``optimizer.state_dict()``).
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    checkpoint = {'epoch': epoch, 'model_state': model_state, 'optimizer_state': optimizer_state}
    torch.save(checkpoint, path)

In [None]:
# Training hyperparameters (previously inline magic numbers).
MAX_EPOCHS = 100
LEARNING_RATE = 1e-4
PATIENCE = 8
BATCH_SIZE = 1

# `device` was never defined in this notebook, so this cell raised NameError on a fresh kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()


def _batch_to_device(batch, device):
    """Return the ``(images, labels)`` tensors of a TorchIO batch, moved to ``device``.

    Labels are cast to int64 because ``nn.CrossEntropyLoss`` expects class indices.
    """
    images = batch['t1'][tio.DATA].to(device)
    labels = torch.as_tensor(batch['label']).long().to(device)
    return images, labels


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for batch in loader:
        images, labels = _batch_to_device(batch, device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean per-batch loss, accuracy)`` of ``model`` on ``loader``, without gradients.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            images, labels = _batch_to_device(batch, device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / len(loader), correct / total


fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(subjects_list)):
    print(f"\n===== FOLD {fold+1} =====")

    train_set = torch.utils.data.Subset(dataset, train_idx)
    val_set   = torch.utils.data.Subset(dataset, val_idx)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=BATCH_SIZE)

    # A fresh model and optimizer per fold, so no weights leak between folds.
    # NOTE(review): ResNetBottleneck with layers=[2, 2, 2, 2] is an unusual block/depth
    # combination — confirm it is intended.
    model = resnet.ResNet(
                            block=resnet.ResNetBottleneck,
                            layers=[2, 2, 2, 2],
                            block_inplanes=[64, 128, 256, 512],
                            spatial_dims=3,
                            n_input_channels=1,
                            num_classes=2,
                        ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    early_stop = EarlyStopping(patience=PATIENCE)
    checkpoint_path = f"best_fold{fold+1}.pth"
    val_acc_at_best = None

    for epoch in range(MAX_EPOCHS):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        # Mean (not summed) validation loss, so values are comparable across folds of different sizes.
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch} | Train Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.2f}")

        if early_stop.step(val_loss):
            print("Early stopping triggered.")
            break

        # step() resets `counter` to 0 exactly when val_loss improved; this replaces the
        # fragile float comparison `best_loss == val_loss` for detecting a new best.
        if early_stop.counter == 0:
            val_acc_at_best = val_acc
            save_checkpoint(model, optimizer, epoch, checkpoint_path)

    fold_results.append({
        "fold": fold + 1,
        "best_val_loss": early_stop.best_loss,
        "val_acc_at_best": val_acc_at_best,
        "checkpoint": checkpoint_path,
    })

pd.DataFrame(fold_results)