In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch

from src.models import UNet_CBAM

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
# from src.models.swincspunetr import SwinCSPUNETR
# from src.models.swincspunetr_unet import SwinCSPUNETR_unet
# from src.models.swincspunetr3plus import SwinCSPUNETR3plus

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\Users\<username>\.conda\envs\UM\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.17.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders, create_dataloaders_bw
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, RandCropd,RandCropByPosNegLabeld, RandGaussianSmoothd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/pretrain_exdata/images"
train_label_dir = "./datasets/pretrain_exdata/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = img_size
n_classes = 7
batch_size = 16 # 13.8GB GPU memory required for 128x128 img size
loader_batch = 1
num_samples = batch_size // loader_batch # 한 이미지에서 뽑을 샘플 수
num_repeat = 4
# MODEL CONFIG
num_epochs = 4000
lamda = 0.52
ce_weight = 0.4
lr = 0.001
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25
num_bottleneck = 2
# CLASS_WEIGHTS
class_weights = None
# class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치
# class_weights = torch.tensor([0.9,1,0.9,1.1,1,1.1,1], dtype=torch.float32)  # 클래스별 가중치
class_weights = torch.tensor([1,1,1,1,1,1,1], dtype=torch.float32)  # 클래스별 가중치
sigma = 2.0
accumulation_steps = 1
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # GaussianSmoothd(
    #     keys=["image"],      # 변환을 적용할 키
    #     sigma=[sigma, sigma, sigma]  # 각 축(x, y, z)의 시그마 값
    #     ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandGaussianSmoothd(
    keys=["image"],      # 변환을 적용할 키
    sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    sigma_y = (0.0, sigma),
    sigma_z = (0.0, sigma),
    prob=1.0,
    ),
])
val_random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # RandGaussianSmoothd(
    # keys=["image"],      # 변환을 적용할 키
    # sigma_x = (0.0, sigma), # 각 축(x, y, z)의 시그마 값
    # sigma_y = (0.0, sigma),
    # sigma_z = (0.0, sigma),
    # prob=1.0,
    # ),
])


In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat
    )

Loading dataset: 100%|██████████| 51/51 [00:04<00:00, 10.63it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00,  8.40it/s]


https://monai.io/model-zoo.html

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import TverskyLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DynamicTverskyLoss 클래스 정의
class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda


# CombinedCETverskyLoss 클래스
class CombinedCETverskyLoss(nn.Module):
    def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        
        # CrossEntropyLoss에서 클래스별 가중치를 적용
        self.ce = nn.CrossEntropyLoss(weight=class_weights, ignore_index=self.ignore_index, reduction='mean', **kwargs)
        
        # TverskyLoss
        self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="none",softmax=True, **kwargs)

    def forward(self, inputs, targets):
        
        # CrossEntropyLoss는 정수형 클래스 인덱스를 사용
        ce_loss = self.ce(inputs, targets)

        # TverskyLoss 계산 (원핫 인코딩된 라벨을 사용)
        
        tversky_loss = self.tversky(inputs, targets)

        # 클래스별 가중치 적용 (Tversky 손실에도 가중치를 곱하기)
        class_weights = torch.tensor(self.ce.weight)  # CrossEntropy의 weight를 사용

        # Tversky 손실이 (B, num_classes) 형태이므로, 가중치를 클래스 차원에 곱합니다.
        tversky_loss = tversky_loss * class_weights.view(1, self.n_classes)

        # 최종 손실 계산
        final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss.mean()  # mean()으로 배치에 대해 평균
        return final_loss

    def set_lamda(self, lamda):
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        return self.tversky.lamda

criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,
    n_classes=n_classes,
    class_weights=class_weights,
).to(device)

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.networks.nets import UNet
from src.models import DP_UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=n_classes,
    channels=(32, 64, 128, 256),
    strides=(2, 2, 2),
    dropout = drop_rate,
).to(device)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"UNET_randGaus_511_241_noclswt_f{feature_size}_d{img_depth}s{img_size}_numb{num_bottleneck}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}_ac{accumulation_steps}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")

기존 best model 발견: model_checkpoints\UNET_randGaus_511_241_noclswt_f48_d96s96_numb2_lr1e-03_a0.52_b0.48_b16_r4_ce0.4_ac1\best_model.pt
기존 학습된 가중치를 성공적으로 로드했습니다.


  checkpoint = torch.load(best_model_path, map_location=device)


In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([16, 1, 96, 96, 96]) torch.Size([16, 1, 96, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [9]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


# 학습

In [10]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    loss = criterion(outputs, labels_onehot)
    # loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
                
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve % 6 == 0 & epochs_no_improve != 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.35)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [11]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
     ) 

Epoch 3/4000


  class_weights = torch.tensor(self.ce.weight)  # CrossEntropy의 weight를 사용
Training: 100%|██████████| 204/204 [03:34<00:00,  1.05s/it, loss=0.443]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.60it/s, loss=0.493]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.0106, Class 2: 0.0000, Class 3: 0.0929, Class 4: 0.5868, Class 5: 0.2053, Class 6: 0.7578, 
Validation F-beta Score
Class 0: 0.9904, Class 1: 0.0074, Class 2: 0.0000, Class 3: 0.1127, Class 4: 0.5912, Class 5: 0.1896, Class 6: 0.7321, 

Overall Mean Dice Score: 0.3307
Overall Mean F-beta Score: 0.3266

Training Loss: 0.4945, Validation Loss: 0.4772, Validation F-beta: 0.3266
SUPER Best model saved. Loss:0.4772, Score:0.3266
Epoch 4/4000


Training: 100%|██████████| 204/204 [03:13<00:00,  1.05it/s, loss=0.437]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.57it/s, loss=0.454]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.0614, Class 2: 0.0004, Class 3: 0.1830, Class 4: 0.6311, Class 5: 0.2590, Class 6: 0.8066, 
Validation F-beta Score
Class 0: 0.9853, Class 1: 0.1259, Class 2: 0.0002, Class 3: 0.2293, Class 4: 0.7216, Class 5: 0.2225, Class 6: 0.8508, 

Overall Mean Dice Score: 0.3882
Overall Mean F-beta Score: 0.4300

Training Loss: 0.4688, Validation Loss: 0.4616, Validation F-beta: 0.4300
SUPER Best model saved. Loss:0.4616, Score:0.4300
Epoch 5/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.412]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.3103, Class 2: 0.0441, Class 3: 0.1653, Class 4: 0.6360, Class 5: 0.2493, Class 6: 0.8613, 
Validation F-beta Score
Class 0: 0.9871, Class 1: 0.4572, Class 2: 0.0294, Class 3: 0.3332, Class 4: 0.6903, Class 5: 0.2149, Class 6: 0.8009, 

Overall Mean Dice Score: 0.4444
Overall Mean F-beta Score: 0.4993

Training Loss: 0.4563, Validation Loss: 0.4498, Validation F-beta: 0.4993
SUPER Best model saved. Loss:0.4498, Score:0.4993
Epoch 6/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.404]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s, loss=0.444]


Validation Dice Score
Class 0: 0.9888, Class 1: 0.3784, Class 2: 0.0565, Class 3: 0.1738, Class 4: 0.6812, Class 5: 0.2296, Class 6: 0.8341, 
Validation F-beta Score
Class 0: 0.9886, Class 1: 0.4968, Class 2: 0.0587, Class 3: 0.3406, Class 4: 0.6711, Class 5: 0.2020, Class 6: 0.8933, 

Overall Mean Dice Score: 0.4594
Overall Mean F-beta Score: 0.5208

Training Loss: 0.4416, Validation Loss: 0.4466, Validation F-beta: 0.5208
SUPER Best model saved. Loss:0.4466, Score:0.5208
Epoch 7/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.389]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s, loss=0.424]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.5472, Class 2: 0.1158, Class 3: 0.1628, Class 4: 0.6257, Class 5: 0.2991, Class 6: 0.8712, 
Validation F-beta Score
Class 0: 0.9831, Class 1: 0.6557, Class 2: 0.1314, Class 3: 0.3148, Class 4: 0.7578, Class 5: 0.2928, Class 6: 0.9062, 

Overall Mean Dice Score: 0.5012
Overall Mean F-beta Score: 0.5854

Training Loss: 0.4343, Validation Loss: 0.4216, Validation F-beta: 0.5854
SUPER Best model saved. Loss:0.4216, Score:0.5854
Epoch 8/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.389]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.424]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.6270, Class 2: 0.0770, Class 3: 0.1901, Class 4: 0.6559, Class 5: 0.3308, Class 6: 0.8857, 
Validation F-beta Score
Class 0: 0.9818, Class 1: 0.7152, Class 2: 0.1345, Class 3: 0.4884, Class 4: 0.7915, Class 5: 0.3708, Class 6: 0.8953, 

Overall Mean Dice Score: 0.5379
Overall Mean F-beta Score: 0.6523

Training Loss: 0.4262, Validation Loss: 0.4232, Validation F-beta: 0.6523
Epoch 9/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.398]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s, loss=0.417]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.6263, Class 2: 0.1007, Class 3: 0.1496, Class 4: 0.6683, Class 5: 0.3033, Class 6: 0.9095, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.6609, Class 2: 0.1128, Class 3: 0.2939, Class 4: 0.8111, Class 5: 0.2868, Class 6: 0.9113, 

Overall Mean Dice Score: 0.5314
Overall Mean F-beta Score: 0.5928

Training Loss: 0.4213, Validation Loss: 0.4216, Validation F-beta: 0.5928
SUPER Best model saved. Loss:0.4216, Score:0.5928
Epoch 10/4000


Training: 100%|██████████| 204/204 [03:21<00:00,  1.01it/s, loss=0.38] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.61it/s, loss=0.404]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7254, Class 2: 0.1664, Class 3: 0.2721, Class 4: 0.6486, Class 5: 0.3824, Class 6: 0.8847, 
Validation F-beta Score
Class 0: 0.9800, Class 1: 0.7482, Class 2: 0.1809, Class 3: 0.5090, Class 4: 0.8391, Class 5: 0.4228, Class 6: 0.9101, 

