In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch


from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\ProgramData\anaconda3\envs\ship\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the 

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [None]:
from src.dataset.dataset import create_dataloaders
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = img_size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
num_repeat = 20
# MODEL CONFIG
num_epochs = 4000
lamda = 0.5
ce_weight = 0.4
lr = 0.001
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25
# CLASS_WEIGHTS
class_weights = None
class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치

accumulation_steps = 8
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    CastToTyped(keys=["image"], dtype=np.float16),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
])


In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0,train_num_repeat=num_repeat)

Loading dataset: 100%|██████████| 24/24 [00:42<00:00,  1.77s/it]
Loading dataset: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it]


https://monai.io/model-zoo.html

In [5]:
from monai.losses import TverskyLoss
import torch
import torch.nn as nn
from MSSSIM_Loss import MSSSIM

def loss_fn(loss, class_weights, device):
    """
    Tversky 손실에 클래스별 가중치를 적용하여 최종 스칼라 값을 반환합니다.

    Args:
        loss: Tversky 손실 텐서 (B, num_classes, H, W, D).
        class_weights: 클래스별 가중치 텐서 (num_classes,).
        device: 사용할 장치 (예: 'cuda' 또는 'cpu').

    Returns:
        torch.Tensor: 최종 가중 평균 손실 값 (스칼라).
    """
    # 가중치를 device로 이동
    class_weights = class_weights.to(device)

    # 클래스 차원에 가중치 적용 (B, num_classes, ...)
    class_weights = class_weights.view(1, n_classes, 1, 1, 1)  # [1, num_classes, 1, 1, 1]
    weighted_loss = loss * class_weights

    # 모든 차원을 평균 내어 스칼라 손실 반환
    final_loss = torch.mean(weighted_loss)
    return final_loss

class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda
        
# criterion = DynamicTverskyLoss(
#     lamda=0.5,
#     include_background=False,
#     reduction="mean",
#     softmax=True
# )

class CombinedLoss(nn.Module):
    def __init__(self, lamda=0.5, ce_weight=0.3, tversky_weight=0.3, ssim_weight=0.4, **kwargs):
        super().__init__()
        self._lamda = lamda
        self.tversky = DynamicTverskyLoss(lamda=lamda, **kwargs)
        self.ce = nn.CrossEntropyLoss()
        self.ssim = MSSSIM(window_size=11, size_average=True, channel=1)
        
        # 가중치 설정 (합이 1이 되도록)
        total = ce_weight + tversky_weight + ssim_weight
        self.ce_weight = ce_weight / total
        self.tversky_weight = tversky_weight / total
        self.ssim_weight = ssim_weight / total
        
    def forward(self, inputs, targets):
        # CE Loss와 Tversky Loss 계산
        ce_loss = self.ce(inputs, targets)
        tversky_loss = self.tversky(inputs, targets)
        
        # SSIM Loss 계산 (1 - SSIM)
        # softmax 적용 후 관심 있는 클래스에 대해서만 SSIM 계산
        probs = torch.softmax(inputs, dim=1)
        ssim_loss = 1 - self.ssim(probs[:, 1:], targets.unsqueeze(1).float())
        
        # 가중 평균 계산
        total_loss = (
            self.ce_weight * ce_loss + 
            self.tversky_weight * tversky_loss + 
            self.ssim_weight * ssim_loss
        )
        
        return total_loss
    
    def set_lamda(self, lamda):
        self._lamda = lamda
        self.tversky.set_lamda(lamda)
    
    @property
    def lamda(self):
        return self._lamda

# 사용 예시
criterion = CombinedLoss(
    lamda=0.5,
    ce_weight=ce_weight,
    tversky_weight=(1-ce_weight)/2,
    ssim_weight=(1-ce_weight)/2,
    include_background=False,
    reduction="mean",
    softmax=True
)

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.metrics import DiceMetric

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=feature_size,
    use_checkpoint=True,
    drop_rate = drop_rate,
    attn_drop_rate = attn_drop_rate,
    use_v2 = use_v2,
).to(device)
# Pretrained weights 불러오기
if use_checkpoint:
    pretrain_path = "./swin_unetr_btcv_segmentation/models/model.pt"
    weight = torch.load(pretrain_path, map_location=device)

    # 출력 레이어의 키를 제외한 나머지 가중치만 로드
    filtered_weights = {k: v for k, v in weight.items() if "out.conv.conv" not in k}

    # strict=False로 로드하여 불일치하는 부분 무시
    model.load_state_dict(filtered_weights, strict=False)
    print("Filtered weights loaded successfully. Output layer will be trained from scratch.")

# Load pretrained weights
# model.load_from(weights=np.load(config_vit.real_pretrained_path, allow_pickle=True))
# TverskyLoss 설정
# 사용 예시
criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,  # CE Loss와 Tversky Loss의 비중을 0.5:0.5로 설정
    include_background=False,
    reduction="mean",
    softmax=True
)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
checkpoint_dir = checkpoint_base_dir / f"SwinUNETR_v2_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}s{img_size}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}"
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")



기존 best model 발견: model_checkpoints\SwinUNETR_v2_pyes_weighted_f48_d96s96_lr1e-03_a0.52_b0.48_b2_r29_ce0.4\best_model.pt


  checkpoint = torch.load(best_model_path, map_location=device)


기존 학습된 가중치를 성공적으로 로드했습니다.


