In [1]:
import hashlib
import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchio as tio
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import BasicUNetPlusPlus
from monai.networks.nets import SwinUNETR
from monai.networks.nets import VNet
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader, random_split

from src.soa_preprocessing import extract_subcubes, compute_subcubes_count, reconstruct_image, evaluate_dice_new, plot_3d_interactive, evaluateSegmentation

In [2]:
# --- Experiment switches (read by the model-selection and training cells) ---

# Which network to build; valid choices run from 0 to 3.
m_choice = 0

# Only recorded in the checkpoint file name in the cells shown here.
using_aug = False

# True: work on 64x64x64 sub-cubes and reassemble 128x128x128 volumes for
# validation; False: feed whole 128x128x128 volumes.
using_decomp = True

# Passed to evaluateSegmentation as both `t` and `det_t`.
bin_threshold = 0.5

In [4]:
# Build the network selected by `m_choice`, then the loss, optimizer, metric
# and LR scheduler that the training cell uses as notebook-level globals.

# Spatial size of one network input; same rule the training function applies.
# Defined here so the SwinUNETR branches work on a fresh kernel.
target_size = (64, 64, 64) if using_decomp else (128, 128, 128)

if m_choice == 0:
    model = VNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1
    ).cuda()

elif m_choice == 1:
    # SwinUNETR, v2 variant.
    model = SwinUNETR(
        img_size=target_size,
        in_channels=1,
        out_channels=1,
        feature_size=48,
        use_checkpoint=True,
        use_v2=True
    ).cuda()

elif m_choice == 2:
    model = BasicUNetPlusPlus(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        features=(16, 32, 64, 128, 256, 16)
    ).cuda()

elif m_choice == 3:
    # SwinUNETR, original (non-v2) variant. This branch previously repeated
    # `m_choice == 2` and could never be reached.
    model = SwinUNETR(
        img_size=target_size,
        in_channels=1,
        out_channels=1,
        feature_size=48,
        use_checkpoint=True,
        use_v2=False
    ).cuda()

else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(f"No model selected! m_choice must be 0, 1, 2 or 3, got {m_choice!r}")


# sigmoid=True: the loss applies the sigmoid itself, so the networks are
# trained to emit raw logits.
criterion = DiceLoss(include_background=False, to_onehot_y=False, sigmoid=True)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

dice_metric = DiceMetric(include_background=False, reduction="mean")

# The training loop steps this scheduler with the mean validation Dice, which
# is higher-is-better, so the plateau detector must run in 'max' mode.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

In [None]:
def _forward_logits(model, images):
    """Run ``model`` on ``images`` and return a single output tensor.

    For ``m_choice == 2`` the raw output is indexed with ``[0]`` first
    (presumably because that network returns a list of outputs -- confirm
    against the installed MONAI version).
    """
    outputs = model(images)
    if m_choice == 2:
        outputs = outputs[0]
    return outputs


def _validate_epoch(model, val_loader):
    """Score ``model`` on every batch of ``val_loader`` and average the metrics.

    Returns:
        dict mapping ``"dice"``, ``"iou"``, ``"tnr"``, ``"tpr"`` and ``"fdr"``
        to the mean of the per-batch values reported by ``evaluateSegmentation``.
    """
    metric_names = ("dice", "iou", "tnr", "tpr", "fdr")
    per_batch = {name: [] for name in metric_names}

    model.eval()
    with torch.no_grad():
        for images, masks in val_loader:
            logits = _forward_logits(model, images.cuda())
            # The loss is built with sigmoid=True, so the network emits logits.
            # Map them to probabilities before they are thresholded at
            # `bin_threshold`.
            # NOTE(review): assumes evaluateSegmentation does not apply a
            # sigmoid itself -- confirm.
            pred_masks = torch.sigmoid(logits).float().cpu().numpy()
            true_masks = masks.cpu().numpy()

            if using_decomp:
                # Presumably stitches one batch of 64^3 sub-cubes back into a
                # single 128^3 volume (last argument looks like the stride) --
                # confirm against reconstruct_image.
                pred_masks = reconstruct_image(pred_masks, (128, 128, 128, 1), (64, 64, 64, 1), 32)
                true_masks = reconstruct_image(true_masks, (128, 128, 128, 1), (64, 64, 64, 1), 32)

            # A copy of the prediction is passed, as in the original call.
            results = evaluateSegmentation(
                [true_masks], [np.copy(pred_masks)], t=bin_threshold, det_t=bin_threshold
            )
            for name in metric_names:
                per_batch[name].append(results[name][0])

    return {name: np.mean(values) for name, values in per_batch.items()}


