In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
# from src.models.swincspunetr import SwinCSPUNETR
# from src.models.swincspunetr_unet import SwinCSPUNETR_unet
# from src.models.swincspunetr3plus import SwinCSPUNETR3plus

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\Users\<username>\.conda\envs\UM\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.17.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders, create_dataloaders_bw
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, RandCropd,RandCropByPosNegLabeld, RandGaussianSmoothd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/train_de/images"
train_label_dir = "./datasets/train_de/labels"
val_img_dir = "./datasets/val_de/images"
val_label_dir = "./datasets/val_de/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = 32
n_classes = 7
batch_size = 32 # 13.8GB GPU memory required for 128x128 img size
loader_batch = 1
num_samples = batch_size // loader_batch # 한 이미지에서 뽑을 샘플 수
num_repeat = 20
# MODEL CONFIG
num_epochs = 4000
lamda = 0.5
beta = 2.0
ce_weight = 0.4
qfl_weight = 0.3
lr = 0.01
feature_size = (32, 64, 128, 256)
use_checkpoint = True
use_v2 = True
drop_rate= 0.2
attn_drop_rate = 0.2
num_bottleneck = 2
# CLASS_WEIGHTS
class_weights = None
# class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([0.9,1,0.9,1.1,1,1.1,1], dtype=torch.float32)  # 클래스별 가중치
sigma = 2.0
accumulation_steps = 1
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # GaussianSmoothd(
    #     keys=["image"],      # 변환을 적용할 키
    #     sigma=[sigma, sigma, sigma]  # 각 축(x, y, z)의 시그마 값
    #     ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # RandGaussianSmoothd(
    # keys=["image"],      # 변환을 적용할 키
    # sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    # sigma_y = (0.0, sigma),
    # sigma_z = (0.0, sigma),
    # prob=1.0,
    # ),
])
val_random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    # RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    # RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    # RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    # RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # RandGaussianSmoothd(
    # keys=["image"],      # 변환을 적용할 키
    # sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    # sigma_y = (0.0, sigma),
    # sigma_z = (0.0, sigma),
    # prob=1.0,
    # ),
])


In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat,
    val_num_repeat=num_repeat,
    )

https://monai.io/model-zoo.html

In [5]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from monai.losses import TverskyLoss

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # DynamicTverskyLoss 클래스 정의
# class DynamicTverskyLoss(TverskyLoss):
#     def __init__(self, lamda=0.5, **kwargs):
#         super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
#         self.lamda = lamda

#     def set_lamda(self, lamda):
#         self.lamda = lamda
#         self.alpha = 1 - lamda
#         self.beta = lamda


# # CombinedCETverskyLoss 클래스
# class CombinedCETverskyLoss_cl_weight(nn.Module):
    
    
#     def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
#         super().__init__()
#         self.n_classes = n_classes
#         self.ce_weight = ce_weight
#         self.ignore_index = ignore_index
        
#         # CrossEntropyLoss에서 클래스별 가중치를 적용
#         self.ce = nn.CrossEntropyLoss(weight=class_weights, ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
#         # TverskyLoss
#         self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="none",softmax=True, **kwargs)

#     def forward(self, inputs, targets):
        
#         # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
#         ce_loss = self.ce(inputs, targets)

#         # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
#         tversky_loss = self.tversky(inputs, targets)

#         # 클래스별 가중치 적용 (Tversky 손실에도 가중치를 곱하기)
#         class_weights = torch.tensor(self.ce.weight)  # CrossEntropy의 weight를 사용

#         # Tversky 손실이 (B, num_classes) 형태이므로, 가중치를 클래스 차원에 곱합니다.
#         tversky_loss = tversky_loss * class_weights.view(1, self.n_classes)

#         # 최종 손실 계산
#         final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss.mean()  # mean()으로 배치에 대해 평균
#         return final_loss

#     def set_lamda(self, lamda):
#         self.tversky.set_lamda(lamda)

#     @property
#     def lamda(self):
#         return self.tversky.lamda

# # CombinedCETverskyLoss 클래스
# class CombinedCETverskyLoss(nn.Module):
    
#     def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
#         super().__init__()
#         self.n_classes = n_classes
#         self.ce_weight = ce_weight
#         self.ignore_index = ignore_index
        
#         # CrossEntropyLoss에서 클래스별 가중치를 적용
#         self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
#         # TverskyLoss
#         self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="mean",softmax=True, **kwargs)

#     def forward(self, inputs, targets):
        
#         # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
#         ce_loss = self.ce(inputs, targets)

#         # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
#         tversky_loss = self.tversky(inputs, targets)

#         # 최종 손실 계산
#         final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss  # mean()으로 배치에 대해 평균
        
#         return final_loss

#     def set_lamda(self, lamda):
#         self.tversky.set_lamda(lamda)

#     @property
#     def lamda(self):
#         return self.tversky.lamda

# criterion = CombinedCETverskyLoss(
#     lamda=lamda,
#     ce_weight=ce_weight,
#     n_classes=n_classes,
#     class_weights=class_weights,
# ).to(device)

In [6]:
import torch

def qfl_loss(cls_pred, quality_pred, gt_onehot, beta=2.0):
    """
    Quality Focal Loss for segmentation tasks.
    
    Args:
        cls_pred: Class probabilities [B, C, D, H, W].
        quality_pred: Predicted quality scores [B, 1, D, H, W].
        gt: Ground truth segmentation mask [B, C, D, H, W] (one-hot encoded).
        beta: Focal Loss hyperparameter.
    
    Returns:
        Loss value (scalar).
    """
    # One-hot encode the ground truth mask
    # num_classes = cls_pred.shape[1]
    
    # Softmax for class probabilities
    prob = torch.softmax(cls_pred, dim=1)  # [B, C, D, H, W]
    prob = torch.clamp(prob, min=1e-6, max=1 - 1e-6)  # 확률 범위 안정화
    
    # 품질 점수를 IoU로 예측했다고 가정 (Sigmoid 사용)
    quality = torch.sigmoid(quality_pred)  # [B, 1, D, H, W]
    quality = torch.clamp(quality, min=1e-6, max=1 - 1e-6)  # 품질 예측 범위 안정화
    
    # Positive Loss: Ground truth mask가 있는 위치에서의 Loss
    pos_mask = gt_onehot == 1
    pos_loss = -((1 - prob)**2) * torch.log(prob) * quality  # [B, C, D, H, W]
    pos_loss = pos_loss[pos_mask].mean()  # Positive 영역 Loss

    # Negative Loss: Ground truth가 없는 위치에서의 Loss
    neg_mask = gt_onehot == 0
    neg_loss = -(prob**beta) * torch.log(1 - prob) * (1 - quality)  # [B, C, D, H, W]
    neg_loss = neg_loss[neg_mask].mean()  # Negative 영역 Loss

    # 최종 Loss
    return pos_loss + neg_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import TverskyLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DynamicTverskyLoss 클래스 정의
class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda

# CombinedCETverskyLoss 클래스
class CombinedCE_QFL_TverskyLoss(nn.Module):
    
    def __init__(self, lamda=0.5, ce_weight=0.4, qfl_weight=0.3,  n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.ce_weight = ce_weight
        self.qfl_weight = qfl_weight
        self.ignore_index = ignore_index
        
        # CrossEntropyLoss에서 클래스별 가중치를 적용
        self.QFL = qfl_loss
        
        self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
        # TverskyLoss
        self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="mean",softmax=True, **kwargs)

    def forward(self, inputs, scores, targets, beta=2.0):
        
        # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
        ce_loss = self.ce(inputs, targets)
        
        # QFL_loss 계산
        QFL_loss = self.QFL(inputs, scores, targets, beta)

        # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
        tversky_loss = self.tversky(inputs, targets)

        # 최종 손실 계산
        final_loss = self.ce_weight * ce_loss + self.qfl_weight * QFL_loss + (1 - self.ce_weight - self.qfl_weight) * tversky_loss  # mean()으로 배치에 대해 평균
        
        return final_loss

    def set_lamda(self, lamda):
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        return self.tversky.lamda

criterion = CombinedCE_QFL_TverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,
    qfl_weight=qfl_weight,
    n_classes=n_classes,
    class_weights=class_weights,
).to(device)


In [7]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.networks.nets import UNet
from src.models import UNet_CBAM_bw
from monai.networks.layers.factories import Act, Norm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = UNet_CBAM_bw(
    spatial_dims=3,
    in_channels=1,
    out_channels=n_classes,
    channels=feature_size,
    strides=(2, 2, 2),
    dropout = drop_rate,
    norm = Norm.INSTANCE,
    act = Act.PRELU,
).to(device)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"CBAM_qfl_deno_noclswt_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_dropr{drop_rate}_lr{lr:.0e}_a{lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}_qfl{qfl_weight}_ac{accumulation_steps}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")



In [8]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([32, 1, 32, 96, 96]) torch.Size([32, 1, 32, 96, 96])


In [9]:
torch.backends.cudnn.benchmark = True

In [10]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        # "attn_drop_rate": attn_drop_rate,
        # "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        # "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


# 학습