In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([2, 1, 96, 96, 96]) torch.Size([2, 1, 96, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [9]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f'SwinUNETR_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_batch{batch_size}_{current_time}'

# wandb 초기화
wandb.init(
    project='czii_SwinUnetR',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwoow070840[0m ([33mwaooang[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

# 학습

In [None]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    # loss = criterion(outputs, labels_onehot)
    loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        if epochs_no_improve%6 == 0:
            # 손실이 개선되지 않았으므로 lambda 감소
            new_lamda = max(criterion.lamda - 0.01, 0.35)  # 최소값은 0.1로 설정
            criterion.set_lamda(new_lamda)
            print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [11]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
     ) 

Epoch 32/4000


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training: 100%|██████████| 696/696 [12:56<00:00,  1.12s/it, loss=0.311]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.12it/s, loss=0.401]


Validation Dice Score
Class 0: 0.9897, Class 1: 0.6125, Class 2: 0.0824, Class 3: 0.2307, Class 4: 0.6816, Class 5: 0.3723, Class 6: 0.5436, 
Validation F-beta Score
Class 0: 0.9906, Class 1: 0.6572, Class 2: 0.0867, Class 3: 0.2404, Class 4: 0.6373, Class 5: 0.3871, Class 6: 0.6323, 

Overall Mean Dice Score: 0.4882
Overall Mean F-beta Score: 0.5108

Training Loss: 0.3099, Validation Loss: 0.3447, Validation F-beta: 0.5108
Epoch 33/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.341]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.12it/s, loss=0.279]


Validation Dice Score
Class 0: 0.9905, Class 1: 0.6668, Class 2: 0.1846, Class 3: 0.4980, Class 4: 0.6712, Class 5: 0.3666, Class 6: 0.9279, 
Validation F-beta Score
Class 0: 0.9885, Class 1: 0.7130, Class 2: 0.1691, Class 3: 0.5477, Class 4: 0.7575, Class 5: 0.4113, Class 6: 0.9486, 

Overall Mean Dice Score: 0.6261
Overall Mean F-beta Score: 0.6756

Training Loss: 0.3110, Validation Loss: 0.2967, Validation F-beta: 0.6756
Epoch 34/4000


Training: 100%|██████████| 696/696 [14:32<00:00,  1.25s/it, loss=0.326]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.31] 


Validation Dice Score
Class 0: 0.9912, Class 1: 0.5044, Class 2: 0.1660, Class 3: 0.3040, Class 4: 0.6207, Class 5: 0.5441, Class 6: 0.6814, 
Validation F-beta Score
Class 0: 0.9911, Class 1: 0.7390, Class 2: 0.1677, Class 3: 0.3256, Class 4: 0.6406, Class 5: 0.5248, Class 6: 0.6831, 

Overall Mean Dice Score: 0.5309
Overall Mean F-beta Score: 0.5826

Training Loss: 0.3091, Validation Loss: 0.3071, Validation F-beta: 0.5826
Epoch 35/4000


Training: 100%|██████████| 696/696 [14:28<00:00,  1.25s/it, loss=0.356]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.23it/s, loss=0.392]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.5078, Class 2: 0.0470, Class 3: 0.3572, Class 4: 0.4593, Class 5: 0.4470, Class 6: 0.3836, 
Validation F-beta Score
Class 0: 0.9915, Class 1: 0.5410, Class 2: 0.0438, Class 3: 0.4229, Class 4: 0.5047, Class 5: 0.5123, Class 6: 0.4760, 

Overall Mean Dice Score: 0.4310
Overall Mean F-beta Score: 0.4914

Training Loss: 0.3121, Validation Loss: 0.3407, Validation F-beta: 0.4914
Epoch 36/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.326]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.204]


Validation Dice Score
Class 0: 0.9912, Class 1: 0.4980, Class 2: 0.2081, Class 3: 0.3978, Class 4: 0.5992, Class 5: 0.4370, Class 6: 0.4731, 
Validation F-beta Score
Class 0: 0.9902, Class 1: 0.6756, Class 2: 0.1917, Class 3: 0.4854, Class 4: 0.6272, Class 5: 0.4528, Class 6: 0.6224, 

Overall Mean Dice Score: 0.4810
Overall Mean F-beta Score: 0.5727

Training Loss: 0.3063, Validation Loss: 0.3238, Validation F-beta: 0.5727
Epoch 37/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.395]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.13it/s, loss=0.262]


Validation Dice Score
Class 0: 0.9912, Class 1: 0.8040, Class 2: 0.1109, Class 3: 0.4763, Class 4: 0.6611, Class 5: 0.5247, Class 6: 0.6283, 
Validation F-beta Score
Class 0: 0.9899, Class 1: 0.8616, Class 2: 0.1317, Class 3: 0.5239, Class 4: 0.6931, Class 5: 0.5664, Class 6: 0.6732, 

Overall Mean Dice Score: 0.6189
Overall Mean F-beta Score: 0.6636

Training Loss: 0.3082, Validation Loss: 0.2943, Validation F-beta: 0.6636
Epoch 38/4000


Training: 100%|██████████| 696/696 [14:29<00:00,  1.25s/it, loss=0.314]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.375]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.6766, Class 2: 0.1682, Class 3: 0.4194, Class 4: 0.6346, Class 5: 0.3533, Class 6: 0.6050, 
Validation F-beta Score
Class 0: 0.9934, Class 1: 0.7426, Class 2: 0.1651, Class 3: 0.4371, Class 4: 0.5524, Class 5: 0.3941, Class 6: 0.6298, 

Overall Mean Dice Score: 0.5378
Overall Mean F-beta Score: 0.5512

Training Loss: 0.3097, Validation Loss: 0.3225, Validation F-beta: 0.5512
Epoch 39/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.381]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.303]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.6889, Class 2: 0.2484, Class 3: 0.3228, Class 4: 0.6312, Class 5: 0.5181, Class 6: 0.6793, 
Validation F-beta Score
Class 0: 0.9911, Class 1: 0.8142, Class 2: 0.3037, Class 3: 0.3524, Class 4: 0.6400, Class 5: 0.5182, Class 6: 0.6855, 

Overall Mean Dice Score: 0.5681
Overall Mean F-beta Score: 0.6021

Training Loss: 0.3080, Validation Loss: 0.3089, Validation F-beta: 0.6021
Epoch 40/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.333]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.16it/s, loss=0.291]


Validation Dice Score
Class 0: 0.9916, Class 1: 0.5021, Class 2: 0.1798, Class 3: 0.4768, Class 4: 0.5458, Class 5: 0.4402, Class 6: 0.4164, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6215, Class 2: 0.1508, Class 3: 0.5081, Class 4: 0.6053, Class 5: 0.4634, Class 6: 0.4072, 

Overall Mean Dice Score: 0.4763
Overall Mean F-beta Score: 0.5211

Training Loss: 0.3060, Validation Loss: 0.3223, Validation F-beta: 0.5211
Epoch 41/4000


Training: 100%|██████████| 696/696 [14:30<00:00,  1.25s/it, loss=0.346]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.17it/s, loss=0.342]


Validation Dice Score
Class 0: 0.9936, Class 1: 0.6302, Class 2: 0.0422, Class 3: 0.3616, Class 4: 0.5668, Class 5: 0.4635, Class 6: 0.6350, 
Validation F-beta Score
Class 0: 0.9927, Class 1: 0.6593, Class 2: 0.0389, Class 3: 0.3635, Class 4: 0.5584, Class 5: 0.5757, Class 6: 0.6719, 

Overall Mean Dice Score: 0.5314
Overall Mean F-beta Score: 0.5658

Training Loss: 0.3093, Validation Loss: 0.3185, Validation F-beta: 0.5658
Epoch 42/4000


Training: 100%|██████████| 696/696 [14:30<00:00,  1.25s/it, loss=0.355]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.13it/s, loss=0.368]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.5763, Class 2: 0.1796, Class 3: 0.3544, Class 4: 0.6110, Class 5: 0.4933, Class 6: 0.5866, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6484, Class 2: 0.2443, Class 3: 0.4187, Class 4: 0.6173, Class 5: 0.5267, Class 6: 0.6243, 

Overall Mean Dice Score: 0.5243
Overall Mean F-beta Score: 0.5671