Overall Mean Dice Score: 0.5827
Overall Mean F-beta Score: 0.6858

Training Loss: 0.4189, Validation Loss: 0.4091, Validation F-beta: 0.6858
SUPER Best model saved. Loss:0.4091, Score:0.6858
Epoch 11/4000


Training: 100%|██████████| 204/204 [03:05<00:00,  1.10it/s, loss=0.397]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.53it/s, loss=0.407]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.6641, Class 2: 0.1485, Class 3: 0.1798, Class 4: 0.6811, Class 5: 0.4176, Class 6: 0.8918, 
Validation F-beta Score
Class 0: 0.9833, Class 1: 0.7479, Class 2: 0.1966, Class 3: 0.2976, Class 4: 0.8122, Class 5: 0.4535, Class 6: 0.9428, 

Overall Mean Dice Score: 0.5669
Overall Mean F-beta Score: 0.6508

Training Loss: 0.4161, Validation Loss: 0.4025, Validation F-beta: 0.6508
Epoch 12/4000


Training: 100%|██████████| 204/204 [03:12<00:00,  1.06it/s, loss=0.338]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.61it/s, loss=0.405]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.6628, Class 2: 0.1480, Class 3: 0.1190, Class 4: 0.6645, Class 5: 0.4326, Class 6: 0.8899, 
Validation F-beta Score
Class 0: 0.9835, Class 1: 0.7461, Class 2: 0.1582, Class 3: 0.2980, Class 4: 0.7766, Class 5: 0.4833, Class 6: 0.9226, 

Overall Mean Dice Score: 0.5538
Overall Mean F-beta Score: 0.6453

Training Loss: 0.4133, Validation Loss: 0.4105, Validation F-beta: 0.6453
Epoch 13/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.377]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.399]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7300, Class 2: 0.1016, Class 3: 0.1944, Class 4: 0.6294, Class 5: 0.4050, Class 6: 0.9068, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.7488, Class 2: 0.1428, Class 3: 0.3602, Class 4: 0.8425, Class 5: 0.3903, Class 6: 0.9396, 

Overall Mean Dice Score: 0.5731
Overall Mean F-beta Score: 0.6563

Training Loss: 0.4107, Validation Loss: 0.3841, Validation F-beta: 0.6563
Epoch 14/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.368]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.43] 


Validation Dice Score
Class 0: 0.9878, Class 1: 0.6763, Class 2: 0.1451, Class 3: 0.1799, Class 4: 0.7009, Class 5: 0.3488, Class 6: 0.8809, 
Validation F-beta Score
Class 0: 0.9839, Class 1: 0.7043, Class 2: 0.1721, Class 3: 0.3343, Class 4: 0.8051, Class 5: 0.3606, Class 6: 0.9017, 

Overall Mean Dice Score: 0.5574
Overall Mean F-beta Score: 0.6212

Training Loss: 0.4096, Validation Loss: 0.4231, Validation F-beta: 0.6212
Epoch 15/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.383]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.64it/s, loss=0.415]


Validation Dice Score
Class 0: 0.9888, Class 1: 0.5303, Class 2: 0.1164, Class 3: 0.1963, Class 4: 0.7161, Class 5: 0.3567, Class 6: 0.8959, 
Validation F-beta Score
Class 0: 0.9851, Class 1: 0.7500, Class 2: 0.1169, Class 3: 0.3482, Class 4: 0.8193, Class 5: 0.3991, Class 6: 0.9256, 

Overall Mean Dice Score: 0.5391
Overall Mean F-beta Score: 0.6484

Training Loss: 0.4086, Validation Loss: 0.4139, Validation F-beta: 0.6484
Epoch 16/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.371]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.399]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.7203, Class 2: 0.1259, Class 3: 0.2875, Class 4: 0.6603, Class 5: 0.4177, Class 6: 0.9068, 
Validation F-beta Score
Class 0: 0.9787, Class 1: 0.7563, Class 2: 0.1965, Class 3: 0.4339, Class 4: 0.8228, Class 5: 0.5445, Class 6: 0.9296, 

Overall Mean Dice Score: 0.5985
Overall Mean F-beta Score: 0.6974

Training Loss: 0.4073, Validation Loss: 0.3907, Validation F-beta: 0.6974
SUPER Best model saved. Loss:0.3907, Score:0.6974
Epoch 17/4000


Training: 100%|██████████| 204/204 [02:58<00:00,  1.14it/s, loss=0.388]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.74it/s, loss=0.424]


Validation Dice Score
Class 0: 0.9881, Class 1: 0.7487, Class 2: 0.2102, Class 3: 0.2028, Class 4: 0.6014, Class 5: 0.4406, Class 6: 0.9175, 
Validation F-beta Score
Class 0: 0.9827, Class 1: 0.7791, Class 2: 0.2344, Class 3: 0.4464, Class 4: 0.8029, Class 5: 0.5318, Class 6: 0.9418, 

Overall Mean Dice Score: 0.5822
Overall Mean F-beta Score: 0.7004

Training Loss: 0.4032, Validation Loss: 0.4048, Validation F-beta: 0.7004
Epoch 18/4000


Training: 100%|██████████| 204/204 [03:06<00:00,  1.09it/s, loss=0.377]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.427]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.6751, Class 2: 0.1251, Class 3: 0.3031, Class 4: 0.6621, Class 5: 0.4331, Class 6: 0.8991, 
Validation F-beta Score
Class 0: 0.9822, Class 1: 0.7982, Class 2: 0.1549, Class 3: 0.5677, Class 4: 0.8245, Class 5: 0.5127, Class 6: 0.9267, 

Overall Mean Dice Score: 0.5945
Overall Mean F-beta Score: 0.7259

Training Loss: 0.4031, Validation Loss: 0.4023, Validation F-beta: 0.7259
Epoch 19/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.34] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7396, Class 2: 0.1249, Class 3: 0.2556, Class 4: 0.6617, Class 5: 0.4493, Class 6: 0.8752, 
Validation F-beta Score
Class 0: 0.9808, Class 1: 0.7867, Class 2: 0.2128, Class 3: 0.5122, Class 4: 0.8133, Class 5: 0.5321, Class 6: 0.9451, 

Overall Mean Dice Score: 0.5963
Overall Mean F-beta Score: 0.7179

Training Loss: 0.4019, Validation Loss: 0.3940, Validation F-beta: 0.7179
Epoch 20/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.335]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.401]


Validation Dice Score
Class 0: 0.9860, Class 1: 0.7308, Class 2: 0.1529, Class 3: 0.2708, Class 4: 0.6393, Class 5: 0.3843, Class 6: 0.8642, 
Validation F-beta Score
Class 0: 0.9796, Class 1: 0.7657, Class 2: 0.2124, Class 3: 0.5054, Class 4: 0.7956, Class 5: 0.5151, Class 6: 0.9425, 

Overall Mean Dice Score: 0.5779
Overall Mean F-beta Score: 0.7048

Training Loss: 0.4002, Validation Loss: 0.3950, Validation F-beta: 0.7048
Epoch 21/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.35] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.39] 


Validation Dice Score
Class 0: 0.9870, Class 1: 0.7279, Class 2: 0.1668, Class 3: 0.2419, Class 4: 0.7139, Class 5: 0.4133, Class 6: 0.9093, 
Validation F-beta Score
Class 0: 0.9808, Class 1: 0.8401, Class 2: 0.2236, Class 3: 0.4314, Class 4: 0.8240, Class 5: 0.5630, Class 6: 0.9421, 

Overall Mean Dice Score: 0.6012
Overall Mean F-beta Score: 0.7201

Training Loss: 0.4013, Validation Loss: 0.4072, Validation F-beta: 0.7201
Epoch 22/4000


Training: 100%|██████████| 204/204 [03:12<00:00,  1.06it/s, loss=0.367]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.362]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.7367, Class 2: 0.1571, Class 3: 0.2493, Class 4: 0.6733, Class 5: 0.4018, Class 6: 0.9033, 
Validation F-beta Score
Class 0: 0.9789, Class 1: 0.8017, Class 2: 0.2100, Class 3: 0.4322, Class 4: 0.8229, Class 5: 0.5411, Class 6: 0.9384, 

Overall Mean Dice Score: 0.5929
Overall Mean F-beta Score: 0.7073

Training Loss: 0.3989, Validation Loss: 0.3877, Validation F-beta: 0.7073
SUPER Best model saved. Loss:0.3877, Score:0.7073
Epoch 23/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.37] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.405]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.7812, Class 2: 0.2194, Class 3: 0.3055, Class 4: 0.6798, Class 5: 0.4272, Class 6: 0.8910, 
Validation F-beta Score
Class 0: 0.9798, Class 1: 0.8253, Class 2: 0.2678, Class 3: 0.5199, Class 4: 0.8557, Class 5: 0.5720, Class 6: 0.9490, 

Overall Mean Dice Score: 0.6169
Overall Mean F-beta Score: 0.7444

Training Loss: 0.3984, Validation Loss: 0.4019, Validation F-beta: 0.7444
Epoch 24/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.337]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9879, Class 1: 0.7433, Class 2: 0.1904, Class 3: 0.3306, Class 4: 0.6883, Class 5: 0.4293, Class 6: 0.8237, 
Validation F-beta Score
Class 0: 0.9818, Class 1: 0.8406, Class 2: 0.2351, Class 3: 0.5821, Class 4: 0.8427, Class 5: 0.5473, Class 6: 0.9367, 

Overall Mean Dice Score: 0.6030
Overall Mean F-beta Score: 0.7499

Training Loss: 0.3966, Validation Loss: 0.3897, Validation F-beta: 0.7499
Epoch 25/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.367]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.363]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7601, Class 2: 0.1223, Class 3: 0.3340, Class 4: 0.6664, Class 5: 0.4940, Class 6: 0.8861, 
Validation F-beta Score
Class 0: 0.9807, Class 1: 0.8412, Class 2: 0.1831, Class 3: 0.5517, Class 4: 0.8589, Class 5: 0.6477, Class 6: 0.9535, 