In [11]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs, scores = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    loss = criterion(outputs, scores, labels_onehot, beta)
    # loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval, ce_weight):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    class_mIoU_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())
                        intersection = torch.sum(pred_i & label_i).float()
                        union = torch.sum(pred_i | label_i).float()
                        iou = (intersection + 1e-8) / (union + 1e-8)
                        class_mIoU_scores[i].append(iou.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
               
        print() 
    if epoch % calculate_dice_interval == 0:
        print("Validation mIoU Score")
        all_classes_mIoU_scores = []
        for i in range(n_classes):
            mean_IoU = np.mean(class_mIoU_scores[i])
            wandb.log({f'class_{i}_IoU_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_IoU:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_mIoU_scores.append(mean_IoU)
                
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        overall_mean_IoU = np.mean(all_classes_mIoU_scores)
        final_score = overall_mean_fbeta * (1 - ce_weight) + overall_mean_IoU * ce_weight
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1, 'overall_mean_IoU_score': overall_mean_IoU, 'final_score': final_score})  
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\nOverall Mean IoU Score: {overall_mean_IoU:.4f}\nFinal_score: {final_score:.4f}")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    
    
    return val_loss / len(val_loader), final_score 

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4, pretrained=False
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval,
            ce_weight = 0.5
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation hybrid_score: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            if pretrained:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model_pretrained.pt')
            else:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve % 6 == 0 & epochs_no_improve != 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.35)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [None]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps,
    pretrained=False
     ) 

Epoch 1/4000


Training: 100%|██████████| 120/120 [03:27<00:00,  1.73s/it, loss=0.305]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.303]


Validation Dice Score
Class 0: 0.9831, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.3657, Class 5: 0.0001, Class 6: 0.4126, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.3062, Class 5: 0.0001, Class 6: 0.4872, 
Validation mIoU Score
Class 0: 0.9668, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.2242, Class 5: 0.0000, Class 6: 0.2637, 

Overall Mean Dice Score: 0.1557
Overall Mean F-beta Score: 0.1587
Overall Mean IoU Score: 0.0976
Final_score: 0.1281
Training Loss: 0.3445, Validation Loss: 0.3034, Validation hybrid_score: 0.1281
SUPER Best model saved. Loss:0.3034, Score:0.1281
Epoch 2/4000


Training: 100%|██████████| 120/120 [02:57<00:00,  1.48s/it, loss=0.277]
Validation: 100%|██████████| 20/20 [00:18<00:00,  1.08it/s, loss=0.274]


Validation Dice Score
Class 0: 0.9840, Class 1: 0.0913, Class 2: 0.0000, Class 3: 0.2528, Class 4: 0.5400, Class 5: 0.2160, Class 6: 0.7349, 
Validation F-beta Score
Class 0: 0.9883, Class 1: 0.0567, Class 2: 0.0000, Class 3: 0.2429, Class 4: 0.5105, Class 5: 0.1597, Class 6: 0.8600, 
Validation mIoU Score
Class 0: 0.9686, Class 1: 0.0482, Class 2: 0.0000, Class 3: 0.1463, Class 4: 0.3715, Class 5: 0.1216, Class 6: 0.5824, 

Overall Mean Dice Score: 0.3670
Overall Mean F-beta Score: 0.3660
Overall Mean IoU Score: 0.2540
Final_score: 0.3100
Training Loss: 0.2854, Validation Loss: 0.2754, Validation hybrid_score: 0.3100
SUPER Best model saved. Loss:0.2754, Score:0.3100
Epoch 3/4000


Training: 100%|██████████| 120/120 [02:55<00:00,  1.46s/it, loss=0.255]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.20it/s, loss=0.251]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.4066, Class 2: 0.0005, Class 3: 0.3899, Class 4: 0.6224, Class 5: 0.3231, Class 6: 0.8388, 
Validation F-beta Score
Class 0: 0.9897, Class 1: 0.4304, Class 2: 0.0003, Class 3: 0.3487, Class 4: 0.5866, Class 5: 0.2766, Class 6: 0.8005, 
Validation mIoU Score
Class 0: 0.9726, Class 1: 0.2572, Class 2: 0.0003, Class 3: 0.2432, Class 4: 0.4530, Class 5: 0.1939, Class 6: 0.7230, 

Overall Mean Dice Score: 0.5162
Overall Mean F-beta Score: 0.4886
Overall Mean IoU Score: 0.3741
Final_score: 0.4313
Training Loss: 0.2633, Validation Loss: 0.2557, Validation hybrid_score: 0.4313
SUPER Best model saved. Loss:0.2557, Score:0.4313
Epoch 4/4000


Training: 100%|██████████| 120/120 [02:53<00:00,  1.45s/it, loss=0.262]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.245]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.6053, Class 2: 0.0092, Class 3: 0.3516, Class 4: 0.6587, Class 5: 0.3510, Class 6: 0.8357, 
Validation F-beta Score
Class 0: 0.9860, Class 1: 0.6332, Class 2: 0.0050, Class 3: 0.5260, Class 4: 0.6549, Class 5: 0.3155, Class 6: 0.9358, 
Validation mIoU Score
Class 0: 0.9726, Class 1: 0.4351, Class 2: 0.0046, Class 3: 0.2145, Class 4: 0.4931, Class 5: 0.2145, Class 6: 0.7182, 

Overall Mean Dice Score: 0.5604
Overall Mean F-beta Score: 0.6131
Overall Mean IoU Score: 0.4151
Final_score: 0.5141
Training Loss: 0.2539, Validation Loss: 0.2483, Validation hybrid_score: 0.5141
SUPER Best model saved. Loss:0.2483, Score:0.5141
Epoch 5/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.251]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.245]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.6766, Class 2: 0.0257, Class 3: 0.3891, Class 4: 0.6705, Class 5: 0.3712, Class 6: 0.8422, 
Validation F-beta Score
Class 0: 0.9886, Class 1: 0.6832, Class 2: 0.0142, Class 3: 0.4463, Class 4: 0.6705, Class 5: 0.3235, Class 6: 0.9189, 
Validation mIoU Score
Class 0: 0.9751, Class 1: 0.5131, Class 2: 0.0132, Class 3: 0.2436, Class 4: 0.5068, Class 5: 0.2302, Class 6: 0.7282, 

Overall Mean Dice Score: 0.5899
Overall Mean F-beta Score: 0.6085
Overall Mean IoU Score: 0.4444
Final_score: 0.5264
Training Loss: 0.2494, Validation Loss: 0.2436, Validation hybrid_score: 0.5264
SUPER Best model saved. Loss:0.2436, Score:0.5264
Epoch 6/4000


Training: 100%|██████████| 120/120 [02:45<00:00,  1.38s/it, loss=0.242]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.255]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7051, Class 2: 0.1035, Class 3: 0.4039, Class 4: 0.6865, Class 5: 0.4260, Class 6: 0.8789, 
Validation F-beta Score
Class 0: 0.9840, Class 1: 0.7423, Class 2: 0.0673, Class 3: 0.5439, Class 4: 0.7394, Class 5: 0.4088, Class 6: 0.9166, 
Validation mIoU Score
Class 0: 0.9718, Class 1: 0.5458, Class 2: 0.0567, Class 3: 0.2572, Class 4: 0.5238, Class 5: 0.2727, Class 6: 0.7844, 

Overall Mean Dice Score: 0.6201
Overall Mean F-beta Score: 0.6702
Overall Mean IoU Score: 0.4768
Final_score: 0.5735
Training Loss: 0.2448, Validation Loss: 0.2412, Validation hybrid_score: 0.5735
SUPER Best model saved. Loss:0.2412, Score:0.5735
Epoch 7/4000


Training: 100%|██████████| 120/120 [02:54<00:00,  1.45s/it, loss=0.246]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.25] 


Validation Dice Score
Class 0: 0.9862, Class 1: 0.6544, Class 2: 0.1228, Class 3: 0.4287, Class 4: 0.6900, Class 5: 0.4367, Class 6: 0.8253, 
Validation F-beta Score
Class 0: 0.9855, Class 1: 0.8275, Class 2: 0.0874, Class 3: 0.5526, Class 4: 0.6691, Class 5: 0.4402, Class 6: 0.9460, 
Validation mIoU Score
Class 0: 0.9727, Class 1: 0.4885, Class 2: 0.0681, Class 3: 0.2739, Class 4: 0.5293, Class 5: 0.2808, Class 6: 0.7067, 

Overall Mean Dice Score: 0.6070
Overall Mean F-beta Score: 0.6870
Overall Mean IoU Score: 0.4559
Final_score: 0.5715
Training Loss: 0.2425, Validation Loss: 0.2411, Validation hybrid_score: 0.5715
Epoch 8/4000


Training: 100%|██████████| 120/120 [02:54<00:00,  1.45s/it, loss=0.24] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.243]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7231, Class 2: 0.0830, Class 3: 0.4454, Class 4: 0.6692, Class 5: 0.4301, Class 6: 0.8501, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.8384, Class 2: 0.0546, Class 3: 0.5067, Class 4: 0.6315, Class 5: 0.3810, Class 6: 0.9481, 
Validation mIoU Score
Class 0: 0.9749, Class 1: 0.5687, Class 2: 0.0449, Class 3: 0.2879, Class 4: 0.5041, Class 5: 0.2756, Class 6: 0.7418, 

Overall Mean Dice Score: 0.6236
Overall Mean F-beta Score: 0.6612
Overall Mean IoU Score: 0.4756
Final_score: 0.5684
Training Loss: 0.2407, Validation Loss: 0.2404, Validation hybrid_score: 0.5684
Epoch 9/4000


Training: 100%|██████████| 120/120 [02:48<00:00,  1.40s/it, loss=0.238]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.7330, Class 2: 0.1719, Class 3: 0.4325, Class 4: 0.6776, Class 5: 0.4389, Class 6: 0.8532, 
Validation F-beta Score
Class 0: 0.9876, Class 1: 0.8327, Class 2: 0.1398, Class 3: 0.6130, Class 4: 0.6458, Class 5: 0.3967, Class 6: 0.9532, 
Validation mIoU Score
Class 0: 0.9744, Class 1: 0.5795, Class 2: 0.0977, Class 3: 0.2774, Class 4: 0.5139, Class 5: 0.2830, Class 6: 0.7465, 