Training Loss: 0.3110, Validation Loss: 0.3131, Validation F-beta: 0.5671
Epoch 43/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.336]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9925, Class 1: 0.6748, Class 2: 0.0832, Class 3: 0.3864, Class 4: 0.6501, Class 5: 0.3707, Class 6: 0.6988, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.7235, Class 2: 0.0931, Class 3: 0.4577, Class 4: 0.6432, Class 5: 0.3827, Class 6: 0.6825, 

Overall Mean Dice Score: 0.5562
Overall Mean F-beta Score: 0.5779

Training Loss: 0.3061, Validation Loss: 0.3007, Validation F-beta: 0.5779
Epoch 44/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.304]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.29] 


Validation Dice Score
Class 0: 0.9907, Class 1: 0.6609, Class 2: 0.2181, Class 3: 0.3340, Class 4: 0.7057, Class 5: 0.3668, Class 6: 0.7752, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.7140, Class 2: 0.2519, Class 3: 0.4110, Class 4: 0.7634, Class 5: 0.3542, Class 6: 0.8032, 

Overall Mean Dice Score: 0.5685
Overall Mean F-beta Score: 0.6092

Training Loss: 0.3056, Validation Loss: 0.2979, Validation F-beta: 0.6092
Epoch 45/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.328]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.17it/s, loss=0.286]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.3699, Class 2: 0.1003, Class 3: 0.4410, Class 4: 0.6309, Class 5: 0.4852, Class 6: 0.6270, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.4084, Class 2: 0.1577, Class 3: 0.5626, Class 4: 0.7762, Class 5: 0.5073, Class 6: 0.7346, 

Overall Mean Dice Score: 0.5108
Overall Mean F-beta Score: 0.5978

Training Loss: 0.3059, Validation Loss: 0.3214, Validation F-beta: 0.5978
Epoch 46/4000


Training: 100%|██████████| 696/696 [14:30<00:00,  1.25s/it, loss=0.264]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.13it/s, loss=0.359]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.6016, Class 2: 0.1019, Class 3: 0.3067, Class 4: 0.4755, Class 5: 0.4281, Class 6: 0.6006, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.6213, Class 2: 0.1178, Class 3: 0.4275, Class 4: 0.4927, Class 5: 0.4359, Class 6: 0.6493, 

Overall Mean Dice Score: 0.4825
Overall Mean F-beta Score: 0.5253

Training Loss: 0.3020, Validation Loss: 0.3424, Validation F-beta: 0.5253
Epoch 47/4000


Training: 100%|██████████| 696/696 [14:29<00:00,  1.25s/it, loss=0.257]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.22it/s, loss=0.258]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.6184, Class 2: 0.2176, Class 3: 0.3738, Class 4: 0.6367, Class 5: 0.4586, Class 6: 0.6070, 
Validation F-beta Score
Class 0: 0.9893, Class 1: 0.6837, Class 2: 0.2179, Class 3: 0.3982, Class 4: 0.6402, Class 5: 0.5500, Class 6: 0.6517, 

Overall Mean Dice Score: 0.5389
Overall Mean F-beta Score: 0.5848

Training Loss: 0.3034, Validation Loss: 0.3028, Validation F-beta: 0.5848
Epoch 48/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.272]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.09it/s, loss=0.256]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.5474, Class 2: 0.1546, Class 3: 0.2956, Class 4: 0.7946, Class 5: 0.5051, Class 6: 0.6974, 
Validation F-beta Score
Class 0: 0.9915, Class 1: 0.5893, Class 2: 0.1667, Class 3: 0.3093, Class 4: 0.8175, Class 5: 0.5085, Class 6: 0.8779, 

Overall Mean Dice Score: 0.5680
Overall Mean F-beta Score: 0.6205

Training Loss: 0.3048, Validation Loss: 0.3131, Validation F-beta: 0.6205
Epoch 49/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.259]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.08it/s, loss=0.192]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.6837, Class 2: 0.1782, Class 3: 0.4956, Class 4: 0.5687, Class 5: 0.3949, Class 6: 0.6830, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.9141, Class 2: 0.1982, Class 3: 0.5037, Class 4: 0.5588, Class 5: 0.3863, Class 6: 0.8610, 

Overall Mean Dice Score: 0.5652
Overall Mean F-beta Score: 0.6448

Training Loss: 0.3093, Validation Loss: 0.2960, Validation F-beta: 0.6448
Epoch 50/4000


Training: 100%|██████████| 696/696 [14:32<00:00,  1.25s/it, loss=0.305]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.12it/s, loss=0.308]


Validation Dice Score
Class 0: 0.9931, Class 1: 0.5827, Class 2: 0.2226, Class 3: 0.4689, Class 4: 0.6534, Class 5: 0.4092, Class 6: 0.7607, 
Validation F-beta Score
Class 0: 0.9911, Class 1: 0.5722, Class 2: 0.3093, Class 3: 0.5840, Class 4: 0.7041, Class 5: 0.4357, Class 6: 0.8726, 

Overall Mean Dice Score: 0.5750
Overall Mean F-beta Score: 0.6337

Training Loss: 0.3016, Validation Loss: 0.2921, Validation F-beta: 0.6337
Epoch 51/4000


Training: 100%|██████████| 696/696 [14:30<00:00,  1.25s/it, loss=0.316]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.11it/s, loss=0.334]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.5437, Class 2: 0.2101, Class 3: 0.2361, Class 4: 0.5423, Class 5: 0.4944, Class 6: 0.6697, 
Validation F-beta Score
Class 0: 0.9907, Class 1: 0.6314, Class 2: 0.2418, Class 3: 0.3079, Class 4: 0.6480, Class 5: 0.5431, Class 6: 0.7114, 

Overall Mean Dice Score: 0.4972
Overall Mean F-beta Score: 0.5683

Training Loss: 0.3031, Validation Loss: 0.3285, Validation F-beta: 0.5683
Epoch 52/4000


Training: 100%|██████████| 696/696 [14:33<00:00,  1.26s/it, loss=0.325]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.12it/s, loss=0.365]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.6992, Class 2: 0.1478, Class 3: 0.2700, Class 4: 0.7546, Class 5: 0.2926, Class 6: 0.6400, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.7672, Class 2: 0.1648, Class 3: 0.3096, Class 4: 0.7572, Class 5: 0.2835, Class 6: 0.8650, 

Overall Mean Dice Score: 0.5313
Overall Mean F-beta Score: 0.5965

Training Loss: 0.3031, Validation Loss: 0.3200, Validation F-beta: 0.5965
Epoch 53/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.32] 
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.19it/s, loss=0.392]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7623, Class 2: 0.1383, Class 3: 0.3337, Class 4: 0.4867, Class 5: 0.4912, Class 6: 0.6713, 
Validation F-beta Score
Class 0: 0.9902, Class 1: 0.8224, Class 2: 0.1339, Class 3: 0.3123, Class 4: 0.4742, Class 5: 0.6115, Class 6: 0.7092, 