Overall Mean Dice Score: 0.6281
Overall Mean F-beta Score: 0.7706

Training Loss: 0.3953, Validation Loss: 0.3793, Validation F-beta: 0.7706
SUPER Best model saved. Loss:0.3793, Score:0.7706
Epoch 26/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.392]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.375]


Validation Dice Score
Class 0: 0.9881, Class 1: 0.7652, Class 2: 0.2159, Class 3: 0.3530, Class 4: 0.6729, Class 5: 0.4490, Class 6: 0.9080, 
Validation F-beta Score
Class 0: 0.9815, Class 1: 0.8195, Class 2: 0.2475, Class 3: 0.6181, Class 4: 0.8656, Class 5: 0.6115, Class 6: 0.9314, 

Overall Mean Dice Score: 0.6296
Overall Mean F-beta Score: 0.7692

Training Loss: 0.3952, Validation Loss: 0.3782, Validation F-beta: 0.7692
Epoch 27/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.359]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.7827, Class 2: 0.1861, Class 3: 0.3430, Class 4: 0.7150, Class 5: 0.5015, Class 6: 0.8936, 
Validation F-beta Score
Class 0: 0.9849, Class 1: 0.8131, Class 2: 0.2383, Class 3: 0.5721, Class 4: 0.8036, Class 5: 0.5958, Class 6: 0.9331, 

Overall Mean Dice Score: 0.6471
Overall Mean F-beta Score: 0.7435

Training Loss: 0.3944, Validation Loss: 0.3787, Validation F-beta: 0.7435
Epoch 28/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.349]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.412]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.7752, Class 2: 0.1314, Class 3: 0.2855, Class 4: 0.7010, Class 5: 0.4476, Class 6: 0.8777, 
Validation F-beta Score
Class 0: 0.9793, Class 1: 0.8458, Class 2: 0.2125, Class 3: 0.4866, Class 4: 0.8400, Class 5: 0.6007, Class 6: 0.9620, 

Overall Mean Dice Score: 0.6174
Overall Mean F-beta Score: 0.7470

Training Loss: 0.3958, Validation Loss: 0.3989, Validation F-beta: 0.7470
Epoch 29/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.367]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.374]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7199, Class 2: 0.1679, Class 3: 0.2979, Class 4: 0.7132, Class 5: 0.4445, Class 6: 0.9020, 
Validation F-beta Score
Class 0: 0.9805, Class 1: 0.8394, Class 2: 0.2503, Class 3: 0.5210, Class 4: 0.8325, Class 5: 0.6177, Class 6: 0.9462, 

Overall Mean Dice Score: 0.6155
Overall Mean F-beta Score: 0.7513

Training Loss: 0.3941, Validation Loss: 0.3858, Validation F-beta: 0.7513
Epoch 30/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.347]
Validation: 100%|██████████| 3/3 [00:02<00:00,  1.34it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9859, Class 1: 0.6913, Class 2: 0.1727, Class 3: 0.3332, Class 4: 0.6830, Class 5: 0.4537, Class 6: 0.8628, 
Validation F-beta Score
Class 0: 0.9787, Class 1: 0.8270, Class 2: 0.2635, Class 3: 0.5701, Class 4: 0.8401, Class 5: 0.5699, Class 6: 0.9383, 

Overall Mean Dice Score: 0.6048
Overall Mean F-beta Score: 0.7491

Training Loss: 0.3929, Validation Loss: 0.3790, Validation F-beta: 0.7491
Epoch 31/4000


Training: 100%|██████████| 204/204 [03:05<00:00,  1.10it/s, loss=0.348]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.63it/s, loss=0.376]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.8052, Class 2: 0.1683, Class 3: 0.2594, Class 4: 0.6185, Class 5: 0.4054, Class 6: 0.8772, 
Validation F-beta Score
Class 0: 0.9813, Class 1: 0.8447, Class 2: 0.2454, Class 3: 0.4407, Class 4: 0.8408, Class 5: 0.5115, Class 6: 0.9545, 

Overall Mean Dice Score: 0.5931
Overall Mean F-beta Score: 0.7184

Training Loss: 0.3930, Validation Loss: 0.3875, Validation F-beta: 0.7184
Epoch 32/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.358]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.369]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.7466, Class 2: 0.1734, Class 3: 0.2663, Class 4: 0.7096, Class 5: 0.4343, Class 6: 0.9041, 
Validation F-beta Score
Class 0: 0.9816, Class 1: 0.8439, Class 2: 0.2481, Class 3: 0.5713, Class 4: 0.8339, Class 5: 0.5685, Class 6: 0.9611, 

Overall Mean Dice Score: 0.6122
Overall Mean F-beta Score: 0.7557

Training Loss: 0.3932, Validation Loss: 0.3856, Validation F-beta: 0.7557
Epoch 33/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.371]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.407]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.7378, Class 2: 0.1314, Class 3: 0.3314, Class 4: 0.6424, Class 5: 0.4265, Class 6: 0.8718, 
Validation F-beta Score
Class 0: 0.9787, Class 1: 0.8279, Class 2: 0.2795, Class 3: 0.6096, Class 4: 0.8004, Class 5: 0.5930, Class 6: 0.9416, 

Overall Mean Dice Score: 0.6020
Overall Mean F-beta Score: 0.7545

Training Loss: 0.3934, Validation Loss: 0.3958, Validation F-beta: 0.7545
Epoch 34/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.343]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9859, Class 1: 0.7373, Class 2: 0.1580, Class 3: 0.3653, Class 4: 0.6851, Class 5: 0.5331, Class 6: 0.8280, 
Validation F-beta Score
Class 0: 0.9779, Class 1: 0.8488, Class 2: 0.2563, Class 3: 0.5810, Class 4: 0.8624, Class 5: 0.6724, Class 6: 0.9448, 

Overall Mean Dice Score: 0.6297
Overall Mean F-beta Score: 0.7819

Training Loss: 0.3912, Validation Loss: 0.3833, Validation F-beta: 0.7819
Epoch 35/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.346]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.7395, Class 2: 0.1403, Class 3: 0.3020, Class 4: 0.6419, Class 5: 0.4724, Class 6: 0.8910, 
Validation F-beta Score
Class 0: 0.9753, Class 1: 0.8716, Class 2: 0.2303, Class 3: 0.5465, Class 4: 0.8537, Class 5: 0.6326, Class 6: 0.9423, 

Overall Mean Dice Score: 0.6094
Overall Mean F-beta Score: 0.7693

Training Loss: 0.3920, Validation Loss: 0.3863, Validation F-beta: 0.7693
Epoch 36/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.357]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.385]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7027, Class 2: 0.1879, Class 3: 0.3012, Class 4: 0.6661, Class 5: 0.3876, Class 6: 0.8968, 
Validation F-beta Score
Class 0: 0.9795, Class 1: 0.8156, Class 2: 0.2719, Class 3: 0.5266, Class 4: 0.8329, Class 5: 0.5966, Class 6: 0.9439, 

Overall Mean Dice Score: 0.5909
Overall Mean F-beta Score: 0.7431

Training Loss: 0.3916, Validation Loss: 0.4026, Validation F-beta: 0.7431
Epoch 37/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.349]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.371]


Validation Dice Score
Class 0: 0.9835, Class 1: 0.7052, Class 2: 0.1683, Class 3: 0.3795, Class 4: 0.6436, Class 5: 0.3792, Class 6: 0.8533, 
Validation F-beta Score
Class 0: 0.9725, Class 1: 0.8233, Class 2: 0.3043, Class 3: 0.5752, Class 4: 0.8650, Class 5: 0.6454, Class 6: 0.9465, 

Overall Mean Dice Score: 0.5921
Overall Mean F-beta Score: 0.7711

Training Loss: 0.3906, Validation Loss: 0.4003, Validation F-beta: 0.7711
Epoch 38/4000


Training: 100%|██████████| 204/204 [03:06<00:00,  1.09it/s, loss=0.366]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.51it/s, loss=0.368]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.7514, Class 2: 0.2129, Class 3: 0.3570, Class 4: 0.6991, Class 5: 0.4821, Class 6: 0.8926, 
Validation F-beta Score
Class 0: 0.9817, Class 1: 0.8327, Class 2: 0.2737, Class 3: 0.5726, Class 4: 0.8174, Class 5: 0.6385, Class 6: 0.9391, 

Overall Mean Dice Score: 0.6365
Overall Mean F-beta Score: 0.7601

Training Loss: 0.3922, Validation Loss: 0.3845, Validation F-beta: 0.7601
Epoch 39/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.368]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.367]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.7747, Class 2: 0.1891, Class 3: 0.3879, Class 4: 0.6441, Class 5: 0.4527, Class 6: 0.9036, 
Validation F-beta Score
Class 0: 0.9812, Class 1: 0.8859, Class 2: 0.2651, Class 3: 0.5607, Class 4: 0.8452, Class 5: 0.6232, Class 6: 0.9541, 

Overall Mean Dice Score: 0.6326
Overall Mean F-beta Score: 0.7738

Training Loss: 0.3906, Validation Loss: 0.3641, Validation F-beta: 0.7738
SUPER Best model saved. Loss:0.3641, Score:0.7738
Epoch 40/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.362]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.372]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.7446, Class 2: 0.1488, Class 3: 0.3086, Class 4: 0.6950, Class 5: 0.4132, Class 6: 0.8985, 
Validation F-beta Score
Class 0: 0.9819, Class 1: 0.8743, Class 2: 0.2323, Class 3: 0.5195, Class 4: 0.8794, Class 5: 0.5395, Class 6: 0.9399, 

Overall Mean Dice Score: 0.6120
Overall Mean F-beta Score: 0.7505

Training Loss: 0.3899, Validation Loss: 0.3897, Validation F-beta: 0.7505
Epoch 41/4000