Overall Mean Dice Score: 0.6270
Overall Mean F-beta Score: 0.6883
Overall Mean IoU Score: 0.4800
Final_score: 0.5842
Training Loss: 0.2395, Validation Loss: 0.2383, Validation hybrid_score: 0.5842
SUPER Best model saved. Loss:0.2383, Score:0.5842
Epoch 10/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.241]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.233]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7679, Class 2: 0.1798, Class 3: 0.4733, Class 4: 0.6960, Class 5: 0.4718, Class 6: 0.8988, 
Validation F-beta Score
Class 0: 0.9879, Class 1: 0.7529, Class 2: 0.1763, Class 3: 0.5451, Class 4: 0.6405, Class 5: 0.4811, Class 6: 0.9428, 
Validation mIoU Score
Class 0: 0.9747, Class 1: 0.6250, Class 2: 0.1032, Class 3: 0.3122, Class 4: 0.5349, Class 5: 0.3106, Class 6: 0.8167, 

Overall Mean Dice Score: 0.6616
Overall Mean F-beta Score: 0.6725
Overall Mean IoU Score: 0.5199
Final_score: 0.5962
Training Loss: 0.2379, Validation Loss: 0.2328, Validation hybrid_score: 0.5962
SUPER Best model saved. Loss:0.2328, Score:0.5962
Epoch 11/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.248]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.24] 


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7445, Class 2: 0.1937, Class 3: 0.4445, Class 4: 0.7087, Class 5: 0.4231, Class 6: 0.8605, 
Validation F-beta Score
Class 0: 0.9900, Class 1: 0.8603, Class 2: 0.1483, Class 3: 0.5630, Class 4: 0.6717, Class 5: 0.3569, Class 6: 0.9231, 
Validation mIoU Score
Class 0: 0.9766, Class 1: 0.5942, Class 2: 0.1106, Class 3: 0.2871, Class 4: 0.5498, Class 5: 0.2697, Class 6: 0.7652, 

Overall Mean Dice Score: 0.6363
Overall Mean F-beta Score: 0.6750
Overall Mean IoU Score: 0.4932
Final_score: 0.5841
Training Loss: 0.2367, Validation Loss: 0.2337, Validation hybrid_score: 0.5841
Epoch 12/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.237]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, loss=0.233]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.7695, Class 2: 0.1689, Class 3: 0.4416, Class 4: 0.7132, Class 5: 0.4672, Class 6: 0.8634, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8185, Class 2: 0.1509, Class 3: 0.5636, Class 4: 0.7308, Class 5: 0.4439, Class 6: 0.9500, 
Validation mIoU Score
Class 0: 0.9753, Class 1: 0.6264, Class 2: 0.0947, Class 3: 0.2849, Class 4: 0.5553, Class 5: 0.3062, Class 6: 0.7610, 

Overall Mean Dice Score: 0.6510
Overall Mean F-beta Score: 0.7014
Overall Mean IoU Score: 0.5068
Final_score: 0.6041
Training Loss: 0.2352, Validation Loss: 0.2341, Validation hybrid_score: 0.6041
Epoch 13/4000


Training: 100%|██████████| 120/120 [02:48<00:00,  1.40s/it, loss=0.237]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.248]


Validation Dice Score
Class 0: 0.9841, Class 1: 0.6812, Class 2: 0.1527, Class 3: 0.4407, Class 4: 0.6565, Class 5: 0.4332, Class 6: 0.8320, 
Validation F-beta Score
Class 0: 0.9794, Class 1: 0.8951, Class 2: 0.1302, Class 3: 0.5799, Class 4: 0.8049, Class 5: 0.4180, Class 6: 0.9511, 
Validation mIoU Score
Class 0: 0.9686, Class 1: 0.5188, Class 2: 0.0881, Class 3: 0.2838, Class 4: 0.4911, Class 5: 0.2784, Class 6: 0.7153, 

Overall Mean Dice Score: 0.6087
Overall Mean F-beta Score: 0.7298
Overall Mean IoU Score: 0.4575
Final_score: 0.5937
Training Loss: 0.2369, Validation Loss: 0.2436, Validation hybrid_score: 0.5937
Epoch 14/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.38s/it, loss=0.246]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7867, Class 2: 0.1675, Class 3: 0.4910, Class 4: 0.7314, Class 5: 0.4663, Class 6: 0.8960, 
Validation F-beta Score
Class 0: 0.9858, Class 1: 0.8123, Class 2: 0.1643, Class 3: 0.5144, Class 4: 0.7290, Class 5: 0.5095, Class 6: 0.9498, 
Validation mIoU Score
Class 0: 0.9741, Class 1: 0.6498, Class 2: 0.0947, Class 3: 0.3281, Class 4: 0.5772, Class 5: 0.3056, Class 6: 0.8122, 

Overall Mean Dice Score: 0.6743
Overall Mean F-beta Score: 0.7030
Overall Mean IoU Score: 0.5346
Final_score: 0.6188
Training Loss: 0.2347, Validation Loss: 0.2327, Validation hybrid_score: 0.6188
SUPER Best model saved. Loss:0.2327, Score:0.6188
Epoch 15/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.239]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.7501, Class 2: 0.2005, Class 3: 0.4763, Class 4: 0.6864, Class 5: 0.4522, Class 6: 0.8869, 
Validation F-beta Score
Class 0: 0.9838, Class 1: 0.8760, Class 2: 0.1629, Class 3: 0.5668, Class 4: 0.7862, Class 5: 0.4341, Class 6: 0.9548, 
Validation mIoU Score
Class 0: 0.9734, Class 1: 0.6018, Class 2: 0.1147, Class 3: 0.3140, Class 4: 0.5250, Class 5: 0.2937, Class 6: 0.7974, 

Overall Mean Dice Score: 0.6504
Overall Mean F-beta Score: 0.7236
Overall Mean IoU Score: 0.5064
Final_score: 0.6150
Training Loss: 0.2342, Validation Loss: 0.2358, Validation hybrid_score: 0.6150
Epoch 16/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.227]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.224]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7232, Class 2: 0.1428, Class 3: 0.4663, Class 4: 0.7289, Class 5: 0.4625, Class 6: 0.8664, 
Validation F-beta Score
Class 0: 0.9850, Class 1: 0.8694, Class 2: 0.1173, Class 3: 0.5632, Class 4: 0.7720, Class 5: 0.4456, Class 6: 0.9501, 
Validation mIoU Score
Class 0: 0.9738, Class 1: 0.5697, Class 2: 0.0801, Class 3: 0.3057, Class 4: 0.5748, Class 5: 0.3029, Class 6: 0.7656, 

Overall Mean Dice Score: 0.6495
Overall Mean F-beta Score: 0.7200
Overall Mean IoU Score: 0.5037
Final_score: 0.6119
Training Loss: 0.2337, Validation Loss: 0.2370, Validation hybrid_score: 0.6119
Epoch 17/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.241]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.7744, Class 2: 0.2465, Class 3: 0.4520, Class 4: 0.6998, Class 5: 0.4803, Class 6: 0.8927, 
Validation F-beta Score
Class 0: 0.9870, Class 1: 0.8654, Class 2: 0.2636, Class 3: 0.5268, Class 4: 0.7291, Class 5: 0.4652, Class 6: 0.9535, 
Validation mIoU Score
Class 0: 0.9762, Class 1: 0.6338, Class 2: 0.1474, Class 3: 0.2939, Class 4: 0.5399, Class 5: 0.3174, Class 6: 0.8070, 

Overall Mean Dice Score: 0.6599
Overall Mean F-beta Score: 0.7080
Overall Mean IoU Score: 0.5184
Final_score: 0.6132
Training Loss: 0.2325, Validation Loss: 0.2316, Validation hybrid_score: 0.6132
Epoch 18/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.233]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, loss=0.216]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7964, Class 2: 0.2117, Class 3: 0.4858, Class 4: 0.7099, Class 5: 0.4823, Class 6: 0.8808, 
Validation F-beta Score
Class 0: 0.9834, Class 1: 0.8454, Class 2: 0.2307, Class 3: 0.6193, Class 4: 0.7637, Class 5: 0.5103, Class 6: 0.9575, 
Validation mIoU Score
Class 0: 0.9731, Class 1: 0.6626, Class 2: 0.1269, Class 3: 0.3228, Class 4: 0.5513, Class 5: 0.3198, Class 6: 0.7880, 

Overall Mean Dice Score: 0.6710
Overall Mean F-beta Score: 0.7392
Overall Mean IoU Score: 0.5289
Final_score: 0.6341
Training Loss: 0.2323, Validation Loss: 0.2309, Validation hybrid_score: 0.6341
SUPER Best model saved. Loss:0.2309, Score:0.6341
Epoch 19/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.229]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.238]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.7630, Class 2: 0.1791, Class 3: 0.4619, Class 4: 0.7310, Class 5: 0.4900, Class 6: 0.8801, 
Validation F-beta Score
Class 0: 0.9865, Class 1: 0.8599, Class 2: 0.1734, Class 3: 0.5642, Class 4: 0.7425, Class 5: 0.4783, Class 6: 0.9562, 
Validation mIoU Score
Class 0: 0.9751, Class 1: 0.6222, Class 2: 0.1034, Class 3: 0.3027, Class 4: 0.5769, Class 5: 0.3255, Class 6: 0.7867, 