Overall Mean Dice Score: 0.5491
Overall Mean F-beta Score: 0.5859

Training Loss: 0.3001, Validation Loss: 0.3027, Validation F-beta: 0.5859
Epoch 54/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.238]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.293]


Validation Dice Score
Class 0: 0.9912, Class 1: 0.6288, Class 2: 0.1911, Class 3: 0.4195, Class 4: 0.5843, Class 5: 0.4989, Class 6: 0.6826, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.7554, Class 2: 0.1940, Class 3: 0.4427, Class 4: 0.5340, Class 5: 0.5423, Class 6: 0.6647, 

Overall Mean Dice Score: 0.5628
Overall Mean F-beta Score: 0.5878

Training Loss: 0.3013, Validation Loss: 0.2948, Validation F-beta: 0.5878
Epoch 55/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.264]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.15it/s, loss=0.292]


Validation Dice Score
Class 0: 0.9929, Class 1: 0.5716, Class 2: 0.2141, Class 3: 0.4107, Class 4: 0.7518, Class 5: 0.4122, Class 6: 0.3990, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.6732, Class 2: 0.1993, Class 3: 0.4432, Class 4: 0.7745, Class 5: 0.4500, Class 6: 0.5411, 

Overall Mean Dice Score: 0.5090
Overall Mean F-beta Score: 0.5764

Training Loss: 0.3023, Validation Loss: 0.3263, Validation F-beta: 0.5764
Epoch 56/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.241]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.11it/s, loss=0.194]


Validation Dice Score
Class 0: 0.9935, Class 1: 0.6717, Class 2: 0.1870, Class 3: 0.4557, Class 4: 0.6193, Class 5: 0.5282, Class 6: 0.6958, 
Validation F-beta Score
Class 0: 0.9935, Class 1: 0.7710, Class 2: 0.2035, Class 3: 0.4931, Class 4: 0.6076, Class 5: 0.5382, Class 6: 0.8030, 

Overall Mean Dice Score: 0.5941
Overall Mean F-beta Score: 0.6426

Training Loss: 0.3023, Validation Loss: 0.2855, Validation F-beta: 0.6426
SUPER Best model saved. Loss:0.2855, Score:0.6426
Epoch 57/4000


Training: 100%|██████████| 696/696 [14:36<00:00,  1.26s/it, loss=0.255]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.326]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.7191, Class 2: 0.2075, Class 3: 0.3530, Class 4: 0.5891, Class 5: 0.4323, Class 6: 0.6898, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.8475, Class 2: 0.2082, Class 3: 0.4338, Class 4: 0.6617, Class 5: 0.4762, Class 6: 0.7249, 

Overall Mean Dice Score: 0.5567
Overall Mean F-beta Score: 0.6288

Training Loss: 0.2989, Validation Loss: 0.3040, Validation F-beta: 0.6288
Epoch 58/4000


Training: 100%|██████████| 696/696 [14:31<00:00,  1.25s/it, loss=0.309]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.16it/s, loss=0.29] 


Validation Dice Score
Class 0: 0.9923, Class 1: 0.6700, Class 2: 0.2162, Class 3: 0.3193, Class 4: 0.7625, Class 5: 0.4179, Class 6: 0.8520, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.7536, Class 2: 0.1904, Class 3: 0.3972, Class 4: 0.7554, Class 5: 0.4162, Class 6: 0.8626, 

Overall Mean Dice Score: 0.6043
Overall Mean F-beta Score: 0.6370

Training Loss: 0.3008, Validation Loss: 0.2798, Validation F-beta: 0.6370
Epoch 59/4000


Training: 100%|██████████| 696/696 [14:33<00:00,  1.25s/it, loss=0.282]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.14it/s, loss=0.273]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.6332, Class 2: 0.1650, Class 3: 0.3452, Class 4: 0.6264, Class 5: 0.4441, Class 6: 0.7849, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.6781, Class 2: 0.1878, Class 3: 0.4658, Class 4: 0.6075, Class 5: 0.4603, Class 6: 0.7878, 

Overall Mean Dice Score: 0.5668
Overall Mean F-beta Score: 0.5999

Training Loss: 0.2995, Validation Loss: 0.3096, Validation F-beta: 0.5999
Epoch 60/4000


Training: 100%|██████████| 696/696 [14:32<00:00,  1.25s/it, loss=0.285]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.16it/s, loss=0.287]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.6279, Class 2: 0.2059, Class 3: 0.3914, Class 4: 0.6962, Class 5: 0.4538, Class 6: 0.6936, 
Validation F-beta Score
Class 0: 0.9928, Class 1: 0.6478, Class 2: 0.2408, Class 3: 0.3502, Class 4: 0.7060, Class 5: 0.4454, Class 6: 0.6936, 

Overall Mean Dice Score: 0.5726
Overall Mean F-beta Score: 0.5686

Training Loss: 0.3011, Validation Loss: 0.3053, Validation F-beta: 0.5686
Epoch 61/4000


Training: 100%|██████████| 696/696 [14:06<00:00,  1.22s/it, loss=0.363]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.22it/s, loss=0.4]  


Validation Dice Score
Class 0: 0.9920, Class 1: 0.7684, Class 2: 0.2533, Class 3: 0.3643, Class 4: 0.6640, Class 5: 0.5071, Class 6: 0.6049, 
Validation F-beta Score
Class 0: 0.9925, Class 1: 0.8161, Class 2: 0.2908, Class 3: 0.4048, Class 4: 0.6312, Class 5: 0.5179, Class 6: 0.8090, 

Overall Mean Dice Score: 0.5818
Overall Mean F-beta Score: 0.6358

Training Loss: 0.3009, Validation Loss: 0.2999, Validation F-beta: 0.6358
Epoch 62/4000


Training: 100%|██████████| 696/696 [16:25<00:00,  1.42s/it, loss=0.304]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.87it/s, loss=0.312]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.6446, Class 2: 0.1518, Class 3: 0.4454, Class 4: 0.8077, Class 5: 0.3919, Class 6: 0.7883, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.6830, Class 2: 0.1448, Class 3: 0.5483, Class 4: 0.8704, Class 5: 0.4465, Class 6: 0.8857, 

Overall Mean Dice Score: 0.6156
Overall Mean F-beta Score: 0.6868

Training Loss: 0.2958, Validation Loss: 0.2860, Validation F-beta: 0.6868
Epoch 63/4000


Training: 100%|██████████| 696/696 [17:01<00:00,  1.47s/it, loss=0.285]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.74it/s, loss=0.325]


