In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.optim as optim
import os
import numpy as np
from matplotlib import pyplot as plt
from timeit import default_timer as timer
import psutil
from os.path import join

from data_load import Dataset
from model.LN_UXFormer import LN_UXFormer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

root_dir = '/data3/glocal_project/multiclass_seg_kid_ney/data' 
raw_data_dir = join(root_dir, 'Training')
results_dir = '/data3/glocal_project/LN_UXFormer/result_bestmodel'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print('CUDA version:', torch.version.cuda)
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB')
    torch.cuda.empty_cache()

print(f'RAM: {round(psutil.virtual_memory()[0]/1000000000, 2)} GB')

Using device: cuda
GPU: NVIDIA GeForce RTX 3090
CUDA version: 11.1
GPU Memory: 23.69 GB
RAM: 269.95 GB




In [2]:
def make_seg_loaders(data, batch_size):
    N = len(data)
    N_train = int(0.9 * N)
    N_dev = N - N_train
    
    print(f'Total: {N}, Train: {N_train}, Validation: {N_dev}')
    
    train_data, dev_data = D.random_split(data, [N_train, N_dev])
    
    train_loader = D.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    dev_loader = D.DataLoader(dev_data, batch_size=batch_size, shuffle=True)
    
    print(f'Train batches: {len(train_loader)}, Validation batches: {len(dev_loader)}')
    
    return train_loader, dev_loader

In [3]:
batch_size = 16
dataset = Dataset(raw_data_dir)

train_loader, dev_loader = make_seg_loaders(dataset, batch_size)

Found 2497 valid image pairs
Skipped 1 pairs with empty masks
Total: 2497, Train: 2247, Validation: 250
Train batches: 141, Validation batches: 16


In [4]:
seg_model = LN_UXFormer(n_channels=1, n_classes=1).to(device)

optimizers = {
    'main': optim.AdamW(seg_model.parameters(), lr=1e-3)
}