Overall Mean Dice Score: 0.6652
Overall Mean F-beta Score: 0.7202
Overall Mean IoU Score: 0.5228
Final_score: 0.6215
Training Loss: 0.2325, Validation Loss: 0.2318, Validation hybrid_score: 0.6215
Epoch 20/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.40s/it, loss=0.236]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.222]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7614, Class 2: 0.2289, Class 3: 0.4902, Class 4: 0.7267, Class 5: 0.4791, Class 6: 0.8669, 
Validation F-beta Score
Class 0: 0.9840, Class 1: 0.8783, Class 2: 0.2450, Class 3: 0.6363, Class 4: 0.8000, Class 5: 0.4865, Class 6: 0.9627, 
Validation mIoU Score
Class 0: 0.9746, Class 1: 0.6163, Class 2: 0.1336, Class 3: 0.3273, Class 4: 0.5723, Class 5: 0.3171, Class 6: 0.7658, 

Overall Mean Dice Score: 0.6649
Overall Mean F-beta Score: 0.7528
Overall Mean IoU Score: 0.5198
Final_score: 0.6363
Training Loss: 0.2301, Validation Loss: 0.2303, Validation hybrid_score: 0.6363
SUPER Best model saved. Loss:0.2303, Score:0.6363
Epoch 21/4000


Training: 100%|██████████| 120/120 [02:45<00:00,  1.38s/it, loss=0.233]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.7458, Class 2: 0.2132, Class 3: 0.4246, Class 4: 0.7284, Class 5: 0.4941, Class 6: 0.8801, 
Validation F-beta Score
Class 0: 0.9838, Class 1: 0.8864, Class 2: 0.2923, Class 3: 0.6574, Class 4: 0.7560, Class 5: 0.4903, Class 6: 0.9599, 
Validation mIoU Score
Class 0: 0.9735, Class 1: 0.5965, Class 2: 0.1236, Class 3: 0.2708, Class 4: 0.5745, Class 5: 0.3304, Class 6: 0.7864, 

Overall Mean Dice Score: 0.6546
Overall Mean F-beta Score: 0.7500
Overall Mean IoU Score: 0.5117
Final_score: 0.6309
Training Loss: 0.2320, Validation Loss: 0.2336, Validation hybrid_score: 0.6309
Epoch 22/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.38s/it, loss=0.24] 
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.24] 


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7804, Class 2: 0.2101, Class 3: 0.4516, Class 4: 0.7345, Class 5: 0.5049, Class 6: 0.8815, 
Validation F-beta Score
Class 0: 0.9834, Class 1: 0.8456, Class 2: 0.2662, Class 3: 0.6260, Class 4: 0.8045, Class 5: 0.5227, Class 6: 0.9500, 
Validation mIoU Score
Class 0: 0.9740, Class 1: 0.6414, Class 2: 0.1216, Class 3: 0.2926, Class 4: 0.5820, Class 5: 0.3396, Class 6: 0.7894, 

Overall Mean Dice Score: 0.6706
Overall Mean F-beta Score: 0.7498
Overall Mean IoU Score: 0.5290
Final_score: 0.6394
Training Loss: 0.2305, Validation Loss: 0.2321, Validation hybrid_score: 0.6394
Epoch 23/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.231]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7429, Class 2: 0.2089, Class 3: 0.4375, Class 4: 0.7187, Class 5: 0.4688, Class 6: 0.9047, 
Validation F-beta Score
Class 0: 0.9836, Class 1: 0.8777, Class 2: 0.2279, Class 3: 0.5752, Class 4: 0.7231, Class 5: 0.5409, Class 6: 0.9516, 
Validation mIoU Score
Class 0: 0.9732, Class 1: 0.5932, Class 2: 0.1223, Class 3: 0.2827, Class 4: 0.5630, Class 5: 0.3084, Class 6: 0.8267, 

Overall Mean Dice Score: 0.6545
Overall Mean F-beta Score: 0.7337
Overall Mean IoU Score: 0.5148
Final_score: 0.6242
Training Loss: 0.2297, Validation Loss: 0.2326, Validation hybrid_score: 0.6242
Epoch 24/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.231]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.233]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7847, Class 2: 0.1963, Class 3: 0.4714, Class 4: 0.7067, Class 5: 0.5091, Class 6: 0.8667, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8724, Class 2: 0.2097, Class 3: 0.6122, Class 4: 0.7324, Class 5: 0.5586, Class 6: 0.9616, 
Validation mIoU Score
Class 0: 0.9740, Class 1: 0.6481, Class 2: 0.1112, Class 3: 0.3108, Class 4: 0.5487, Class 5: 0.3423, Class 6: 0.7652, 

Overall Mean Dice Score: 0.6677
Overall Mean F-beta Score: 0.7475
Overall Mean IoU Score: 0.5230
Final_score: 0.6352
Training Loss: 0.2304, Validation Loss: 0.2322, Validation hybrid_score: 0.6352
Epoch 25/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.7963, Class 2: 0.1860, Class 3: 0.4917, Class 4: 0.7249, Class 5: 0.4783, Class 6: 0.8593, 
Validation F-beta Score
Class 0: 0.9837, Class 1: 0.8569, Class 2: 0.1968, Class 3: 0.5668, Class 4: 0.7204, Class 5: 0.5710, Class 6: 0.9640, 
Validation mIoU Score
Class 0: 0.9735, Class 1: 0.6626, Class 2: 0.1053, Class 3: 0.3273, Class 4: 0.5699, Class 5: 0.3173, Class 6: 0.7540, 

Overall Mean Dice Score: 0.6701
Overall Mean F-beta Score: 0.7358
Overall Mean IoU Score: 0.5262
Final_score: 0.6310
Training Loss: 0.2298, Validation Loss: 0.2319, Validation hybrid_score: 0.6310
Epoch 26/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.225]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7494, Class 2: 0.2123, Class 3: 0.4888, Class 4: 0.7390, Class 5: 0.5128, Class 6: 0.8360, 
Validation F-beta Score
Class 0: 0.9835, Class 1: 0.8968, Class 2: 0.2337, Class 3: 0.6274, Class 4: 0.7503, Class 5: 0.5706, Class 6: 0.9636, 
Validation mIoU Score
Class 0: 0.9732, Class 1: 0.6005, Class 2: 0.1216, Class 3: 0.3243, Class 4: 0.5875, Class 5: 0.3465, Class 6: 0.7197, 

Overall Mean Dice Score: 0.6652
Overall Mean F-beta Score: 0.7617
Overall Mean IoU Score: 0.5157
Final_score: 0.6387
Training Loss: 0.2296, Validation Loss: 0.2333, Validation hybrid_score: 0.6387
Epoch 27/4000


Training: 100%|██████████| 120/120 [02:57<00:00,  1.48s/it, loss=0.228]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.247]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7876, Class 2: 0.2495, Class 3: 0.4647, Class 4: 0.7304, Class 5: 0.4850, Class 6: 0.8676, 
Validation F-beta Score
Class 0: 0.9832, Class 1: 0.7994, Class 2: 0.3051, Class 3: 0.5457, Class 4: 0.7566, Class 5: 0.5807, Class 6: 0.9647, 
Validation mIoU Score
Class 0: 0.9737, Class 1: 0.6510, Class 2: 0.1457, Class 3: 0.3051, Class 4: 0.5778, Class 5: 0.3234, Class 6: 0.7680, 

Overall Mean Dice Score: 0.6671
Overall Mean F-beta Score: 0.7294
Overall Mean IoU Score: 0.5251
Final_score: 0.6272
Training Loss: 0.2295, Validation Loss: 0.2330, Validation hybrid_score: 0.6272
Epoch 28/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.238]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.231]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7368, Class 2: 0.2657, Class 3: 0.4848, Class 4: 0.7458, Class 5: 0.4835, Class 6: 0.8790, 
Validation F-beta Score
Class 0: 0.9850, Class 1: 0.8800, Class 2: 0.2612, Class 3: 0.6347, Class 4: 0.7625, Class 5: 0.5157, Class 6: 0.9605, 
Validation mIoU Score
Class 0: 0.9747, Class 1: 0.5852, Class 2: 0.1575, Class 3: 0.3209, Class 4: 0.5954, Class 5: 0.3208, Class 6: 0.7850, 

Overall Mean Dice Score: 0.6660
Overall Mean F-beta Score: 0.7507
Overall Mean IoU Score: 0.5215
Final_score: 0.6361
Training Loss: 0.2302, Validation Loss: 0.2305, Validation hybrid_score: 0.6361
Epoch 29/4000


Training: 100%|██████████| 120/120 [02:53<00:00,  1.45s/it, loss=0.244]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s, loss=0.238]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7818, Class 2: 0.2945, Class 3: 0.4876, Class 4: 0.7425, Class 5: 0.4827, Class 6: 0.8846, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.8844, Class 2: 0.2826, Class 3: 0.6181, Class 4: 0.7243, Class 5: 0.5710, Class 6: 0.9594, 
Validation mIoU Score
Class 0: 0.9741, Class 1: 0.6433, Class 2: 0.1793, Class 3: 0.3240, Class 4: 0.5914, Class 5: 0.3201, Class 6: 0.7936, 

Overall Mean Dice Score: 0.6759
Overall Mean F-beta Score: 0.7514
Overall Mean IoU Score: 0.5345
Final_score: 0.6430
Training Loss: 0.2285, Validation Loss: 0.2296, Validation hybrid_score: 0.6430
SUPER Best model saved. Loss:0.2296, Score:0.6430
Epoch 30/4000


Training: 100%|██████████| 120/120 [02:55<00:00,  1.46s/it, loss=0.235]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.25it/s, loss=0.226]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.7965, Class 2: 0.1730, Class 3: 0.4753, Class 4: 0.7248, Class 5: 0.5002, Class 6: 0.8727, 
Validation F-beta Score
Class 0: 0.9841, Class 1: 0.8580, Class 2: 0.1605, Class 3: 0.5910, Class 4: 0.7359, Class 5: 0.5461, Class 6: 0.9606, 
Validation mIoU Score
Class 0: 0.9728, Class 1: 0.6630, Class 2: 0.0998, Class 3: 0.3131, Class 4: 0.5693, Class 5: 0.3351, Class 6: 0.7752, 

