In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch

from src.models import UNet_CBAM

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.networks.layers.factories import Act, Norm

from monai.config import print_config
from monai.metrics import DiceMetric
# from src.models.swincspunetr import SwinCSPUNETR
# from src.models.swincspunetr_unet import SwinCSPUNETR_unet
# from src.models.swincspunetr3plus import SwinCSPUNETR3plus

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\Users\<username>\.conda\envs\UM\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.17.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders, create_dataloaders_bw
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, RandCropd,RandCropByPosNegLabeld, RandGaussianSmoothd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/pretrain_exdata/images"
train_label_dir = "./datasets/pretrain_exdata/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = 32
n_classes = 7
batch_size = 16 # 13.8GB GPU memory required for 128x128 img size
loader_batch = 1
num_samples = batch_size // loader_batch # 한 이미지에서 뽑을 샘플 수
num_repeat = 4
# MODEL CONFIG
num_epochs = 4000
lamda = 0.52
ce_weight = 0.4
lr = 0.001
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25
num_bottleneck = 2
# CLASS_WEIGHTS
class_weights = None
# class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([0.9,1,0.9,1.1,1,1.1,1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([1,1,1,1,1,1,1], dtype=torch.float32)  # 클래스별 가중치
sigma = 2.0
accumulation_steps = 1
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # GaussianSmoothd(
    #     keys=["image"],      # 변환을 적용할 키
    #     sigma=[sigma, sigma, sigma]  # 각 축(x, y, z)의 시그마 값
    #     ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandGaussianSmoothd(
    keys=["image"],      # 변환을 적용할 키
    sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    sigma_y = (0.0, sigma),
    sigma_z = (0.0, sigma),
    prob=1.0,
    ),
])
val_random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # RandGaussianSmoothd(
    # keys=["image"],      # 변환을 적용할 키
    # sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    # sigma_y = (0.0, sigma),
    # sigma_z = (0.0, sigma),
    # prob=1.0,
    # ),
])


In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat,
    val_num_repeat=num_repeat,
    )

Loading dataset: 100%|██████████| 51/51 [00:04<00:00, 10.24it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  9.39it/s]


https://monai.io/model-zoo.html

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import TverskyLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DynamicTverskyLoss 클래스 정의
class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda


# CombinedCETverskyLoss 클래스
class CombinedCETverskyLoss(nn.Module):

    
    def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        
        # CrossEntropyLoss에서 클래스별 가중치를 적용
        self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
        # TverskyLoss
        self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="mean",softmax=True, **kwargs)

    def forward(self, inputs, targets):
        
        # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
        ce_loss = self.ce(inputs, targets)

        # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
        tversky_loss = self.tversky(inputs, targets)

        # 최종 손실 계산
        final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss  # mean()으로 배치에 대해 평균
        
        return final_loss

    def set_lamda(self, lamda):
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        return self.tversky.lamda

criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,
    n_classes=n_classes,
    class_weights=class_weights,
).to(device)

In [5]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.networks.nets import UNet
from src.models import *


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

enc_channels = (32, 64, 128, 256)
enc_strides = (2, 2, 2)
num_layers_enc = (1, 1, 1, 2)

core_channels = 64
dec_channels = (128, 64, 32)
dec_strides = (2, 2, 2)
num_layers_dec = (1, 1, 1)

skip_map = {
    0: [("enc", 2)],       # 디코더0 => 인코더2
    1: [("enc", 3), ("enc", 1)],  # 디코더1 => 인코더1 + 디코더0
    2: [("enc", 3), ("dec", 0), ("enc", 0)]   # 디코더2 => 인코더0 + 디코더1
}

model = FlexibleUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=n_classes,
    encoder_channels=enc_channels,
    encoder_strides=enc_strides,
    core_channels=core_channels,
    decoder_channels=dec_channels,
    decoder_strides=dec_strides,
    num_layers_encoder=num_layers_enc,
    num_layers_decoder=num_layers_dec,
    skip_connections=skip_map,
    kernel_size=3,
    up_kernel_size=3,
    act=Act.PRELU,
    norm=Norm.INSTANCE,
    dropout=drop_rate,
    bias=True,
    mode="trilinear",
    align_corners=False,
).to(device)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"FLEX_IoU_511_241_f{feature_size}_lr{lr:.0e}_a{lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}_ac{accumulation_steps}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model_pretrained.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss', 'best_val_fbeta_score']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                best_val_fbeta_score = checkpoint['best_val_fbeta_score']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")

기존 best model 발견: model_checkpoints\FLEX_IoU_511_241_f48_lr1e-03_a0.52_b16_r4_ce0.4_ac1\best_model_pretrained.pt
기존 학습된 가중치를 성공적으로 로드했습니다.


  checkpoint = torch.load(best_model_path, map_location=device)


In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([16, 1, 32, 96, 96]) torch.Size([16, 1, 32, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [9]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        # "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        # "attn_drop_rate": attn_drop_rate,
        # "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        # "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


# 학습

In [6]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    loss = criterion(outputs, labels_onehot)
    # loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval, ce_weight):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    class_mIoU_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())
                        intersection = torch.sum(pred_i & label_i).float()
                        union = torch.sum(pred_i | label_i).float()
                        iou = (intersection + 1e-8) / (union + 1e-8)
                        class_mIoU_scores[i].append(iou.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
               
        print() 
    if epoch % calculate_dice_interval == 0:
        print("Validation mIoU Score")
        all_classes_mIoU_scores = []
        for i in range(n_classes):
            mean_IoU = np.mean(class_mIoU_scores[i])
            wandb.log({f'class_{i}_IoU_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_IoU:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_mIoU_scores.append(mean_IoU)
                
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        overall_mean_IoU = np.mean(all_classes_mIoU_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1, 'overall_mean_IoU_score': overall_mean_IoU})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\nOverall Mean IoU Score: {overall_mean_IoU:.4f}")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    final_score = overall_mean_fbeta * (1 - ce_weight) + overall_mean_IoU * ce_weight
    
    return val_loss / len(val_loader), final_score 

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4, pretrained=False
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval,
            ce_weight = ce_weight
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation hybrid_score: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            if pretrained:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model_pretrained.pt')
            else:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve % 6 == 0 & epochs_no_improve != 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.35)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [11]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps,
    pretrained=True,
    ) 

Epoch 9/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.439]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.475]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.5307, Class 2: 0.0000, Class 3: 0.3157, Class 4: 0.6869, Class 5: 0.3269, Class 6: 0.8969, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.5791, Class 2: 0.0000, Class 3: 0.5371, Class 4: 0.7487, Class 5: 0.2776, Class 6: 0.9087, 
Validation mIoU Score
Class 0: 0.9751, Class 1: 0.3655, Class 2: 0.0000, Class 3: 0.1890, Class 4: 0.5310, Class 5: 0.1956, Class 6: 0.8138, 

Overall Mean Dice Score: 0.5514
Overall Mean F-beta Score: 0.6103
Overall Mean IoU Score: 0.4190
Training Loss: 0.4622, Validation Loss: 0.4593, Validation hybrid_score: 0.5338
Epoch 10/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.24it/s, loss=0.426]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.81it/s, loss=0.455]


Validation Dice Score
Class 0: 0.9831, Class 1: 0.5603, Class 2: 0.0000, Class 3: 0.3819, Class 4: 0.6511, Class 5: 0.3437, Class 6: 0.8978, 
Validation F-beta Score
Class 0: 0.9785, Class 1: 0.5558, Class 2: 0.0000, Class 3: 0.5154, Class 4: 0.7683, Class 5: 0.3542, Class 6: 0.9274, 
Validation mIoU Score
Class 0: 0.9668, Class 1: 0.4088, Class 2: 0.0000, Class 3: 0.2377, Class 4: 0.4836, Class 5: 0.2104, Class 6: 0.8168, 

Overall Mean Dice Score: 0.5670
Overall Mean F-beta Score: 0.6242
Overall Mean IoU Score: 0.4315
Training Loss: 0.4573, Validation Loss: 0.4618, Validation hybrid_score: 0.5471
Epoch 11/4000


Training: 100%|██████████| 204/204 [02:55<00:00,  1.16it/s, loss=0.429]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, loss=0.457]


Validation Dice Score
Class 0: 0.9850, Class 1: 0.6743, Class 2: 0.0000, Class 3: 0.3228, Class 4: 0.7077, Class 5: 0.3966, Class 6: 0.8944, 
Validation F-beta Score
Class 0: 0.9801, Class 1: 0.7434, Class 2: 0.0000, Class 3: 0.5828, Class 4: 0.8330, Class 5: 0.4030, Class 6: 0.9390, 
Validation mIoU Score
Class 0: 0.9705, Class 1: 0.5119, Class 2: 0.0000, Class 3: 0.1948, Class 4: 0.5509, Class 5: 0.2486, Class 6: 0.8096, 