Training: 100%|██████████| 204/204 [03:11<00:00,  1.07it/s, loss=0.388]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.408]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7461, Class 2: 0.1542, Class 3: 0.3714, Class 4: 0.6225, Class 5: 0.4450, Class 6: 0.8716, 
Validation F-beta Score
Class 0: 0.9814, Class 1: 0.8346, Class 2: 0.2556, Class 3: 0.5763, Class 4: 0.8186, Class 5: 0.6630, Class 6: 0.9296, 

Overall Mean Dice Score: 0.6113
Overall Mean F-beta Score: 0.7644

Training Loss: 0.3902, Validation Loss: 0.3948, Validation F-beta: 0.7644
Epoch 42/4000


Training: 100%|██████████| 204/204 [03:03<00:00,  1.11it/s, loss=0.361]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.385]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7243, Class 2: 0.1961, Class 3: 0.3275, Class 4: 0.6580, Class 5: 0.4278, Class 6: 0.9143, 
Validation F-beta Score
Class 0: 0.9808, Class 1: 0.8624, Class 2: 0.3014, Class 3: 0.5536, Class 4: 0.7982, Class 5: 0.5714, Class 6: 0.9586, 

Overall Mean Dice Score: 0.6104
Overall Mean F-beta Score: 0.7488

Training Loss: 0.3902, Validation Loss: 0.3793, Validation F-beta: 0.7488
Epoch 43/4000


Training: 100%|██████████| 204/204 [03:04<00:00,  1.11it/s, loss=0.337]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.389]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7674, Class 2: 0.1191, Class 3: 0.3187, Class 4: 0.6756, Class 5: 0.4417, Class 6: 0.8727, 
Validation F-beta Score
Class 0: 0.9796, Class 1: 0.8555, Class 2: 0.2186, Class 3: 0.5365, Class 4: 0.8601, Class 5: 0.6294, Class 6: 0.9270, 

Overall Mean Dice Score: 0.6152
Overall Mean F-beta Score: 0.7617

Training Loss: 0.3911, Validation Loss: 0.3811, Validation F-beta: 0.7617
Epoch 44/4000


Training: 100%|██████████| 204/204 [03:11<00:00,  1.06it/s, loss=0.343]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s, loss=0.415]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.7222, Class 2: 0.1508, Class 3: 0.2731, Class 4: 0.6747, Class 5: 0.4966, Class 6: 0.8509, 
Validation F-beta Score
Class 0: 0.9798, Class 1: 0.8483, Class 2: 0.2237, Class 3: 0.5008, Class 4: 0.8613, Class 5: 0.6528, Class 6: 0.9577, 

Overall Mean Dice Score: 0.6035
Overall Mean F-beta Score: 0.7642

Training Loss: 0.3905, Validation Loss: 0.3948, Validation F-beta: 0.7642
Epoch 45/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.346]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.343]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7573, Class 2: 0.1770, Class 3: 0.4198, Class 4: 0.7079, Class 5: 0.4897, Class 6: 0.8872, 
Validation F-beta Score
Class 0: 0.9813, Class 1: 0.8506, Class 2: 0.2583, Class 3: 0.6170, Class 4: 0.8541, Class 5: 0.6480, Class 6: 0.9232, 

Overall Mean Dice Score: 0.6524
Overall Mean F-beta Score: 0.7786

Training Loss: 0.3919, Validation Loss: 0.3689, Validation F-beta: 0.7786
Epoch 46/4000


Training: 100%|██████████| 204/204 [03:29<00:00,  1.03s/it, loss=0.362]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.63it/s, loss=0.397]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.7361, Class 2: 0.1818, Class 3: 0.3725, Class 4: 0.7318, Class 5: 0.4157, Class 6: 0.8918, 
Validation F-beta Score
Class 0: 0.9817, Class 1: 0.8204, Class 2: 0.2490, Class 3: 0.6044, Class 4: 0.7946, Class 5: 0.6218, Class 6: 0.8925, 

Overall Mean Dice Score: 0.6296
Overall Mean F-beta Score: 0.7467

Training Loss: 0.3934, Validation Loss: 0.3967, Validation F-beta: 0.7467
Epoch 47/4000


Training: 100%|██████████| 204/204 [03:18<00:00,  1.03it/s, loss=0.365]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.54it/s, loss=0.417]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.6948, Class 2: 0.0755, Class 3: 0.3005, Class 4: 0.7252, Class 5: 0.5189, Class 6: 0.8720, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8357, Class 2: 0.1463, Class 3: 0.5269, Class 4: 0.7982, Class 5: 0.6130, Class 6: 0.9465, 

Overall Mean Dice Score: 0.6223
Overall Mean F-beta Score: 0.7440

Training Loss: 0.3887, Validation Loss: 0.4049, Validation F-beta: 0.7440
Epoch 48/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.363]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.358]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7780, Class 2: 0.1930, Class 3: 0.3352, Class 4: 0.7409, Class 5: 0.5193, Class 6: 0.8933, 
Validation F-beta Score
Class 0: 0.9844, Class 1: 0.8850, Class 2: 0.2501, Class 3: 0.4944, Class 4: 0.8149, Class 5: 0.6227, Class 6: 0.9353, 

Overall Mean Dice Score: 0.6533
Overall Mean F-beta Score: 0.7504

Training Loss: 0.3867, Validation Loss: 0.3650, Validation F-beta: 0.7504
Epoch 49/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.357]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.351]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.7533, Class 2: 0.1765, Class 3: 0.3762, Class 4: 0.7274, Class 5: 0.4858, Class 6: 0.9057, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.8409, Class 2: 0.2478, Class 3: 0.5562, Class 4: 0.8297, Class 5: 0.6077, Class 6: 0.9632, 

Overall Mean Dice Score: 0.6497
Overall Mean F-beta Score: 0.7595

Training Loss: 0.3871, Validation Loss: 0.3658, Validation F-beta: 0.7595
Epoch 50/4000


Training: 100%|██████████| 204/204 [03:10<00:00,  1.07it/s, loss=0.333]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.392]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7888, Class 2: 0.1383, Class 3: 0.3521, Class 4: 0.7440, Class 5: 0.4161, Class 6: 0.8855, 
Validation F-beta Score
Class 0: 0.9844, Class 1: 0.8753, Class 2: 0.2449, Class 3: 0.5472, Class 4: 0.7756, Class 5: 0.5541, Class 6: 0.9286, 

Overall Mean Dice Score: 0.6373
Overall Mean F-beta Score: 0.7361

Training Loss: 0.3857, Validation Loss: 0.3895, Validation F-beta: 0.7361
Epoch 51/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.358]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.359]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.7912, Class 2: 0.1107, Class 3: 0.3760, Class 4: 0.6845, Class 5: 0.4521, Class 6: 0.9067, 
Validation F-beta Score
Class 0: 0.9826, Class 1: 0.8919, Class 2: 0.1990, Class 3: 0.5098, Class 4: 0.8193, Class 5: 0.6210, Class 6: 0.9562, 

Overall Mean Dice Score: 0.6421
Overall Mean F-beta Score: 0.7596

Training Loss: 0.3849, Validation Loss: 0.3669, Validation F-beta: 0.7596
Epoch 52/4000


Training: 100%|██████████| 204/204 [03:08<00:00,  1.08it/s, loss=0.354]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.365]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7553, Class 2: 0.1899, Class 3: 0.3530, Class 4: 0.6965, Class 5: 0.4629, Class 6: 0.8980, 
Validation F-beta Score
Class 0: 0.9832, Class 1: 0.8736, Class 2: 0.2375, Class 3: 0.5216, Class 4: 0.8457, Class 5: 0.5967, Class 6: 0.9529, 

Overall Mean Dice Score: 0.6331
Overall Mean F-beta Score: 0.7581

Training Loss: 0.3836, Validation Loss: 0.3680, Validation F-beta: 0.7581
Epoch 53/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.352]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9887, Class 1: 0.6854, Class 2: 0.1658, Class 3: 0.3808, Class 4: 0.7761, Class 5: 0.4872, Class 6: 0.9105, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8601, Class 2: 0.2583, Class 3: 0.5522, Class 4: 0.8377, Class 5: 0.6162, Class 6: 0.9716, 

Overall Mean Dice Score: 0.6480
Overall Mean F-beta Score: 0.7675

Training Loss: 0.3865, Validation Loss: 0.3875, Validation F-beta: 0.7675
Epoch 54/4000


Training: 100%|██████████| 204/204 [03:07<00:00,  1.09it/s, loss=0.352]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.42] 


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7017, Class 2: 0.1968, Class 3: 0.3113, Class 4: 0.6976, Class 5: 0.3947, Class 6: 0.8906, 
Validation F-beta Score
Class 0: 0.9829, Class 1: 0.8514, Class 2: 0.3263, Class 3: 0.5151, Class 4: 0.8007, Class 5: 0.6004, Class 6: 0.9412, 

Overall Mean Dice Score: 0.5992
Overall Mean F-beta Score: 0.7418

Training Loss: 0.3837, Validation Loss: 0.4014, Validation F-beta: 0.7418
Epoch 55/4000


Training: 100%|██████████| 204/204 [03:09<00:00,  1.08it/s, loss=0.32] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.36] 

Validation Dice Score
Class 0: 0.9890, Class 1: 0.7902, Class 2: 0.1999, Class 3: 0.3765, Class 4: 0.7542, Class 5: 0.4313, Class 6: 0.8952, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.8736, Class 2: 0.2996, Class 3: 0.5670, Class 4: 0.7895, Class 5: 0.5914, Class 6: 0.9429, 

Overall Mean Dice Score: 0.6495
Overall Mean F-beta Score: 0.7529

Training Loss: 0.3834, Validation Loss: 0.3716, Validation F-beta: 0.7529
Early stopping





