In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch

from src.models import UNet_CBAM

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.networks.layers.factories import Act, Norm

from monai.config import print_config
from monai.metrics import DiceMetric
# from src.models.swincspunetr import SwinCSPUNETR
# from src.models.swincspunetr_unet import SwinCSPUNETR_unet
# from src.models.swincspunetr3plus import SwinCSPUNETR3plus

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\Users\<username>\.conda\envs\UM\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.17.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders, create_dataloaders_bw
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, RandCropd,RandCropByPosNegLabeld, RandGaussianSmoothd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/pretrain_exdata/images"
train_label_dir = "./datasets/pretrain_exdata/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = 32
n_classes = 7
batch_size = 16 # 13.8GB GPU memory required for 128x128 img size
loader_batch = 1
num_samples = batch_size // loader_batch # 한 이미지에서 뽑을 샘플 수
num_repeat = 4
# MODEL CONFIG
num_epochs = 4000
lamda = 0.52
ce_weight = 0.4
lr = 0.001
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25
num_bottleneck = 2
# CLASS_WEIGHTS
class_weights = None
# class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([0.9,1,0.9,1.1,1,1.1,1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([1,1,1,1,1,1,1], dtype=torch.float32)  # 클래스별 가중치
sigma = 2.0
accumulation_steps = 1
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # GaussianSmoothd(
    #     keys=["image"],      # 변환을 적용할 키
    #     sigma=[sigma, sigma, sigma]  # 각 축(x, y, z)의 시그마 값
    #     ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandGaussianSmoothd(
    keys=["image"],      # 변환을 적용할 키
    sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    sigma_y = (0.0, sigma),
    sigma_z = (0.0, sigma),
    prob=1.0,
    ),
])
val_random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # RandGaussianSmoothd(
    # keys=["image"],      # 변환을 적용할 키
    # sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    # sigma_y = (0.0, sigma),
    # sigma_z = (0.0, sigma),
    # prob=1.0,
    # ),
])


In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat,
    val_num_repeat=num_repeat,
    )

Loading dataset: 100%|██████████| 51/51 [00:05<00:00,  9.42it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  6.94it/s]


https://monai.io/model-zoo.html

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import TverskyLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DynamicTverskyLoss 클래스 정의
class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda


# CombinedCETverskyLoss 클래스
class CombinedCETverskyLoss(nn.Module):

    
    def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        
        # CrossEntropyLoss에서 클래스별 가중치를 적용
        self.ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
        # TverskyLoss
        self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="mean",softmax=True, **kwargs)

    def forward(self, inputs, targets):
        
        # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
        ce_loss = self.ce(inputs, targets)

        # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
        tversky_loss = self.tversky(inputs, targets)

        # 최종 손실 계산
        final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss  # mean()으로 배치에 대해 평균
        
        return final_loss

    def set_lamda(self, lamda):
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        return self.tversky.lamda

criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,
    n_classes=n_classes,
    class_weights=class_weights,
).to(device)

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.networks.nets import UNet
from src.models import *


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

enc_channels = (32, 64, 128, 256)
enc_strides = (2, 2, 2)
num_layers_enc = (1, 1, 1, 2)

core_channels = 64
dec_channels = (128, 64, 32)
dec_strides = (2, 2, 2)
num_layers_dec = (1, 1, 1)

skip_map = {
    0: [("enc", 2)],       # 디코더0 => 인코더2
    1: [("enc", 3), ("enc", 1)],  # 디코더1 => 인코더1 + 디코더0
    2: [("enc", 3), ("dec", 0), ("enc", 0)]   # 디코더2 => 인코더0 + 디코더1
}

model = FlexibleUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    encoder_channels=enc_channels,
    encoder_strides=enc_strides,
    core_channels=core_channels,
    decoder_channels=dec_channels,
    decoder_strides=dec_strides,
    num_layers_encoder=num_layers_enc,
    num_layers_decoder=num_layers_dec,
    skip_connections=skip_map,
    kernel_size=3,
    up_kernel_size=3,
    act=Act.PRELU,
    norm=Norm.INSTANCE,
    dropout=drop_rate,
    bias=True,
    mode="trilinear",
    align_corners=False,
).to(device)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"FLEX_randGaus_511_241_noclswt_gelu_batch_f{feature_size}_lr{lr:.0e}_a{lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}_ac{accumulation_steps}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model_pretrained.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss', 'best_val_fbeta_score']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                best_val_fbeta_score = checkpoint['best_val_fbeta_score']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")

기존 best model 발견: model_checkpoints\FLEX_randGaus_511_241_noclswt_gelu_batch_f48_lr1e-03_a0.52_b16_r4_ce0.4_ac1\best_model_pretrained.pt
기존 학습된 가중치를 성공적으로 로드했습니다.


  checkpoint = torch.load(best_model_path, map_location=device)


In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([16, 1, 32, 96, 96]) torch.Size([16, 1, 32, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [9]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        # "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        # "attn_drop_rate": attn_drop_rate,
        # "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        # "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


# 학습

In [10]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    loss = criterion(outputs, labels_onehot)
    # loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
                
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4, pretrained=False
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            if pretrained:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model_pretrained.pt')
            else:
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve % 6 == 0 & epochs_no_improve != 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.35)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [11]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps,
    pretrained=True,
    ) 

Epoch 6/4000


Training: 100%|██████████| 204/204 [03:25<00:00,  1.01s/it, loss=0.455]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s, loss=0.488]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.2206, Class 2: 0.0000, Class 3: 0.2965, Class 4: 0.6636, Class 5: 0.2697, Class 6: 0.8818, 
Validation F-beta Score
Class 0: 0.9870, Class 1: 0.2572, Class 2: 0.0000, Class 3: 0.4791, Class 4: 0.7301, Class 5: 0.2203, Class 6: 0.8906, 

Overall Mean Dice Score: 0.4665
Overall Mean F-beta Score: 0.5155

Training Loss: 0.4814, Validation Loss: 0.4771, Validation F-beta: 0.5155
SUPER Best model saved. Loss:0.4771, Score:0.5155
Epoch 7/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.19it/s, loss=0.439]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, loss=0.468]


Validation Dice Score
Class 0: 0.9836, Class 1: 0.4223, Class 2: 0.0000, Class 3: 0.3646, Class 4: 0.6410, Class 5: 0.3024, Class 6: 0.9021, 
Validation F-beta Score
Class 0: 0.9813, Class 1: 0.4090, Class 2: 0.0000, Class 3: 0.4851, Class 4: 0.7076, Class 5: 0.2963, Class 6: 0.9275, 

Overall Mean Dice Score: 0.5265
Overall Mean F-beta Score: 0.5651

Training Loss: 0.4731, Validation Loss: 0.4706, Validation F-beta: 0.5651
SUPER Best model saved. Loss:0.4706, Score:0.5651
Epoch 8/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.443]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.47] 


Validation Dice Score
Class 0: 0.9853, Class 1: 0.5726, Class 2: 0.0000, Class 3: 0.3224, Class 4: 0.6935, Class 5: 0.3751, Class 6: 0.8898, 
Validation F-beta Score
Class 0: 0.9817, Class 1: 0.6727, Class 2: 0.0000, Class 3: 0.5664, Class 4: 0.7962, Class 5: 0.3635, Class 6: 0.9166, 

Overall Mean Dice Score: 0.5707
Overall Mean F-beta Score: 0.6631

Training Loss: 0.4655, Validation Loss: 0.4619, Validation F-beta: 0.6631
SUPER Best model saved. Loss:0.4619, Score:0.6631
Epoch 9/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.455]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s, loss=0.472]


Validation Dice Score
Class 0: 0.9837, Class 1: 0.6264, Class 2: 0.0000, Class 3: 0.2742, Class 4: 0.6963, Class 5: 0.3337, Class 6: 0.8894, 
Validation F-beta Score
Class 0: 0.9791, Class 1: 0.6644, Class 2: 0.0000, Class 3: 0.4710, Class 4: 0.8001, Class 5: 0.3376, Class 6: 0.9337, 

Overall Mean Dice Score: 0.5640
Overall Mean F-beta Score: 0.6414

Training Loss: 0.4570, Validation Loss: 0.4634, Validation F-beta: 0.6414
Epoch 10/4000


Training: 100%|██████████| 204/204 [02:51<00:00,  1.19it/s, loss=0.432]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.95it/s, loss=0.452]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.5410, Class 2: 0.0000, Class 3: 0.3351, Class 4: 0.7243, Class 5: 0.3683, Class 6: 0.9144, 
Validation F-beta Score
Class 0: 0.9841, Class 1: 0.4900, Class 2: 0.0000, Class 3: 0.5407, Class 4: 0.7813, Class 5: 0.3949, Class 6: 0.9095, 

Overall Mean Dice Score: 0.5766
Overall Mean F-beta Score: 0.6233

Training Loss: 0.4551, Validation Loss: 0.4592, Validation F-beta: 0.6233
Epoch 11/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.24it/s, loss=0.435]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, loss=0.449]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.5387, Class 2: 0.0016, Class 3: 0.4157, Class 4: 0.6347, Class 5: 0.3143, Class 6: 0.9165, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.6042, Class 2: 0.0008, Class 3: 0.5988, Class 4: 0.7662, Class 5: 0.3268, Class 6: 0.9428, 

Overall Mean Dice Score: 0.5640
Overall Mean F-beta Score: 0.6478