Overall Mean Dice Score: 0.5992
Overall Mean F-beta Score: 0.7002
Overall Mean IoU Score: 0.4632
Training Loss: 0.4540, Validation Loss: 0.4538, Validation hybrid_score: 0.6054
SUPER Best model saved. Loss:0.4538, Score:0.6054
Epoch 12/4000


Training: 100%|██████████| 204/204 [02:42<00:00,  1.26it/s, loss=0.443]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.469]


Validation Dice Score
Class 0: 0.9832, Class 1: 0.7164, Class 2: 0.0076, Class 3: 0.2907, Class 4: 0.6890, Class 5: 0.3765, Class 6: 0.8861, 
Validation F-beta Score
Class 0: 0.9771, Class 1: 0.7121, Class 2: 0.0041, Class 3: 0.4787, Class 4: 0.8219, Class 5: 0.3983, Class 6: 0.9459, 
Validation mIoU Score
Class 0: 0.9669, Class 1: 0.5600, Class 2: 0.0038, Class 3: 0.1743, Class 4: 0.5267, Class 5: 0.2337, Class 6: 0.7959, 

Overall Mean Dice Score: 0.5917
Overall Mean F-beta Score: 0.6714
Overall Mean IoU Score: 0.4581
Training Loss: 0.4484, Validation Loss: 0.4579, Validation hybrid_score: 0.5861
Epoch 13/4000


Training: 100%|██████████| 204/204 [02:53<00:00,  1.17it/s, loss=0.423]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9860, Class 1: 0.5753, Class 2: 0.0244, Class 3: 0.3701, Class 4: 0.7222, Class 5: 0.3955, Class 6: 0.9192, 
Validation F-beta Score
Class 0: 0.9818, Class 1: 0.4955, Class 2: 0.0144, Class 3: 0.5324, Class 4: 0.8132, Class 5: 0.4599, Class 6: 0.9283, 
Validation mIoU Score
Class 0: 0.9724, Class 1: 0.4083, Class 2: 0.0126, Class 3: 0.2338, Class 4: 0.5667, Class 5: 0.2584, Class 6: 0.8506, 

Overall Mean Dice Score: 0.5965
Overall Mean F-beta Score: 0.6458
Overall Mean IoU Score: 0.4636
Training Loss: 0.4477, Validation Loss: 0.4557, Validation hybrid_score: 0.5729
Epoch 14/4000


Training: 100%|██████████| 204/204 [02:57<00:00,  1.15it/s, loss=0.418]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.80it/s, loss=0.446]


Validation Dice Score
Class 0: 0.9848, Class 1: 0.6124, Class 2: 0.1466, Class 3: 0.4258, Class 4: 0.6410, Class 5: 0.3620, Class 6: 0.9178, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.6805, Class 2: 0.0969, Class 3: 0.6217, Class 4: 0.7978, Class 5: 0.4003, Class 6: 0.9516, 
Validation mIoU Score
Class 0: 0.9701, Class 1: 0.4433, Class 2: 0.0795, Class 3: 0.2711, Class 4: 0.4763, Class 5: 0.2239, Class 6: 0.8487, 

Overall Mean Dice Score: 0.5918
Overall Mean F-beta Score: 0.6904
Overall Mean IoU Score: 0.4526
Training Loss: 0.4439, Validation Loss: 0.4499, Validation hybrid_score: 0.5953
Epoch 15/4000


Training: 100%|██████████| 204/204 [02:42<00:00,  1.26it/s, loss=0.439]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s, loss=0.45] 


Validation Dice Score
Class 0: 0.9867, Class 1: 0.5600, Class 2: 0.1289, Class 3: 0.2761, Class 4: 0.6885, Class 5: 0.4876, Class 6: 0.8906, 
Validation F-beta Score
Class 0: 0.9835, Class 1: 0.5964, Class 2: 0.1050, Class 3: 0.4322, Class 4: 0.7808, Class 5: 0.5112, Class 6: 0.9225, 
Validation mIoU Score
Class 0: 0.9738, Class 1: 0.3913, Class 2: 0.0699, Class 3: 0.1631, Class 4: 0.5274, Class 5: 0.3235, Class 6: 0.8033, 

Overall Mean Dice Score: 0.5805
Overall Mean F-beta Score: 0.6486
Overall Mean IoU Score: 0.4417
Training Loss: 0.4410, Validation Loss: 0.4449, Validation hybrid_score: 0.5659
Epoch 16/4000


Training: 100%|██████████| 204/204 [02:57<00:00,  1.15it/s, loss=0.428]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9825, Class 1: 0.7154, Class 2: 0.1744, Class 3: 0.2736, Class 4: 0.6389, Class 5: 0.5137, Class 6: 0.8761, 
Validation F-beta Score
Class 0: 0.9738, Class 1: 0.7986, Class 2: 0.1743, Class 3: 0.4997, Class 4: 0.8394, Class 5: 0.6155, Class 6: 0.9333, 
Validation mIoU Score
Class 0: 0.9656, Class 1: 0.5574, Class 2: 0.0963, Class 3: 0.1623, Class 4: 0.4724, Class 5: 0.3459, Class 6: 0.7842, 

Overall Mean Dice Score: 0.6035
Overall Mean F-beta Score: 0.7373
Overall Mean IoU Score: 0.4645
Training Loss: 0.4501, Validation Loss: 0.4408, Validation hybrid_score: 0.6282
SUPER Best model saved. Loss:0.4408, Score:0.6282
Epoch 17/4000


Training: 100%|██████████| 204/204 [02:54<00:00,  1.17it/s, loss=0.435]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.72it/s, loss=0.455]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.6265, Class 2: 0.1591, Class 3: 0.4440, Class 4: 0.7505, Class 5: 0.3762, Class 6: 0.8814, 
Validation F-beta Score
Class 0: 0.9788, Class 1: 0.6850, Class 2: 0.1569, Class 3: 0.5872, Class 4: 0.8076, Class 5: 0.5026, Class 6: 0.9144, 
Validation mIoU Score
Class 0: 0.9692, Class 1: 0.4651, Class 2: 0.0885, Class 3: 0.2904, Class 4: 0.6015, Class 5: 0.2351, Class 6: 0.7897, 

Overall Mean Dice Score: 0.6157
Overall Mean F-beta Score: 0.6994
Overall Mean IoU Score: 0.4764
Training Loss: 0.4465, Validation Loss: 0.4472, Validation hybrid_score: 0.6102
Epoch 18/4000


Training: 100%|██████████| 204/204 [02:53<00:00,  1.18it/s, loss=0.401]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.73it/s, loss=0.41] 


Validation Dice Score
Class 0: 0.9824, Class 1: 0.5743, Class 2: 0.1883, Class 3: 0.3745, Class 4: 0.7298, Class 5: 0.4186, Class 6: 0.8931, 
Validation F-beta Score
Class 0: 0.9752, Class 1: 0.6448, Class 2: 0.1920, Class 3: 0.5625, Class 4: 0.8189, Class 5: 0.5461, Class 6: 0.9189, 
Validation mIoU Score
Class 0: 0.9654, Class 1: 0.4223, Class 2: 0.1048, Class 3: 0.2310, Class 4: 0.5749, Class 5: 0.2697, Class 6: 0.8071, 

Overall Mean Dice Score: 0.5981
Overall Mean F-beta Score: 0.6982
Overall Mean IoU Score: 0.4610
Training Loss: 0.4461, Validation Loss: 0.4466, Validation hybrid_score: 0.6033
Epoch 19/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.19it/s, loss=0.422]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.449]


Validation Dice Score
Class 0: 0.9847, Class 1: 0.5948, Class 2: 0.1030, Class 3: 0.2359, Class 4: 0.6373, Class 5: 0.4494, Class 6: 0.9356, 
Validation F-beta Score
Class 0: 0.9777, Class 1: 0.6520, Class 2: 0.1851, Class 3: 0.4397, Class 4: 0.8285, Class 5: 0.5078, Class 6: 0.9564, 
Validation mIoU Score
Class 0: 0.9699, Class 1: 0.4357, Class 2: 0.0548, Class 3: 0.1369, Class 4: 0.4753, Class 5: 0.2958, Class 6: 0.8792, 

Overall Mean Dice Score: 0.5706
Overall Mean F-beta Score: 0.6769
Overall Mean IoU Score: 0.4446
Training Loss: 0.4437, Validation Loss: 0.4511, Validation hybrid_score: 0.5840
Epoch 20/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.25it/s, loss=0.431]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.466]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.7259, Class 2: 0.2029, Class 3: 0.3907, Class 4: 0.6366, Class 5: 0.4518, Class 6: 0.9035, 
Validation F-beta Score
Class 0: 0.9767, Class 1: 0.7639, Class 2: 0.2492, Class 3: 0.6520, Class 4: 0.8228, Class 5: 0.5415, Class 6: 0.9532, 
Validation mIoU Score
Class 0: 0.9693, Class 1: 0.5735, Class 2: 0.1184, Class 3: 0.2445, Class 4: 0.4698, Class 5: 0.2954, Class 6: 0.8242, 