def train(model, train_loader, val_loader, fold_idx, epochs=5, data_path=None):
    """Train ``model`` and print validation metrics after every epoch.

    Uses the notebook-level ``optimizer``, ``criterion``, ``scheduler``,
    ``m_choice``, ``using_decomp``, ``using_aug`` and ``bin_threshold``.
    A checkpoint is written to ``models/`` on every fourth epoch, starting
    with the first.

    Args:
        model: network to optimise; inputs are moved to the GPU with ``.cuda()``.
        train_loader: iterable of ``(images, masks)`` training batches.
            Ignored when ``data_path`` is given.
        val_loader: iterable of ``(images, masks)`` validation batches.
            Ignored when ``data_path`` is given.
        fold_idx: fold number; only used in the checkpoint file name.
        epochs: number of passes over the training data.
        data_path: optional path to a file readable by ``torch.load`` that
            holds a dict with ``'train'`` and ``'val'`` datasets. When given,
            both loaders are built from it.

    Raises:
        ValueError: if ``data_path`` is not given and either loader is ``None``.
    """
    if data_path is not None:
        # weights_only=False unpickles arbitrary Python objects (needed for
        # stored Dataset instances) -- only load files you trust.
        data_dict = torch.load(data_path, weights_only=False)

        batch_size_train = 8 if using_decomp else 2
        # 27 presumably equals the number of sub-cubes per volume, so one
        # validation batch is one volume -- confirm against the dataset.
        batch_size_val = 27 if using_decomp else 1

        train_loader = DataLoader(data_dict['train'], batch_size=batch_size_train, shuffle=True)
        val_loader = DataLoader(data_dict['val'], batch_size=batch_size_val, shuffle=False)
    elif train_loader is None or val_loader is None:
        raise ValueError(
            "train() needs either data_path or both train_loader and val_loader"
        )

    for epoch in range(epochs):
        print(f"{datetime.now().strftime('%H:%M:%S')} | Starting new epoch...")
        model.train()
        epoch_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.cuda(), masks.cuda()
            optimizer.zero_grad()
            outputs = _forward_logits(model, images)
            # Force targets into [0, 1] before the Dice loss.
            masks = masks.clamp(0, 1)
            loss = criterion(outputs, masks)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()

        # Guard against an empty loader so the average never divides by zero.
        mean_loss = epoch_loss / max(len(train_loader), 1)

        if epoch % 4 == 0:
            os.makedirs("models", exist_ok=True)
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float (epoch mean) rather than the last batch's
                # graph-attached GPU tensor.
                'loss': mean_loss
            }
            torch.save(checkpoint, f"models/{m_choice}_fold{fold_idx}_epoch{epoch+1}_aug{using_aug}.pth")
        print(f"Epoch {epoch+1} - Training Loss: {mean_loss:.4f}")

        # Validation
        metrics = _validate_epoch(model, val_loader)
        print(
            f"Epoch {epoch+1}: DICE={metrics['dice']} | IOU={metrics['iou']} | "
            f"TNR={metrics['tnr']} | TPR={metrics['tpr']} | FDR={metrics['fdr']}"
        )
        # NOTE(review): Dice is higher-is-better, so the notebook-level
        # ReduceLROnPlateau has to be created with mode='max' -- confirm.
        scheduler.step(metrics['dice'])

In [None]:
# Run the training routine once per fold: 5 folds, 17 epochs each. fold_idx is
# forwarded to train().
# NOTE(review): train_loader and val_loader are not defined in any cell visible
# here -- confirm they exist in the kernel before this cell runs.
# NOTE(review): the same `model` object is passed on every iteration and is not
# re-created between folds in this cell -- confirm that carrying weights from
# one fold into the next is intended.
for fold_idx in range(5):
    train(model, train_loader, val_loader, fold_idx, epochs=17)