Validation Dice Score
Class 0: 0.9921, Class 1: 0.6691, Class 2: 0.1289, Class 3: 0.2875, Class 4: 0.6113, Class 5: 0.4477, Class 6: 0.6374, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.7633, Class 2: 0.1328, Class 3: 0.2726, Class 4: 0.6266, Class 5: 0.4383, Class 6: 0.7882, 

Overall Mean Dice Score: 0.5306
Overall Mean F-beta Score: 0.5778

Training Loss: 0.3002, Validation Loss: 0.3109, Validation F-beta: 0.5778
Epoch 64/4000


Training: 100%|██████████| 696/696 [16:34<00:00,  1.43s/it, loss=0.284]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.97it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.7049, Class 2: 0.1590, Class 3: 0.3654, Class 4: 0.5503, Class 5: 0.3718, Class 6: 0.6155, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.7341, Class 2: 0.2428, Class 3: 0.4730, Class 4: 0.5587, Class 5: 0.3633, Class 6: 0.7947, 

Overall Mean Dice Score: 0.5216
Overall Mean F-beta Score: 0.5848

Training Loss: 0.3013, Validation Loss: 0.3246, Validation F-beta: 0.5848
Epoch 65/4000


Training: 100%|██████████| 696/696 [18:08<00:00,  1.56s/it, loss=0.311]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.03it/s, loss=0.298]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.7386, Class 2: 0.1155, Class 3: 0.4045, Class 4: 0.5158, Class 5: 0.4732, Class 6: 0.6743, 
Validation F-beta Score
Class 0: 0.9926, Class 1: 0.8379, Class 2: 0.1270, Class 3: 0.4662, Class 4: 0.5182, Class 5: 0.4658, Class 6: 0.6959, 

Overall Mean Dice Score: 0.5613
Overall Mean F-beta Score: 0.5968

Training Loss: 0.3002, Validation Loss: 0.3149, Validation F-beta: 0.5968
Epoch 66/4000


Training: 100%|██████████| 696/696 [17:14<00:00,  1.49s/it, loss=0.309]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.76it/s, loss=0.366]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.6023, Class 2: 0.1917, Class 3: 0.3909, Class 4: 0.6180, Class 5: 0.3394, Class 6: 0.6793, 
Validation F-beta Score
Class 0: 0.9862, Class 1: 0.6809, Class 2: 0.2127, Class 3: 0.4212, Class 4: 0.7185, Class 5: 0.4815, Class 6: 0.7797, 

Overall Mean Dice Score: 0.5260
Overall Mean F-beta Score: 0.6164

Training Loss: 0.2998, Validation Loss: 0.3044, Validation F-beta: 0.6164
Epoch 67/4000


Training: 100%|██████████| 696/696 [20:49<00:00,  1.80s/it, loss=0.362]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.48it/s, loss=0.362]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.6465, Class 2: 0.0570, Class 3: 0.2518, Class 4: 0.6653, Class 5: 0.4192, Class 6: 0.5741, 
Validation F-beta Score
Class 0: 0.9933, Class 1: 0.6596, Class 2: 0.0636, Class 3: 0.2578, Class 4: 0.6297, Class 5: 0.3714, Class 6: 0.5775, 

Overall Mean Dice Score: 0.5114
Overall Mean F-beta Score: 0.4992

Training Loss: 0.3005, Validation Loss: 0.3307, Validation F-beta: 0.4992
Epoch 68/4000


Training: 100%|██████████| 696/696 [21:02<00:00,  1.81s/it, loss=0.303]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.73it/s, loss=0.222]


Validation Dice Score
Class 0: 0.9925, Class 1: 0.4985, Class 2: 0.2133, Class 3: 0.5337, Class 4: 0.7501, Class 5: 0.3707, Class 6: 0.6887, 
Validation F-beta Score
Class 0: 0.9917, Class 1: 0.5689, Class 2: 0.2585, Class 3: 0.5860, Class 4: 0.7973, Class 5: 0.4143, Class 6: 0.7085, 

Overall Mean Dice Score: 0.5683
Overall Mean F-beta Score: 0.6150

Training Loss: 0.2998, Validation Loss: 0.3066, Validation F-beta: 0.6150
Epoch 69/4000


Training: 100%|██████████| 696/696 [19:35<00:00,  1.69s/it, loss=0.265]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.02it/s, loss=0.227]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.7196, Class 2: 0.2040, Class 3: 0.4504, Class 4: 0.6056, Class 5: 0.4806, Class 6: 0.7683, 
Validation F-beta Score
Class 0: 0.9933, Class 1: 0.8485, Class 2: 0.2804, Class 3: 0.5169, Class 4: 0.5798, Class 5: 0.4724, Class 6: 0.8609, 

Overall Mean Dice Score: 0.6049
Overall Mean F-beta Score: 0.6557

Training Loss: 0.2957, Validation Loss: 0.2879, Validation F-beta: 0.6557
Epoch 70/4000


Training: 100%|██████████| 696/696 [16:52<00:00,  1.45s/it, loss=0.277]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.96it/s, loss=0.356]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7783, Class 2: 0.2584, Class 3: 0.3879, Class 4: 0.5789, Class 5: 0.5147, Class 6: 0.6829, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.8645, Class 2: 0.2670, Class 3: 0.4250, Class 4: 0.5894, Class 5: 0.5198, Class 6: 0.6945, 

Overall Mean Dice Score: 0.5886
Overall Mean F-beta Score: 0.6186

Training Loss: 0.2927, Validation Loss: 0.2861, Validation F-beta: 0.6186
Epoch 71/4000


Training: 100%|██████████| 696/696 [16:37<00:00,  1.43s/it, loss=0.332]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.81it/s, loss=0.234]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.7053, Class 2: 0.1286, Class 3: 0.4286, Class 4: 0.5766, Class 5: 0.4568, Class 6: 0.7721, 
Validation F-beta Score
Class 0: 0.9904, Class 1: 0.7549, Class 2: 0.1428, Class 3: 0.5196, Class 4: 0.5938, Class 5: 0.5006, Class 6: 0.7954, 

Overall Mean Dice Score: 0.5879
Overall Mean F-beta Score: 0.6329

Training Loss: 0.2926, Validation Loss: 0.2911, Validation F-beta: 0.6329
Epoch 72/4000


Training: 100%|██████████| 696/696 [16:57<00:00,  1.46s/it, loss=0.323]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.84it/s, loss=0.294]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.6950, Class 2: 0.2401, Class 3: 0.4207, Class 4: 0.6218, Class 5: 0.4007, Class 6: 0.5382, 
Validation F-beta Score
Class 0: 0.9944, Class 1: 0.7071, Class 2: 0.2275, Class 3: 0.4360, Class 4: 0.5812, Class 5: 0.4232, Class 6: 0.7203, 

Overall Mean Dice Score: 0.5353
Overall Mean F-beta Score: 0.5735

Training Loss: 0.2948, Validation Loss: 0.3171, Validation F-beta: 0.5735
Epoch 73/4000