Training Loss: 0.4585, Validation Loss: 0.4570, Validation F-beta: 0.6478
Epoch 12/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.21it/s, loss=0.454]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.73it/s, loss=0.453]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.5631, Class 2: 0.0544, Class 3: 0.2956, Class 4: 0.6881, Class 5: 0.4954, Class 6: 0.8690, 
Validation F-beta Score
Class 0: 0.9820, Class 1: 0.5345, Class 2: 0.0312, Class 3: 0.4789, Class 4: 0.7994, Class 5: 0.5539, Class 6: 0.9131, 

Overall Mean Dice Score: 0.5823
Overall Mean F-beta Score: 0.6559

Training Loss: 0.4537, Validation Loss: 0.4487, Validation F-beta: 0.6559
Epoch 13/4000


Training: 100%|██████████| 204/204 [02:41<00:00,  1.26it/s, loss=0.434]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s, loss=0.438]


Validation Dice Score
Class 0: 0.9826, Class 1: 0.7194, Class 2: 0.1359, Class 3: 0.2648, Class 4: 0.6315, Class 5: 0.5189, Class 6: 0.8640, 
Validation F-beta Score
Class 0: 0.9743, Class 1: 0.7543, Class 2: 0.0936, Class 3: 0.4854, Class 4: 0.8448, Class 5: 0.6035, Class 6: 0.9344, 

Overall Mean Dice Score: 0.5997
Overall Mean F-beta Score: 0.7245

Training Loss: 0.4537, Validation Loss: 0.4436, Validation F-beta: 0.7245
SUPER Best model saved. Loss:0.4436, Score:0.7245
Epoch 14/4000


Training: 100%|██████████| 204/204 [02:49<00:00,  1.21it/s, loss=0.439]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.89it/s, loss=0.461]


Validation Dice Score
Class 0: 0.9847, Class 1: 0.6158, Class 2: 0.1238, Class 3: 0.4243, Class 4: 0.7412, Class 5: 0.3775, Class 6: 0.8980, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.6352, Class 2: 0.0960, Class 3: 0.5336, Class 4: 0.7923, Class 5: 0.4832, Class 6: 0.9176, 

Overall Mean Dice Score: 0.6114
Overall Mean F-beta Score: 0.6724

Training Loss: 0.4508, Validation Loss: 0.4490, Validation F-beta: 0.6724
Epoch 15/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.20it/s, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9834, Class 1: 0.5433, Class 2: 0.1743, Class 3: 0.3389, Class 4: 0.7381, Class 5: 0.4071, Class 6: 0.8914, 
Validation F-beta Score
Class 0: 0.9779, Class 1: 0.5623, Class 2: 0.1681, Class 3: 0.5125, Class 4: 0.8123, Class 5: 0.4949, Class 6: 0.8983, 

Overall Mean Dice Score: 0.5838
Overall Mean F-beta Score: 0.6561

Training Loss: 0.4497, Validation Loss: 0.4490, Validation F-beta: 0.6561
Epoch 16/4000


Training: 100%|██████████| 204/204 [02:46<00:00,  1.23it/s, loss=0.428]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.79it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.5921, Class 2: 0.0956, Class 3: 0.2327, Class 4: 0.6359, Class 5: 0.4306, Class 6: 0.9327, 
Validation F-beta Score
Class 0: 0.9784, Class 1: 0.6037, Class 2: 0.1512, Class 3: 0.4176, Class 4: 0.8181, Class 5: 0.4861, Class 6: 0.9546, 

Overall Mean Dice Score: 0.5648
Overall Mean F-beta Score: 0.6560

Training Loss: 0.4465, Validation Loss: 0.4528, Validation F-beta: 0.6560
Epoch 17/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.434]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s, loss=0.465]


Validation Dice Score
Class 0: 0.9848, Class 1: 0.7024, Class 2: 0.1977, Class 3: 0.4000, Class 4: 0.6326, Class 5: 0.4477, Class 6: 0.9033, 
Validation F-beta Score
Class 0: 0.9776, Class 1: 0.7137, Class 2: 0.2149, Class 3: 0.6816, Class 4: 0.8216, Class 5: 0.5108, Class 6: 0.9540, 

Overall Mean Dice Score: 0.6172
Overall Mean F-beta Score: 0.7363

Training Loss: 0.4448, Validation Loss: 0.4483, Validation F-beta: 0.7363
Epoch 18/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.18it/s, loss=0.424]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.79it/s, loss=0.459]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.5100, Class 2: 0.1164, Class 3: 0.3745, Class 4: 0.6385, Class 5: 0.3939, Class 6: 0.8888, 
Validation F-beta Score
Class 0: 0.9776, Class 1: 0.6464, Class 2: 0.2295, Class 3: 0.5422, Class 4: 0.7840, Class 5: 0.4969, Class 6: 0.9472, 

Overall Mean Dice Score: 0.5611
Overall Mean F-beta Score: 0.6833

Training Loss: 0.4447, Validation Loss: 0.4434, Validation F-beta: 0.6833
Epoch 19/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.402]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9824, Class 1: 0.7134, Class 2: 0.1129, Class 3: 0.4021, Class 4: 0.6982, Class 5: 0.4198, Class 6: 0.9046, 
Validation F-beta Score
Class 0: 0.9762, Class 1: 0.8210, Class 2: 0.1435, Class 3: 0.5783, Class 4: 0.8064, Class 5: 0.5083, Class 6: 0.9672, 

Overall Mean Dice Score: 0.6276
Overall Mean F-beta Score: 0.7362

Training Loss: 0.4433, Validation Loss: 0.4441, Validation F-beta: 0.7362
Epoch 20/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.25it/s, loss=0.435]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.93it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9837, Class 1: 0.7358, Class 2: 0.1383, Class 3: 0.2802, Class 4: 0.7112, Class 5: 0.4232, Class 6: 0.9173, 
Validation F-beta Score
Class 0: 0.9769, Class 1: 0.8046, Class 2: 0.1601, Class 3: 0.5189, Class 4: 0.7898, Class 5: 0.5428, Class 6: 0.9364, 

Overall Mean Dice Score: 0.6135
Overall Mean F-beta Score: 0.7185

Training Loss: 0.4405, Validation Loss: 0.4456, Validation F-beta: 0.7185
Epoch 21/4000


Training: 100%|██████████| 204/204 [02:36<00:00,  1.30it/s, loss=0.418]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.92it/s, loss=0.453]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.6876, Class 2: 0.1016, Class 3: 0.1728, Class 4: 0.7368, Class 5: 0.4125, Class 6: 0.9173, 
Validation F-beta Score
Class 0: 0.9809, Class 1: 0.7202, Class 2: 0.1936, Class 3: 0.2762, Class 4: 0.8374, Class 5: 0.4086, Class 6: 0.9304, 

Overall Mean Dice Score: 0.5854
Overall Mean F-beta Score: 0.6345

Training Loss: 0.4412, Validation Loss: 0.4519, Validation F-beta: 0.6345
Epoch 22/4000


Training: 100%|██████████| 204/204 [02:39<00:00,  1.28it/s, loss=0.389]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7144, Class 2: 0.1980, Class 3: 0.3779, Class 4: 0.7329, Class 5: 0.4327, Class 6: 0.9081, 
Validation F-beta Score
Class 0: 0.9825, Class 1: 0.7681, Class 2: 0.2107, Class 3: 0.5635, Class 4: 0.8058, Class 5: 0.5059, Class 6: 0.9397, 

Overall Mean Dice Score: 0.6332
Overall Mean F-beta Score: 0.7166

Training Loss: 0.4405, Validation Loss: 0.4354, Validation F-beta: 0.7166
Epoch 23/4000


Training: 100%|██████████| 204/204 [02:46<00:00,  1.23it/s, loss=0.4]  
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, loss=0.413]


Validation Dice Score
Class 0: 0.9858, Class 1: 0.6909, Class 2: 0.2527, Class 3: 0.3382, Class 4: 0.7124, Class 5: 0.4674, Class 6: 0.9110, 
Validation F-beta Score
Class 0: 0.9786, Class 1: 0.8470, Class 2: 0.3824, Class 3: 0.5356, Class 4: 0.8547, Class 5: 0.5925, Class 6: 0.9392, 

Overall Mean Dice Score: 0.6240
Overall Mean F-beta Score: 0.7538

Training Loss: 0.4400, Validation Loss: 0.4285, Validation F-beta: 0.7538
SUPER Best model saved. Loss:0.4285, Score:0.7538
Epoch 24/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.422]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9826, Class 1: 0.7766, Class 2: 0.1732, Class 3: 0.3988, Class 4: 0.7149, Class 5: 0.4080, Class 6: 0.8154, 
Validation F-beta Score
Class 0: 0.9748, Class 1: 0.8506, Class 2: 0.1967, Class 3: 0.5741, Class 4: 0.8673, Class 5: 0.4586, Class 6: 0.9272, 

Overall Mean Dice Score: 0.6228
Overall Mean F-beta Score: 0.7356

Training Loss: 0.4389, Validation Loss: 0.4406, Validation F-beta: 0.7356
Epoch 25/4000


Training: 100%|██████████| 204/204 [02:39<00:00,  1.28it/s, loss=0.423]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.92it/s, loss=0.446]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.7564, Class 2: 0.2226, Class 3: 0.3626, Class 4: 0.7503, Class 5: 0.4361, Class 6: 0.9193, 
Validation F-beta Score
Class 0: 0.9809, Class 1: 0.8275, Class 2: 0.3551, Class 3: 0.5185, Class 4: 0.8141, Class 5: 0.4963, Class 6: 0.9203, 