0,1
class_0_dice_score,█▅▇▆█▆▆▆▇▄▆▅▄▅▄▆▆▆█▄▄▆▆▄▄▅▁▆▆▇▅▅▅▆▆▇█▇▇▇
class_0_f_beta_score,█▆▇▇▅▄▅▅▅▆▅▄▄▄▄▅▄▅▆▄▄▅▃▃▂▁▅▄▅▄▄▄▄▆▆▆▅▅▆▆
class_1_dice_score,▁▁▄▆▆▇▇▇▇▇▇█▇▇▇▇█▇██▇▇█▇▇▇▇▇██▇█▇█▇█████
class_1_f_beta_score,▁▂▅▅▆▆▇▇▇▇▇▇▇▇▇█▇▇█▇█▇██▇█▇▇████▇███████
class_2_dice_score,▁▁▂▃▅▆▆▆▄▆▅█▅▅▆▆█▅█▇▇▇▅▆▅▆█▇▆▆▅▆▇▇▃▇▅▅▇▇
class_2_f_beta_score,▁▂▂▄▄▅▆▅▄▅▆▅▆▆▆▆▅▇▆▇▇▇▇▇▆█▇▆▇█▆▇▇▄▇▇▆▆▇█
class_3_dice_score,▁▃▃▃▂▂▅▃▂▃▃▅▃▆▄▄▄▆▆▆▅▅▆▅▅▇▇▇▇▆▆▆▅█▇▇▇▇▇▇
class_3_f_beta_score,▁▃▃▃▆▆▂▂▃▃▅▅▆▅▅▇▇█▇▆▅▇█▇▇▇▇▆▇▇▆██▆▆▇▆▆▇▇
class_4_dice_score,▁▃▃▅▃▄▄▅▄▃▆▄▂▄▃▅▅▅▄▅▆▅▂▆▃▄▃▆▃▆▅▅▆▇▇▇█▅▆█
class_4_f_beta_score,▁▄▃▃▅▆▇▆▆▆▇▆▇▆▆▇▇▇██▇▇▇▇▆▇█▆█▇█▇▆▆▆▅▇▇▇▆

0,1
class_0_dice_score,0.98896
class_0_f_beta_score,0.98455
class_1_dice_score,0.79021
class_1_f_beta_score,0.87358
class_2_dice_score,0.19994
class_2_f_beta_score,0.29959
class_3_dice_score,0.37654
class_3_f_beta_score,0.56702
class_4_dice_score,0.75425
class_4_f_beta_score,0.78952


In [12]:
train_img_dir = './datasets/public_data/images'
train_label_dir = './datasets/public_data/labels'

train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders_bw(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    val_non_random_transforms=non_random_transforms,
    random_transforms = random_transforms, 
    val_random_transforms=val_random_transforms,
    batch_size = loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat
    )

batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

Loading dataset: 100%|██████████| 24/24 [00:02<00:00,  8.40it/s]
Loading dataset: 100%|██████████| 1/1 [00:00<00:00, 13.90it/s]


torch.Size([16, 1, 96, 96, 96]) torch.Size([16, 1, 96, 96, 96])


In [13]:
torch.backends.cudnn.benchmark = True

In [14]:

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")


기존 best model 발견: model_checkpoints\UNET_randGaus_511_241_noclswt_f48_d96s96_numb2_lr1e-03_a0.52_b0.48_b16_r4_ce0.4_ac1\best_model.pt
기존 학습된 가중치를 성공적으로 로드했습니다.


  checkpoint = torch.load(best_model_path, map_location=device)


In [15]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_UNet',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        "num_bottleneck": num_bottleneck,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

In [16]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
     ) 

Epoch 40/4000


  class_weights = torch.tensor(self.ce.weight)  # CrossEntropy의 weight를 사용
Training: 100%|██████████| 96/96 [01:24<00:00,  1.14it/s, loss=0.384]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9793, Class 1: 0.7131, Class 2: 0.3147, Class 3: 0.3278, Class 4: 0.5494, Class 5: 0.3526, Class 6: 0.8205, 
Validation F-beta Score
Class 0: 0.9640, Class 1: 0.8745, Class 2: 0.4542, Class 3: 0.4191, Class 4: 0.8662, Class 5: 0.6779, Class 6: 0.9508, 

Overall Mean Dice Score: 0.5527
Overall Mean F-beta Score: 0.7577

Training Loss: 0.4128, Validation Loss: 0.4210, Validation F-beta: 0.7577
Epoch 41/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.423]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.6343, Class 2: 0.2167, Class 3: 0.3624, Class 4: 0.6642, Class 5: 0.4190, Class 6: 0.7874, 
Validation F-beta Score
Class 0: 0.9758, Class 1: 0.8716, Class 2: 0.2931, Class 3: 0.5254, Class 4: 0.8078, Class 5: 0.6112, Class 6: 0.9473, 

Overall Mean Dice Score: 0.5735
Overall Mean F-beta Score: 0.7527

Training Loss: 0.4211, Validation Loss: 0.4157, Validation F-beta: 0.7527
Epoch 42/4000


Training: 100%|██████████| 96/96 [01:28<00:00,  1.08it/s, loss=0.437]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.372]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7744, Class 2: 0.3218, Class 3: 0.3681, Class 4: 0.6941, Class 5: 0.4286, Class 6: 0.8976, 
Validation F-beta Score
Class 0: 0.9818, Class 1: 0.8799, Class 2: 0.3553, Class 3: 0.5084, Class 4: 0.8463, Class 5: 0.5578, Class 6: 0.9655, 

Overall Mean Dice Score: 0.6326
Overall Mean F-beta Score: 0.7516

Training Loss: 0.4191, Validation Loss: 0.3807, Validation F-beta: 0.7516
Epoch 43/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.14it/s, loss=0.437]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.62it/s, loss=0.39] 


Validation Dice Score
Class 0: 0.9882, Class 1: 0.7899, Class 2: 0.1844, Class 3: 0.3285, Class 4: 0.7371, Class 5: 0.4378, Class 6: 0.8349, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8645, Class 2: 0.2687, Class 3: 0.5634, Class 4: 0.7790, Class 5: 0.5455, Class 6: 0.9562, 

Overall Mean Dice Score: 0.6256
Overall Mean F-beta Score: 0.7417

Training Loss: 0.4085, Validation Loss: 0.3938, Validation F-beta: 0.7417
Epoch 44/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.38] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.383]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.7490, Class 2: 0.1795, Class 3: 0.2979, Class 4: 0.6825, Class 5: 0.4203, Class 6: 0.8586, 
Validation F-beta Score
Class 0: 0.9845, Class 1: 0.8622, Class 2: 0.1752, Class 3: 0.3975, Class 4: 0.7656, Class 5: 0.5174, Class 6: 0.9606, 

Overall Mean Dice Score: 0.6016
Overall Mean F-beta Score: 0.7007

Training Loss: 0.4069, Validation Loss: 0.3849, Validation F-beta: 0.7007
Epoch 45/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.14it/s, loss=0.402]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.392]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7757, Class 2: 0.1060, Class 3: 0.2840, Class 4: 0.6910, Class 5: 0.4319, Class 6: 0.8838, 
Validation F-beta Score
Class 0: 0.9814, Class 1: 0.8744, Class 2: 0.2407, Class 3: 0.4566, Class 4: 0.8077, Class 5: 0.6014, Class 6: 0.9660, 

Overall Mean Dice Score: 0.6133
Overall Mean F-beta Score: 0.7412

Training Loss: 0.4050, Validation Loss: 0.3938, Validation F-beta: 0.7412
Epoch 46/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.14it/s, loss=0.399]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.78it/s, loss=0.363]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.7447, Class 2: 0.2316, Class 3: 0.3594, Class 4: 0.7268, Class 5: 0.4722, Class 6: 0.9133, 
Validation F-beta Score
Class 0: 0.9875, Class 1: 0.8450, Class 2: 0.2825, Class 3: 0.4453, Class 4: 0.7702, Class 5: 0.6046, Class 6: 0.9660, 

Overall Mean Dice Score: 0.6433
Overall Mean F-beta Score: 0.7262

Training Loss: 0.4041, Validation Loss: 0.3791, Validation F-beta: 0.7262
Epoch 47/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.409]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.357]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.8180, Class 2: 0.2697, Class 3: 0.4423, Class 4: 0.7197, Class 5: 0.4247, Class 6: 0.9093, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8601, Class 2: 0.3265, Class 3: 0.5626, Class 4: 0.7931, Class 5: 0.4615, Class 6: 0.9634, 

Overall Mean Dice Score: 0.6628
Overall Mean F-beta Score: 0.7281

Training Loss: 0.4031, Validation Loss: 0.3703, Validation F-beta: 0.7281
Epoch 48/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.399]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.375]


Validation Dice Score
Class 0: 0.9881, Class 1: 0.7975, Class 2: 0.2704, Class 3: 0.3806, Class 4: 0.7200, Class 5: 0.4471, Class 6: 0.8869, 
Validation F-beta Score
Class 0: 0.9835, Class 1: 0.8741, Class 2: 0.3963, Class 3: 0.4591, Class 4: 0.7690, Class 5: 0.6008, Class 6: 0.9620, 

Overall Mean Dice Score: 0.6464
Overall Mean F-beta Score: 0.7330

Training Loss: 0.4014, Validation Loss: 0.3731, Validation F-beta: 0.7330
Epoch 49/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.394]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.379]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.7579, Class 2: 0.2707, Class 3: 0.2696, Class 4: 0.6932, Class 5: 0.5026, Class 6: 0.8920, 
Validation F-beta Score
Class 0: 0.9853, Class 1: 0.8929, Class 2: 0.3348, Class 3: 0.4100, Class 4: 0.7523, Class 5: 0.6121, Class 6: 0.9732, 

Overall Mean Dice Score: 0.6231
Overall Mean F-beta Score: 0.7281

Training Loss: 0.4029, Validation Loss: 0.3829, Validation F-beta: 0.7281
Epoch 50/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.13it/s, loss=0.414]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.35] 