schedulers = optim.lr_scheduler.ReduceLROnPlateau(
    optimizers['main'],
    mode='min',
    factor=0.9,
    patience=3,
    verbose=True,
    min_lr=1e-8
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
def binary_dice_loss_sigmoid(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred_flat = pred.reshape(pred.size(0), -1)
    target_flat = target.reshape(target.size(0), -1)
    
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    
    dice = (2. * intersection + smooth) / (union + smooth)
    loss = 1 - dice.mean()
    
    return loss


def combined_binary_loss_sigmoid(pred, target, dice_weight=0.95, ce_weight=0.05):
    dice = binary_dice_loss_sigmoid(pred, target)
    
    bce_loss = nn.BCEWithLogitsLoss()
    pred_reshaped = pred.view(pred.size(0), -1)
    target_reshaped = target.view(target.size(0), -1)
    ce = bce_loss(pred_reshaped, target_reshaped)
    
    total_loss = dice_weight * dice + ce_weight * ce
    
    return total_loss, dice.item(), ce.item()


def train_epoch(seg_model, optimizers, train_loader, epoch):
    seg_model.train()
    total_loss = 0
    total_dice_value = 0
    total_ce_value = 0

    for batch in train_loader:
        flair, seg = batch
        flair = flair.to(device)
        seg = seg.to(device)
        
        if isinstance(optimizers, dict):
            for optimizer in optimizers.values():
                optimizer.zero_grad()
        else:
            optimizers.zero_grad()
            
        output = seg_model(flair)
        dice_loss, dice_value, ce_value = combined_binary_loss_sigmoid(output, seg)
        
        total_dice_value += dice_value
        total_ce_value += ce_value
        
        loss = torch.mean(dice_loss)
        total_loss += loss.item()
        
        loss.backward()
        
        if isinstance(optimizers, dict):
            for optimizer in optimizers.values():
                optimizer.step()
        else:
            optimizers.step()
        
    avg_dice_value = total_dice_value / len(train_loader)
    avg_ce_value = total_ce_value / len(train_loader)   
    
    print(f'Train - Dice Loss: {avg_dice_value:.4f}, CE Loss: {avg_ce_value:.4f}, Combined Loss: {total_loss/len(train_loader):.4f}')
    
    return total_loss / len(train_loader)


def evaluate(seg_model, dev_loader, epoch, schedulers): 
    seg_model.eval()
    
    total_loss = 0
    total_dice_value = 0
    total_ce_value = 0
    total_dice_score = 0
    
    with torch.no_grad():
        for batch in dev_loader:
            flair, y = batch
            flair = flair.to(device)
            y = y.to(device)
            
            output = seg_model(flair)
            dice_loss, dice_value, ce_value = combined_binary_loss_sigmoid(output, y)
            
            total_dice_value += dice_value
            total_ce_value += ce_value
            
            dice_score = 1 - binary_dice_loss_sigmoid(output, y)
            total_dice_score += dice_score 
            
            loss = torch.mean(dice_loss)
            total_loss += loss.item()

    avg_loss = total_loss / len(dev_loader)
    avg_dice_score = total_dice_score / len(dev_loader)
    avg_dice_value = total_dice_value / len(dev_loader)
    avg_ce_value = total_ce_value / len(dev_loader)
        
    print(f'Validation - Epoch {epoch}')
    print(f'Dice Score: {avg_dice_score:.4f}')
    print(f'Dice Loss: {avg_dice_value:.4f}, CE Loss: {avg_ce_value:.4f}, Combined Loss: {avg_loss:.4f}')
    
    if isinstance(schedulers, dict):
        for scheduler in schedulers.values():
            scheduler.step(avg_loss)
    else:
        schedulers.step(avg_loss)

    return avg_loss

In [6]:
total_epochs = 1000
train_losses = []
dev_losses = []
best_loss = float('inf')

os.makedirs(results_dir, exist_ok=True)

try:
    for epoch in range(1, total_epochs + 1):
        start_time = timer()
        
        train_loss = train_epoch(seg_model, optimizers, train_loader, epoch)
        train_losses.append(train_loss)
        
        dev_loss = evaluate(seg_model, dev_loader, epoch, schedulers)
        dev_losses.append(dev_loss)
        
        end_time = timer()
        
        torch.save(seg_model.state_dict(), 
                  join(results_dir, f'LN_UXFormer_epoch_{epoch}.pt'))
        
        if dev_loss < best_loss:
            best_loss = dev_loss
            torch.save(seg_model.state_dict(), 
                      join(results_dir, 'LN_UXFormer_epoch.pt'))
            print(f'Best model saved with loss: {best_loss:.4f}')
        
        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Valid Loss: {dev_loss:.4f}, Time: {(end_time - start_time):.3f}s\n')
    
    print('Training completed.')
    
except KeyboardInterrupt:
    print('Training interrupted.')

Train - Dice Loss: 0.8671, CE Loss: 1.1552, Combined Loss: 0.8815
Validation - Epoch 1
Dice Score: 0.2311
Dice Loss: 0.7689, CE Loss: 1.4558, Combined Loss: 0.8033
Best model saved with loss: 0.8033
Epoch 1: Train Loss: 0.8815, Valid Loss: 0.8033, Time: 74.393s

Train - Dice Loss: 0.7440, CE Loss: 0.7979, Combined Loss: 0.7467
Validation - Epoch 2
Dice Score: 0.2723
Dice Loss: 0.7277, CE Loss: 0.5146, Combined Loss: 0.7171
Best model saved with loss: 0.7171
Epoch 2: Train Loss: 0.7467, Valid Loss: 0.7171, Time: 73.878s

Train - Dice Loss: 0.6361, CE Loss: 0.2753, Combined Loss: 0.6180
Validation - Epoch 3
Dice Score: 0.4333
Dice Loss: 0.5667, CE Loss: 0.2068, Combined Loss: 0.5487
Best model saved with loss: 0.5487
Epoch 3: Train Loss: 0.6180, Valid Loss: 0.5487, Time: 76.259s

Train - Dice Loss: 0.5330, CE Loss: 0.2114, Combined Loss: 0.5169
Validation - Epoch 4
Dice Score: 0.4984
Dice Loss: 0.5016, CE Loss: 0.1571, Combined Loss: 0.4844
Best model saved with loss: 0.4844
Epoch 4: Tra