Overall Mean Dice Score: 0.6217
Overall Mean F-beta Score: 0.7467
Overall Mean IoU Score: 0.4815
Training Loss: 0.4423, Validation Loss: 0.4473, Validation hybrid_score: 0.6406
Epoch 21/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.426]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.459]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.5142, Class 2: 0.1196, Class 3: 0.3678, Class 4: 0.6217, Class 5: 0.4050, Class 6: 0.8906, 
Validation F-beta Score
Class 0: 0.9774, Class 1: 0.6799, Class 2: 0.2389, Class 3: 0.5779, Class 4: 0.7838, Class 5: 0.4846, Class 6: 0.9529, 
Validation mIoU Score
Class 0: 0.9691, Class 1: 0.3787, Class 2: 0.0638, Class 3: 0.2255, Class 4: 0.4560, Class 5: 0.2589, Class 6: 0.8037, 

Overall Mean Dice Score: 0.5599
Overall Mean F-beta Score: 0.6958
Overall Mean IoU Score: 0.4246
Training Loss: 0.4419, Validation Loss: 0.4432, Validation hybrid_score: 0.5873
Epoch 22/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.393]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9835, Class 1: 0.7115, Class 2: 0.0576, Class 3: 0.4071, Class 4: 0.7065, Class 5: 0.4342, Class 6: 0.9310, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.8167, Class 2: 0.0678, Class 3: 0.5680, Class 4: 0.7718, Class 5: 0.5076, Class 6: 0.9251, 
Validation mIoU Score
Class 0: 0.9676, Class 1: 0.5571, Class 2: 0.0303, Class 3: 0.2608, Class 4: 0.5491, Class 5: 0.2795, Class 6: 0.8714, 

Overall Mean Dice Score: 0.6381
Overall Mean F-beta Score: 0.7178
Overall Mean IoU Score: 0.5036
Training Loss: 0.4392, Validation Loss: 0.4401, Validation hybrid_score: 0.6321
SUPER Best model saved. Loss:0.4401, Score:0.6321
Epoch 23/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.18it/s, loss=0.436]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9833, Class 1: 0.7585, Class 2: 0.1709, Class 3: 0.3208, Class 4: 0.6962, Class 5: 0.4421, Class 6: 0.9207, 
Validation F-beta Score
Class 0: 0.9751, Class 1: 0.8188, Class 2: 0.2047, Class 3: 0.5163, Class 4: 0.8344, Class 5: 0.5717, Class 6: 0.9349, 
Validation mIoU Score
Class 0: 0.9671, Class 1: 0.6110, Class 2: 0.0971, Class 3: 0.1938, Class 4: 0.5340, Class 5: 0.2845, Class 6: 0.8532, 

Overall Mean Dice Score: 0.6277
Overall Mean F-beta Score: 0.7352
Overall Mean IoU Score: 0.4953
Training Loss: 0.4365, Validation Loss: 0.4417, Validation hybrid_score: 0.6393
Epoch 24/4000


Training: 100%|██████████| 204/204 [02:51<00:00,  1.19it/s, loss=0.417]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.90it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.6522, Class 2: 0.1223, Class 3: 0.2648, Class 4: 0.7546, Class 5: 0.4700, Class 6: 0.9171, 
Validation F-beta Score
Class 0: 0.9812, Class 1: 0.7423, Class 2: 0.2055, Class 3: 0.3597, Class 4: 0.8345, Class 5: 0.5189, Class 6: 0.9112, 
Validation mIoU Score
Class 0: 0.9711, Class 1: 0.4946, Class 2: 0.0678, Class 3: 0.1549, Class 4: 0.6076, Class 5: 0.3103, Class 6: 0.8474, 

Overall Mean Dice Score: 0.6117
Overall Mean F-beta Score: 0.6733
Overall Mean IoU Score: 0.4830
Training Loss: 0.4376, Validation Loss: 0.4406, Validation hybrid_score: 0.5972
Epoch 25/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.376]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, loss=0.433]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7067, Class 2: 0.2482, Class 3: 0.4353, Class 4: 0.7311, Class 5: 0.4527, Class 6: 0.9078, 
Validation F-beta Score
Class 0: 0.9820, Class 1: 0.7937, Class 2: 0.2557, Class 3: 0.6141, Class 4: 0.8198, Class 5: 0.5445, Class 6: 0.9453, 
Validation mIoU Score
Class 0: 0.9738, Class 1: 0.5475, Class 2: 0.1460, Class 3: 0.2789, Class 4: 0.5789, Class 5: 0.2939, Class 6: 0.8344, 

Overall Mean Dice Score: 0.6467
Overall Mean F-beta Score: 0.7435
Overall Mean IoU Score: 0.5067
Training Loss: 0.4365, Validation Loss: 0.4300, Validation hybrid_score: 0.6488
SUPER Best model saved. Loss:0.4300, Score:0.6488
Epoch 26/4000


Training: 100%|██████████| 204/204 [02:49<00:00,  1.21it/s, loss=0.395]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.407]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.7377, Class 2: 0.2844, Class 3: 0.3854, Class 4: 0.7255, Class 5: 0.4779, Class 6: 0.9156, 
Validation F-beta Score
Class 0: 0.9820, Class 1: 0.8131, Class 2: 0.3797, Class 3: 0.5069, Class 4: 0.8160, Class 5: 0.6154, Class 6: 0.8984, 
Validation mIoU Score
Class 0: 0.9744, Class 1: 0.5877, Class 2: 0.1668, Class 3: 0.2397, Class 4: 0.5696, Class 5: 0.3173, Class 6: 0.8450, 

Overall Mean Dice Score: 0.6484
Overall Mean F-beta Score: 0.7300
Overall Mean IoU Score: 0.5118
Training Loss: 0.4358, Validation Loss: 0.4230, Validation hybrid_score: 0.6427
Epoch 27/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.414]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9841, Class 1: 0.7673, Class 2: 0.2305, Class 3: 0.4335, Class 4: 0.7542, Class 5: 0.4432, Class 6: 0.8300, 
Validation F-beta Score
Class 0: 0.9780, Class 1: 0.8480, Class 2: 0.2269, Class 3: 0.5834, Class 4: 0.8420, Class 5: 0.5441, Class 6: 0.9128, 
Validation mIoU Score
Class 0: 0.9688, Class 1: 0.6276, Class 2: 0.1366, Class 3: 0.2781, Class 4: 0.6064, Class 5: 0.2877, Class 6: 0.7301, 

Overall Mean Dice Score: 0.6456
Overall Mean F-beta Score: 0.7461
Overall Mean IoU Score: 0.5060
Training Loss: 0.4359, Validation Loss: 0.4315, Validation hybrid_score: 0.6500
Epoch 28/4000


Training: 100%|██████████| 204/204 [02:53<00:00,  1.17it/s, loss=0.417]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, loss=0.444]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.7773, Class 2: 0.2376, Class 3: 0.3812, Class 4: 0.7532, Class 5: 0.4484, Class 6: 0.9136, 
Validation F-beta Score
Class 0: 0.9805, Class 1: 0.8280, Class 2: 0.3367, Class 3: 0.5228, Class 4: 0.8143, Class 5: 0.5493, Class 6: 0.9055, 
Validation mIoU Score
Class 0: 0.9712, Class 1: 0.6365, Class 2: 0.1393, Class 3: 0.2421, Class 4: 0.6042, Class 5: 0.2892, Class 6: 0.8414, 

Overall Mean Dice Score: 0.6547
Overall Mean F-beta Score: 0.7240
Overall Mean IoU Score: 0.5227
Training Loss: 0.4343, Validation Loss: 0.4371, Validation hybrid_score: 0.6435
Epoch 29/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.18it/s, loss=0.415]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.433]


Validation Dice Score
Class 0: 0.9847, Class 1: 0.6696, Class 2: 0.1088, Class 3: 0.4559, Class 4: 0.7575, Class 5: 0.5633, Class 6: 0.9038, 
Validation F-beta Score
Class 0: 0.9792, Class 1: 0.8327, Class 2: 0.1298, Class 3: 0.6277, Class 4: 0.8318, Class 5: 0.6373, Class 6: 0.9405, 
Validation mIoU Score
Class 0: 0.9698, Class 1: 0.5167, Class 2: 0.0611, Class 3: 0.2970, Class 4: 0.6121, Class 5: 0.3927, Class 6: 0.8260, 

Overall Mean Dice Score: 0.6700
Overall Mean F-beta Score: 0.7740
Overall Mean IoU Score: 0.5289
Training Loss: 0.4348, Validation Loss: 0.4239, Validation hybrid_score: 0.6759
SUPER Best model saved. Loss:0.4239, Score:0.6759
Epoch 30/4000