Validation Dice Score
Class 0: 0.9904, Class 1: 0.8233, Class 2: 0.1670, Class 3: 0.3752, Class 4: 0.7154, Class 5: 0.4998, Class 6: 0.9031, 
Validation F-beta Score
Class 0: 0.9882, Class 1: 0.8646, Class 2: 0.1900, Class 3: 0.4559, Class 4: 0.7975, Class 5: 0.5329, Class 6: 0.9697, 

Overall Mean Dice Score: 0.6634
Overall Mean F-beta Score: 0.7241

Training Loss: 0.4004, Validation Loss: 0.3451, Validation F-beta: 0.7241
SUPER Best model saved. Loss:0.3451, Score:0.7241
Epoch 51/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.402]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.391]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.7533, Class 2: 0.2571, Class 3: 0.3844, Class 4: 0.7429, Class 5: 0.4487, Class 6: 0.8718, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8572, Class 2: 0.3757, Class 3: 0.4794, Class 4: 0.7556, Class 5: 0.5362, Class 6: 0.9603, 

Overall Mean Dice Score: 0.6402
Overall Mean F-beta Score: 0.7177

Training Loss: 0.4043, Validation Loss: 0.3883, Validation F-beta: 0.7177
Epoch 52/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.13it/s, loss=0.388]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.399]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.6726, Class 2: 0.2163, Class 3: 0.3397, Class 4: 0.7481, Class 5: 0.4288, Class 6: 0.9021, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.8274, Class 2: 0.2474, Class 3: 0.4020, Class 4: 0.8127, Class 5: 0.5897, Class 6: 0.9711, 

Overall Mean Dice Score: 0.6183
Overall Mean F-beta Score: 0.7206

Training Loss: 0.4018, Validation Loss: 0.3918, Validation F-beta: 0.7206
Epoch 53/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.13it/s, loss=0.388]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.369]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7909, Class 2: 0.1894, Class 3: 0.4442, Class 4: 0.7280, Class 5: 0.4722, Class 6: 0.8921, 
Validation F-beta Score
Class 0: 0.9845, Class 1: 0.8421, Class 2: 0.2924, Class 3: 0.5485, Class 4: 0.7829, Class 5: 0.6186, Class 6: 0.9622, 

Overall Mean Dice Score: 0.6655
Overall Mean F-beta Score: 0.7509

Training Loss: 0.4002, Validation Loss: 0.3600, Validation F-beta: 0.7509
Epoch 54/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.415]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.398]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.8061, Class 2: 0.2889, Class 3: 0.2846, Class 4: 0.6608, Class 5: 0.4727, Class 6: 0.8939, 
Validation F-beta Score
Class 0: 0.9858, Class 1: 0.8818, Class 2: 0.3227, Class 3: 0.4386, Class 4: 0.8066, Class 5: 0.5795, Class 6: 0.9719, 

Overall Mean Dice Score: 0.6236
Overall Mean F-beta Score: 0.7357

Training Loss: 0.4043, Validation Loss: 0.3827, Validation F-beta: 0.7357
Epoch 55/4000


Training: 100%|██████████| 96/96 [01:41<00:00,  1.05s/it, loss=0.398]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.53it/s, loss=0.401]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.7689, Class 2: 0.2669, Class 3: 0.4381, Class 4: 0.6991, Class 5: 0.4890, Class 6: 0.8964, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.8433, Class 2: 0.3706, Class 3: 0.6392, Class 4: 0.8118, Class 5: 0.6056, Class 6: 0.9568, 

Overall Mean Dice Score: 0.6583
Overall Mean F-beta Score: 0.7713

Training Loss: 0.3981, Validation Loss: 0.3756, Validation F-beta: 0.7713
Epoch 56/4000


Training: 100%|██████████| 96/96 [01:33<00:00,  1.03it/s, loss=0.412]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.60it/s, loss=0.352]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.7814, Class 2: 0.2101, Class 3: 0.3673, Class 4: 0.7058, Class 5: 0.4794, Class 6: 0.8930, 
Validation F-beta Score
Class 0: 0.9850, Class 1: 0.8385, Class 2: 0.3607, Class 3: 0.5210, Class 4: 0.7881, Class 5: 0.5710, Class 6: 0.9632, 

Overall Mean Dice Score: 0.6454
Overall Mean F-beta Score: 0.7364

Training Loss: 0.4006, Validation Loss: 0.3734, Validation F-beta: 0.7364
Epoch 57/4000


Training: 100%|██████████| 96/96 [01:32<00:00,  1.04it/s, loss=0.423]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.63it/s, loss=0.381]


Validation Dice Score
Class 0: 0.9888, Class 1: 0.7910, Class 2: 0.1396, Class 3: 0.3791, Class 4: 0.6912, Class 5: 0.4107, Class 6: 0.8685, 
Validation F-beta Score
Class 0: 0.9854, Class 1: 0.8668, Class 2: 0.1733, Class 3: 0.4846, Class 4: 0.7667, Class 5: 0.5057, Class 6: 0.9573, 

Overall Mean Dice Score: 0.6281
Overall Mean F-beta Score: 0.7162

Training Loss: 0.3985, Validation Loss: 0.3776, Validation F-beta: 0.7162
Epoch 58/4000


Training: 100%|██████████| 96/96 [01:31<00:00,  1.04it/s, loss=0.401]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.56it/s, loss=0.361]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.7593, Class 2: 0.2952, Class 3: 0.3495, Class 4: 0.7467, Class 5: 0.4574, Class 6: 0.9175, 
Validation F-beta Score
Class 0: 0.9872, Class 1: 0.8788, Class 2: 0.3497, Class 3: 0.4666, Class 4: 0.7685, Class 5: 0.5504, Class 6: 0.9616, 

Overall Mean Dice Score: 0.6461
Overall Mean F-beta Score: 0.7252

Training Loss: 0.3972, Validation Loss: 0.3849, Validation F-beta: 0.7252
Epoch 59/4000


Training: 100%|██████████| 96/96 [01:28<00:00,  1.09it/s, loss=0.434]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.34] 


Validation Dice Score
Class 0: 0.9894, Class 1: 0.7875, Class 2: 0.2023, Class 3: 0.3551, Class 4: 0.7163, Class 5: 0.4662, Class 6: 0.8908, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8510, Class 2: 0.2434, Class 3: 0.4587, Class 4: 0.7773, Class 5: 0.5233, Class 6: 0.9633, 

Overall Mean Dice Score: 0.6432
Overall Mean F-beta Score: 0.7147

Training Loss: 0.3996, Validation Loss: 0.3672, Validation F-beta: 0.7147
Epoch 60/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.373]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.378]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.8106, Class 2: 0.3398, Class 3: 0.3831, Class 4: 0.7506, Class 5: 0.4648, Class 6: 0.8913, 
Validation F-beta Score
Class 0: 0.9858, Class 1: 0.8603, Class 2: 0.4056, Class 3: 0.4614, Class 4: 0.8028, Class 5: 0.6024, Class 6: 0.9603, 

Overall Mean Dice Score: 0.6601
Overall Mean F-beta Score: 0.7374

Training Loss: 0.3988, Validation Loss: 0.3779, Validation F-beta: 0.7374
Epoch 61/4000


Training: 100%|██████████| 96/96 [01:28<00:00,  1.09it/s, loss=0.394]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.35] 


Validation Dice Score
Class 0: 0.9902, Class 1: 0.7933, Class 2: 0.3355, Class 3: 0.4612, Class 4: 0.7342, Class 5: 0.4827, Class 6: 0.8897, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.8610, Class 2: 0.4419, Class 3: 0.6220, Class 4: 0.8033, Class 5: 0.6027, Class 6: 0.9567, 

Overall Mean Dice Score: 0.6722
Overall Mean F-beta Score: 0.7691

Training Loss: 0.4012, Validation Loss: 0.3646, Validation F-beta: 0.7691
Epoch 62/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.09it/s, loss=0.407]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.62it/s, loss=0.345]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.7833, Class 2: 0.2132, Class 3: 0.4557, Class 4: 0.7212, Class 5: 0.5111, Class 6: 0.9019, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8706, Class 2: 0.2990, Class 3: 0.5576, Class 4: 0.8264, Class 5: 0.6430, Class 6: 0.9590, 

Overall Mean Dice Score: 0.6746
Overall Mean F-beta Score: 0.7713

Training Loss: 0.3994, Validation Loss: 0.3580, Validation F-beta: 0.7713
Epoch 63/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.423]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.349]


Validation Dice Score
Class 0: 0.9908, Class 1: 0.8062, Class 2: 0.3805, Class 3: 0.4451, Class 4: 0.7506, Class 5: 0.4731, Class 6: 0.9104, 
Validation F-beta Score
Class 0: 0.9872, Class 1: 0.8778, Class 2: 0.4728, Class 3: 0.6083, Class 4: 0.8362, Class 5: 0.5814, Class 6: 0.9493, 

Overall Mean Dice Score: 0.6771
Overall Mean F-beta Score: 0.7706

Training Loss: 0.3978, Validation Loss: 0.3541, Validation F-beta: 0.7706
Epoch 64/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.404]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.352]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.8352, Class 2: 0.2747, Class 3: 0.4710, Class 4: 0.7440, Class 5: 0.5097, Class 6: 0.9086, 
Validation F-beta Score
Class 0: 0.9873, Class 1: 0.8773, Class 2: 0.3484, Class 3: 0.5923, Class 4: 0.8069, Class 5: 0.5987, Class 6: 0.9603, 

Overall Mean Dice Score: 0.6937
Overall Mean F-beta Score: 0.7671

Training Loss: 0.3997, Validation Loss: 0.3598, Validation F-beta: 0.7671
Epoch 65/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.416]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.389]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.7923, Class 2: 0.2320, Class 3: 0.3728, Class 4: 0.7452, Class 5: 0.4671, Class 6: 0.9009, 
Validation F-beta Score
Class 0: 0.9862, Class 1: 0.8671, Class 2: 0.3367, Class 3: 0.4750, Class 4: 0.7849, Class 5: 0.5682, Class 6: 0.9677, 

Overall Mean Dice Score: 0.6557
Overall Mean F-beta Score: 0.7326

