In [14]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from monai.losses import DiceCELoss, FocalLoss, DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference

from data_loading import create_dataloaders, get_data_list
from model import create_qct_segmentation_model


In [None]:
def _mean_dice(dice_scores):
    """Reduce a list of per-batch Dice values to one Python float.

    Each entry may be a Python/NumPy scalar, a NumPy array, or a torch tensor
    on any device. Tensors are reduced on their own device and converted with
    ``.item()``, so CUDA results no longer break ``np.mean``. For entries of
    equal size, the mean of per-batch means equals the global mean.
    Returns NaN for an empty list instead of triggering NumPy's
    "Mean of empty slice" warning.
    """
    if not dice_scores:
        return float("nan")
    per_batch = [
        score.detach().float().mean().item() if torch.is_tensor(score) else float(np.mean(score))
        for score in dice_scores
    ]
    return float(np.mean(per_batch))


def _train_one_epoch(model, loader, loss_fn, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Raises:
        ValueError: if ``loader`` yields no batches (avoids a bare ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(loader, desc=desc):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("Training loader yielded no batches; check the data list and train/val split.")
    return total_loss / num_batches


def _validate(model, loader, device, num_classes, roi_size, sw_batch_size):
    """Evaluate ``model`` with sliding-window inference; return the mean Dice (NaN if no batches)."""
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=roi_size,
                sw_batch_size=sw_batch_size,
                predictor=model,
            )
            dice_scores.append(compute_dice(outputs, labels, num_classes))
    return _mean_dice(dice_scores)


def train_model(images_dir, labels_dir, model_name='unet', epochs=10, device='cuda',
                *, num_classes=5, lr=1e-4, batch_size=1, roi_size=(64, 64, 64),
                sw_batch_size=1, save_path=None):
    """Train a QCT segmentation model and checkpoint the best-validating weights.

    Each epoch runs one training pass (AdamW), then sliding-window validation.
    Whenever the mean validation Dice improves, ``model.state_dict()`` is saved.

    Args:
        images_dir: Directory of input images, passed to ``get_data_list``.
        labels_dir: Directory of label maps, passed to ``get_data_list``.
        model_name: Architecture name for ``create_qct_segmentation_model``.
            Also used in the default checkpoint filename.
        epochs: Number of training epochs.
        device: Torch device string for the model and batches.
        num_classes: Number of segmentation classes. It is used for class
            weights and Dice. Presumably it must match the model's output
            channels — TODO confirm.
        lr: AdamW learning rate.
        batch_size: Training batch size forwarded to ``create_dataloaders``.
        roi_size: Sliding-window patch size used during validation.
        sw_batch_size: Number of windows evaluated per forward pass in validation.
        save_path: Checkpoint path. Defaults to ``./{model_name}_best_model.pth``.

    Returns:
        The best mean validation Dice seen. This is ``-inf`` if no epoch
        produced a comparable (non-NaN) Dice, e.g. ``epochs == 0``.

    NOTE(review): ``calculate_class_weights``, ``get_loss_fn`` and
    ``compute_dice`` are not imported in the imports cell. They must be
    defined in an earlier cell, or Restart & Run All fails with NameError.
    """
    data_list = get_data_list(images_dir, labels_dir)
    train_loader, val_loader = create_dataloaders(data_list, batch_size=batch_size)

    weights = calculate_class_weights(train_loader, num_classes, device)

    model = create_qct_segmentation_model(model_name).to(device)
    loss_fn = get_loss_fn(weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    if save_path is None:
        save_path = f"./{model_name}_best_model.pth"
    # Start below any real score so the first epoch always checkpoints, even with Dice 0.0.
    # NaN never compares greater, so NaN epochs are skipped.
    best_dice = float("-inf")

    for epoch in range(epochs):
        train_loss = _train_one_epoch(
            model, train_loader, loss_fn, optimizer, device,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        print(f"Train Loss: {train_loss:.4f}")

        avg_dice = _validate(model, val_loader, device, num_classes, roi_size, sw_batch_size)
        print(f"Validation Dice: {avg_dice:.4f}")

        if avg_dice > best_dice:
            best_dice = avg_dice
            torch.save(model.state_dict(), save_path)
            print(f"✅ New best model saved! Dice = {best_dice:.4f}")

    return best_dice


: 

In [None]:
# Data root: set FEMUR_DATA_ROOT to point elsewhere instead of editing this machine-specific default.
root_dir = os.environ.get(
    "FEMUR_DATA_ROOT",
    "/home/user/auto-annotation/auto-annotation/dataset/femur-bone/data",
)
images_dir = os.path.join(root_dir, "image")
labels_dir = os.path.join(root_dir, "label")

# Fail fast with a clear message instead of deep inside the data-loading code.
for required_dir in (images_dir, labels_dir):
    if not os.path.isdir(required_dir):
        raise FileNotFoundError(f"Expected data directory not found: {required_dir}")

# Seed torch (all devices) and NumPy's global RNG for repeatable runs.
# NOTE(review): MONAI random transforms may keep their own RNG state; consider
# monai.utils.set_determinism if full determinism is required — TODO confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# train_model defaults to "cuda"; fall back to CPU so the cell also runs on GPU-less machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

train_model(images_dir, labels_dir, model_name="unet", epochs=20, device=device)


Loading dataset: 100%|██████████| 20/20 [00:23<00:00,  1.15s/it]
Loading dataset:  71%|███████▏  | 5/7 [00:13<00:04,  2.03s/it]