Overall Mean Dice Score: 0.6449
Overall Mean F-beta Score: 0.7153

Training Loss: 0.4376, Validation Loss: 0.4396, Validation F-beta: 0.7153
Epoch 26/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.42] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.6746, Class 2: 0.0835, Class 3: 0.4020, Class 4: 0.7496, Class 5: 0.5385, Class 6: 0.8932, 
Validation F-beta Score
Class 0: 0.9787, Class 1: 0.8557, Class 2: 0.1111, Class 3: 0.6437, Class 4: 0.8457, Class 5: 0.5470, Class 6: 0.9368, 

Overall Mean Dice Score: 0.6516
Overall Mean F-beta Score: 0.7658

Training Loss: 0.4381, Validation Loss: 0.4302, Validation F-beta: 0.7658
Epoch 27/4000


Training: 100%|██████████| 204/204 [02:39<00:00,  1.28it/s, loss=0.402]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, loss=0.445]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7219, Class 2: 0.2729, Class 3: 0.4127, Class 4: 0.6491, Class 5: 0.2990, Class 6: 0.8929, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.8418, Class 2: 0.3112, Class 3: 0.6321, Class 4: 0.8076, Class 5: 0.3849, Class 6: 0.9525, 

Overall Mean Dice Score: 0.5951
Overall Mean F-beta Score: 0.7238

Training Loss: 0.4355, Validation Loss: 0.4347, Validation F-beta: 0.7238
Epoch 28/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.412]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.79it/s, loss=0.428]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.7014, Class 2: 0.2001, Class 3: 0.4280, Class 4: 0.7453, Class 5: 0.4989, Class 6: 0.9191, 
Validation F-beta Score
Class 0: 0.9831, Class 1: 0.8282, Class 2: 0.3117, Class 3: 0.6222, Class 4: 0.8117, Class 5: 0.5857, Class 6: 0.9447, 

Overall Mean Dice Score: 0.6585
Overall Mean F-beta Score: 0.7585

Training Loss: 0.4366, Validation Loss: 0.4217, Validation F-beta: 0.7585
SUPER Best model saved. Loss:0.4217, Score:0.7585
Epoch 29/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.434]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.423]


Validation Dice Score
Class 0: 0.9840, Class 1: 0.6989, Class 2: 0.1152, Class 3: 0.4241, Class 4: 0.7254, Class 5: 0.5433, Class 6: 0.8446, 
Validation F-beta Score
Class 0: 0.9775, Class 1: 0.8263, Class 2: 0.1400, Class 3: 0.6245, Class 4: 0.8187, Class 5: 0.6162, Class 6: 0.9169, 

Overall Mean Dice Score: 0.6473
Overall Mean F-beta Score: 0.7605

Training Loss: 0.4353, Validation Loss: 0.4351, Validation F-beta: 0.7605
Epoch 30/4000


Training: 100%|██████████| 204/204 [02:37<00:00,  1.30it/s, loss=0.424]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s, loss=0.446]


Validation Dice Score
Class 0: 0.9820, Class 1: 0.6755, Class 2: 0.1189, Class 3: 0.4107, Class 4: 0.6863, Class 5: 0.4716, Class 6: 0.8873, 
Validation F-beta Score
Class 0: 0.9728, Class 1: 0.8037, Class 2: 0.2509, Class 3: 0.5542, Class 4: 0.7981, Class 5: 0.6554, Class 6: 0.9324, 

Overall Mean Dice Score: 0.6263
Overall Mean F-beta Score: 0.7488

Training Loss: 0.4348, Validation Loss: 0.4398, Validation F-beta: 0.7488
Epoch 31/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.401]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9860, Class 1: 0.6253, Class 2: 0.1783, Class 3: 0.3534, Class 4: 0.7335, Class 5: 0.4108, Class 6: 0.9107, 
Validation F-beta Score
Class 0: 0.9814, Class 1: 0.6968, Class 2: 0.2305, Class 3: 0.5642, Class 4: 0.8076, Class 5: 0.4601, Class 6: 0.9374, 

Overall Mean Dice Score: 0.6067
Overall Mean F-beta Score: 0.6932

Training Loss: 0.4361, Validation Loss: 0.4464, Validation F-beta: 0.6932
Epoch 32/4000


Training: 100%|██████████| 204/204 [02:53<00:00,  1.17it/s, loss=0.419]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.73it/s, loss=0.43]


Validation Dice Score
Class 0: 0.9831, Class 1: 0.6819, Class 2: 0.1444, Class 3: 0.4571, Class 4: 0.7146, Class 5: 0.4469, Class 6: 0.8863, 
Validation F-beta Score
Class 0: 0.9749, Class 1: 0.7978, Class 2: 0.1933, Class 3: 0.6206, Class 4: 0.8535, Class 5: 0.5935, Class 6: 0.9321, 

Overall Mean Dice Score: 0.6373
Overall Mean F-beta Score: 0.7595

Training Loss: 0.4356, Validation Loss: 0.4401, Validation F-beta: 0.7595
Epoch 33/4000


Training: 100%|██████████| 204/204 [02:54<00:00,  1.17it/s, loss=0.392]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.457]


Validation Dice Score
Class 0: 0.9832, Class 1: 0.7513, Class 2: 0.1217, Class 3: 0.3735, Class 4: 0.6371, Class 5: 0.4122, Class 6: 0.9195, 
Validation F-beta Score
Class 0: 0.9738, Class 1: 0.8441, Class 2: 0.2255, Class 3: 0.5637, Class 4: 0.8634, Class 5: 0.5295, Class 6: 0.9419, 

Overall Mean Dice Score: 0.6187
Overall Mean F-beta Score: 0.7485

Training Loss: 0.4342, Validation Loss: 0.4387, Validation F-beta: 0.7485
Epoch 34/4000


Training: 100%|██████████| 204/204 [02:55<00:00,  1.16it/s, loss=0.401]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.80it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.7643, Class 2: 0.1234, Class 3: 0.4005, Class 4: 0.6011, Class 5: 0.4268, Class 6: 0.9066, 
Validation F-beta Score
Class 0: 0.9768, Class 1: 0.8614, Class 2: 0.1752, Class 3: 0.6121, Class 4: 0.7615, Class 5: 0.5525, Class 6: 0.9542, 

Overall Mean Dice Score: 0.6199
Overall Mean F-beta Score: 0.7484

Training Loss: 0.4342, Validation Loss: 0.4319, Validation F-beta: 0.7484
Epoch 35/4000


Training: 100%|██████████| 204/204 [02:53<00:00,  1.18it/s, loss=0.426]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9860, Class 1: 0.7152, Class 2: 0.1780, Class 3: 0.3832, Class 4: 0.6781, Class 5: 0.5000, Class 6: 0.9131, 
Validation F-beta Score
Class 0: 0.9806, Class 1: 0.8654, Class 2: 0.2290, Class 3: 0.6197, Class 4: 0.7737, Class 5: 0.5922, Class 6: 0.9361, 

Overall Mean Dice Score: 0.6379
Overall Mean F-beta Score: 0.7574

Training Loss: 0.4326, Validation Loss: 0.4318, Validation F-beta: 0.7574
Epoch 36/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.428]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.7629, Class 2: 0.0731, Class 3: 0.4008, Class 4: 0.6886, Class 5: 0.4347, Class 6: 0.8680, 
Validation F-beta Score
Class 0: 0.9775, Class 1: 0.8078, Class 2: 0.1075, Class 3: 0.5808, Class 4: 0.8341, Class 5: 0.5605, Class 6: 0.9147, 

Overall Mean Dice Score: 0.6310
Overall Mean F-beta Score: 0.7396

Training Loss: 0.4333, Validation Loss: 0.4327, Validation F-beta: 0.7396
Epoch 37/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.394]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.407]


Validation Dice Score
Class 0: 0.9842, Class 1: 0.7015, Class 2: 0.1541, Class 3: 0.4032, Class 4: 0.7517, Class 5: 0.4923, Class 6: 0.8704, 
Validation F-beta Score
Class 0: 0.9779, Class 1: 0.8258, Class 2: 0.2441, Class 3: 0.6188, Class 4: 0.8327, Class 5: 0.5678, Class 6: 0.9008, 

Overall Mean Dice Score: 0.6439
Overall Mean F-beta Score: 0.7492

Training Loss: 0.4323, Validation Loss: 0.4344, Validation F-beta: 0.7492
Epoch 38/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.406]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.94it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.7340, Class 2: 0.2115, Class 3: 0.4293, Class 4: 0.7415, Class 5: 0.4548, Class 6: 0.8607, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.8197, Class 2: 0.2852, Class 3: 0.6786, Class 4: 0.8048, Class 5: 0.5179, Class 6: 0.9172, 

Overall Mean Dice Score: 0.6441
Overall Mean F-beta Score: 0.7477

Training Loss: 0.4323, Validation Loss: 0.4306, Validation F-beta: 0.7477
Epoch 39/4000


Training: 100%|██████████| 204/204 [02:40<00:00,  1.27it/s, loss=0.403]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.89it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9852, Class 1: 0.6701, Class 2: 0.0355, Class 3: 0.4363, Class 4: 0.7189, Class 5: 0.5595, Class 6: 0.9032, 
Validation F-beta Score
Class 0: 0.9799, Class 1: 0.8331, Class 2: 0.0676, Class 3: 0.6105, Class 4: 0.7982, Class 5: 0.6429, Class 6: 0.9330, 