Training: 100%|██████████| 204/204 [02:42<00:00,  1.26it/s, loss=0.405]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.93it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9852, Class 1: 0.7652, Class 2: 0.2878, Class 3: 0.4290, Class 4: 0.6361, Class 5: 0.3176, Class 6: 0.9102, 
Validation F-beta Score
Class 0: 0.9783, Class 1: 0.8390, Class 2: 0.3187, Class 3: 0.6057, Class 4: 0.8150, Class 5: 0.4580, Class 6: 0.9332, 
Validation mIoU Score
Class 0: 0.9709, Class 1: 0.6226, Class 2: 0.1687, Class 3: 0.2803, Class 4: 0.4795, Class 5: 0.1930, Class 6: 0.8356, 

Overall Mean Dice Score: 0.6116
Overall Mean F-beta Score: 0.7302
Overall Mean IoU Score: 0.4822
Training Loss: 0.4330, Validation Loss: 0.4328, Validation hybrid_score: 0.6310
Epoch 31/4000


Training: 100%|██████████| 204/204 [02:51<00:00,  1.19it/s, loss=0.41] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s, loss=0.421]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7256, Class 2: 0.2050, Class 3: 0.4841, Class 4: 0.7624, Class 5: 0.5220, Class 6: 0.9225, 
Validation F-beta Score
Class 0: 0.9843, Class 1: 0.8092, Class 2: 0.2905, Class 3: 0.6393, Class 4: 0.7969, Class 5: 0.6435, Class 6: 0.9195, 
Validation mIoU Score
Class 0: 0.9768, Class 1: 0.5714, Class 2: 0.1167, Class 3: 0.3228, Class 4: 0.6171, Class 5: 0.3538, Class 6: 0.8565, 

Overall Mean Dice Score: 0.6833
Overall Mean F-beta Score: 0.7617
Overall Mean IoU Score: 0.5443
Training Loss: 0.4335, Validation Loss: 0.4156, Validation hybrid_score: 0.6747
Epoch 32/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.18it/s, loss=0.437]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.415]


Validation Dice Score
Class 0: 0.9850, Class 1: 0.6598, Class 2: 0.1113, Class 3: 0.4627, Class 4: 0.7359, Class 5: 0.5480, Class 6: 0.8434, 
Validation F-beta Score
Class 0: 0.9803, Class 1: 0.8032, Class 2: 0.1209, Class 3: 0.6176, Class 4: 0.7944, Class 5: 0.6170, Class 6: 0.9075, 
Validation mIoU Score
Class 0: 0.9705, Class 1: 0.5019, Class 2: 0.0613, Class 3: 0.3015, Class 4: 0.5847, Class 5: 0.3777, Class 6: 0.7345, 

Overall Mean Dice Score: 0.6499
Overall Mean F-beta Score: 0.7479
Overall Mean IoU Score: 0.5001
Training Loss: 0.4327, Validation Loss: 0.4305, Validation hybrid_score: 0.6488
Epoch 33/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.18it/s, loss=0.419]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.444]


Validation Dice Score
Class 0: 0.9834, Class 1: 0.7129, Class 2: 0.1291, Class 3: 0.4175, Class 4: 0.7021, Class 5: 0.4804, Class 6: 0.9038, 
Validation F-beta Score
Class 0: 0.9762, Class 1: 0.7973, Class 2: 0.2290, Class 3: 0.5663, Class 4: 0.7714, Class 5: 0.6612, Class 6: 0.9009, 
Validation mIoU Score
Class 0: 0.9674, Class 1: 0.5600, Class 2: 0.0710, Class 3: 0.2684, Class 4: 0.5439, Class 5: 0.3198, Class 6: 0.8262, 

Overall Mean Dice Score: 0.6434
Overall Mean F-beta Score: 0.7394
Overall Mean IoU Score: 0.5037
Training Loss: 0.4326, Validation Loss: 0.4347, Validation hybrid_score: 0.6451
Epoch 34/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.20it/s, loss=0.395]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9855, Class 1: 0.6702, Class 2: 0.1897, Class 3: 0.3807, Class 4: 0.7414, Class 5: 0.4209, Class 6: 0.9135, 
Validation F-beta Score
Class 0: 0.9801, Class 1: 0.7517, Class 2: 0.2296, Class 3: 0.5764, Class 4: 0.8165, Class 5: 0.5195, Class 6: 0.9230, 
Validation mIoU Score
Class 0: 0.9714, Class 1: 0.5164, Class 2: 0.1061, Class 3: 0.2383, Class 4: 0.5907, Class 5: 0.2752, Class 6: 0.8414, 

Overall Mean Dice Score: 0.6253
Overall Mean F-beta Score: 0.7174
Overall Mean IoU Score: 0.4924
Training Loss: 0.4339, Validation Loss: 0.4435, Validation hybrid_score: 0.6274
Epoch 35/4000


Training: 100%|██████████| 204/204 [02:37<00:00,  1.29it/s, loss=0.416]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9841, Class 1: 0.7145, Class 2: 0.1717, Class 3: 0.4821, Class 4: 0.7387, Class 5: 0.4572, Class 6: 0.8837, 
Validation F-beta Score
Class 0: 0.9767, Class 1: 0.7884, Class 2: 0.1965, Class 3: 0.6126, Class 4: 0.8526, Class 5: 0.6164, Class 6: 0.9257, 
Validation mIoU Score
Class 0: 0.9686, Class 1: 0.5600, Class 2: 0.0994, Class 3: 0.3187, Class 4: 0.5868, Class 5: 0.2970, Class 6: 0.7935, 

Overall Mean Dice Score: 0.6552
Overall Mean F-beta Score: 0.7591
Overall Mean IoU Score: 0.5112
Training Loss: 0.4330, Validation Loss: 0.4348, Validation hybrid_score: 0.6600
Epoch 36/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.4]  
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9838, Class 1: 0.7485, Class 2: 0.1648, Class 3: 0.4389, Class 4: 0.6555, Class 5: 0.4394, Class 6: 0.9209, 
Validation F-beta Score
Class 0: 0.9751, Class 1: 0.8632, Class 2: 0.2357, Class 3: 0.6174, Class 4: 0.8481, Class 5: 0.6018, Class 6: 0.9286, 
Validation mIoU Score
Class 0: 0.9682, Class 1: 0.6003, Class 2: 0.0933, Class 3: 0.2849, Class 4: 0.4878, Class 5: 0.2818, Class 6: 0.8538, 

Overall Mean Dice Score: 0.6407
Overall Mean F-beta Score: 0.7718
Overall Mean IoU Score: 0.5017
Training Loss: 0.4322, Validation Loss: 0.4287, Validation hybrid_score: 0.6638
Epoch 37/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.25it/s, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.441]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7707, Class 2: 0.1118, Class 3: 0.4317, Class 4: 0.6098, Class 5: 0.4475, Class 6: 0.9194, 
Validation F-beta Score
Class 0: 0.9785, Class 1: 0.8804, Class 2: 0.1438, Class 3: 0.6003, Class 4: 0.7467, Class 5: 0.5907, Class 6: 0.9328, 
Validation mIoU Score
Class 0: 0.9718, Class 1: 0.6299, Class 2: 0.0626, Class 3: 0.2784, Class 4: 0.4796, Class 5: 0.2896, Class 6: 0.8520, 

Overall Mean Dice Score: 0.6358
Overall Mean F-beta Score: 0.7502
Overall Mean IoU Score: 0.5059
Training Loss: 0.4315, Validation Loss: 0.4281, Validation hybrid_score: 0.6525
Epoch 38/4000


Training: 100%|██████████| 204/204 [02:49<00:00,  1.21it/s, loss=0.427]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7606, Class 2: 0.1690, Class 3: 0.4144, Class 4: 0.6809, Class 5: 0.5118, Class 6: 0.9135, 
Validation F-beta Score
Class 0: 0.9812, Class 1: 0.8464, Class 2: 0.2213, Class 3: 0.6441, Class 4: 0.7659, Class 5: 0.6138, Class 6: 0.9264, 
Validation mIoU Score
Class 0: 0.9729, Class 1: 0.6194, Class 2: 0.0955, Class 3: 0.2654, Class 4: 0.5252, Class 5: 0.3461, Class 6: 0.8429, 

Overall Mean Dice Score: 0.6562
Overall Mean F-beta Score: 0.7593
Overall Mean IoU Score: 0.5198
Training Loss: 0.4304, Validation Loss: 0.4282, Validation hybrid_score: 0.6635
Epoch 39/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.25it/s, loss=0.429]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.411]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7820, Class 2: 0.0872, Class 3: 0.3959, Class 4: 0.7041, Class 5: 0.4515, Class 6: 0.8857, 
Validation F-beta Score
Class 0: 0.9810, Class 1: 0.8396, Class 2: 0.1129, Class 3: 0.5524, Class 4: 0.7958, Class 5: 0.5610, Class 6: 0.8974, 
Validation mIoU Score
Class 0: 0.9731, Class 1: 0.6459, Class 2: 0.0478, Class 3: 0.2480, Class 4: 0.5499, Class 5: 0.2931, Class 6: 0.7963, 