Training Loss: 0.4004, Validation Loss: 0.3776, Validation F-beta: 0.7326
Epoch 66/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.411]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.353]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.7672, Class 2: 0.2257, Class 3: 0.4072, Class 4: 0.7514, Class 5: 0.4837, Class 6: 0.8964, 
Validation F-beta Score
Class 0: 0.9877, Class 1: 0.8579, Class 2: 0.2779, Class 3: 0.5488, Class 4: 0.7844, Class 5: 0.5670, Class 6: 0.9552, 

Overall Mean Dice Score: 0.6612
Overall Mean F-beta Score: 0.7427

Training Loss: 0.3991, Validation Loss: 0.3641, Validation F-beta: 0.7427
Epoch 67/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.416]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.64it/s, loss=0.366]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.7417, Class 2: 0.2333, Class 3: 0.4488, Class 4: 0.7324, Class 5: 0.4828, Class 6: 0.8804, 
Validation F-beta Score
Class 0: 0.9860, Class 1: 0.8820, Class 2: 0.2973, Class 3: 0.5583, Class 4: 0.7766, Class 5: 0.5725, Class 6: 0.9449, 

Overall Mean Dice Score: 0.6572
Overall Mean F-beta Score: 0.7469

Training Loss: 0.3988, Validation Loss: 0.3580, Validation F-beta: 0.7469
Epoch 68/4000


Training: 100%|██████████| 96/96 [01:28<00:00,  1.09it/s, loss=0.402]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.65it/s, loss=0.352]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.8358, Class 2: 0.3023, Class 3: 0.4353, Class 4: 0.6781, Class 5: 0.4547, Class 6: 0.8991, 
Validation F-beta Score
Class 0: 0.9863, Class 1: 0.8721, Class 2: 0.3608, Class 3: 0.5872, Class 4: 0.7817, Class 5: 0.5805, Class 6: 0.9608, 

Overall Mean Dice Score: 0.6606
Overall Mean F-beta Score: 0.7565

Training Loss: 0.3980, Validation Loss: 0.3560, Validation F-beta: 0.7565
Epoch 69/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.39] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.347]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.8119, Class 2: 0.2534, Class 3: 0.4125, Class 4: 0.7488, Class 5: 0.4482, Class 6: 0.9172, 
Validation F-beta Score
Class 0: 0.9870, Class 1: 0.8609, Class 2: 0.3314, Class 3: 0.5665, Class 4: 0.7848, Class 5: 0.5639, Class 6: 0.9662, 

Overall Mean Dice Score: 0.6677
Overall Mean F-beta Score: 0.7485

Training Loss: 0.3983, Validation Loss: 0.3675, Validation F-beta: 0.7485
Epoch 70/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.13it/s, loss=0.391]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.376]


Validation Dice Score
Class 0: 0.9897, Class 1: 0.7621, Class 2: 0.2558, Class 3: 0.4118, Class 4: 0.6858, Class 5: 0.4707, Class 6: 0.8806, 
Validation F-beta Score
Class 0: 0.9861, Class 1: 0.8741, Class 2: 0.4260, Class 3: 0.5871, Class 4: 0.7510, Class 5: 0.5570, Class 6: 0.9474, 

Overall Mean Dice Score: 0.6422
Overall Mean F-beta Score: 0.7433

Training Loss: 0.3996, Validation Loss: 0.3746, Validation F-beta: 0.7433
Epoch 71/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.13it/s, loss=0.397]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.74it/s, loss=0.351]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.7805, Class 2: 0.2243, Class 3: 0.4653, Class 4: 0.7446, Class 5: 0.5696, Class 6: 0.8746, 
Validation F-beta Score
Class 0: 0.9872, Class 1: 0.8694, Class 2: 0.2682, Class 3: 0.5584, Class 4: 0.7941, Class 5: 0.6324, Class 6: 0.9583, 

Overall Mean Dice Score: 0.6869
Overall Mean F-beta Score: 0.7625

Training Loss: 0.3962, Validation Loss: 0.3587, Validation F-beta: 0.7625
Epoch 72/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.10it/s, loss=0.402]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.372]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.7921, Class 2: 0.1912, Class 3: 0.4048, Class 4: 0.7045, Class 5: 0.5095, Class 6: 0.9018, 
Validation F-beta Score
Class 0: 0.9843, Class 1: 0.8779, Class 2: 0.2480, Class 3: 0.5506, Class 4: 0.7916, Class 5: 0.6212, Class 6: 0.9530, 

Overall Mean Dice Score: 0.6625
Overall Mean F-beta Score: 0.7589

Training Loss: 0.3984, Validation Loss: 0.3610, Validation F-beta: 0.7589
Epoch 73/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.391]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.71it/s, loss=0.367]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.7698, Class 2: 0.3336, Class 3: 0.3751, Class 4: 0.7166, Class 5: 0.4290, Class 6: 0.8946, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.8390, Class 2: 0.4243, Class 3: 0.5643, Class 4: 0.7874, Class 5: 0.5815, Class 6: 0.9609, 

Overall Mean Dice Score: 0.6370
Overall Mean F-beta Score: 0.7466

Training Loss: 0.3984, Validation Loss: 0.3804, Validation F-beta: 0.7466
Epoch 74/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.413]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.339]


Validation Dice Score
Class 0: 0.9897, Class 1: 0.7513, Class 2: 0.2124, Class 3: 0.4361, Class 4: 0.7441, Class 5: 0.4491, Class 6: 0.8794, 
Validation F-beta Score
Class 0: 0.9859, Class 1: 0.8326, Class 2: 0.2636, Class 3: 0.5294, Class 4: 0.7942, Class 5: 0.6057, Class 6: 0.9527, 

Overall Mean Dice Score: 0.6520
Overall Mean F-beta Score: 0.7429

Training Loss: 0.3988, Validation Loss: 0.3723, Validation F-beta: 0.7429
Epoch 75/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.411]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.353]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.7866, Class 2: 0.3616, Class 3: 0.4417, Class 4: 0.7379, Class 5: 0.4949, Class 6: 0.8945, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8530, Class 2: 0.3882, Class 3: 0.5494, Class 4: 0.8039, Class 5: 0.5913, Class 6: 0.9582, 

Overall Mean Dice Score: 0.6711
Overall Mean F-beta Score: 0.7512

Training Loss: 0.3999, Validation Loss: 0.3666, Validation F-beta: 0.7512
Epoch 76/4000


Training: 100%|██████████| 96/96 [01:30<00:00,  1.07it/s, loss=0.407]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.345]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.8242, Class 2: 0.2849, Class 3: 0.4501, Class 4: 0.7084, Class 5: 0.4881, Class 6: 0.9056, 
Validation F-beta Score
Class 0: 0.9863, Class 1: 0.8976, Class 2: 0.3395, Class 3: 0.5507, Class 4: 0.8182, Class 5: 0.6292, Class 6: 0.9629, 

Overall Mean Dice Score: 0.6753
Overall Mean F-beta Score: 0.7717

Training Loss: 0.4019, Validation Loss: 0.3436, Validation F-beta: 0.7717
SUPER Best model saved. Loss:0.3436, Score:0.7717
Epoch 77/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.391]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.353]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.7885, Class 2: 0.2732, Class 3: 0.3917, Class 4: 0.7517, Class 5: 0.4362, Class 6: 0.9105, 
Validation F-beta Score
Class 0: 0.9859, Class 1: 0.8734, Class 2: 0.3615, Class 3: 0.5680, Class 4: 0.8541, Class 5: 0.5560, Class 6: 0.9568, 

Overall Mean Dice Score: 0.6557
Overall Mean F-beta Score: 0.7617

Training Loss: 0.4005, Validation Loss: 0.3710, Validation F-beta: 0.7617
Epoch 78/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.377]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.389]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.7782, Class 2: 0.2944, Class 3: 0.4442, Class 4: 0.6930, Class 5: 0.4858, Class 6: 0.8826, 
Validation F-beta Score
Class 0: 0.9873, Class 1: 0.8624, Class 2: 0.3887, Class 3: 0.5718, Class 4: 0.7732, Class 5: 0.6503, Class 6: 0.9507, 

Overall Mean Dice Score: 0.6568
Overall Mean F-beta Score: 0.7617

Training Loss: 0.3983, Validation Loss: 0.3730, Validation F-beta: 0.7617
Epoch 79/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.09it/s, loss=0.403]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.371]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.7853, Class 2: 0.3060, Class 3: 0.4077, Class 4: 0.6865, Class 5: 0.4583, Class 6: 0.9148, 
Validation F-beta Score
Class 0: 0.9861, Class 1: 0.8627, Class 2: 0.3830, Class 3: 0.5688, Class 4: 0.7740, Class 5: 0.5365, Class 6: 0.9683, 

Overall Mean Dice Score: 0.6505
Overall Mean F-beta Score: 0.7421

Training Loss: 0.4004, Validation Loss: 0.3608, Validation F-beta: 0.7421
Epoch 80/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.404]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.369]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.8033, Class 2: 0.1804, Class 3: 0.3695, Class 4: 0.7490, Class 5: 0.4975, Class 6: 0.8804, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.8842, Class 2: 0.2381, Class 3: 0.5147, Class 4: 0.8108, Class 5: 0.6218, Class 6: 0.9381, 

Overall Mean Dice Score: 0.6599
Overall Mean F-beta Score: 0.7539

Training Loss: 0.3989, Validation Loss: 0.3613, Validation F-beta: 0.7539
Epoch 81/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.418]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.402]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.7554, Class 2: 0.1747, Class 3: 0.3488, Class 4: 0.7252, Class 5: 0.5089, Class 6: 0.8701, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8552, Class 2: 0.2064, Class 3: 0.4661, Class 4: 0.8302, Class 5: 0.6203, Class 6: 0.9647, 

Overall Mean Dice Score: 0.6417
Overall Mean F-beta Score: 0.7473