Overall Mean Dice Score: 0.6739
Overall Mean F-beta Score: 0.7383
Overall Mean IoU Score: 0.5311
Final_score: 0.6347
Training Loss: 0.2284, Validation Loss: 0.2341, Validation hybrid_score: 0.6347
Epoch 31/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.234]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.246]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.7471, Class 2: 0.2668, Class 3: 0.4651, Class 4: 0.7277, Class 5: 0.4941, Class 6: 0.8578, 
Validation F-beta Score
Class 0: 0.9808, Class 1: 0.8874, Class 2: 0.3243, Class 3: 0.6059, Class 4: 0.7794, Class 5: 0.6093, Class 6: 0.9689, 
Validation mIoU Score
Class 0: 0.9723, Class 1: 0.5991, Class 2: 0.1609, Class 3: 0.3047, Class 4: 0.5733, Class 5: 0.3299, Class 6: 0.7522, 

Overall Mean Dice Score: 0.6584
Overall Mean F-beta Score: 0.7702
Overall Mean IoU Score: 0.5118
Final_score: 0.6410
Training Loss: 0.2293, Validation Loss: 0.2351, Validation hybrid_score: 0.6410
Epoch 32/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.249]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.219]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7942, Class 2: 0.2756, Class 3: 0.4625, Class 4: 0.7131, Class 5: 0.4591, Class 6: 0.8862, 
Validation F-beta Score
Class 0: 0.9875, Class 1: 0.8426, Class 2: 0.2438, Class 3: 0.5904, Class 4: 0.7358, Class 5: 0.4057, Class 6: 0.9549, 
Validation mIoU Score
Class 0: 0.9756, Class 1: 0.6598, Class 2: 0.1660, Class 3: 0.3024, Class 4: 0.5557, Class 5: 0.3009, Class 6: 0.7966, 

Overall Mean Dice Score: 0.6630
Overall Mean F-beta Score: 0.7059
Overall Mean IoU Score: 0.5231
Final_score: 0.6145
Training Loss: 0.2285, Validation Loss: 0.2309, Validation hybrid_score: 0.6145
Epoch 33/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.22] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.224]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7553, Class 2: 0.2934, Class 3: 0.4706, Class 4: 0.7376, Class 5: 0.4881, Class 6: 0.8673, 
Validation F-beta Score
Class 0: 0.9847, Class 1: 0.8531, Class 2: 0.3188, Class 3: 0.6612, Class 4: 0.7370, Class 5: 0.5328, Class 6: 0.9668, 
Validation mIoU Score
Class 0: 0.9745, Class 1: 0.6091, Class 2: 0.1758, Class 3: 0.3096, Class 4: 0.5865, Class 5: 0.3246, Class 6: 0.7668, 

Overall Mean Dice Score: 0.6638
Overall Mean F-beta Score: 0.7502
Overall Mean IoU Score: 0.5193
Final_score: 0.6348
Training Loss: 0.2275, Validation Loss: 0.2324, Validation hybrid_score: 0.6348
Epoch 34/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.238]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.224]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.7647, Class 2: 0.1246, Class 3: 0.4934, Class 4: 0.7245, Class 5: 0.5089, Class 6: 0.8973, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.8832, Class 2: 0.0904, Class 3: 0.6105, Class 4: 0.7109, Class 5: 0.5094, Class 6: 0.9536, 
Validation mIoU Score
Class 0: 0.9755, Class 1: 0.6213, Class 2: 0.0695, Class 3: 0.3300, Class 4: 0.5697, Class 5: 0.3431, Class 6: 0.8142, 

Overall Mean Dice Score: 0.6778
Overall Mean F-beta Score: 0.7335
Overall Mean IoU Score: 0.5357
Final_score: 0.6346
Training Loss: 0.2276, Validation Loss: 0.2309, Validation hybrid_score: 0.6346
Epoch 35/4000


Training: 100%|██████████| 120/120 [02:53<00:00,  1.45s/it, loss=0.225]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7815, Class 2: 0.2534, Class 3: 0.4805, Class 4: 0.7301, Class 5: 0.4905, Class 6: 0.8799, 
Validation F-beta Score
Class 0: 0.9811, Class 1: 0.8297, Class 2: 0.2318, Class 3: 0.5859, Class 4: 0.8069, Class 5: 0.5817, Class 6: 0.9604, 
Validation mIoU Score
Class 0: 0.9718, Class 1: 0.6425, Class 2: 0.1503, Class 3: 0.3181, Class 4: 0.5756, Class 5: 0.3268, Class 6: 0.7862, 

Overall Mean Dice Score: 0.6725
Overall Mean F-beta Score: 0.7529
Overall Mean IoU Score: 0.5299
Final_score: 0.6414
Training Loss: 0.2284, Validation Loss: 0.2325, Validation hybrid_score: 0.6414
Epoch 36/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.43s/it, loss=0.233]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.224]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7657, Class 2: 0.2086, Class 3: 0.4667, Class 4: 0.7097, Class 5: 0.5142, Class 6: 0.8743, 
Validation F-beta Score
Class 0: 0.9837, Class 1: 0.8811, Class 2: 0.2090, Class 3: 0.6151, Class 4: 0.8008, Class 5: 0.5086, Class 6: 0.9607, 
Validation mIoU Score
Class 0: 0.9741, Class 1: 0.6221, Class 2: 0.1219, Class 3: 0.3058, Class 4: 0.5514, Class 5: 0.3486, Class 6: 0.7776, 

Overall Mean Dice Score: 0.6661
Overall Mean F-beta Score: 0.7533
Overall Mean IoU Score: 0.5211
Final_score: 0.6372
Training Loss: 0.2285, Validation Loss: 0.2295, Validation hybrid_score: 0.6372
Epoch 37/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.43s/it, loss=0.229]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7920, Class 2: 0.2970, Class 3: 0.4942, Class 4: 0.7367, Class 5: 0.4865, Class 6: 0.9059, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8055, Class 2: 0.2892, Class 3: 0.6020, Class 4: 0.7739, Class 5: 0.5019, Class 6: 0.9491, 
Validation mIoU Score
Class 0: 0.9767, Class 1: 0.6571, Class 2: 0.1813, Class 3: 0.3294, Class 4: 0.5843, Class 5: 0.3239, Class 6: 0.8282, 

Overall Mean Dice Score: 0.6830
Overall Mean F-beta Score: 0.7265
Overall Mean IoU Score: 0.5446
Final_score: 0.6355
Training Loss: 0.2275, Validation Loss: 0.2277, Validation hybrid_score: 0.6355
Epoch 38/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.234]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, loss=0.243]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.8002, Class 2: 0.2309, Class 3: 0.4798, Class 4: 0.7487, Class 5: 0.4835, Class 6: 0.8875, 
Validation F-beta Score
Class 0: 0.9868, Class 1: 0.8673, Class 2: 0.2557, Class 3: 0.6008, Class 4: 0.7534, Class 5: 0.4736, Class 6: 0.9651, 
Validation mIoU Score
Class 0: 0.9758, Class 1: 0.6681, Class 2: 0.1379, Class 3: 0.3175, Class 4: 0.5994, Class 5: 0.3203, Class 6: 0.7981, 

Overall Mean Dice Score: 0.6800
Overall Mean F-beta Score: 0.7320
Overall Mean IoU Score: 0.5407
Final_score: 0.6364
Training Loss: 0.2270, Validation Loss: 0.2287, Validation hybrid_score: 0.6364
Epoch 39/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.223]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.7595, Class 2: 0.2387, Class 3: 0.4780, Class 4: 0.7373, Class 5: 0.4920, Class 6: 0.8845, 
Validation F-beta Score
Class 0: 0.9829, Class 1: 0.9041, Class 2: 0.2756, Class 3: 0.5866, Class 4: 0.7662, Class 5: 0.5582, Class 6: 0.9667, 
Validation mIoU Score
Class 0: 0.9728, Class 1: 0.6138, Class 2: 0.1435, Class 3: 0.3157, Class 4: 0.5847, Class 5: 0.3279, Class 6: 0.7940, 

Overall Mean Dice Score: 0.6703
Overall Mean F-beta Score: 0.7564
Overall Mean IoU Score: 0.5272
Final_score: 0.6418
Training Loss: 0.2271, Validation Loss: 0.2310, Validation hybrid_score: 0.6418
Epoch 40/4000


Training: 100%|██████████| 120/120 [02:50<00:00,  1.42s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.23] 


Validation Dice Score
Class 0: 0.9871, Class 1: 0.8018, Class 2: 0.2173, Class 3: 0.4886, Class 4: 0.7318, Class 5: 0.4704, Class 6: 0.8875, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8192, Class 2: 0.2436, Class 3: 0.6059, Class 4: 0.7260, Class 5: 0.4523, Class 6: 0.9587, 
Validation mIoU Score
Class 0: 0.9745, Class 1: 0.6700, Class 2: 0.1264, Class 3: 0.3255, Class 4: 0.5776, Class 5: 0.3096, Class 6: 0.7984, 