Training: 100%|██████████| 696/696 [16:41<00:00,  1.44s/it, loss=0.352]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.95it/s, loss=0.29] 


Validation Dice Score
Class 0: 0.9928, Class 1: 0.6852, Class 2: 0.1992, Class 3: 0.3510, Class 4: 0.5128, Class 5: 0.4302, Class 6: 0.6106, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.7495, Class 2: 0.2459, Class 3: 0.4092, Class 4: 0.5189, Class 5: 0.4531, Class 6: 0.7124, 

Overall Mean Dice Score: 0.5180
Overall Mean F-beta Score: 0.5686

Training Loss: 0.2921, Validation Loss: 0.3148, Validation F-beta: 0.5686
Epoch 74/4000


Training: 100%|██████████| 696/696 [19:05<00:00,  1.65s/it, loss=0.339]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.52it/s, loss=0.334]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.5646, Class 2: 0.2660, Class 3: 0.5032, Class 4: 0.7122, Class 5: 0.4932, Class 6: 0.7629, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.6756, Class 2: 0.2902, Class 3: 0.5791, Class 4: 0.7021, Class 5: 0.5335, Class 6: 0.8484, 

Overall Mean Dice Score: 0.6072
Overall Mean F-beta Score: 0.6677

Training Loss: 0.2935, Validation Loss: 0.2965, Validation F-beta: 0.6677
Epoch 75/4000


Training: 100%|██████████| 696/696 [27:50<00:00,  2.40s/it, loss=0.261]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.293]


Validation Dice Score
Class 0: 0.9930, Class 1: 0.6086, Class 2: 0.1430, Class 3: 0.3412, Class 4: 0.6257, Class 5: 0.3026, Class 6: 0.6875, 
Validation F-beta Score
Class 0: 0.9929, Class 1: 0.7619, Class 2: 0.1400, Class 3: 0.3705, Class 4: 0.6539, Class 5: 0.2947, Class 6: 0.6955, 

Overall Mean Dice Score: 0.5131
Overall Mean F-beta Score: 0.5553

Training Loss: 0.2925, Validation Loss: 0.3148, Validation F-beta: 0.5553
Epoch 76/4000


Training: 100%|██████████| 696/696 [26:47<00:00,  2.31s/it, loss=0.316]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.46it/s, loss=0.197]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.6766, Class 2: 0.1407, Class 3: 0.4911, Class 4: 0.5756, Class 5: 0.4419, Class 6: 0.6262, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.8056, Class 2: 0.1105, Class 3: 0.5901, Class 4: 0.6105, Class 5: 0.3973, Class 6: 0.7127, 

Overall Mean Dice Score: 0.5623
Overall Mean F-beta Score: 0.6232

Training Loss: 0.2902, Validation Loss: 0.3051, Validation F-beta: 0.6232
Epoch 77/4000


Training: 100%|██████████| 696/696 [26:32<00:00,  2.29s/it, loss=0.357]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.47it/s, loss=0.39] 


Validation Dice Score
Class 0: 0.9912, Class 1: 0.6459, Class 2: 0.0897, Class 3: 0.4260, Class 4: 0.5597, Class 5: 0.4787, Class 6: 0.5527, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.8151, Class 2: 0.0779, Class 3: 0.4738, Class 4: 0.5960, Class 5: 0.5217, Class 6: 0.6336, 

Overall Mean Dice Score: 0.5326
Overall Mean F-beta Score: 0.6080

Training Loss: 0.2908, Validation Loss: 0.3039, Validation F-beta: 0.6080
Epoch 78/4000


Training: 100%|██████████| 696/696 [26:52<00:00,  2.32s/it, loss=0.327]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.45it/s, loss=0.283]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.6026, Class 2: 0.3103, Class 3: 0.3129, Class 4: 0.7416, Class 5: 0.4706, Class 6: 0.8277, 
Validation F-beta Score
Class 0: 0.9915, Class 1: 0.7516, Class 2: 0.3307, Class 3: 0.3440, Class 4: 0.7712, Class 5: 0.5061, Class 6: 0.9325, 

Overall Mean Dice Score: 0.5911
Overall Mean F-beta Score: 0.6611

Training Loss: 0.2935, Validation Loss: 0.2963, Validation F-beta: 0.6611
Epoch 79/4000


Training: 100%|██████████| 696/696 [27:31<00:00,  2.37s/it, loss=0.25] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.26it/s, loss=0.262]


Validation Dice Score
Class 0: 0.9930, Class 1: 0.7194, Class 2: 0.3317, Class 3: 0.3692, Class 4: 0.6065, Class 5: 0.4067, Class 6: 0.7037, 
Validation F-beta Score
Class 0: 0.9915, Class 1: 0.8252, Class 2: 0.3241, Class 3: 0.4035, Class 4: 0.5915, Class 5: 0.4904, Class 6: 0.8785, 

Overall Mean Dice Score: 0.5611
Overall Mean F-beta Score: 0.6378

Training Loss: 0.2906, Validation Loss: 0.2970, Validation F-beta: 0.6378
Epoch 80/4000


Training: 100%|██████████| 696/696 [26:58<00:00,  2.32s/it, loss=0.258]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.54it/s, loss=0.363]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.5815, Class 2: 0.1344, Class 3: 0.2905, Class 4: 0.5416, Class 5: 0.3249, Class 6: 0.5339, 
Validation F-beta Score
Class 0: 0.9899, Class 1: 0.7514, Class 2: 0.1490, Class 3: 0.3558, Class 4: 0.6449, Class 5: 0.3778, Class 6: 0.7131, 

Overall Mean Dice Score: 0.4545
Overall Mean F-beta Score: 0.5686

Training Loss: 0.2901, Validation Loss: 0.3458, Validation F-beta: 0.5686
Epoch 81/4000


Training: 100%|██████████| 696/696 [27:52<00:00,  2.40s/it, loss=0.308]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.177]


Validation Dice Score
Class 0: 0.9927, Class 1: 0.6445, Class 2: 0.2860, Class 3: 0.4094, Class 4: 0.7587, Class 5: 0.3951, Class 6: 0.8416, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.7810, Class 2: 0.2611, Class 3: 0.4159, Class 4: 0.7651, Class 5: 0.3783, Class 6: 0.8845, 

Overall Mean Dice Score: 0.6098
Overall Mean F-beta Score: 0.6450

Training Loss: 0.2885, Validation Loss: 0.2743, Validation F-beta: 0.6450
SUPER Best model saved. Loss:0.2743, Score:0.6450
Epoch 82/4000


Training: 100%|██████████| 696/696 [28:31<00:00,  2.46s/it, loss=0.271]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9914, Class 1: 0.6071, Class 2: 0.1692, Class 3: 0.3268, Class 4: 0.6270, Class 5: 0.3691, Class 6: 0.7026, 
Validation F-beta Score
Class 0: 0.9925, Class 1: 0.7496, Class 2: 0.1775, Class 3: 0.3302, Class 4: 0.6219, Class 5: 0.3412, Class 6: 0.7838, 