Overall Mean Dice Score: 0.6438
Overall Mean F-beta Score: 0.7292
Overall Mean IoU Score: 0.5066
Training Loss: 0.4316, Validation Loss: 0.4289, Validation hybrid_score: 0.6402
Epoch 40/4000


Training: 100%|██████████| 204/204 [02:51<00:00,  1.19it/s, loss=0.396]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.7337, Class 2: 0.1494, Class 3: 0.4047, Class 4: 0.7630, Class 5: 0.4992, Class 6: 0.8705, 
Validation F-beta Score
Class 0: 0.9784, Class 1: 0.8329, Class 2: 0.2273, Class 3: 0.6089, Class 4: 0.8343, Class 5: 0.5908, Class 6: 0.9028, 
Validation mIoU Score
Class 0: 0.9695, Class 1: 0.5899, Class 2: 0.0826, Class 3: 0.2579, Class 4: 0.6178, Class 5: 0.3331, Class 6: 0.7721, 

Overall Mean Dice Score: 0.6542
Overall Mean F-beta Score: 0.7539
Overall Mean IoU Score: 0.5142
Training Loss: 0.4310, Validation Loss: 0.4331, Validation hybrid_score: 0.6580
Epoch 41/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.41] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.93it/s, loss=0.439]

Validation Dice Score
Class 0: 0.9854, Class 1: 0.7372, Class 2: 0.2283, Class 3: 0.4611, Class 4: 0.7445, Class 5: 0.4683, Class 6: 0.8770, 
Validation F-beta Score
Class 0: 0.9813, Class 1: 0.8078, Class 2: 0.2788, Class 3: 0.7005, Class 4: 0.7948, Class 5: 0.5460, Class 6: 0.9044, 
Validation mIoU Score
Class 0: 0.9713, Class 1: 0.5892, Class 2: 0.1291, Class 3: 0.3002, Class 4: 0.5947, Class 5: 0.3061, Class 6: 0.7834, 

Overall Mean Dice Score: 0.6576
Overall Mean F-beta Score: 0.7507
Overall Mean IoU Score: 0.5147
Training Loss: 0.4305, Validation Loss: 0.4261, Validation hybrid_score: 0.6563
Early stopping





0,1
class_0_IoU_score,▂▅▆▇▅▇▄▅▃▄███▄▅▃▇▁▃▂▆▅▄▂▁▄▄▅▅▄▁▂▂
class_0_dice_score,▇▂▄▂▅▄▆▁▃▁▄▃▃▂▂▅▆▇▃▅▄▄█▄▂▅▃▃▅▆▆▄▅
class_0_f_beta_score,█▄▄▃▅▄▆▁▄▂▃▃▃▄▂▅▅▅▃▅▄▃▇▅▂▄▃▂▄▅▅▃▅
class_1_IoU_score,▂▅▆▇▅▇▄▅▃▄███▄▅▃▇▁▃▂▆▅▄▂▁▄▄▅▅▄▁▂▂
class_1_dice_score,▁▂▅▆▃▄▂▆▄▃▃▇▁▆▇▅▆▇██▅█▇▅▆▅▆▇█▇█▇▇
class_1_f_beta_score,▃▂▆▅▁▄▃▇▄▄▄▆▄▇▇▅▆▇▇▇▇▇▇▇▆▆▆██▇▇▇▇
class_2_IoU_score,▂▅▆▇▅▇▄▅▃▄███▄▅▃▇▁▃▂▆▅▄▂▁▄▄▅▅▄▁▂▂
class_2_dice_score,▁▁▁▁▂▅▄▅▅▆▄▆▄▂▅▄▇█▇▇▄█▆▄▄▆▅▅▄▅▃▅▇
class_2_f_beta_score,▁▁▁▁▁▃▃▄▄▅▄▆▅▂▅▅▆█▅▇▃▇▆▃▅▅▅▅▄▅▃▅▆
class_3_IoU_score,▂▅▆▇▅▇▄▅▃▄███▄▅▃▇▁▃▂▆▅▄▂▁▄▄▅▅▄▁▂▂

0,1
class_0_IoU_score,0.90442
class_0_dice_score,0.98544
class_0_f_beta_score,0.98134
class_1_IoU_score,0.90442
class_1_dice_score,0.73724
class_1_f_beta_score,0.80777
class_2_IoU_score,0.90442
class_2_dice_score,0.22828
class_2_f_beta_score,0.27882
class_3_IoU_score,0.90442


In [8]:


train_img_dir = './datasets/denoised_data/images'
train_label_dir = './datasets/denoised_data/labels'

train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms = non_random_transforms,
    random_transforms = val_random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat
    )

batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

Loading dataset: 100%|██████████| 6/6 [00:00<00:00,  7.05it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00, 11.05it/s]


torch.Size([16, 1, 32, 96, 96]) torch.Size([16, 1, 32, 96, 96])


In [9]:
torch.backends.cudnn.benchmark = True

In [10]:

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model_pretrained.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss', 'best_val_fbeta_score']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                best_val_fbeta_score = checkpoint['best_val_fbeta_score']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")
            
# lr = lr/10
            
# optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)


기존 best model 발견: model_checkpoints\FLEX_IoU_511_241_f48_lr1e-03_a0.52_b16_r4_ce0.4_ac1\best_model_pretrained.pt
기존 학습된 가중치를 성공적으로 로드했습니다.


  checkpoint = torch.load(best_model_path, map_location=device)


In [11]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        # "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        # "attn_drop_rate": attn_drop_rate,
        # "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        # "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps,
    pretrained=False,
    ) 

Epoch 30/4000


Training: 100%|██████████| 24/24 [00:42<00:00,  1.78s/it, loss=0.461]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s, loss=0.408]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7970, Class 2: 0.1998, Class 3: 0.2868, Class 4: 0.7101, Class 5: 0.5153, Class 6: 0.9027, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.8452, Class 2: 0.1802, Class 3: 0.5003, Class 4: 0.7985, Class 5: 0.5693, Class 6: 0.9473, 
Validation mIoU Score
Class 0: 0.9741, Class 1: 0.6625, Class 2: 0.1110, Class 3: 0.1674, Class 4: 0.5505, Class 5: 0.3471, Class 6: 0.8226, 

Overall Mean Dice Score: 0.6424
Overall Mean F-beta Score: 0.7321
Overall Mean IoU Score: 0.5100
Training Loss: 0.4506, Validation Loss: 0.4078, Validation hybrid_score: 0.6433
Epoch 31/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.35it/s, loss=0.438]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9860, Class 1: 0.7921, Class 2: 0.0066, Class 3: 0.4165, Class 4: 0.6855, Class 5: 0.4681, Class 6: 0.9034, 
Validation F-beta Score
Class 0: 0.9862, Class 1: 0.9060, Class 2: 0.0137, Class 3: 0.6230, Class 4: 0.6752, Class 5: 0.4216, Class 6: 0.9273, 
Validation mIoU Score
Class 0: 0.9724, Class 1: 0.6557, Class 2: 0.0033, Class 3: 0.2630, Class 4: 0.5215, Class 5: 0.3056, Class 6: 0.8238, 

Overall Mean Dice Score: 0.6531
Overall Mean F-beta Score: 0.7106
Overall Mean IoU Score: 0.5139
Training Loss: 0.4468, Validation Loss: 0.4318, Validation hybrid_score: 0.6319
Epoch 32/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.428]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.6459, Class 2: 0.3828, Class 3: 0.4661, Class 4: 0.7737, Class 5: 0.5017, Class 6: 0.7488, 
Validation F-beta Score
Class 0: 0.9818, Class 1: 0.8487, Class 2: 0.5445, Class 3: 0.5314, Class 4: 0.8174, Class 5: 0.6545, Class 6: 0.8742, 
Validation mIoU Score
Class 0: 0.9743, Class 1: 0.4770, Class 2: 0.2367, Class 3: 0.3039, Class 4: 0.6309, Class 5: 0.3348, Class 6: 0.5985, 

Overall Mean Dice Score: 0.6273
Overall Mean F-beta Score: 0.7453
Overall Mean IoU Score: 0.4690
Training Loss: 0.4399, Validation Loss: 0.4261, Validation hybrid_score: 0.6348
Epoch 33/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.432]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.79it/s, loss=0.458]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.6664, Class 2: 0.1861, Class 3: 0.4162, Class 4: 0.7488, Class 5: 0.3967, Class 6: 0.9111, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.8349, Class 2: 0.1404, Class 3: 0.5743, Class 4: 0.8133, Class 5: 0.5496, Class 6: 0.9279, 
Validation mIoU Score
Class 0: 0.9750, Class 1: 0.4997, Class 2: 0.1026, Class 3: 0.2628, Class 4: 0.5984, Class 5: 0.2474, Class 6: 0.8367, 