Overall Mean Dice Score: 0.6576
Overall Mean F-beta Score: 0.7635

Training Loss: 0.4337, Validation Loss: 0.4286, Validation F-beta: 0.7635
Epoch 40/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.412]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, loss=0.438]


Validation Dice Score
Class 0: 0.9846, Class 1: 0.7623, Class 2: 0.1073, Class 3: 0.4394, Class 4: 0.6867, Class 5: 0.4872, Class 6: 0.9108, 
Validation F-beta Score
Class 0: 0.9772, Class 1: 0.8411, Class 2: 0.1795, Class 3: 0.5952, Class 4: 0.8455, Class 5: 0.6120, Class 6: 0.9550, 

Overall Mean Dice Score: 0.6573
Overall Mean F-beta Score: 0.7698

Training Loss: 0.4320, Validation Loss: 0.4247, Validation F-beta: 0.7698
Epoch 41/4000


Training: 100%|██████████| 204/204 [02:42<00:00,  1.25it/s, loss=0.418]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.464]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.7092, Class 2: 0.1255, Class 3: 0.3555, Class 4: 0.7541, Class 5: 0.4143, Class 6: 0.8800, 
Validation F-beta Score
Class 0: 0.9802, Class 1: 0.8503, Class 2: 0.2555, Class 3: 0.5840, Class 4: 0.8158, Class 5: 0.4250, Class 6: 0.9430, 

Overall Mean Dice Score: 0.6226
Overall Mean F-beta Score: 0.7236

Training Loss: 0.4311, Validation Loss: 0.4411, Validation F-beta: 0.7236
Epoch 42/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.405]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.84it/s, loss=0.431]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.7293, Class 2: 0.1273, Class 3: 0.3398, Class 4: 0.7044, Class 5: 0.4268, Class 6: 0.9319, 
Validation F-beta Score
Class 0: 0.9794, Class 1: 0.8562, Class 2: 0.1963, Class 3: 0.5247, Class 4: 0.8171, Class 5: 0.5198, Class 6: 0.9227, 

Overall Mean Dice Score: 0.6264
Overall Mean F-beta Score: 0.7281

Training Loss: 0.4322, Validation Loss: 0.4249, Validation F-beta: 0.7281
Epoch 43/4000


Training: 100%|██████████| 204/204 [02:52<00:00,  1.19it/s, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.454]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.6161, Class 2: 0.1011, Class 3: 0.4607, Class 4: 0.7798, Class 5: 0.5703, Class 6: 0.9061, 
Validation F-beta Score
Class 0: 0.9839, Class 1: 0.8213, Class 2: 0.1944, Class 3: 0.6256, Class 4: 0.8261, Class 5: 0.5893, Class 6: 0.9379, 

Overall Mean Dice Score: 0.6666
Overall Mean F-beta Score: 0.7601

Training Loss: 0.4314, Validation Loss: 0.4238, Validation F-beta: 0.7601
Epoch 44/4000


Training: 100%|██████████| 204/204 [02:39<00:00,  1.28it/s, loss=0.432]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.94it/s, loss=0.467]


Validation Dice Score
Class 0: 0.9846, Class 1: 0.6958, Class 2: 0.1497, Class 3: 0.3601, Class 4: 0.7474, Class 5: 0.4248, Class 6: 0.8986, 
Validation F-beta Score
Class 0: 0.9795, Class 1: 0.8308, Class 2: 0.2495, Class 3: 0.5257, Class 4: 0.7987, Class 5: 0.5362, Class 6: 0.9170, 

Overall Mean Dice Score: 0.6254
Overall Mean F-beta Score: 0.7217

Training Loss: 0.4314, Validation Loss: 0.4431, Validation F-beta: 0.7217
Epoch 45/4000


Training: 100%|██████████| 204/204 [02:36<00:00,  1.30it/s, loss=0.419]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.452]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.6922, Class 2: 0.1829, Class 3: 0.4084, Class 4: 0.7170, Class 5: 0.4186, Class 6: 0.9243, 
Validation F-beta Score
Class 0: 0.9817, Class 1: 0.8075, Class 2: 0.3224, Class 3: 0.5666, Class 4: 0.8109, Class 5: 0.4636, Class 6: 0.9290, 

Overall Mean Dice Score: 0.6321
Overall Mean F-beta Score: 0.7155

Training Loss: 0.4300, Validation Loss: 0.4342, Validation F-beta: 0.7155
Epoch 46/4000


Training: 100%|██████████| 204/204 [02:37<00:00,  1.30it/s, loss=0.399]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s, loss=0.434]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.7082, Class 2: 0.1474, Class 3: 0.3868, Class 4: 0.7687, Class 5: 0.4903, Class 6: 0.9184, 
Validation F-beta Score
Class 0: 0.9789, Class 1: 0.7883, Class 2: 0.2183, Class 3: 0.5510, Class 4: 0.8363, Class 5: 0.5862, Class 6: 0.9064, 

Overall Mean Dice Score: 0.6545
Overall Mean F-beta Score: 0.7336

Training Loss: 0.4306, Validation Loss: 0.4310, Validation F-beta: 0.7336
Epoch 47/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.402]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.8040, Class 2: 0.1993, Class 3: 0.4375, Class 4: 0.6178, Class 5: 0.5238, Class 6: 0.9076, 
Validation F-beta Score
Class 0: 0.9786, Class 1: 0.9068, Class 2: 0.3208, Class 3: 0.6479, Class 4: 0.7999, Class 5: 0.6475, Class 6: 0.9323, 

Overall Mean Dice Score: 0.6582
Overall Mean F-beta Score: 0.7869

Training Loss: 0.4295, Validation Loss: 0.4356, Validation F-beta: 0.7869
Epoch 48/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.42] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.80it/s, loss=0.42] 


Validation Dice Score
Class 0: 0.9844, Class 1: 0.7864, Class 2: 0.0970, Class 3: 0.4240, Class 4: 0.7582, Class 5: 0.4847, Class 6: 0.9096, 
Validation F-beta Score
Class 0: 0.9782, Class 1: 0.8652, Class 2: 0.1943, Class 3: 0.6128, Class 4: 0.8078, Class 5: 0.6340, Class 6: 0.9528, 

Overall Mean Dice Score: 0.6726
Overall Mean F-beta Score: 0.7745

Training Loss: 0.4298, Validation Loss: 0.4246, Validation F-beta: 0.7745
Epoch 49/4000


Training: 100%|██████████| 204/204 [02:43<00:00,  1.25it/s, loss=0.393]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.95it/s, loss=0.433]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.7205, Class 2: 0.2246, Class 3: 0.3647, Class 4: 0.7704, Class 5: 0.5086, Class 6: 0.9033, 
Validation F-beta Score
Class 0: 0.9792, Class 1: 0.8465, Class 2: 0.2703, Class 3: 0.6045, Class 4: 0.8621, Class 5: 0.6344, Class 6: 0.9438, 

Overall Mean Dice Score: 0.6535
Overall Mean F-beta Score: 0.7783

Training Loss: 0.4290, Validation Loss: 0.4307, Validation F-beta: 0.7783
Epoch 50/4000


Training: 100%|██████████| 204/204 [02:35<00:00,  1.31it/s, loss=0.432]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.439]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.7590, Class 2: 0.1120, Class 3: 0.4400, Class 4: 0.7004, Class 5: 0.4895, Class 6: 0.9218, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.8545, Class 2: 0.2154, Class 3: 0.5774, Class 4: 0.7962, Class 5: 0.5628, Class 6: 0.9291, 

Overall Mean Dice Score: 0.6621
Overall Mean F-beta Score: 0.7440

Training Loss: 0.4292, Validation Loss: 0.4322, Validation F-beta: 0.7440
Epoch 51/4000


Training: 100%|██████████| 204/204 [02:36<00:00,  1.31it/s, loss=0.418]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.99it/s, loss=0.438]


Validation Dice Score
Class 0: 0.9838, Class 1: 0.6834, Class 2: 0.1926, Class 3: 0.4535, Class 4: 0.7112, Class 5: 0.5339, Class 6: 0.9049, 
Validation F-beta Score
Class 0: 0.9763, Class 1: 0.8687, Class 2: 0.3008, Class 3: 0.6135, Class 4: 0.8308, Class 5: 0.6428, Class 6: 0.9535, 

Overall Mean Dice Score: 0.6574
Overall Mean F-beta Score: 0.7819

Training Loss: 0.4291, Validation Loss: 0.4333, Validation F-beta: 0.7819
Epoch 52/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.425]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s, loss=0.438]


Validation Dice Score
Class 0: 0.9842, Class 1: 0.7630, Class 2: 0.1557, Class 3: 0.4490, Class 4: 0.7501, Class 5: 0.4877, Class 6: 0.7763, 
Validation F-beta Score
Class 0: 0.9769, Class 1: 0.8566, Class 2: 0.1668, Class 3: 0.5899, Class 4: 0.8483, Class 5: 0.6702, Class 6: 0.8622, 

Overall Mean Dice Score: 0.6452
Overall Mean F-beta Score: 0.7654

Training Loss: 0.4292, Validation Loss: 0.4381, Validation F-beta: 0.7654
Epoch 53/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.39] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.95it/s, loss=0.416]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.6637, Class 2: 0.1910, Class 3: 0.4906, Class 4: 0.7467, Class 5: 0.5392, Class 6: 0.8955, 
Validation F-beta Score
Class 0: 0.9778, Class 1: 0.7819, Class 2: 0.3464, Class 3: 0.5651, Class 4: 0.8474, Class 5: 0.7052, Class 6: 0.9333, 