Overall Mean Dice Score: 0.6760
Overall Mean F-beta Score: 0.7124
Overall Mean IoU Score: 0.5362
Final_score: 0.6243
Training Loss: 0.2266, Validation Loss: 0.2298, Validation hybrid_score: 0.6243
Epoch 41/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.234]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.24] 


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7703, Class 2: 0.2469, Class 3: 0.4847, Class 4: 0.7141, Class 5: 0.4844, Class 6: 0.8714, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8788, Class 2: 0.2414, Class 3: 0.6074, Class 4: 0.7233, Class 5: 0.4951, Class 6: 0.9643, 
Validation mIoU Score
Class 0: 0.9741, Class 1: 0.6279, Class 2: 0.1445, Class 3: 0.3224, Class 4: 0.5565, Class 5: 0.3221, Class 6: 0.7723, 

Overall Mean Dice Score: 0.6650
Overall Mean F-beta Score: 0.7338
Overall Mean IoU Score: 0.5202
Final_score: 0.6270
Training Loss: 0.2264, Validation Loss: 0.2325, Validation hybrid_score: 0.6270
Epoch 42/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.43s/it, loss=0.235]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.233]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.8033, Class 2: 0.2452, Class 3: 0.4793, Class 4: 0.7245, Class 5: 0.5045, Class 6: 0.8847, 
Validation F-beta Score
Class 0: 0.9838, Class 1: 0.8597, Class 2: 0.2411, Class 3: 0.5882, Class 4: 0.7317, Class 5: 0.5853, Class 6: 0.9675, 
Validation mIoU Score
Class 0: 0.9735, Class 1: 0.6727, Class 2: 0.1467, Class 3: 0.3181, Class 4: 0.5692, Class 5: 0.3391, Class 6: 0.7937, 

Overall Mean Dice Score: 0.6793
Overall Mean F-beta Score: 0.7465
Overall Mean IoU Score: 0.5386
Final_score: 0.6425
Training Loss: 0.2270, Validation Loss: 0.2297, Validation hybrid_score: 0.6425
Epoch 43/4000


Training: 100%|██████████| 120/120 [02:50<00:00,  1.42s/it, loss=0.243]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.221]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7938, Class 2: 0.2709, Class 3: 0.4940, Class 4: 0.7194, Class 5: 0.4908, Class 6: 0.8927, 
Validation F-beta Score
Class 0: 0.9834, Class 1: 0.8605, Class 2: 0.3002, Class 3: 0.5555, Class 4: 0.7172, Class 5: 0.6186, Class 6: 0.9651, 
Validation mIoU Score
Class 0: 0.9740, Class 1: 0.6598, Class 2: 0.1655, Class 3: 0.3300, Class 4: 0.5634, Class 5: 0.3272, Class 6: 0.8068, 

Overall Mean Dice Score: 0.6781
Overall Mean F-beta Score: 0.7434
Overall Mean IoU Score: 0.5374
Final_score: 0.6404
Training Loss: 0.2269, Validation Loss: 0.2286, Validation hybrid_score: 0.6404
Epoch 44/4000


Training: 100%|██████████| 120/120 [02:55<00:00,  1.46s/it, loss=0.226]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.20it/s, loss=0.237]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7765, Class 2: 0.2647, Class 3: 0.5125, Class 4: 0.7399, Class 5: 0.5136, Class 6: 0.8950, 
Validation F-beta Score
Class 0: 0.9854, Class 1: 0.8555, Class 2: 0.2376, Class 3: 0.5783, Class 4: 0.7640, Class 5: 0.5486, Class 6: 0.9624, 
Validation mIoU Score
Class 0: 0.9749, Class 1: 0.6369, Class 2: 0.1620, Class 3: 0.3463, Class 4: 0.5884, Class 5: 0.3470, Class 6: 0.8104, 

Overall Mean Dice Score: 0.6875
Overall Mean F-beta Score: 0.7418
Overall Mean IoU Score: 0.5458
Final_score: 0.6438
Training Loss: 0.2245, Validation Loss: 0.2307, Validation hybrid_score: 0.6438
Epoch 45/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.229]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.7996, Class 2: 0.2489, Class 3: 0.4797, Class 4: 0.7361, Class 5: 0.4983, Class 6: 0.9010, 
Validation F-beta Score
Class 0: 0.9834, Class 1: 0.8665, Class 2: 0.2403, Class 3: 0.6260, Class 4: 0.7942, Class 5: 0.5620, Class 6: 0.9610, 
Validation mIoU Score
Class 0: 0.9744, Class 1: 0.6684, Class 2: 0.1507, Class 3: 0.3179, Class 4: 0.5831, Class 5: 0.3333, Class 6: 0.8202, 

Overall Mean Dice Score: 0.6829
Overall Mean F-beta Score: 0.7619
Overall Mean IoU Score: 0.5446
Final_score: 0.6533
Training Loss: 0.2265, Validation Loss: 0.2306, Validation hybrid_score: 0.6533
Epoch 46/4000


Training: 100%|██████████| 120/120 [02:48<00:00,  1.40s/it, loss=0.221]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.243]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7884, Class 2: 0.3731, Class 3: 0.5142, Class 4: 0.7499, Class 5: 0.4772, Class 6: 0.9017, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8640, Class 2: 0.3625, Class 3: 0.5746, Class 4: 0.7608, Class 5: 0.4979, Class 6: 0.9326, 
Validation mIoU Score
Class 0: 0.9756, Class 1: 0.6524, Class 2: 0.2334, Class 3: 0.3484, Class 4: 0.6005, Class 5: 0.3160, Class 6: 0.8213, 

Overall Mean Dice Score: 0.6863
Overall Mean F-beta Score: 0.7260
Overall Mean IoU Score: 0.5477
Final_score: 0.6368
Training Loss: 0.2248, Validation Loss: 0.2316, Validation hybrid_score: 0.6368
Epoch 47/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.38s/it, loss=0.223]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.235]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.7973, Class 2: 0.2652, Class 3: 0.4759, Class 4: 0.7257, Class 5: 0.4675, Class 6: 0.8904, 
Validation F-beta Score
Class 0: 0.9836, Class 1: 0.8546, Class 2: 0.3446, Class 3: 0.6286, Class 4: 0.7190, Class 5: 0.5450, Class 6: 0.9660, 
Validation mIoU Score
Class 0: 0.9736, Class 1: 0.6638, Class 2: 0.1581, Class 3: 0.3134, Class 4: 0.5707, Class 5: 0.3074, Class 6: 0.8028, 

Overall Mean Dice Score: 0.6714
Overall Mean F-beta Score: 0.7426
Overall Mean IoU Score: 0.5316
Final_score: 0.6371
Training Loss: 0.2258, Validation Loss: 0.2305, Validation hybrid_score: 0.6371
Epoch 48/4000


Training: 100%|██████████| 120/120 [02:57<00:00,  1.48s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.25it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.7792, Class 2: 0.2660, Class 3: 0.4636, Class 4: 0.7217, Class 5: 0.4954, Class 6: 0.8918, 
Validation F-beta Score
Class 0: 0.9838, Class 1: 0.8636, Class 2: 0.2930, Class 3: 0.6448, Class 4: 0.7025, Class 5: 0.5917, Class 6: 0.9690, 
Validation mIoU Score
Class 0: 0.9732, Class 1: 0.6402, Class 2: 0.1572, Class 3: 0.3033, Class 4: 0.5661, Class 5: 0.3315, Class 6: 0.8053, 

Overall Mean Dice Score: 0.6703
Overall Mean F-beta Score: 0.7543
Overall Mean IoU Score: 0.5293
Final_score: 0.6418
Training Loss: 0.2255, Validation Loss: 0.2321, Validation hybrid_score: 0.6418
Epoch 49/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.222]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.223]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.7705, Class 2: 0.2870, Class 3: 0.4660, Class 4: 0.7282, Class 5: 0.4850, Class 6: 0.8872, 
Validation F-beta Score
Class 0: 0.9864, Class 1: 0.8911, Class 2: 0.2541, Class 3: 0.6382, Class 4: 0.7207, Class 5: 0.5124, Class 6: 0.9661, 
Validation mIoU Score
Class 0: 0.9761, Class 1: 0.6303, Class 2: 0.1765, Class 3: 0.3061, Class 4: 0.5737, Class 5: 0.3230, Class 6: 0.7978, 

Overall Mean Dice Score: 0.6674
Overall Mean F-beta Score: 0.7457
Overall Mean IoU Score: 0.5262
Final_score: 0.6359
Training Loss: 0.2252, Validation Loss: 0.2300, Validation hybrid_score: 0.6359
Epoch 50/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7524, Class 2: 0.3105, Class 3: 0.5071, Class 4: 0.7391, Class 5: 0.4869, Class 6: 0.8932, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8942, Class 2: 0.3701, Class 3: 0.6126, Class 4: 0.7510, Class 5: 0.5038, Class 6: 0.9641, 
Validation mIoU Score
Class 0: 0.9749, Class 1: 0.6069, Class 2: 0.1911, Class 3: 0.3413, Class 4: 0.5868, Class 5: 0.3239, Class 6: 0.8075, 

Overall Mean Dice Score: 0.6757
Overall Mean F-beta Score: 0.7451
Overall Mean IoU Score: 0.5333
Final_score: 0.6392
Training Loss: 0.2249, Validation Loss: 0.2286, Validation hybrid_score: 0.6392
Epoch 51/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7904, Class 2: 0.2954, Class 3: 0.4818, Class 4: 0.7347, Class 5: 0.4921, Class 6: 0.9093, 
Validation F-beta Score
Class 0: 0.9853, Class 1: 0.8609, Class 2: 0.3321, Class 3: 0.6168, Class 4: 0.7529, Class 5: 0.5230, Class 6: 0.9567, 
Validation mIoU Score
Class 0: 0.9749, Class 1: 0.6557, Class 2: 0.1793, Class 3: 0.3190, Class 4: 0.5817, Class 5: 0.3290, Class 6: 0.8342, 