Overall Mean Dice Score: 0.6278
Overall Mean F-beta Score: 0.7400
Overall Mean IoU Score: 0.4890
Training Loss: 0.4417, Validation Loss: 0.4575, Validation hybrid_score: 0.6396
Epoch 34/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.454]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s, loss=0.457]


Validation Dice Score
Class 0: 0.9853, Class 1: 0.7529, Class 2: 0.1667, Class 3: 0.4422, Class 4: 0.6771, Class 5: 0.3630, Class 6: 0.9037, 
Validation F-beta Score
Class 0: 0.9794, Class 1: 0.8619, Class 2: 0.2579, Class 3: 0.5379, Class 4: 0.7324, Class 5: 0.5321, Class 6: 0.9090, 
Validation mIoU Score
Class 0: 0.9710, Class 1: 0.6037, Class 2: 0.0909, Class 3: 0.2839, Class 4: 0.5118, Class 5: 0.2217, Class 6: 0.8242, 

Overall Mean Dice Score: 0.6278
Overall Mean F-beta Score: 0.7147
Overall Mean IoU Score: 0.4891
Training Loss: 0.4452, Validation Loss: 0.4569, Validation hybrid_score: 0.6244
Epoch 35/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.427]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.80it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7483, Class 2: 0.0000, Class 3: 0.4553, Class 4: 0.8098, Class 5: 0.3098, Class 6: 0.9124, 
Validation F-beta Score
Class 0: 0.9847, Class 1: 0.8155, Class 2: 0.0000, Class 3: 0.5133, Class 4: 0.8463, Class 5: 0.3485, Class 6: 0.9486, 
Validation mIoU Score
Class 0: 0.9757, Class 1: 0.5978, Class 2: 0.0000, Class 3: 0.2947, Class 4: 0.6805, Class 5: 0.1833, Class 6: 0.8389, 

Overall Mean Dice Score: 0.6471
Overall Mean F-beta Score: 0.6944
Overall Mean IoU Score: 0.5190
Training Loss: 0.4381, Validation Loss: 0.4363, Validation hybrid_score: 0.6243
Epoch 36/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.448]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.6939, Class 2: 0.2015, Class 3: 0.4534, Class 4: 0.7651, Class 5: 0.5164, Class 6: 0.8960, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.8859, Class 2: 0.1876, Class 3: 0.6066, Class 4: 0.7835, Class 5: 0.5609, Class 6: 0.9477, 
Validation mIoU Score
Class 0: 0.9743, Class 1: 0.5313, Class 2: 0.1120, Class 3: 0.2932, Class 4: 0.6195, Class 5: 0.3481, Class 6: 0.8115, 

Overall Mean Dice Score: 0.6650
Overall Mean F-beta Score: 0.7569
Overall Mean IoU Score: 0.5207
Training Loss: 0.4385, Validation Loss: 0.4426, Validation hybrid_score: 0.6624
Epoch 37/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.424]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s, loss=0.4]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.7775, Class 2: 0.1749, Class 3: 0.4950, Class 4: 0.6846, Class 5: 0.6026, Class 6: 0.9346, 
Validation F-beta Score
Class 0: 0.9876, Class 1: 0.8620, Class 2: 0.2240, Class 3: 0.4897, Class 4: 0.7292, Class 5: 0.6257, Class 6: 0.9349, 
Validation mIoU Score
Class 0: 0.9785, Class 1: 0.6360, Class 2: 0.0958, Class 3: 0.3289, Class 4: 0.5205, Class 5: 0.4312, Class 6: 0.8773, 

Overall Mean Dice Score: 0.6989
Overall Mean F-beta Score: 0.7283
Overall Mean IoU Score: 0.5588
Training Loss: 0.4383, Validation Loss: 0.3999, Validation hybrid_score: 0.6605
Epoch 38/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.429]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.80it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.7585, Class 2: 0.1186, Class 3: 0.5262, Class 4: 0.7091, Class 5: 0.4531, Class 6: 0.9016, 
Validation F-beta Score
Class 0: 0.9862, Class 1: 0.8775, Class 2: 0.1113, Class 3: 0.6849, Class 4: 0.7775, Class 5: 0.4333, Class 6: 0.9330, 
Validation mIoU Score
Class 0: 0.9770, Class 1: 0.6110, Class 2: 0.0630, Class 3: 0.3570, Class 4: 0.5493, Class 5: 0.2929, Class 6: 0.8208, 

Overall Mean Dice Score: 0.6697
Overall Mean F-beta Score: 0.7412
Overall Mean IoU Score: 0.5262
Training Loss: 0.4396, Validation Loss: 0.4256, Validation hybrid_score: 0.6552
Epoch 39/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.35it/s, loss=0.428]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s, loss=0.429]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7427, Class 2: 0.0000, Class 3: 0.5739, Class 4: 0.7997, Class 5: 0.5785, Class 6: 0.8400, 
Validation F-beta Score
Class 0: 0.9832, Class 1: 0.8731, Class 2: 0.0000, Class 3: 0.5981, Class 4: 0.9089, Class 5: 0.5820, Class 6: 0.9044, 
Validation mIoU Score
Class 0: 0.9745, Class 1: 0.5907, Class 2: 0.0000, Class 3: 0.4025, Class 4: 0.6662, Class 5: 0.4070, Class 6: 0.7241, 

Overall Mean Dice Score: 0.7070
Overall Mean F-beta Score: 0.7733
Overall Mean IoU Score: 0.5581
Training Loss: 0.4390, Validation Loss: 0.4288, Validation hybrid_score: 0.6872
Epoch 40/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.457]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s, loss=0.391]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.8067, Class 2: 0.2487, Class 3: 0.4982, Class 4: 0.6638, Class 5: 0.6359, Class 6: 0.9381, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8500, Class 2: 0.3167, Class 3: 0.5883, Class 4: 0.7108, Class 5: 0.6234, Class 6: 0.9663, 
Validation mIoU Score
Class 0: 0.9770, Class 1: 0.6760, Class 2: 0.1420, Class 3: 0.3317, Class 4: 0.4968, Class 5: 0.4661, Class 6: 0.8834, 

Overall Mean Dice Score: 0.7085
Overall Mean F-beta Score: 0.7477
Overall Mean IoU Score: 0.5708
Training Loss: 0.4364, Validation Loss: 0.3905, Validation hybrid_score: 0.6770
SUPER Best model saved. Loss:0.3905, Score:0.6770
Epoch 41/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.419]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s, loss=0.416]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.7697, Class 2: 0.0042, Class 3: 0.5121, Class 4: 0.7333, Class 5: 0.5053, Class 6: 0.9251, 
Validation F-beta Score
Class 0: 0.9854, Class 1: 0.8708, Class 2: 0.0130, Class 3: 0.7815, Class 4: 0.6732, Class 5: 0.6528, Class 6: 0.9492, 
Validation mIoU Score
Class 0: 0.9755, Class 1: 0.6257, Class 2: 0.0021, Class 3: 0.3442, Class 4: 0.5789, Class 5: 0.3381, Class 6: 0.8606, 

Overall Mean Dice Score: 0.6891
Overall Mean F-beta Score: 0.7855
Overall Mean IoU Score: 0.5495
Training Loss: 0.4381, Validation Loss: 0.4160, Validation hybrid_score: 0.6911
Epoch 42/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.454]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s, loss=0.445]


Validation Dice Score
Class 0: 0.9881, Class 1: 0.6616, Class 2: 0.1458, Class 3: 0.4646, Class 4: 0.5387, Class 5: 0.4216, Class 6: 0.9410, 
Validation F-beta Score
Class 0: 0.9833, Class 1: 0.8255, Class 2: 0.3343, Class 3: 0.5813, Class 4: 0.6759, Class 5: 0.5304, Class 6: 0.9690, 
Validation mIoU Score
Class 0: 0.9765, Class 1: 0.4943, Class 2: 0.0786, Class 3: 0.3026, Class 4: 0.3686, Class 5: 0.2671, Class 6: 0.8885, 

Overall Mean Dice Score: 0.6055
Overall Mean F-beta Score: 0.7164
Overall Mean IoU Score: 0.4642
Training Loss: 0.4361, Validation Loss: 0.4446, Validation hybrid_score: 0.6155
Epoch 43/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.437]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s, loss=0.439]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.7932, Class 2: 0.0018, Class 3: 0.4180, Class 4: 0.8046, Class 5: 0.4559, Class 6: 0.9275, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.9061, Class 2: 0.0021, Class 3: 0.4847, Class 4: 0.8550, Class 5: 0.6265, Class 6: 0.9309, 
Validation mIoU Score
Class 0: 0.9782, Class 1: 0.6572, Class 2: 0.0009, Class 3: 0.2642, Class 4: 0.6731, Class 5: 0.2953, Class 6: 0.8648, 