Overall Mean Dice Score: 0.6671
Overall Mean F-beta Score: 0.7666

Training Loss: 0.4277, Validation Loss: 0.4210, Validation F-beta: 0.7666
SUPER Best model saved. Loss:0.4210, Score:0.7666
Epoch 54/4000


Training: 100%|██████████| 204/204 [02:49<00:00,  1.20it/s, loss=0.401]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.77it/s, loss=0.4]  


Validation Dice Score
Class 0: 0.9870, Class 1: 0.8202, Class 2: 0.1932, Class 3: 0.4266, Class 4: 0.6924, Class 5: 0.5136, Class 6: 0.9082, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.8393, Class 2: 0.2919, Class 3: 0.5837, Class 4: 0.8051, Class 5: 0.5487, Class 6: 0.9484, 

Overall Mean Dice Score: 0.6722
Overall Mean F-beta Score: 0.7450

Training Loss: 0.4298, Validation Loss: 0.4084, Validation F-beta: 0.7450
Epoch 55/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.20it/s, loss=0.398]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.87it/s, loss=0.425]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.8035, Class 2: 0.1367, Class 3: 0.4428, Class 4: 0.6820, Class 5: 0.5460, Class 6: 0.8890, 
Validation F-beta Score
Class 0: 0.9817, Class 1: 0.8660, Class 2: 0.1801, Class 3: 0.5797, Class 4: 0.7885, Class 5: 0.6683, Class 6: 0.9300, 

Overall Mean Dice Score: 0.6726
Overall Mean F-beta Score: 0.7665

Training Loss: 0.4285, Validation Loss: 0.4133, Validation F-beta: 0.7665
Epoch 56/4000


Training: 100%|██████████| 204/204 [02:36<00:00,  1.30it/s, loss=0.391]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9829, Class 1: 0.7490, Class 2: 0.1914, Class 3: 0.4344, Class 4: 0.7151, Class 5: 0.4841, Class 6: 0.9315, 
Validation F-beta Score
Class 0: 0.9746, Class 1: 0.8652, Class 2: 0.3303, Class 3: 0.5756, Class 4: 0.8280, Class 5: 0.6510, Class 6: 0.9544, 

Overall Mean Dice Score: 0.6628
Overall Mean F-beta Score: 0.7748

Training Loss: 0.4280, Validation Loss: 0.4429, Validation F-beta: 0.7748
Epoch 57/4000


Training: 100%|██████████| 204/204 [02:41<00:00,  1.26it/s, loss=0.412]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.7103, Class 2: 0.1847, Class 3: 0.4188, Class 4: 0.7465, Class 5: 0.4816, Class 6: 0.9244, 
Validation F-beta Score
Class 0: 0.9807, Class 1: 0.8593, Class 2: 0.2756, Class 3: 0.5440, Class 4: 0.8199, Class 5: 0.5961, Class 6: 0.9426, 

Overall Mean Dice Score: 0.6563
Overall Mean F-beta Score: 0.7524

Training Loss: 0.4276, Validation Loss: 0.4244, Validation F-beta: 0.7524
Epoch 58/4000


Training: 100%|██████████| 204/204 [02:47<00:00,  1.22it/s, loss=0.392]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.7743, Class 2: 0.1402, Class 3: 0.3909, Class 4: 0.7175, Class 5: 0.4826, Class 6: 0.9108, 
Validation F-beta Score
Class 0: 0.9795, Class 1: 0.8371, Class 2: 0.2835, Class 3: 0.5812, Class 4: 0.8195, Class 5: 0.5957, Class 6: 0.9337, 

Overall Mean Dice Score: 0.6552
Overall Mean F-beta Score: 0.7534

Training Loss: 0.4260, Validation Loss: 0.4227, Validation F-beta: 0.7534
Epoch 59/4000


Training: 100%|██████████| 204/204 [02:44<00:00,  1.24it/s, loss=0.42] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.79it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.7507, Class 2: 0.1638, Class 3: 0.4296, Class 4: 0.7467, Class 5: 0.4437, Class 6: 0.9016, 
Validation F-beta Score
Class 0: 0.9781, Class 1: 0.7816, Class 2: 0.2384, Class 3: 0.6358, Class 4: 0.8473, Class 5: 0.5840, Class 6: 0.9250, 

Overall Mean Dice Score: 0.6544
Overall Mean F-beta Score: 0.7547

Training Loss: 0.4266, Validation Loss: 0.4393, Validation F-beta: 0.7547
Epoch 60/4000


Training: 100%|██████████| 204/204 [02:51<00:00,  1.19it/s, loss=0.395]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.79it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9858, Class 1: 0.7159, Class 2: 0.1736, Class 3: 0.3810, Class 4: 0.6666, Class 5: 0.5357, Class 6: 0.9111, 
Validation F-beta Score
Class 0: 0.9790, Class 1: 0.8364, Class 2: 0.2684, Class 3: 0.5956, Class 4: 0.8038, Class 5: 0.6424, Class 6: 0.9425, 

Overall Mean Dice Score: 0.6421
Overall Mean F-beta Score: 0.7641

Training Loss: 0.4277, Validation Loss: 0.4296, Validation F-beta: 0.7641
Epoch 61/4000


Training: 100%|██████████| 204/204 [02:38<00:00,  1.29it/s, loss=0.408]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.7795, Class 2: 0.1987, Class 3: 0.4798, Class 4: 0.7600, Class 5: 0.4051, Class 6: 0.9201, 
Validation F-beta Score
Class 0: 0.9789, Class 1: 0.8907, Class 2: 0.2892, Class 3: 0.6758, Class 4: 0.8666, Class 5: 0.5091, Class 6: 0.9322, 

Overall Mean Dice Score: 0.6689
Overall Mean F-beta Score: 0.7749

Training Loss: 0.4264, Validation Loss: 0.4231, Validation F-beta: 0.7749
Epoch 62/4000


Training: 100%|██████████| 204/204 [02:46<00:00,  1.22it/s, loss=0.399]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.89it/s, loss=0.43] 


Validation Dice Score
Class 0: 0.9842, Class 1: 0.6996, Class 2: 0.1519, Class 3: 0.3622, Class 4: 0.7244, Class 5: 0.5479, Class 6: 0.9156, 
Validation F-beta Score
Class 0: 0.9773, Class 1: 0.8141, Class 2: 0.2812, Class 3: 0.5095, Class 4: 0.8637, Class 5: 0.6446, Class 6: 0.9067, 

Overall Mean Dice Score: 0.6499
Overall Mean F-beta Score: 0.7477

Training Loss: 0.4275, Validation Loss: 0.4289, Validation F-beta: 0.7477
Epoch 63/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.421]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.421]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7489, Class 2: 0.1452, Class 3: 0.4304, Class 4: 0.7684, Class 5: 0.5024, Class 6: 0.9359, 
Validation F-beta Score
Class 0: 0.9798, Class 1: 0.8508, Class 2: 0.1832, Class 3: 0.6598, Class 4: 0.8527, Class 5: 0.6055, Class 6: 0.9415, 

Overall Mean Dice Score: 0.6772
Overall Mean F-beta Score: 0.7821

Training Loss: 0.4260, Validation Loss: 0.4249, Validation F-beta: 0.7821
Epoch 64/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.77it/s, loss=0.423]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7969, Class 2: 0.0941, Class 3: 0.3652, Class 4: 0.7661, Class 5: 0.5061, Class 6: 0.9329, 
Validation F-beta Score
Class 0: 0.9802, Class 1: 0.8916, Class 2: 0.1808, Class 3: 0.6125, Class 4: 0.8272, Class 5: 0.6373, Class 6: 0.9418, 

Overall Mean Dice Score: 0.6734
Overall Mean F-beta Score: 0.7821

Training Loss: 0.4260, Validation Loss: 0.4193, Validation F-beta: 0.7821
SUPER Best model saved. Loss:0.4193, Score:0.7821
Epoch 65/4000


Training: 100%|██████████| 204/204 [02:49<00:00,  1.21it/s, loss=0.392]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, loss=0.422]


Validation Dice Score
Class 0: 0.9854, Class 1: 0.7818, Class 2: 0.1629, Class 3: 0.4184, Class 4: 0.7768, Class 5: 0.4530, Class 6: 0.9037, 
Validation F-beta Score
Class 0: 0.9793, Class 1: 0.8671, Class 2: 0.2473, Class 3: 0.6215, Class 4: 0.8394, Class 5: 0.5732, Class 6: 0.9320, 

Overall Mean Dice Score: 0.6667
Overall Mean F-beta Score: 0.7666

Training Loss: 0.4238, Validation Loss: 0.4326, Validation F-beta: 0.7666
Epoch 66/4000


Training: 100%|██████████| 204/204 [02:50<00:00,  1.20it/s, loss=0.396]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, loss=0.396]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.8053, Class 2: 0.1338, Class 3: 0.3742, Class 4: 0.7409, Class 5: 0.5569, Class 6: 0.9165, 
Validation F-beta Score
Class 0: 0.9816, Class 1: 0.8893, Class 2: 0.1781, Class 3: 0.5833, Class 4: 0.8410, Class 5: 0.6246, Class 6: 0.9157, 

Overall Mean Dice Score: 0.6788
Overall Mean F-beta Score: 0.7708