Overall Mean Dice Score: 0.5265
Overall Mean F-beta Score: 0.5654

Training Loss: 0.2904, Validation Loss: 0.3284, Validation F-beta: 0.5654
Epoch 83/4000


Training: 100%|██████████| 696/696 [29:30<00:00,  2.54s/it, loss=0.294]
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.31it/s, loss=0.287]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.6458, Class 2: 0.1764, Class 3: 0.4143, Class 4: 0.6172, Class 5: 0.2946, Class 6: 0.6989, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.7612, Class 2: 0.1914, Class 3: 0.4336, Class 4: 0.5887, Class 5: 0.2404, Class 6: 0.8112, 

Overall Mean Dice Score: 0.5342
Overall Mean F-beta Score: 0.5670

Training Loss: 0.2934, Validation Loss: 0.3119, Validation F-beta: 0.5670
Epoch 84/4000


Training: 100%|██████████| 696/696 [29:32<00:00,  2.55s/it, loss=0.244] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.43it/s, loss=0.195]


Validation Dice Score
Class 0: 0.9920, Class 1: 0.6795, Class 2: 0.1374, Class 3: 0.3612, Class 4: 0.6975, Class 5: 0.4983, Class 6: 0.7047, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.7811, Class 2: 0.1709, Class 3: 0.3695, Class 4: 0.7162, Class 5: 0.5440, Class 6: 0.8045, 

Overall Mean Dice Score: 0.5882
Overall Mean F-beta Score: 0.6430

Training Loss: 0.2882, Validation Loss: 0.2982, Validation F-beta: 0.6430
Epoch 85/4000


Training: 100%|██████████| 696/696 [30:38<00:00,  2.64s/it, loss=0.283]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.291]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.6600, Class 2: 0.0274, Class 3: 0.2345, Class 4: 0.7042, Class 5: 0.3929, Class 6: 0.7078, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.7905, Class 2: 0.0221, Class 3: 0.2802, Class 4: 0.7049, Class 5: 0.3863, Class 6: 0.8090, 

Overall Mean Dice Score: 0.5399
Overall Mean F-beta Score: 0.5942

Training Loss: 0.2960, Validation Loss: 0.3161, Validation F-beta: 0.5942
Epoch 86/4000


Training: 100%|██████████| 696/696 [27:50<00:00,  2.40s/it, loss=0.285]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.41it/s, loss=0.321]


Validation Dice Score
Class 0: 0.9935, Class 1: 0.7866, Class 2: 0.0462, Class 3: 0.3918, Class 4: 0.4786, Class 5: 0.5349, Class 6: 0.8534, 
Validation F-beta Score
Class 0: 0.9935, Class 1: 0.9059, Class 2: 0.0385, Class 3: 0.4338, Class 4: 0.4870, Class 5: 0.5681, Class 6: 0.8799, 

Overall Mean Dice Score: 0.6091
Overall Mean F-beta Score: 0.6550

Training Loss: 0.2879, Validation Loss: 0.2857, Validation F-beta: 0.6550
Epoch 87/4000


Training: 100%|██████████| 696/696 [29:15<00:00,  2.52s/it, loss=0.341]
Validation: 100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.34] 


Validation Dice Score
Class 0: 0.9924, Class 1: 0.6798, Class 2: 0.2050, Class 3: 0.2983, Class 4: 0.5874, Class 5: 0.4305, Class 6: 0.7794, 
Validation F-beta Score
Class 0: 0.9926, Class 1: 0.7671, Class 2: 0.2624, Class 3: 0.3350, Class 4: 0.6003, Class 5: 0.4167, Class 6: 0.7994, 

Overall Mean Dice Score: 0.5551
Overall Mean F-beta Score: 0.5837

Training Loss: 0.2877, Validation Loss: 0.3210, Validation F-beta: 0.5837
Epoch 88/4000


Training: 100%|██████████| 696/696 [30:14<00:00,  2.61s/it, loss=0.271]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.43it/s, loss=0.341]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.5407, Class 2: 0.2273, Class 3: 0.4021, Class 4: 0.6907, Class 5: 0.3627, Class 6: 0.6167, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.6998, Class 2: 0.2535, Class 3: 0.4476, Class 4: 0.7316, Class 5: 0.3654, Class 6: 0.6152, 

Overall Mean Dice Score: 0.5226
Overall Mean F-beta Score: 0.5719

Training Loss: 0.2886, Validation Loss: 0.3208, Validation F-beta: 0.5719
Epoch 89/4000


Training: 100%|██████████| 696/696 [27:29<00:00,  2.37s/it, loss=0.356]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.43it/s, loss=0.312]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.6147, Class 2: 0.1928, Class 3: 0.2870, Class 4: 0.4929, Class 5: 0.3579, Class 6: 0.7014, 
Validation F-beta Score
Class 0: 0.9934, Class 1: 0.7332, Class 2: 0.2167, Class 3: 0.3362, Class 4: 0.5185, Class 5: 0.3821, Class 6: 0.8830, 

Overall Mean Dice Score: 0.4908
Overall Mean F-beta Score: 0.5706

Training Loss: 0.2879, Validation Loss: 0.3320, Validation F-beta: 0.5706
Epoch 90/4000


Training: 100%|██████████| 696/696 [20:21<00:00,  1.76s/it, loss=0.264]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.98it/s, loss=0.239]


Validation Dice Score
Class 0: 0.9928, Class 1: 0.6504, Class 2: 0.1109, Class 3: 0.3540, Class 4: 0.7140, Class 5: 0.3905, Class 6: 0.7008, 
Validation F-beta Score
Class 0: 0.9928, Class 1: 0.7603, Class 2: 0.1366, Class 3: 0.3688, Class 4: 0.8077, Class 5: 0.3807, Class 6: 0.8342, 

Overall Mean Dice Score: 0.5619
Overall Mean F-beta Score: 0.6303

Training Loss: 0.2848, Validation Loss: 0.3016, Validation F-beta: 0.6303
Epoch 91/4000


Training: 100%|██████████| 696/696 [17:09<00:00,  1.48s/it, loss=0.297]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.01it/s, loss=0.291]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.7627, Class 2: 0.1766, Class 3: 0.4016, Class 4: 0.5812, Class 5: 0.4510, Class 6: 0.6806, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.8826, Class 2: 0.1638, Class 3: 0.4422, Class 4: 0.5918, Class 5: 0.4599, Class 6: 0.7743, 

Overall Mean Dice Score: 0.5754
Overall Mean F-beta Score: 0.6301

Training Loss: 0.2915, Validation Loss: 0.3042, Validation F-beta: 0.6301
Epoch 92/4000