Overall Mean Dice Score: 0.6817
Overall Mean F-beta Score: 0.7420
Overall Mean IoU Score: 0.5439
Final_score: 0.6430
Training Loss: 0.2248, Validation Loss: 0.2298, Validation hybrid_score: 0.6430
Epoch 52/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.231]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.221]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7843, Class 2: 0.2840, Class 3: 0.4894, Class 4: 0.7249, Class 5: 0.4867, Class 6: 0.8795, 
Validation F-beta Score
Class 0: 0.9823, Class 1: 0.8678, Class 2: 0.2741, Class 3: 0.6163, Class 4: 0.7920, Class 5: 0.5579, Class 6: 0.9715, 
Validation mIoU Score
Class 0: 0.9730, Class 1: 0.6479, Class 2: 0.1715, Class 3: 0.3251, Class 4: 0.5700, Class 5: 0.3229, Class 6: 0.7853, 

Overall Mean Dice Score: 0.6730
Overall Mean F-beta Score: 0.7611
Overall Mean IoU Score: 0.5302
Final_score: 0.6457
Training Loss: 0.2242, Validation Loss: 0.2305, Validation hybrid_score: 0.6457
Epoch 53/4000


Training: 100%|██████████| 120/120 [02:56<00:00,  1.47s/it, loss=0.234]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.214]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7783, Class 2: 0.2747, Class 3: 0.4708, Class 4: 0.7444, Class 5: 0.5229, Class 6: 0.9010, 
Validation F-beta Score
Class 0: 0.9841, Class 1: 0.8804, Class 2: 0.3326, Class 3: 0.6022, Class 4: 0.7607, Class 5: 0.5986, Class 6: 0.9638, 
Validation mIoU Score
Class 0: 0.9747, Class 1: 0.6395, Class 2: 0.1642, Class 3: 0.3092, Class 4: 0.5939, Class 5: 0.3554, Class 6: 0.8205, 

Overall Mean Dice Score: 0.6835
Overall Mean F-beta Score: 0.7612
Overall Mean IoU Score: 0.5437
Final_score: 0.6524
Training Loss: 0.2244, Validation Loss: 0.2293, Validation hybrid_score: 0.6524
SUPER Best model saved. Loss:0.2293, Score:0.6524
Epoch 54/4000


Training: 100%|██████████| 120/120 [02:51<00:00,  1.43s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7903, Class 2: 0.2459, Class 3: 0.4940, Class 4: 0.7395, Class 5: 0.5138, Class 6: 0.8986, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8809, Class 2: 0.2930, Class 3: 0.6276, Class 4: 0.7644, Class 5: 0.5608, Class 6: 0.9646, 
Validation mIoU Score
Class 0: 0.9740, Class 1: 0.6550, Class 2: 0.1447, Class 3: 0.3295, Class 4: 0.5874, Class 5: 0.3467, Class 6: 0.8165, 

Overall Mean Dice Score: 0.6872
Overall Mean F-beta Score: 0.7597
Overall Mean IoU Score: 0.5470
Final_score: 0.6533
Training Loss: 0.2225, Validation Loss: 0.2284, Validation hybrid_score: 0.6533
SUPER Best model saved. Loss:0.2284, Score:0.6533
Epoch 55/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.227]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.7772, Class 2: 0.2189, Class 3: 0.5176, Class 4: 0.7404, Class 5: 0.5084, Class 6: 0.9037, 
Validation F-beta Score
Class 0: 0.9855, Class 1: 0.8663, Class 2: 0.2667, Class 3: 0.6249, Class 4: 0.7426, Class 5: 0.5616, Class 6: 0.9565, 
Validation mIoU Score
Class 0: 0.9755, Class 1: 0.6399, Class 2: 0.1266, Class 3: 0.3508, Class 4: 0.5889, Class 5: 0.3440, Class 6: 0.8257, 

Overall Mean Dice Score: 0.6894
Overall Mean F-beta Score: 0.7504
Overall Mean IoU Score: 0.5499
Final_score: 0.6501
Training Loss: 0.2232, Validation Loss: 0.2271, Validation hybrid_score: 0.6501
Epoch 56/4000


Training: 100%|██████████| 120/120 [02:55<00:00,  1.46s/it, loss=0.229]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.225]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.7954, Class 2: 0.2770, Class 3: 0.5263, Class 4: 0.7432, Class 5: 0.5010, Class 6: 0.8864, 
Validation F-beta Score
Class 0: 0.9843, Class 1: 0.8858, Class 2: 0.2878, Class 3: 0.6106, Class 4: 0.7813, Class 5: 0.5697, Class 6: 0.9631, 
Validation mIoU Score
Class 0: 0.9752, Class 1: 0.6618, Class 2: 0.1677, Class 3: 0.3592, Class 4: 0.5926, Class 5: 0.3362, Class 6: 0.7965, 

Overall Mean Dice Score: 0.6905
Overall Mean F-beta Score: 0.7621
Overall Mean IoU Score: 0.5493
Final_score: 0.6557
Training Loss: 0.2229, Validation Loss: 0.2289, Validation hybrid_score: 0.6557
Epoch 57/4000


Training: 100%|██████████| 120/120 [02:52<00:00,  1.44s/it, loss=0.228]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7900, Class 2: 0.2596, Class 3: 0.4728, Class 4: 0.7435, Class 5: 0.4914, Class 6: 0.8984, 
Validation F-beta Score
Class 0: 0.9831, Class 1: 0.8666, Class 2: 0.2657, Class 3: 0.6245, Class 4: 0.7802, Class 5: 0.5500, Class 6: 0.9627, 
Validation mIoU Score
Class 0: 0.9730, Class 1: 0.6559, Class 2: 0.1550, Class 3: 0.3118, Class 4: 0.5927, Class 5: 0.3279, Class 6: 0.8162, 

Overall Mean Dice Score: 0.6792
Overall Mean F-beta Score: 0.7568
Overall Mean IoU Score: 0.5409
Final_score: 0.6489
Training Loss: 0.2224, Validation Loss: 0.2307, Validation hybrid_score: 0.6489
Epoch 58/4000


Training: 100%|██████████| 120/120 [02:49<00:00,  1.41s/it, loss=0.222]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.219]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7957, Class 2: 0.2642, Class 3: 0.4925, Class 4: 0.7386, Class 5: 0.5076, Class 6: 0.8939, 
Validation F-beta Score
Class 0: 0.9865, Class 1: 0.8891, Class 2: 0.2825, Class 3: 0.5626, Class 4: 0.7168, Class 5: 0.5301, Class 6: 0.9581, 
Validation mIoU Score
Class 0: 0.9747, Class 1: 0.6628, Class 2: 0.1591, Class 3: 0.3287, Class 4: 0.5864, Class 5: 0.3413, Class 6: 0.8085, 

Overall Mean Dice Score: 0.6857
Overall Mean F-beta Score: 0.7313
Overall Mean IoU Score: 0.5455
Final_score: 0.6384
Training Loss: 0.2216, Validation Loss: 0.2316, Validation hybrid_score: 0.6384
Epoch 59/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.227]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.7804, Class 2: 0.3059, Class 3: 0.5098, Class 4: 0.7517, Class 5: 0.5012, Class 6: 0.8963, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8705, Class 2: 0.3496, Class 3: 0.5952, Class 4: 0.7629, Class 5: 0.5560, Class 6: 0.9642, 
Validation mIoU Score
Class 0: 0.9758, Class 1: 0.6450, Class 2: 0.1845, Class 3: 0.3437, Class 4: 0.6033, Class 5: 0.3356, Class 6: 0.8124, 

Overall Mean Dice Score: 0.6879
Overall Mean F-beta Score: 0.7498
Overall Mean IoU Score: 0.5480
Final_score: 0.6489
Training Loss: 0.2239, Validation Loss: 0.2291, Validation hybrid_score: 0.6489
Epoch 60/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.38s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.7860, Class 2: 0.2423, Class 3: 0.4959, Class 4: 0.7363, Class 5: 0.5017, Class 6: 0.8871, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.8638, Class 2: 0.2341, Class 3: 0.6060, Class 4: 0.7354, Class 5: 0.5484, Class 6: 0.9612, 
Validation mIoU Score
Class 0: 0.9752, Class 1: 0.6489, Class 2: 0.1450, Class 3: 0.3313, Class 4: 0.5836, Class 5: 0.3375, Class 6: 0.7977, 

Overall Mean Dice Score: 0.6814
Overall Mean F-beta Score: 0.7430
Overall Mean IoU Score: 0.5398
Final_score: 0.6414
Training Loss: 0.2235, Validation Loss: 0.2300, Validation hybrid_score: 0.6414
Epoch 61/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.40s/it, loss=0.228]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.236]


Validation Dice Score
Class 0: 0.9880, Class 1: 0.7985, Class 2: 0.2494, Class 3: 0.4892, Class 4: 0.7481, Class 5: 0.4955, Class 6: 0.9010, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8649, Class 2: 0.2665, Class 3: 0.5999, Class 4: 0.7382, Class 5: 0.5321, Class 6: 0.9552, 
Validation mIoU Score
Class 0: 0.9763, Class 1: 0.6666, Class 2: 0.1489, Class 3: 0.3259, Class 4: 0.5986, Class 5: 0.3323, Class 6: 0.8204, 