Training Loss: 0.4240, Validation Loss: 0.4152, Validation F-beta: 0.7708
Epoch 67/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.433]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.90it/s, loss=0.425]


Validation Dice Score
Class 0: 0.9852, Class 1: 0.8107, Class 2: 0.1422, Class 3: 0.4194, Class 4: 0.7073, Class 5: 0.4222, Class 6: 0.9171, 
Validation F-beta Score
Class 0: 0.9793, Class 1: 0.8787, Class 2: 0.2423, Class 3: 0.5828, Class 4: 0.7950, Class 5: 0.5423, Class 6: 0.9474, 

Overall Mean Dice Score: 0.6553
Overall Mean F-beta Score: 0.7493

Training Loss: 0.4249, Validation Loss: 0.4229, Validation F-beta: 0.7493
Epoch 68/4000


Training: 100%|██████████| 204/204 [02:46<00:00,  1.23it/s, loss=0.388]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7296, Class 2: 0.2002, Class 3: 0.4762, Class 4: 0.7507, Class 5: 0.5233, Class 6: 0.9239, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.8672, Class 2: 0.2359, Class 3: 0.6083, Class 4: 0.8440, Class 5: 0.5688, Class 6: 0.9404, 

Overall Mean Dice Score: 0.6807
Overall Mean F-beta Score: 0.7658

Training Loss: 0.4226, Validation Loss: 0.4273, Validation F-beta: 0.7658
Epoch 69/4000


Training: 100%|██████████| 204/204 [02:40<00:00,  1.27it/s, loss=0.392]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.88it/s, loss=0.416]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.7804, Class 2: 0.1703, Class 3: 0.4991, Class 4: 0.7149, Class 5: 0.4973, Class 6: 0.9103, 
Validation F-beta Score
Class 0: 0.9806, Class 1: 0.8924, Class 2: 0.2993, Class 3: 0.6792, Class 4: 0.8098, Class 5: 0.5784, Class 6: 0.9406, 

Overall Mean Dice Score: 0.6804
Overall Mean F-beta Score: 0.7801

Training Loss: 0.4234, Validation Loss: 0.4156, Validation F-beta: 0.7801
Epoch 70/4000


Training: 100%|██████████| 204/204 [02:35<00:00,  1.31it/s, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.94it/s, loss=0.41] 


Validation Dice Score
Class 0: 0.9861, Class 1: 0.7584, Class 2: 0.2779, Class 3: 0.4988, Class 4: 0.7567, Class 5: 0.4447, Class 6: 0.8830, 
Validation F-beta Score
Class 0: 0.9826, Class 1: 0.8251, Class 2: 0.3171, Class 3: 0.5877, Class 4: 0.8076, Class 5: 0.4950, Class 6: 0.9357, 

Overall Mean Dice Score: 0.6683
Overall Mean F-beta Score: 0.7302

Training Loss: 0.4221, Validation Loss: 0.4225, Validation F-beta: 0.7302
Epoch 71/4000


Training: 100%|██████████| 204/204 [02:48<00:00,  1.21it/s, loss=0.391]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7097, Class 2: 0.0723, Class 3: 0.4929, Class 4: 0.7125, Class 5: 0.5322, Class 6: 0.9150, 
Validation F-beta Score
Class 0: 0.9824, Class 1: 0.8510, Class 2: 0.0903, Class 3: 0.6266, Class 4: 0.8069, Class 5: 0.5902, Class 6: 0.9537, 

Overall Mean Dice Score: 0.6725
Overall Mean F-beta Score: 0.7657

Training Loss: 0.4213, Validation Loss: 0.4235, Validation F-beta: 0.7657
Epoch 72/4000


Training: 100%|██████████| 204/204 [02:42<00:00,  1.25it/s, loss=0.414]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, loss=0.4]  


Validation Dice Score
Class 0: 0.9880, Class 1: 0.8292, Class 2: 0.1462, Class 3: 0.4487, Class 4: 0.7429, Class 5: 0.4682, Class 6: 0.8815, 
Validation F-beta Score
Class 0: 0.9839, Class 1: 0.8850, Class 2: 0.2663, Class 3: 0.6182, Class 4: 0.8102, Class 5: 0.5420, Class 6: 0.9244, 

Overall Mean Dice Score: 0.6741
Overall Mean F-beta Score: 0.7560

Training Loss: 0.4239, Validation Loss: 0.4248, Validation F-beta: 0.7560
Epoch 73/4000


Training: 100%|██████████| 204/204 [02:35<00:00,  1.31it/s, loss=0.391]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7145, Class 2: 0.2222, Class 3: 0.4098, Class 4: 0.7435, Class 5: 0.5451, Class 6: 0.9014, 
Validation F-beta Score
Class 0: 0.9840, Class 1: 0.8265, Class 2: 0.2421, Class 3: 0.5497, Class 4: 0.7842, Class 5: 0.6016, Class 6: 0.9256, 

Overall Mean Dice Score: 0.6629
Overall Mean F-beta Score: 0.7375

Training Loss: 0.4217, Validation Loss: 0.4399, Validation F-beta: 0.7375
Epoch 74/4000


Training: 100%|██████████| 204/204 [02:35<00:00,  1.31it/s, loss=0.406]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.97it/s, loss=0.438]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.8185, Class 2: 0.1699, Class 3: 0.3682, Class 4: 0.6813, Class 5: 0.4974, Class 6: 0.9339, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.8736, Class 2: 0.2004, Class 3: 0.4661, Class 4: 0.8297, Class 5: 0.6217, Class 6: 0.9544, 

Overall Mean Dice Score: 0.6599
Overall Mean F-beta Score: 0.7491

Training Loss: 0.4221, Validation Loss: 0.4317, Validation F-beta: 0.7491
Epoch 75/4000


Training: 100%|██████████| 204/204 [02:41<00:00,  1.27it/s, loss=0.4]  
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.6842, Class 2: 0.1185, Class 3: 0.3816, Class 4: 0.7437, Class 5: 0.5233, Class 6: 0.8907, 
Validation F-beta Score
Class 0: 0.9831, Class 1: 0.8559, Class 2: 0.1781, Class 3: 0.5418, Class 4: 0.7980, Class 5: 0.5836, Class 6: 0.9496, 

Overall Mean Dice Score: 0.6447
Overall Mean F-beta Score: 0.7458

Training Loss: 0.4218, Validation Loss: 0.4347, Validation F-beta: 0.7458
Epoch 76/4000


Training: 100%|██████████| 204/204 [02:41<00:00,  1.26it/s, loss=0.406]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.95it/s, loss=0.43] 


Validation Dice Score
Class 0: 0.9865, Class 1: 0.7822, Class 2: 0.0862, Class 3: 0.4291, Class 4: 0.7444, Class 5: 0.5250, Class 6: 0.9004, 
Validation F-beta Score
Class 0: 0.9814, Class 1: 0.8908, Class 2: 0.1533, Class 3: 0.5700, Class 4: 0.8325, Class 5: 0.6146, Class 6: 0.9286, 

Overall Mean Dice Score: 0.6762
Overall Mean F-beta Score: 0.7673

Training Loss: 0.4217, Validation Loss: 0.4254, Validation F-beta: 0.7673
Epoch 77/4000


Training: 100%|██████████| 204/204 [02:41<00:00,  1.27it/s, loss=0.396]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s, loss=0.436]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.6350, Class 2: 0.1119, Class 3: 0.4374, Class 4: 0.7626, Class 5: 0.4402, Class 6: 0.8951, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.7742, Class 2: 0.1445, Class 3: 0.5458, Class 4: 0.8047, Class 5: 0.5372, Class 6: 0.9211, 

Overall Mean Dice Score: 0.6341
Overall Mean F-beta Score: 0.7166

Training Loss: 0.4228, Validation Loss: 0.4370, Validation F-beta: 0.7166
Epoch 78/4000


Training: 100%|██████████| 204/204 [02:40<00:00,  1.27it/s, loss=0.41] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s, loss=0.423]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.7727, Class 2: 0.1946, Class 3: 0.4374, Class 4: 0.7377, Class 5: 0.5335, Class 6: 0.8932, 
Validation F-beta Score
Class 0: 0.9824, Class 1: 0.8858, Class 2: 0.3157, Class 3: 0.5682, Class 4: 0.7888, Class 5: 0.6041, Class 6: 0.9401, 

Overall Mean Dice Score: 0.6749
Overall Mean F-beta Score: 0.7574

Training Loss: 0.4219, Validation Loss: 0.4255, Validation F-beta: 0.7574
Epoch 79/4000


Training: 100%|██████████| 204/204 [02:45<00:00,  1.23it/s, loss=0.376]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.84it/s, loss=0.451]

Validation Dice Score
Class 0: 0.9879, Class 1: 0.7494, Class 2: 0.2209, Class 3: 0.4159, Class 4: 0.7559, Class 5: 0.4756, Class 6: 0.8696, 
Validation F-beta Score
Class 0: 0.9829, Class 1: 0.9017, Class 2: 0.2909, Class 3: 0.5576, Class 4: 0.8475, Class 5: 0.5795, Class 6: 0.9416, 

Overall Mean Dice Score: 0.6533
Overall Mean F-beta Score: 0.7656

Training Loss: 0.4208, Validation Loss: 0.4308, Validation F-beta: 0.7656
Early stopping