Overall Mean Dice Score: 0.6798
Overall Mean F-beta Score: 0.7606
Overall Mean IoU Score: 0.5509
Training Loss: 0.4349, Validation Loss: 0.4387, Validation hybrid_score: 0.6768
Epoch 44/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.403]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7897, Class 2: 0.3119, Class 3: 0.4917, Class 4: 0.6804, Class 5: 0.4179, Class 6: 0.9080, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.8375, Class 2: 0.3639, Class 3: 0.5846, Class 4: 0.7096, Class 5: 0.5841, Class 6: 0.9606, 
Validation mIoU Score
Class 0: 0.9746, Class 1: 0.6525, Class 2: 0.1848, Class 3: 0.3260, Class 4: 0.5156, Class 5: 0.2641, Class 6: 0.8315, 

Overall Mean Dice Score: 0.6575
Overall Mean F-beta Score: 0.7353
Overall Mean IoU Score: 0.5179
Training Loss: 0.4384, Validation Loss: 0.4064, Validation hybrid_score: 0.6483
Epoch 45/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.39it/s, loss=0.449]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s, loss=0.441]


Validation Dice Score
Class 0: 0.9836, Class 1: 0.7448, Class 2: 0.0000, Class 3: 0.4366, Class 4: 0.7175, Class 5: 0.4791, Class 6: 0.8685, 
Validation F-beta Score
Class 0: 0.9802, Class 1: 0.8798, Class 2: 0.0000, Class 3: 0.5161, Class 4: 0.7643, Class 5: 0.4964, Class 6: 0.9309, 
Validation mIoU Score
Class 0: 0.9678, Class 1: 0.5934, Class 2: 0.0000, Class 3: 0.2793, Class 4: 0.5594, Class 5: 0.3150, Class 6: 0.7676, 

Overall Mean Dice Score: 0.6493
Overall Mean F-beta Score: 0.7175
Overall Mean IoU Score: 0.5029
Training Loss: 0.4370, Validation Loss: 0.4415, Validation hybrid_score: 0.6317
Epoch 46/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.434]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s, loss=0.434]


Validation Dice Score
Class 0: 0.9856, Class 1: 0.6599, Class 2: 0.0806, Class 3: 0.4549, Class 4: 0.7582, Class 5: 0.4801, Class 6: 0.9042, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.9085, Class 2: 0.2002, Class 3: 0.6008, Class 4: 0.7920, Class 5: 0.5403, Class 6: 0.9557, 
Validation mIoU Score
Class 0: 0.9716, Class 1: 0.4924, Class 2: 0.0420, Class 3: 0.2944, Class 4: 0.6106, Class 5: 0.3159, Class 6: 0.8252, 

Overall Mean Dice Score: 0.6515
Overall Mean F-beta Score: 0.7595
Overall Mean IoU Score: 0.5077
Training Loss: 0.4301, Validation Loss: 0.4342, Validation hybrid_score: 0.6588
Epoch 47/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.441]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9833, Class 1: 0.7847, Class 2: 0.2118, Class 3: 0.6125, Class 4: 0.8010, Class 5: 0.5260, Class 6: 0.7149, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.8477, Class 2: 0.2341, Class 3: 0.6748, Class 4: 0.8031, Class 5: 0.6657, Class 6: 0.9198, 
Validation mIoU Score
Class 0: 0.9672, Class 1: 0.6456, Class 2: 0.1185, Class 3: 0.4414, Class 4: 0.6680, Class 5: 0.3568, Class 6: 0.5562, 

Overall Mean Dice Score: 0.6878
Overall Mean F-beta Score: 0.7822
Overall Mean IoU Score: 0.5336
Training Loss: 0.4398, Validation Loss: 0.4362, Validation hybrid_score: 0.6828
Epoch 48/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.433]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s, loss=0.419]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.8068, Class 2: 0.4008, Class 3: 0.4028, Class 4: 0.8011, Class 5: 0.5138, Class 6: 0.8859, 
Validation F-beta Score
Class 0: 0.9873, Class 1: 0.8711, Class 2: 0.3988, Class 3: 0.5264, Class 4: 0.8212, Class 5: 0.5506, Class 6: 0.9124, 
Validation mIoU Score
Class 0: 0.9789, Class 1: 0.6762, Class 2: 0.2506, Class 3: 0.2522, Class 4: 0.6682, Class 5: 0.3457, Class 6: 0.7951, 

Overall Mean Dice Score: 0.6821
Overall Mean F-beta Score: 0.7363
Overall Mean IoU Score: 0.5475
Training Loss: 0.4364, Validation Loss: 0.4188, Validation hybrid_score: 0.6608
Epoch 49/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.429]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9852, Class 1: 0.7079, Class 2: 0.0876, Class 3: 0.6013, Class 4: 0.5970, Class 5: 0.5176, Class 6: 0.9318, 
Validation F-beta Score
Class 0: 0.9799, Class 1: 0.8695, Class 2: 0.1405, Class 3: 0.7115, Class 4: 0.7268, Class 5: 0.5981, Class 6: 0.9622, 
Validation mIoU Score
Class 0: 0.9707, Class 1: 0.5479, Class 2: 0.0458, Class 3: 0.4299, Class 4: 0.4255, Class 5: 0.3492, Class 6: 0.8724, 

Overall Mean Dice Score: 0.6711
Overall Mean F-beta Score: 0.7736
Overall Mean IoU Score: 0.5250
Training Loss: 0.4351, Validation Loss: 0.4368, Validation hybrid_score: 0.6742
Epoch 50/4000


Training: 100%|██████████| 24/24 [00:18<00:00,  1.29it/s, loss=0.415]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.6776, Class 2: 0.4092, Class 3: 0.6781, Class 4: 0.8163, Class 5: 0.2417, Class 6: 0.8695, 
Validation F-beta Score
Class 0: 0.9801, Class 1: 0.8635, Class 2: 0.4158, Class 3: 0.8454, Class 4: 0.8532, Class 5: 0.4591, Class 6: 0.9500, 
Validation mIoU Score
Class 0: 0.9739, Class 1: 0.5124, Class 2: 0.2573, Class 3: 0.5130, Class 4: 0.6896, Class 5: 0.1375, Class 6: 0.7692, 

Overall Mean Dice Score: 0.6567
Overall Mean F-beta Score: 0.7942
Overall Mean IoU Score: 0.5243
Training Loss: 0.4389, Validation Loss: 0.4474, Validation hybrid_score: 0.6863
Epoch 51/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.409]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9825, Class 1: 0.7418, Class 2: 0.0730, Class 3: 0.3331, Class 4: 0.7113, Class 5: 0.4938, Class 6: 0.9046, 
Validation F-beta Score
Class 0: 0.9805, Class 1: 0.8564, Class 2: 0.0870, Class 3: 0.4284, Class 4: 0.7301, Class 5: 0.4931, Class 6: 0.9580, 
Validation mIoU Score
Class 0: 0.9656, Class 1: 0.5895, Class 2: 0.0379, Class 3: 0.1999, Class 4: 0.5520, Class 5: 0.3279, Class 6: 0.8259, 

Overall Mean Dice Score: 0.6369
Overall Mean F-beta Score: 0.6932
Overall Mean IoU Score: 0.4990
Training Loss: 0.4369, Validation Loss: 0.4316, Validation hybrid_score: 0.6155
Epoch 52/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.428]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.6997, Class 2: 0.1317, Class 3: 0.3562, Class 4: 0.7695, Class 5: 0.4680, Class 6: 0.9079, 
Validation F-beta Score
Class 0: 0.9807, Class 1: 0.8904, Class 2: 0.1614, Class 3: 0.4525, Class 4: 0.8272, Class 5: 0.6264, Class 6: 0.9629, 
Validation mIoU Score
Class 0: 0.9728, Class 1: 0.5381, Class 2: 0.0705, Class 3: 0.2167, Class 4: 0.6254, Class 5: 0.3055, Class 6: 0.8314, 

Overall Mean Dice Score: 0.6403
Overall Mean F-beta Score: 0.7519
Overall Mean IoU Score: 0.5034
Training Loss: 0.4324, Validation Loss: 0.4315, Validation hybrid_score: 0.6525
Epoch 53/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.436]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s, loss=0.429]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.8086, Class 2: 0.3216, Class 3: 0.4265, Class 4: 0.7768, Class 5: 0.6093, Class 6: 0.8570, 
Validation F-beta Score
Class 0: 0.9832, Class 1: 0.7614, Class 2: 0.3655, Class 3: 0.4282, Class 4: 0.7723, Class 5: 0.7587, Class 6: 0.9349, 
Validation mIoU Score
Class 0: 0.9706, Class 1: 0.6787, Class 2: 0.1916, Class 3: 0.2710, Class 4: 0.6351, Class 5: 0.4381, Class 6: 0.7498, 