Overall Mean Dice Score: 0.6865
Overall Mean F-beta Score: 0.7381
Overall Mean IoU Score: 0.5487
Final_score: 0.6434
Training Loss: 0.2219, Validation Loss: 0.2274, Validation hybrid_score: 0.6434
Epoch 62/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.227]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.218]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7957, Class 2: 0.3214, Class 3: 0.4916, Class 4: 0.7435, Class 5: 0.5085, Class 6: 0.9077, 
Validation F-beta Score
Class 0: 0.9847, Class 1: 0.8797, Class 2: 0.4105, Class 3: 0.6234, Class 4: 0.7884, Class 5: 0.5347, Class 6: 0.9544, 
Validation mIoU Score
Class 0: 0.9750, Class 1: 0.6629, Class 2: 0.1983, Class 3: 0.3278, Class 4: 0.5930, Class 5: 0.3425, Class 6: 0.8312, 

Overall Mean Dice Score: 0.6894
Overall Mean F-beta Score: 0.7561
Overall Mean IoU Score: 0.5515
Final_score: 0.6538
Training Loss: 0.2225, Validation Loss: 0.2281, Validation hybrid_score: 0.6538
SUPER Best model saved. Loss:0.2281, Score:0.6538
Epoch 63/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.227]
Validation: 100%|██████████| 20/20 [00:22<00:00,  1.15s/it, loss=0.229]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.7987, Class 2: 0.2233, Class 3: 0.4680, Class 4: 0.7247, Class 5: 0.5044, Class 6: 0.9070, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.8959, Class 2: 0.2692, Class 3: 0.5405, Class 4: 0.7508, Class 5: 0.5260, Class 6: 0.9610, 
Validation mIoU Score
Class 0: 0.9752, Class 1: 0.6663, Class 2: 0.1318, Class 3: 0.3076, Class 4: 0.5697, Class 5: 0.3392, Class 6: 0.8304, 

Overall Mean Dice Score: 0.6806
Overall Mean F-beta Score: 0.7348
Overall Mean IoU Score: 0.5426
Final_score: 0.6387
Training Loss: 0.2223, Validation Loss: 0.2293, Validation hybrid_score: 0.6387
Epoch 64/4000


Training: 100%|██████████| 120/120 [02:57<00:00,  1.48s/it, loss=0.23] 
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s, loss=0.229]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7901, Class 2: 0.2837, Class 3: 0.4547, Class 4: 0.7558, Class 5: 0.5012, Class 6: 0.8885, 
Validation F-beta Score
Class 0: 0.9845, Class 1: 0.9020, Class 2: 0.3419, Class 3: 0.5730, Class 4: 0.7640, Class 5: 0.5674, Class 6: 0.9671, 
Validation mIoU Score
Class 0: 0.9746, Class 1: 0.6544, Class 2: 0.1729, Class 3: 0.2974, Class 4: 0.6081, Class 5: 0.3361, Class 6: 0.8001, 

Overall Mean Dice Score: 0.6781
Overall Mean F-beta Score: 0.7547
Overall Mean IoU Score: 0.5392
Final_score: 0.6470
Training Loss: 0.2225, Validation Loss: 0.2303, Validation hybrid_score: 0.6470
Epoch 65/4000


Training: 100%|██████████| 120/120 [02:56<00:00,  1.47s/it, loss=0.226]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.25it/s, loss=0.231]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.8092, Class 2: 0.2296, Class 3: 0.4857, Class 4: 0.7401, Class 5: 0.5171, Class 6: 0.9050, 
Validation F-beta Score
Class 0: 0.9849, Class 1: 0.8739, Class 2: 0.2995, Class 3: 0.6053, Class 4: 0.7283, Class 5: 0.5892, Class 6: 0.9627, 
Validation mIoU Score
Class 0: 0.9745, Class 1: 0.6805, Class 2: 0.1351, Class 3: 0.3228, Class 4: 0.5884, Class 5: 0.3512, Class 6: 0.8270, 

Overall Mean Dice Score: 0.6914
Overall Mean F-beta Score: 0.7519
Overall Mean IoU Score: 0.5540
Final_score: 0.6529
Training Loss: 0.2215, Validation Loss: 0.2288, Validation hybrid_score: 0.6529
Epoch 66/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.38s/it, loss=0.222]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.225]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7760, Class 2: 0.2894, Class 3: 0.4734, Class 4: 0.7428, Class 5: 0.4969, Class 6: 0.9003, 
Validation F-beta Score
Class 0: 0.9836, Class 1: 0.8682, Class 2: 0.3205, Class 3: 0.6115, Class 4: 0.7828, Class 5: 0.5509, Class 6: 0.9705, 
Validation mIoU Score
Class 0: 0.9739, Class 1: 0.6371, Class 2: 0.1750, Class 3: 0.3112, Class 4: 0.5917, Class 5: 0.3334, Class 6: 0.8193, 

Overall Mean Dice Score: 0.6779
Overall Mean F-beta Score: 0.7568
Overall Mean IoU Score: 0.5386
Final_score: 0.6477
Training Loss: 0.2212, Validation Loss: 0.2305, Validation hybrid_score: 0.6477
Epoch 67/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.225]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.238]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7886, Class 2: 0.2566, Class 3: 0.4893, Class 4: 0.7520, Class 5: 0.5005, Class 6: 0.8841, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.8752, Class 2: 0.2958, Class 3: 0.5641, Class 4: 0.8064, Class 5: 0.5519, Class 6: 0.9636, 
Validation mIoU Score
Class 0: 0.9757, Class 1: 0.6562, Class 2: 0.1516, Class 3: 0.3267, Class 4: 0.6033, Class 5: 0.3363, Class 6: 0.7928, 

Overall Mean Dice Score: 0.6829
Overall Mean F-beta Score: 0.7523
Overall Mean IoU Score: 0.5431
Final_score: 0.6477
Training Loss: 0.2215, Validation Loss: 0.2273, Validation hybrid_score: 0.6477
Epoch 68/4000


Training: 100%|██████████| 120/120 [02:46<00:00,  1.39s/it, loss=0.219]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.8203, Class 2: 0.3370, Class 3: 0.4864, Class 4: 0.7403, Class 5: 0.4862, Class 6: 0.8836, 
Validation F-beta Score
Class 0: 0.9859, Class 1: 0.8906, Class 2: 0.3780, Class 3: 0.5553, Class 4: 0.7287, Class 5: 0.5465, Class 6: 0.9657, 
Validation mIoU Score
Class 0: 0.9758, Class 1: 0.6964, Class 2: 0.2085, Class 3: 0.3241, Class 4: 0.5890, Class 5: 0.3245, Class 6: 0.7918, 

Overall Mean Dice Score: 0.6834
Overall Mean F-beta Score: 0.7373
Overall Mean IoU Score: 0.5452
Final_score: 0.6413
Training Loss: 0.2218, Validation Loss: 0.2286, Validation hybrid_score: 0.6413
Epoch 69/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.219]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.25it/s, loss=0.212]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7691, Class 2: 0.2577, Class 3: 0.5181, Class 4: 0.7545, Class 5: 0.5013, Class 6: 0.9030, 
Validation F-beta Score
Class 0: 0.9849, Class 1: 0.8820, Class 2: 0.3262, Class 3: 0.6143, Class 4: 0.7779, Class 5: 0.5482, Class 6: 0.9622, 
Validation mIoU Score
Class 0: 0.9750, Class 1: 0.6272, Class 2: 0.1548, Class 3: 0.3507, Class 4: 0.6062, Class 5: 0.3354, Class 6: 0.8236, 

Overall Mean Dice Score: 0.6892
Overall Mean F-beta Score: 0.7569
Overall Mean IoU Score: 0.5486
Final_score: 0.6528
Training Loss: 0.2216, Validation Loss: 0.2284, Validation hybrid_score: 0.6528
Epoch 70/4000


Training: 100%|██████████| 120/120 [02:49<00:00,  1.41s/it, loss=0.226]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, loss=0.232]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7881, Class 2: 0.2886, Class 3: 0.5146, Class 4: 0.7657, Class 5: 0.4909, Class 6: 0.8882, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.8953, Class 2: 0.2835, Class 3: 0.5868, Class 4: 0.7744, Class 5: 0.5650, Class 6: 0.9658, 
Validation mIoU Score
Class 0: 0.9747, Class 1: 0.6512, Class 2: 0.1788, Class 3: 0.3486, Class 4: 0.6214, Class 5: 0.3283, Class 6: 0.7994, 

Overall Mean Dice Score: 0.6895
Overall Mean F-beta Score: 0.7575
Overall Mean IoU Score: 0.5498
Final_score: 0.6536
Training Loss: 0.2209, Validation Loss: 0.2296, Validation hybrid_score: 0.6536
Epoch 71/4000


Training: 100%|██████████| 120/120 [02:47<00:00,  1.39s/it, loss=0.229]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, loss=0.222]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7964, Class 2: 0.2858, Class 3: 0.4884, Class 4: 0.7308, Class 5: 0.4859, Class 6: 0.8989, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8864, Class 2: 0.3525, Class 3: 0.5839, Class 4: 0.7681, Class 5: 0.5481, Class 6: 0.9653, 
Validation mIoU Score
Class 0: 0.9749, Class 1: 0.6639, Class 2: 0.1709, Class 3: 0.3245, Class 4: 0.5771, Class 5: 0.3235, Class 6: 0.8166, 

Overall Mean Dice Score: 0.6801
Overall Mean F-beta Score: 0.7503
Overall Mean IoU Score: 0.5411
Final_score: 0.6457
Training Loss: 0.2216, Validation Loss: 0.2284, Validation hybrid_score: 0.6457
Epoch 72/4000


Training:  62%|██████▏   | 74/120 [01:52<01:10,  1.53s/it, loss=0.214]

In [None]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






: 

: 

# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

: 

: 

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


: 

: 

In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


: 

: 

In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

: 

: 

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

: 

: 

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv


: 

: 