0,1
class_0_dice_score,▇▃▃▆▄▄▃▅▇▆▅▃▁▆▂▅▅▄▅▆▅▄▆▄▃▄▇▂▅▅▆▆▅▆▇▇▇▆▅█
class_0_f_beta_score,█▆▅▅▁▃▂▅▄▃▅▁▃▅▃▄▆▄▅▄▃▄▃▃▆▄▃▄▃▄▄▄▅▆▆▆▄▆▅▆
class_1_dice_score,▁▄▅▄▃▆▃▆▆▆▇▇▅▆▆▅▆▇▅▇▆▄▆▆▇▇▅██▇▆▆▇▇▇█▆▆▅▇
class_1_f_beta_score,▄▄▁▂▅▂▃▇▆▅▇▇▇▇▇▆▇▇▇▇▇▇█▇▇▆▇▇▇▇▆▇█▇█▇▇▇██
class_2_dice_score,▁▁▁▁▁▂▆▄▄▄▆▇▇▅▅▅▃▆▄▄▅▃▇▄▆▆▄▆▆▅▅▆▅▅▅▆█▇▃▇
class_2_f_beta_score,▁▁▁▁▁▄▅▄▄▅██▇▇▄▅▅▃▅▆▄▅▅▇▇▅▇▄▆▄▆▆▄▄▆▆▅▅▄▄
class_3_dice_score,▄▅▃▄▆▂▆▆▁▅▆▆▆▆▅▅▆▆▇▇▅▇▅▇▇▇▆▇▅▅▅▆▅██▅▅▆▇▆
class_3_f_beta_score,▁▂▄▃▁█▅▃▃▅▆▆▄▆▆█▆▅▃▆▄▇▆▅▄▅▅▇▅▂▆▅▆█▅▆▁▃▄▄
class_4_dice_score,▃▂▄▆▂▆▂▂▄▅▆▅▅▂▆▅▂▄▄▆█▇█▁▇▇▄▅▇▃█▇▇▅▆▄▆▆▇▇
class_4_f_beta_score,▁▄▅▄▅▅▆▅▄▆█▅▅▄▅▃▆▅▄▇▅▆▅▆▅▄▆▅▄▆▇█▇▇▄▅▅▄▆▅

0,1
class_0_dice_score,0.98785
class_0_f_beta_score,0.98294
class_1_dice_score,0.74941
class_1_f_beta_score,0.90166
class_2_dice_score,0.22091
class_2_f_beta_score,0.29086
class_3_dice_score,0.41589
class_3_f_beta_score,0.55761
class_4_dice_score,0.7559
class_4_f_beta_score,0.84749


In [12]:
train_img_dir = './datasets/public_data/images'
train_label_dir = './datasets/public_data/labels'

train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat
    )

batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

Loading dataset: 100%|██████████| 24/24 [00:04<00:00,  5.50it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  3.04it/s]


torch.Size([16, 1, 32, 96, 96]) torch.Size([16, 1, 32, 96, 96])


In [13]:
torch.backends.cudnn.benchmark = True

In [14]:

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model_pretrained.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss', 'best_val_fbeta_score']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                best_val_fbeta_score = checkpoint['best_val_fbeta_score']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")
            
lr = lr/2
            
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)


기존 best model 발견: model_checkpoints\FLEX_randGaus_511_241_noclswt_gelu_batch_f48_lr1e-03_a0.52_b16_r4_ce0.4_ac1\best_model_pretrained.pt


  checkpoint = torch.load(best_model_path, map_location=device)


기존 학습된 가중치를 성공적으로 로드했습니다.


In [15]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        # "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        # "attn_drop_rate": attn_drop_rate,
        # "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        # "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

In [16]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps,
    pretrained=False,
    ) 

Epoch 65/4000


Training: 100%|██████████| 96/96 [01:19<00:00,  1.20it/s, loss=0.434]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s, loss=0.416]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7543, Class 2: 0.0000, Class 3: 0.4627, Class 4: 0.7784, Class 5: 0.4993, Class 6: 0.9357, 
Validation F-beta Score
Class 0: 0.9822, Class 1: 0.8482, Class 2: 0.0000, Class 3: 0.5227, Class 4: 0.8553, Class 5: 0.5658, Class 6: 0.9601, 

Overall Mean Dice Score: 0.6861
Overall Mean F-beta Score: 0.7504

Training Loss: 0.4408, Validation Loss: 0.4161, Validation F-beta: 0.7504
Epoch 66/4000


Training: 100%|██████████| 96/96 [01:18<00:00,  1.23it/s, loss=0.455]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7467, Class 2: 0.2922, Class 3: 0.5460, Class 4: 0.7802, Class 5: 0.4758, Class 6: 0.9176, 
Validation F-beta Score
Class 0: 0.9806, Class 1: 0.8498, Class 2: 0.3442, Class 3: 0.7822, Class 4: 0.8705, Class 5: 0.6579, Class 6: 0.9843, 

Overall Mean Dice Score: 0.6933
Overall Mean F-beta Score: 0.8289

Training Loss: 0.4345, Validation Loss: 0.4059, Validation F-beta: 0.8289
SUPER Best model saved. Loss:0.4059, Score:0.8289
Epoch 67/4000


Training: 100%|██████████| 96/96 [01:16<00:00,  1.25it/s, loss=0.462]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s, loss=0.435]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.6392, Class 2: 0.4487, Class 3: 0.3072, Class 4: 0.6218, Class 5: 0.4432, Class 6: 0.8798, 
Validation F-beta Score
Class 0: 0.9826, Class 1: 0.8489, Class 2: 0.5309, Class 3: 0.3653, Class 4: 0.7131, Class 5: 0.5461, Class 6: 0.9137, 

Overall Mean Dice Score: 0.5783
Overall Mean F-beta Score: 0.6774

Training Loss: 0.4380, Validation Loss: 0.4347, Validation F-beta: 0.6774
Epoch 68/4000


Training: 100%|██████████| 96/96 [01:11<00:00,  1.34it/s, loss=0.454]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.01it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9856, Class 1: 0.7192, Class 2: 0.3155, Class 3: 0.5851, Class 4: 0.7651, Class 5: 0.5589, Class 6: 0.9453, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.8687, Class 2: 0.4902, Class 3: 0.7065, Class 4: 0.7993, Class 5: 0.7071, Class 6: 0.9662, 

Overall Mean Dice Score: 0.7147
Overall Mean F-beta Score: 0.8096

Training Loss: 0.4386, Validation Loss: 0.4426, Validation F-beta: 0.8096
Epoch 69/4000


Training: 100%|██████████| 96/96 [01:12<00:00,  1.33it/s, loss=0.439]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.92it/s, loss=0.445]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.2858, Class 2: 0.0000, Class 3: 0.4974, Class 4: 0.7486, Class 5: 0.4869, Class 6: 0.8880, 
Validation F-beta Score
Class 0: 0.9838, Class 1: 0.6176, Class 2: 0.0000, Class 3: 0.6832, Class 4: 0.7575, Class 5: 0.5010, Class 6: 0.9648, 

Overall Mean Dice Score: 0.5813
Overall Mean F-beta Score: 0.7048

Training Loss: 0.4376, Validation Loss: 0.4451, Validation F-beta: 0.7048
Epoch 70/4000


Training: 100%|██████████| 96/96 [01:12<00:00,  1.32it/s, loss=0.433]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s, loss=0.429]


Validation Dice Score
Class 0: 0.9855, Class 1: 0.7318, Class 2: 0.0207, Class 3: 0.4363, Class 4: 0.7310, Class 5: 0.5790, Class 6: 0.9348, 
Validation F-beta Score
Class 0: 0.9825, Class 1: 0.8867, Class 2: 0.0231, Class 3: 0.5587, Class 4: 0.7536, Class 5: 0.6445, Class 6: 0.9652, 

Overall Mean Dice Score: 0.6826
Overall Mean F-beta Score: 0.7617

Training Loss: 0.4338, Validation Loss: 0.4293, Validation F-beta: 0.7617
Epoch 71/4000


Training: 100%|██████████| 96/96 [01:12<00:00,  1.32it/s, loss=0.44] 
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s, loss=0.434]


Validation Dice Score
Class 0: 0.9830, Class 1: 0.8525, Class 2: 0.1387, Class 3: 0.5554, Class 4: 0.6837, Class 5: 0.4802, Class 6: 0.9344, 
Validation F-beta Score
Class 0: 0.9762, Class 1: 0.9079, Class 2: 0.1957, Class 3: 0.6885, Class 4: 0.8401, Class 5: 0.5183, Class 6: 0.9704, 

Overall Mean Dice Score: 0.7012
Overall Mean F-beta Score: 0.7850

Training Loss: 0.4341, Validation Loss: 0.4335, Validation F-beta: 0.7850
Epoch 72/4000


Training: 100%|██████████| 96/96 [01:13<00:00,  1.31it/s, loss=0.433]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.92it/s, loss=0.431]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.7753, Class 2: 0.3013, Class 3: 0.5399, Class 4: 0.7286, Class 5: 0.2433, Class 6: 0.9439, 
Validation F-beta Score
Class 0: 0.9813, Class 1: 0.8852, Class 2: 0.5099, Class 3: 0.5608, Class 4: 0.7470, Class 5: 0.3998, Class 6: 0.9559, 

Overall Mean Dice Score: 0.6462
Overall Mean F-beta Score: 0.7098

Training Loss: 0.4350, Validation Loss: 0.4312, Validation F-beta: 0.7098
Epoch 73/4000