Overall Mean Dice Score: 0.6956
Overall Mean F-beta Score: 0.7311
Overall Mean IoU Score: 0.5546
Training Loss: 0.4332, Validation Loss: 0.4285, Validation hybrid_score: 0.6605
Epoch 54/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.413]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s, loss=0.428]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.3261, Class 2: 0.3827, Class 3: 0.5691, Class 4: 0.8050, Class 5: 0.5801, Class 6: 0.8865, 
Validation F-beta Score
Class 0: 0.9851, Class 1: 0.6976, Class 2: 0.3942, Class 3: 0.5541, Class 4: 0.8173, Class 5: 0.6801, Class 6: 0.9533, 
Validation mIoU Score
Class 0: 0.9750, Class 1: 0.1948, Class 2: 0.2366, Class 3: 0.3977, Class 4: 0.6737, Class 5: 0.4085, Class 6: 0.7961, 

Overall Mean Dice Score: 0.6334
Overall Mean F-beta Score: 0.7405
Overall Mean IoU Score: 0.4942
Training Loss: 0.4396, Validation Loss: 0.4278, Validation hybrid_score: 0.6420
Epoch 55/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.44] 
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s, loss=0.424]


Validation Dice Score
Class 0: 0.9836, Class 1: 0.6315, Class 2: 0.0996, Class 3: 0.5892, Class 4: 0.7875, Class 5: 0.4394, Class 6: 0.8702, 
Validation F-beta Score
Class 0: 0.9823, Class 1: 0.8575, Class 2: 0.1453, Class 3: 0.6572, Class 4: 0.7854, Class 5: 0.4476, Class 6: 0.9538, 
Validation mIoU Score
Class 0: 0.9677, Class 1: 0.4615, Class 2: 0.0524, Class 3: 0.4176, Class 4: 0.6495, Class 5: 0.2815, Class 6: 0.7702, 

Overall Mean Dice Score: 0.6636
Overall Mean F-beta Score: 0.7403
Overall Mean IoU Score: 0.5161
Training Loss: 0.4355, Validation Loss: 0.4239, Validation hybrid_score: 0.6506
Epoch 56/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.441]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s, loss=0.45]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.6958, Class 2: 0.2807, Class 3: 0.1906, Class 4: 0.8162, Class 5: 0.5047, Class 6: 0.9306, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.8989, Class 2: 0.3146, Class 3: 0.5296, Class 4: 0.7979, Class 5: 0.6007, Class 6: 0.9561, 
Validation mIoU Score
Class 0: 0.9750, Class 1: 0.5335, Class 2: 0.1633, Class 3: 0.1053, Class 4: 0.6894, Class 5: 0.3375, Class 6: 0.8702, 

Overall Mean Dice Score: 0.6276
Overall Mean F-beta Score: 0.7566
Overall Mean IoU Score: 0.5072
Training Loss: 0.4328, Validation Loss: 0.4503, Validation hybrid_score: 0.6569
Epoch 57/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.447]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s, loss=0.413]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.7058, Class 2: 0.1961, Class 3: 0.4089, Class 4: 0.7307, Class 5: 0.5287, Class 6: 0.8935, 
Validation F-beta Score
Class 0: 0.9816, Class 1: 0.8587, Class 2: 0.2316, Class 3: 0.5740, Class 4: 0.7787, Class 5: 0.6660, Class 6: 0.9180, 
Validation mIoU Score
Class 0: 0.9735, Class 1: 0.5453, Class 2: 0.1087, Class 3: 0.2570, Class 4: 0.5757, Class 5: 0.3594, Class 6: 0.8074, 

Overall Mean Dice Score: 0.6535
Overall Mean F-beta Score: 0.7591
Overall Mean IoU Score: 0.5090
Training Loss: 0.4360, Validation Loss: 0.4132, Validation hybrid_score: 0.6590
Epoch 58/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.441]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.69it/s, loss=0.445]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7781, Class 2: 0.2526, Class 3: 0.4570, Class 4: 0.7573, Class 5: 0.3939, Class 6: 0.9362, 
Validation F-beta Score
Class 0: 0.9836, Class 1: 0.8695, Class 2: 0.2273, Class 3: 0.6104, Class 4: 0.7852, Class 5: 0.4674, Class 6: 0.9604, 
Validation mIoU Score
Class 0: 0.9739, Class 1: 0.6368, Class 2: 0.1446, Class 3: 0.2962, Class 4: 0.6094, Class 5: 0.2452, Class 6: 0.8801, 

Overall Mean Dice Score: 0.6645
Overall Mean F-beta Score: 0.7386
Overall Mean IoU Score: 0.5336
Training Loss: 0.4377, Validation Loss: 0.4454, Validation hybrid_score: 0.6566
Epoch 59/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.36it/s, loss=0.422]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7543, Class 2: 0.0000, Class 3: 0.4964, Class 4: 0.7054, Class 5: 0.4157, Class 6: 0.9190, 
Validation F-beta Score
Class 0: 0.9814, Class 1: 0.8345, Class 2: 0.0000, Class 3: 0.5167, Class 4: 0.8155, Class 5: 0.5028, Class 6: 0.9620, 
Validation mIoU Score
Class 0: 0.9732, Class 1: 0.6055, Class 2: 0.0000, Class 3: 0.3302, Class 4: 0.5448, Class 5: 0.2624, Class 6: 0.8502, 

Overall Mean Dice Score: 0.6582
Overall Mean F-beta Score: 0.7263
Overall Mean IoU Score: 0.5186
Training Loss: 0.4346, Validation Loss: 0.4472, Validation hybrid_score: 0.6432
Epoch 60/4000


Training: 100%|██████████| 24/24 [00:17<00:00,  1.37it/s, loss=0.425]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.79it/s, loss=0.428]

Validation Dice Score
Class 0: 0.9880, Class 1: 0.7175, Class 2: 0.2954, Class 3: 0.4819, Class 4: 0.7488, Class 5: 0.4187, Class 6: 0.8976, 
Validation F-beta Score
Class 0: 0.9835, Class 1: 0.8817, Class 2: 0.4516, Class 3: 0.4960, Class 4: 0.8249, Class 5: 0.5459, Class 6: 0.9342, 
Validation mIoU Score
Class 0: 0.9763, Class 1: 0.5594, Class 2: 0.1733, Class 3: 0.3175, Class 4: 0.5984, Class 5: 0.2648, Class 6: 0.8142, 

Overall Mean Dice Score: 0.6529
Overall Mean F-beta Score: 0.7365
Overall Mean IoU Score: 0.5109
Training Loss: 0.4379, Validation Loss: 0.4281, Validation hybrid_score: 0.6463
Early stopping





0,1
class_0_IoU_score,▆▅▁▅▄▆▆▅▅▃█▇█▅▇▅▇▄▄█▇▇█▅▇▇▇▄▇▇▅
class_0_dice_score,▅▅▆▆▄▆▆█▇▆▇▆▇█▆▂▄▂█▄▅▁▅▄▆▂▆▅▅▅▇
class_0_f_beta_score,▄▇▃▄▁▆▆█▇▄▇▆▄▆▄▂▃▁█▁▂▂▂▄▆▃▅▃▅▃▄
class_1_IoU_score,▆▅▁▅▄▆▆▅▅▃█▇█▅▇▅▇▄▄█▇▇█▅▇▇▇▄▇▇▅
class_1_dice_score,██▆▆▇▇▆█▇▇█▇▆██▇▆██▇▆▇▆█▁▅▆▇█▇▇
class_1_f_beta_score,▆█▆▆▆▅▇▆▇▇▆▇▅█▆▇█▆▇▇▇▆▇▃▁▆█▆▇▆▇
class_2_IoU_score,▆▅▁▅▄▆▆▅▅▃█▇█▅▇▅▇▄▄█▇▇█▅▇▇▇▄▇▇▅
class_2_dice_score,▄▁█▄▄▁▄▄▃▁▅▁▃▁▆▁▂▅█▂█▂▃▇█▃▆▄▅▁▆
class_2_f_beta_score,▃▁█▃▄▁▃▄▂▁▅▁▅▁▆▁▄▄▆▃▆▂▃▆▆▃▅▄▄▁▇
class_3_IoU_score,▆▅▁▅▄▆▆▅▅▃█▇█▅▇▅▇▄▄█▇▇█▅▇▇▇▄▇▇▅

0,1
class_0_IoU_score,0.9342
class_0_dice_score,0.988
class_0_f_beta_score,0.98351
class_1_IoU_score,0.9342
class_1_dice_score,0.71748
class_1_f_beta_score,0.88165
class_2_IoU_score,0.9342
class_2_dice_score,0.29536
class_2_f_beta_score,0.45163
class_3_IoU_score,0.9342


In [17]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






: 

: 

# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

: 

: 

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


: 

: 

In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


: 

: 

In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

: 

: 

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

: 

: 

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv


: 

: 