Training Loss: 0.3993, Validation Loss: 0.3805, Validation F-beta: 0.7473
Epoch 82/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.39] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.324]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.7955, Class 2: 0.2612, Class 3: 0.4987, Class 4: 0.7414, Class 5: 0.5123, Class 6: 0.8768, 
Validation F-beta Score
Class 0: 0.9854, Class 1: 0.8728, Class 2: 0.3495, Class 3: 0.6340, Class 4: 0.8279, Class 5: 0.6267, Class 6: 0.9495, 

Overall Mean Dice Score: 0.6849
Overall Mean F-beta Score: 0.7822

Training Loss: 0.3984, Validation Loss: 0.3512, Validation F-beta: 0.7822
Epoch 83/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.42] 
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.7862, Class 2: 0.2824, Class 3: 0.4541, Class 4: 0.7499, Class 5: 0.4635, Class 6: 0.8827, 
Validation F-beta Score
Class 0: 0.9867, Class 1: 0.8510, Class 2: 0.4098, Class 3: 0.6499, Class 4: 0.7666, Class 5: 0.5910, Class 6: 0.9350, 

Overall Mean Dice Score: 0.6673
Overall Mean F-beta Score: 0.7587

Training Loss: 0.3959, Validation Loss: 0.3775, Validation F-beta: 0.7587
Epoch 84/4000


Training: 100%|██████████| 96/96 [01:21<00:00,  1.18it/s, loss=0.397]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.76it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9897, Class 1: 0.7124, Class 2: 0.1024, Class 3: 0.3557, Class 4: 0.7351, Class 5: 0.5482, Class 6: 0.8596, 
Validation F-beta Score
Class 0: 0.9865, Class 1: 0.8449, Class 2: 0.1685, Class 3: 0.5685, Class 4: 0.7830, Class 5: 0.6188, Class 6: 0.9643, 

Overall Mean Dice Score: 0.6422
Overall Mean F-beta Score: 0.7559

Training Loss: 0.3977, Validation Loss: 0.3967, Validation F-beta: 0.7559
Epoch 85/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.10it/s, loss=0.415]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.62it/s, loss=0.358]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.8060, Class 2: 0.2213, Class 3: 0.3804, Class 4: 0.7519, Class 5: 0.5198, Class 6: 0.8888, 
Validation F-beta Score
Class 0: 0.9861, Class 1: 0.8903, Class 2: 0.2652, Class 3: 0.4846, Class 4: 0.8094, Class 5: 0.6051, Class 6: 0.9533, 

Overall Mean Dice Score: 0.6694
Overall Mean F-beta Score: 0.7486

Training Loss: 0.3988, Validation Loss: 0.3593, Validation F-beta: 0.7486
Epoch 86/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.09it/s, loss=0.407]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, loss=0.337]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.7923, Class 2: 0.2510, Class 3: 0.4664, Class 4: 0.7420, Class 5: 0.5006, Class 6: 0.9068, 
Validation F-beta Score
Class 0: 0.9870, Class 1: 0.8532, Class 2: 0.3483, Class 3: 0.5755, Class 4: 0.8240, Class 5: 0.5977, Class 6: 0.9697, 

Overall Mean Dice Score: 0.6816
Overall Mean F-beta Score: 0.7640

Training Loss: 0.3980, Validation Loss: 0.3519, Validation F-beta: 0.7640
Epoch 87/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.09it/s, loss=0.386]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.385]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.8179, Class 2: 0.2239, Class 3: 0.4049, Class 4: 0.7442, Class 5: 0.4213, Class 6: 0.8835, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.8682, Class 2: 0.3589, Class 3: 0.5865, Class 4: 0.7831, Class 5: 0.5323, Class 6: 0.9460, 

Overall Mean Dice Score: 0.6544
Overall Mean F-beta Score: 0.7432

Training Loss: 0.3993, Validation Loss: 0.3823, Validation F-beta: 0.7432
Epoch 88/4000


Training: 100%|██████████| 96/96 [01:26<00:00,  1.11it/s, loss=0.408]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s, loss=0.349]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.8062, Class 2: 0.1358, Class 3: 0.4094, Class 4: 0.7053, Class 5: 0.4712, Class 6: 0.8999, 
Validation F-beta Score
Class 0: 0.9856, Class 1: 0.8773, Class 2: 0.2215, Class 3: 0.5268, Class 4: 0.7871, Class 5: 0.6140, Class 6: 0.9651, 

Overall Mean Dice Score: 0.6584
Overall Mean F-beta Score: 0.7541

Training Loss: 0.3983, Validation Loss: 0.3604, Validation F-beta: 0.7541
Epoch 89/4000


Training: 100%|██████████| 96/96 [01:25<00:00,  1.12it/s, loss=0.404]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.73it/s, loss=0.358]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.7794, Class 2: 0.2480, Class 3: 0.3992, Class 4: 0.7255, Class 5: 0.4605, Class 6: 0.8888, 
Validation F-beta Score
Class 0: 0.9853, Class 1: 0.8813, Class 2: 0.3026, Class 3: 0.5324, Class 4: 0.8341, Class 5: 0.5740, Class 6: 0.9615, 

Overall Mean Dice Score: 0.6507
Overall Mean F-beta Score: 0.7566

Training Loss: 0.4010, Validation Loss: 0.3593, Validation F-beta: 0.7566
Epoch 90/4000


Training: 100%|██████████| 96/96 [01:24<00:00,  1.14it/s, loss=0.399]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s, loss=0.372]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.7476, Class 2: 0.2487, Class 3: 0.4066, Class 4: 0.7805, Class 5: 0.5046, Class 6: 0.8979, 
Validation F-beta Score
Class 0: 0.9869, Class 1: 0.8600, Class 2: 0.3425, Class 3: 0.5626, Class 4: 0.8180, Class 5: 0.5876, Class 6: 0.9759, 

Overall Mean Dice Score: 0.6675
Overall Mean F-beta Score: 0.7608

Training Loss: 0.3980, Validation Loss: 0.3799, Validation F-beta: 0.7608
Epoch 91/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.404]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.7441, Class 2: 0.2954, Class 3: 0.3742, Class 4: 0.7088, Class 5: 0.4217, Class 6: 0.8913, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.8428, Class 2: 0.4534, Class 3: 0.5666, Class 4: 0.7836, Class 5: 0.5878, Class 6: 0.9514, 

Overall Mean Dice Score: 0.6280
Overall Mean F-beta Score: 0.7464

Training Loss: 0.3991, Validation Loss: 0.3896, Validation F-beta: 0.7464
Epoch 92/4000


Training: 100%|██████████| 96/96 [01:27<00:00,  1.10it/s, loss=0.388]
Validation: 100%|██████████| 3/3 [00:01<00:00,  1.68it/s, loss=0.355]

Validation Dice Score
Class 0: 0.9895, Class 1: 0.8093, Class 2: 0.2713, Class 3: 0.4178, Class 4: 0.7519, Class 5: 0.4325, Class 6: 0.8945, 
Validation F-beta Score
Class 0: 0.9858, Class 1: 0.8699, Class 2: 0.3837, Class 3: 0.5612, Class 4: 0.7828, Class 5: 0.5757, Class 6: 0.9573, 

Overall Mean Dice Score: 0.6612
Overall Mean F-beta Score: 0.7494

Training Loss: 0.3992, Validation Loss: 0.3662, Validation F-beta: 0.7494
Early stopping





0,1
class_0_dice_score,▁▄▆▆▆█▇▆▇▇▇▇▇▇▇▇███▇█▇▇▇▇▇███▇▇▇▇▇▇▇▇▇▇▇
class_0_f_beta_score,▁▆▇▇▆█▇▇██▇▇▇▇▇█▇█▇█▇█▇▇██▇▇█▇█▇█▇▇█▇▇█▇
class_1_dice_score,▄▁▆▆▅▅▇▇▅█▂▆▇▆▆▅▆▇▇▆▆▆▅█▇▆▆▆▅▆▆▆▆▇▅▄▇▇▆▅
class_1_f_beta_score,▅▆▅▄▆▄▆█▄▁▆▃▂▅▆▄▄▅▆▆▄▆▅▄▆▂█▆▄▅▄▆▃▃▇▅▆▆▄▅
class_2_dice_score,▆▄▇▃▁▅▅▅▃▅▃▆▅▄▂▄▇▄█▅▄▄▆▅▅▃▇▄█▆▃▅▆▁▄▄▂▅▅▅
class_2_f_beta_score,█▄▃▁▃▅▆▅▁▆▄▅▆▅▃▄█▅▅▄▅▅▇▃▇▆▅▅▆▆▂▅▇▁▃▅▂▄▅▆
class_3_dice_score,▃▄▄▃▂▄▁▄▅▃▁▆▄▄▃▄▇▇▇▄▆▆▅▅▇▆▆▇▅▅▃█▇▄▄▅▅▅▅▆
class_3_f_beta_score,▂▅▄▆▁▂▆▃▁▃▁▅▂█▄▃▃▇▅▆▅▅▆▆▅▆▅▅▅▆▆██▆▃▆▅▅▆▆
class_4_dice_score,▁▄▅▅▆▅▆▇▇▆▆▆▅▇▆▇▆▇▇▇▇▅▇▆▆▇▆▇▅▅▆▇▇▇▇▇▆▆█▇
class_4_f_beta_score,█▄▇▃▂▂▄▂▁▄▅▃▄▅▃▂▃▄▄▆▄▃▃▃▃▁▄▃▃▄▅▇▂▅▆▃▅▃▆▃

0,1
class_0_dice_score,0.98954
class_0_f_beta_score,0.98581
class_1_dice_score,0.80926
class_1_f_beta_score,0.8699
class_2_dice_score,0.27132
class_2_f_beta_score,0.38371
class_3_dice_score,0.41778
class_3_f_beta_score,0.56121
class_4_dice_score,0.75191
class_4_f_beta_score,0.78279


In [17]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






: 

: 

# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

: 

: 

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


: 

: 

In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


: 

: 

In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

: 

: 

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

: 

: 

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv


: 

: 