Training: 100%|██████████| 96/96 [01:10<00:00,  1.36it/s, loss=0.443]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s, loss=0.4]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7729, Class 2: 0.1665, Class 3: 0.5759, Class 4: 0.8200, Class 5: 0.5018, Class 6: 0.9249, 
Validation F-beta Score
Class 0: 0.9824, Class 1: 0.8846, Class 2: 0.3402, Class 3: 0.6973, Class 4: 0.8762, Class 5: 0.6838, Class 6: 0.9702, 

Overall Mean Dice Score: 0.7191
Overall Mean F-beta Score: 0.8224

Training Loss: 0.4352, Validation Loss: 0.3999, Validation F-beta: 0.8224
Epoch 74/4000


Training: 100%|██████████| 96/96 [01:10<00:00,  1.37it/s, loss=0.444]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.06it/s, loss=0.429]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.7777, Class 2: 0.0000, Class 3: 0.5938, Class 4: 0.6883, Class 5: 0.5033, Class 6: 0.9499, 
Validation F-beta Score
Class 0: 0.9859, Class 1: 0.8874, Class 2: 0.0000, Class 3: 0.6914, Class 4: 0.8132, Class 5: 0.5226, Class 6: 0.9622, 

Overall Mean Dice Score: 0.7026
Overall Mean F-beta Score: 0.7753

Training Loss: 0.4348, Validation Loss: 0.4289, Validation F-beta: 0.7753
Epoch 75/4000


Training: 100%|██████████| 96/96 [01:10<00:00,  1.36it/s, loss=0.425]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.8700, Class 2: 0.2141, Class 3: 0.3286, Class 4: 0.8254, Class 5: 0.5546, Class 6: 0.8946, 
Validation F-beta Score
Class 0: 0.9843, Class 1: 0.8973, Class 2: 0.3005, Class 3: 0.5377, Class 4: 0.8492, Class 5: 0.6122, Class 6: 0.9437, 

Overall Mean Dice Score: 0.6947
Overall Mean F-beta Score: 0.7680

Training Loss: 0.4362, Validation Loss: 0.4372, Validation F-beta: 0.7680
Epoch 76/4000


Training: 100%|██████████| 96/96 [01:10<00:00,  1.35it/s, loss=0.431]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.04it/s, loss=0.428]


Validation Dice Score
Class 0: 0.9840, Class 1: 0.7896, Class 2: 0.3123, Class 3: 0.4720, Class 4: 0.7242, Class 5: 0.4595, Class 6: 0.8955, 
Validation F-beta Score
Class 0: 0.9820, Class 1: 0.8844, Class 2: 0.4396, Class 3: 0.5973, Class 4: 0.7192, Class 5: 0.5070, Class 6: 0.9360, 

Overall Mean Dice Score: 0.6681
Overall Mean F-beta Score: 0.7288

Training Loss: 0.4358, Validation Loss: 0.4275, Validation F-beta: 0.7288
Epoch 77/4000


Training: 100%|██████████| 96/96 [01:11<00:00,  1.33it/s, loss=0.45] 
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s, loss=0.439]


Validation Dice Score
Class 0: 0.9821, Class 1: 0.8183, Class 2: 0.0390, Class 3: 0.0294, Class 4: 0.7401, Class 5: 0.5195, Class 6: 0.9228, 
Validation F-beta Score
Class 0: 0.9786, Class 1: 0.8849, Class 2: 0.0506, Class 3: 0.0348, Class 4: 0.7812, Class 5: 0.5507, Class 6: 0.9584, 

Overall Mean Dice Score: 0.6060
Overall Mean F-beta Score: 0.6420

Training Loss: 0.4336, Validation Loss: 0.4387, Validation F-beta: 0.6420
Epoch 78/4000


Training: 100%|██████████| 96/96 [01:11<00:00,  1.35it/s, loss=0.445]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s, loss=0.426]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.8244, Class 2: 0.1282, Class 3: 0.3651, Class 4: 0.7634, Class 5: 0.5493, Class 6: 0.8975, 
Validation F-beta Score
Class 0: 0.9821, Class 1: 0.9119, Class 2: 0.2046, Class 3: 0.4866, Class 4: 0.8512, Class 5: 0.6163, Class 6: 0.9531, 

Overall Mean Dice Score: 0.6799
Overall Mean F-beta Score: 0.7638

Training Loss: 0.4332, Validation Loss: 0.4263, Validation F-beta: 0.7638
Epoch 79/4000


Training: 100%|██████████| 96/96 [01:11<00:00,  1.34it/s, loss=0.442]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.8209, Class 2: 0.2444, Class 3: 0.5915, Class 4: 0.6950, Class 5: 0.5023, Class 6: 0.9071, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.8891, Class 2: 0.2880, Class 3: 0.6647, Class 4: 0.7888, Class 5: 0.6527, Class 6: 0.9623, 

Overall Mean Dice Score: 0.7034
Overall Mean F-beta Score: 0.7915

Training Loss: 0.4341, Validation Loss: 0.4138, Validation F-beta: 0.7915
Epoch 80/4000


Training: 100%|██████████| 96/96 [01:11<00:00,  1.35it/s, loss=0.446]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.6817, Class 2: 0.3431, Class 3: 0.5083, Class 4: 0.7964, Class 5: 0.4169, Class 6: 0.9333, 
Validation F-beta Score
Class 0: 0.9833, Class 1: 0.9044, Class 2: 0.7044, Class 3: 0.6548, Class 4: 0.8201, Class 5: 0.5337, Class 6: 0.9557, 

Overall Mean Dice Score: 0.6673
Overall Mean F-beta Score: 0.7737

Training Loss: 0.4339, Validation Loss: 0.4432, Validation F-beta: 0.7737
Epoch 81/4000


Training: 100%|██████████| 96/96 [01:12<00:00,  1.33it/s, loss=0.398]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s, loss=0.439]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7665, Class 2: 0.1712, Class 3: 0.5274, Class 4: 0.7236, Class 5: 0.4948, Class 6: 0.9325, 
Validation F-beta Score
Class 0: 0.9849, Class 1: 0.8957, Class 2: 0.1720, Class 3: 0.6867, Class 4: 0.7668, Class 5: 0.4785, Class 6: 0.9779, 

Overall Mean Dice Score: 0.6890
Overall Mean F-beta Score: 0.7611

Training Loss: 0.4321, Validation Loss: 0.4389, Validation F-beta: 0.7611
Epoch 82/4000


Training: 100%|██████████| 96/96 [01:14<00:00,  1.30it/s, loss=0.439]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.92it/s, loss=0.447]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.8412, Class 2: 0.0027, Class 3: 0.3131, Class 4: 0.7419, Class 5: 0.1902, Class 6: 0.9335, 
Validation F-beta Score
Class 0: 0.9782, Class 1: 0.8644, Class 2: 0.0043, Class 3: 0.5580, Class 4: 0.8847, Class 5: 0.2380, Class 6: 0.9611, 

Overall Mean Dice Score: 0.6040
Overall Mean F-beta Score: 0.7013

Training Loss: 0.4337, Validation Loss: 0.4474, Validation F-beta: 0.7013
Epoch 83/4000


Training: 100%|██████████| 96/96 [01:13<00:00,  1.30it/s, loss=0.434]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.94it/s, loss=0.436]

Validation Dice Score
Class 0: 0.9889, Class 1: 0.7747, Class 2: 0.1371, Class 3: 0.5213, Class 4: 0.7809, Class 5: 0.4775, Class 6: 0.9334, 
Validation F-beta Score
Class 0: 0.9843, Class 1: 0.8838, Class 2: 0.2232, Class 3: 0.6392, Class 4: 0.8528, Class 5: 0.6035, Class 6: 0.9697, 

Overall Mean Dice Score: 0.6976
Overall Mean F-beta Score: 0.7898

Training Loss: 0.4337, Validation Loss: 0.4364, Validation F-beta: 0.7898
Early stopping





0,1
class_0_dice_score,▆▆▆▅▅▄▂▆▇█▇▃▁▆▆▆▆▄█
class_0_f_beta_score,▅▄▆▄▆▆▁▅▅█▇▅▃▅▅▆▇▂▇
class_1_dice_score,▇▇▅▆▁▆█▇▇▇█▇▇▇▇▆▇█▇
class_1_f_beta_score,▆▇▆▇▁▇█▇▇▇█▇▇█▇██▇▇
class_2_dice_score,▁▆█▆▁▁▃▆▄▁▄▆▂▃▅▆▄▁▃
class_2_f_beta_score,▁▄▆▆▁▁▃▆▄▁▄▅▂▃▄█▃▁▃
class_3_dice_score,▆▇▄█▇▆█▇██▅▆▁▅█▇▇▅▇
class_3_f_beta_score,▆█▄▇▇▆▇▆▇▇▆▆▁▅▇▇▇▆▇
class_4_dice_score,▆▆▁▆▅▅▃▅█▃█▅▅▆▄▇▅▅▆
class_4_f_beta_score,▇▇▁▅▃▃▆▂█▅▇▁▄▇▄▅▃█▇

0,1
class_0_dice_score,0.9889
class_0_f_beta_score,0.98426
class_1_dice_score,0.77472
class_1_f_beta_score,0.88382
class_2_dice_score,0.13707
class_2_f_beta_score,0.22318
class_3_dice_score,0.52134
class_3_f_beta_score,0.63915
class_4_dice_score,0.78092
class_4_f_beta_score,0.85282


In [17]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






: 

: 

# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

: 

: 

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


: 

: 

In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


: 

: 

In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

: 

: 

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

: 

: 

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv


: 

: 