Training: 100%|██████████| 696/696 [17:07<00:00,  1.48s/it, loss=0.357]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.82it/s, loss=0.452]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.4888, Class 2: 0.2463, Class 3: 0.4220, Class 4: 0.5417, Class 5: 0.4247, Class 6: 0.4668, 
Validation F-beta Score
Class 0: 0.9911, Class 1: 0.6142, Class 2: 0.2621, Class 3: 0.4117, Class 4: 0.5364, Class 5: 0.4594, Class 6: 0.5623, 

Overall Mean Dice Score: 0.4688
Overall Mean F-beta Score: 0.5168

Training Loss: 0.2885, Validation Loss: 0.3329, Validation F-beta: 0.5168
Epoch 93/4000


Training: 100%|██████████| 696/696 [17:48<00:00,  1.54s/it, loss=0.257]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.66it/s, loss=0.223]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.5228, Class 2: 0.1870, Class 3: 0.4904, Class 4: 0.7338, Class 5: 0.5327, Class 6: 0.6949, 
Validation F-beta Score
Class 0: 0.9936, Class 1: 0.7253, Class 2: 0.1590, Class 3: 0.5544, Class 4: 0.7187, Class 5: 0.5962, Class 6: 0.7882, 

Overall Mean Dice Score: 0.5949
Overall Mean F-beta Score: 0.6766

Training Loss: 0.2882, Validation Loss: 0.2988, Validation F-beta: 0.6766
Epoch 94/4000


Training: 100%|██████████| 696/696 [23:12<00:00,  2.00s/it, loss=0.241]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.72it/s, loss=0.37] 


Validation Dice Score
Class 0: 0.9907, Class 1: 0.6939, Class 2: 0.1032, Class 3: 0.3032, Class 4: 0.4930, Class 5: 0.4611, Class 6: 0.4613, 
Validation F-beta Score
Class 0: 0.9929, Class 1: 0.8484, Class 2: 0.1207, Class 3: 0.2898, Class 4: 0.4509, Class 5: 0.4371, Class 6: 0.6349, 

Overall Mean Dice Score: 0.4825
Overall Mean F-beta Score: 0.5322

Training Loss: 0.2886, Validation Loss: 0.3362, Validation F-beta: 0.5322
Epoch 95/4000


Training: 100%|██████████| 696/696 [21:42<00:00,  1.87s/it, loss=0.383]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.70it/s, loss=0.317]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.7371, Class 2: 0.2874, Class 3: 0.3876, Class 4: 0.6062, Class 5: 0.4708, Class 6: 0.7951, 
Validation F-beta Score
Class 0: 0.9931, Class 1: 0.7616, Class 2: 0.2641, Class 3: 0.3726, Class 4: 0.5300, Class 5: 0.4586, Class 6: 0.8370, 

Overall Mean Dice Score: 0.5993
Overall Mean F-beta Score: 0.5920

Training Loss: 0.2906, Validation Loss: 0.2915, Validation F-beta: 0.5920
Epoch 96/4000


Training: 100%|██████████| 696/696 [22:51<00:00,  1.97s/it, loss=0.312]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.68it/s, loss=0.388]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7978, Class 2: 0.1639, Class 3: 0.3568, Class 4: 0.6824, Class 5: 0.3203, Class 6: 0.7002, 
Validation F-beta Score
Class 0: 0.9906, Class 1: 0.8970, Class 2: 0.1312, Class 3: 0.3802, Class 4: 0.7352, Class 5: 0.3190, Class 6: 0.7987, 

Overall Mean Dice Score: 0.5715
Overall Mean F-beta Score: 0.6260

Training Loss: 0.2847, Validation Loss: 0.3089, Validation F-beta: 0.6260
Epoch 97/4000


Training: 100%|██████████| 696/696 [20:27<00:00,  1.76s/it, loss=0.275]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.25it/s, loss=0.35] 


Validation Dice Score
Class 0: 0.9934, Class 1: 0.6063, Class 2: 0.1415, Class 3: 0.2834, Class 4: 0.6122, Class 5: 0.4530, Class 6: 0.7043, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.6660, Class 2: 0.1393, Class 3: 0.3382, Class 4: 0.6820, Class 5: 0.5263, Class 6: 0.7133, 

Overall Mean Dice Score: 0.5318
Overall Mean F-beta Score: 0.5852

Training Loss: 0.2857, Validation Loss: 0.3109, Validation F-beta: 0.5852
Epoch 98/4000


Training: 100%|██████████| 696/696 [16:47<00:00,  1.45s/it, loss=0.339]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.93it/s, loss=0.282]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.6313, Class 2: 0.1744, Class 3: 0.2759, Class 4: 0.6013, Class 5: 0.4823, Class 6: 0.5458, 
Validation F-beta Score
Class 0: 0.9935, Class 1: 0.9115, Class 2: 0.1602, Class 3: 0.3298, Class 4: 0.5578, Class 5: 0.4494, Class 6: 0.7219, 

Overall Mean Dice Score: 0.5073
Overall Mean F-beta Score: 0.5941

Training Loss: 0.2860, Validation Loss: 0.3331, Validation F-beta: 0.5941
Epoch 99/4000


Training: 100%|██████████| 696/696 [16:43<00:00,  1.44s/it, loss=0.255]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.83it/s, loss=0.404]


Validation Dice Score
Class 0: 0.9922, Class 1: 0.6891, Class 2: 0.1443, Class 3: 0.4317, Class 4: 0.5753, Class 5: 0.4575, Class 6: 0.6239, 
Validation F-beta Score
Class 0: 0.9935, Class 1: 0.8224, Class 2: 0.1374, Class 3: 0.4380, Class 4: 0.5538, Class 5: 0.4805, Class 6: 0.6412, 

Overall Mean Dice Score: 0.5555
Overall Mean F-beta Score: 0.5872

Training Loss: 0.2847, Validation Loss: 0.3119, Validation F-beta: 0.5872
Epoch 100/4000


Training: 100%|██████████| 696/696 [18:26<00:00,  1.59s/it, loss=0.353]
Validation: 100%|██████████| 12/12 [00:06<00:00,  1.93it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9908, Class 1: 0.6187, Class 2: 0.2218, Class 3: 0.3210, Class 4: 0.6951, Class 5: 0.4295, Class 6: 0.5522, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.6622, Class 2: 0.1894, Class 3: 0.3558, Class 4: 0.6606, Class 5: 0.4737, Class 6: 0.5645, 

Overall Mean Dice Score: 0.5233
Overall Mean F-beta Score: 0.5434

Training Loss: 0.2857, Validation Loss: 0.3174, Validation F-beta: 0.5434
Epoch 101/4000


Training:  35%|███▌      | 246/696 [06:28<12:58,  1.73s/it, loss=0.279]

In [12]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
