In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


import torch

# Print MONAI's environment/version report so the library versions used for this run are recorded in the notebook output.
print_config()

In [None]:
# class_info = {
#     0: {"name": "background", "weight": 10000},  # weight 없음
#     1: {"name": "apo-ferritin", "weight": 300},
#     2: {"name": "beta-amylase", "weight": 100}, # 4130
#     3: {"name": "beta-galactosidase", "weight": 150}, #3080
#     4: {"name": "ribosome", "weight": 6000},
#     5: {"name": "thyroglobulin", "weight": 4000},
#     6: {"name": "virus-like-particle", "weight": 2000},
# }

# # 가중치에 비례한 비율 계산
# raw_ratios = {
#     k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값
#     for k, v in class_info.items()
# }
# total = sum(raw_ratios.values())
# ratios = {k: v / total for k, v in raw_ratios.items()}

# # 최종 합계가 1인지 확인
# final_total = sum(ratios.values())
# print("클래스 비율:", ratios)
# print("최종 합계:", final_total)

# # 비율을 리스트로 변환
# ratios_list = [ratios[k] for k in sorted(ratios.keys())]
# print("클래스 비율 리스트:", ratios_list)

In [None]:
# Per-class crop-sampling weights, keyed by label value.
# They are normalised below into `ratios_list` and passed to RandCropByLabelClassesd.
# Background has weight 0, so it gets a sampling ratio of 0.
class_info = {
    0: {"name": "background", "weight": 0},
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100},  # original note: 4130
    3: {"name": "beta-galactosidase", "weight": 2000},  # original note: 3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 2000},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# Raw weight per class; a missing (None) weight falls back to a tiny default of 0.01.
raw_ratios = {
    label_id: 0.01 if info["weight"] is None else info["weight"]
    for label_id, info in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {label_id: share / total for label_id, share in raw_ratios.items()}

# Sanity check: the normalised ratios should sum to 1.
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# Ordered by label value, as the transform expects a list.
ratios_list = [ratios[label_id] for label_id in sorted(ratios)]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.14084507042253522, 2: 0.014084507042253521, 3: 0.28169014084507044, 4: 0.14084507042253522, 5: 0.28169014084507044, 6: 0.14084507042253522}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.14084507042253522, 0.014084507042253521, 0.28169014084507044, 0.14084507042253522, 0.28169014084507044, 0.14084507042253522]


In [None]:
from src.dataset.dataset import create_dataloaders
import numpy as np
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, CastToTyped,
)

# ---- Paths ------------------------------------------------------------------
train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"

# ---- Data -------------------------------------------------------------------
img_depth = 96  # patch depth
img_size = 96  # patch height/width; patches are img_depth x img_size x img_size
n_classes = 7  # background + 6 particle classes (see class_info)
batch_size = 2  # original note: 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size  # random crops drawn from each volume (this is the effective batch size)
loader_batch = 1  # volumes per DataLoader batch

# ---- Model / optimisation ---------------------------------------------------
num_epochs = 4000
lamda = 0.52  # Tversky trade-off: beta = lamda (FN weight), alpha = 1 - lamda (FP weight)
lr = 0.001
feature_size = 48
use_checkpoint = True  # SwinUNETR gradient checkpointing (memory saving), not pretrained weights
use_v2 = True
drop_rate = 0.25
attn_drop_rate = 0.25

# ---- Loss weighting ---------------------------------------------------------
# Per-class loss weights indexed by label value; set to None to disable weighting.
# (The previous `class_weights = None` line was a dead store, overwritten immediately.)
class_weights = torch.tensor([0.001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)
assert class_weights is None or class_weights.numel() == n_classes, "need exactly one weight per class"
ce_weight = 0.4  # CE share of the combined loss; the Tversky term gets 1 - ce_weight
accumulation_steps = 4  # micro-batches whose gradients are summed per optimizer step

# ---- Training state (overwritten when a checkpoint is resumed) --------------
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

# Deterministic preprocessing (no randomness involved).
non_random_transforms = Compose(
    [
        # Volumes come without a channel axis; add one to both image and label.
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        # Intensity normalisation of the image only (MONAI default: subtract mean, divide by std).
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # CastToTyped(keys=["image"], dtype=np.float16),
        # Light denoising: Gaussian blur with sigma 1.0 along each spatial axis.
        GaussianSmoothd(keys=["image"], sigma=[1.0, 1.0, 1.0]),
    ]
)

# Random sampling + augmentation.
random_transforms = Compose(
    [
        # Draw `num_samples` patches per volume; the class a patch is centred on is
        # sampled according to `ratios_list` (one ratio per label value).
        RandCropByLabelClassesd(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[img_depth, img_size, img_size],
            num_classes=n_classes,
            num_samples=num_samples,
            ratios=ratios_list,
        ),
        # Geometric augmentation, applied identically to image and label.
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
        *[
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis)
            for axis in (0, 1, 2)
        ],
    ]
)


In [None]:
# Drop references to loaders from a previous run of this cell first, so their
# memory can be released before the new datasets are built.
train_loader = val_loader = None
train_loader, val_loader = create_dataloaders(
    train_img_dir,
    train_label_dir,
    val_img_dir,
    val_label_dir,
    num_workers=0,
    batch_size=loader_batch,
    random_transforms=random_transforms,
    non_random_transforms=non_random_transforms,
)

Loading dataset: 100%|██████████| 24/24 [00:41<00:00,  1.71s/it]
Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Reference: MONAI Model Zoo — https://monai.io/model-zoo.html

In [None]:
from monai.losses import TverskyLoss
import torch

def loss_fn(loss, class_weights, device):
    """Apply per-class weights to a loss tensor and reduce it to a scalar.

    Args:
        loss: Loss tensor. It is either per-class, with the class axis at dim 1
            (e.g. ``(B, C, ...)`` from a MONAI loss built with ``reduction="none"``),
            or an already-reduced scalar. If the class axis has one entry fewer
            than ``class_weights`` (``include_background=False``), the background
            weight is dropped.
        class_weights: 1-D tensor or sequence of per-class weights indexed by label
            value, or ``None`` for no weighting.
        device: Device the weights are moved to (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        torch.Tensor: scalar weighted mean loss.

    Raises:
        ValueError: if dim 1 of ``loss`` matches neither ``len(class_weights)`` nor
            ``len(class_weights) - 1``.
    """
    if class_weights is None:
        return torch.mean(loss)

    weights = torch.as_tensor(class_weights, device=device)
    if not weights.is_floating_point():
        weights = weights.float()
    weights = weights.flatten()

    if loss.dim() < 2:
        # The loss is already reduced (e.g. the criterion uses reduction="mean").
        # Classes can no longer be weighted individually. We keep the previous
        # behaviour: the loss is scaled by the mean class weight.
        return torch.mean(loss) * weights.mean()

    n_channels = loss.shape[1]
    if n_channels == weights.numel() - 1:
        # include_background=False removes label 0 from the loss's class axis.
        weights = weights[1:]
    elif n_channels != weights.numel():
        raise ValueError(
            f"loss has {n_channels} classes on dim 1 but {weights.numel()} class weights were given"
        )

    # Broadcast the weights along dim 1, whatever the number of trailing spatial dims.
    view_shape = (1, n_channels) + (1,) * (loss.dim() - 2)
    return torch.mean(loss * weights.view(view_shape))

In [None]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.metrics import DiceMetric
from monai.losses import TverskyLoss
import torch.nn as nn
import torch.nn.functional as F

# Earlier TverskyLoss configuration, kept for reference only (disabled):
# criterion = TverskyLoss(
#     alpha=1 - lamda,  # weight on false positives
#     beta=lamda,       # weight on false negatives
#     include_background=False,  # exclude the background class
#     reduction="none",  # return the unreduced loss
#     softmax=True
# )

class DynamicTverskyLoss(TverskyLoss):
    """MONAI ``TverskyLoss`` whose FP/FN trade-off can be changed during training.

    A single knob ``lamda`` sets ``beta = lamda`` (false-negative weight) and
    ``alpha = 1 - lamda`` (false-positive weight). All other keyword arguments
    are forwarded unchanged to ``TverskyLoss``.
    """

    def __init__(self, lamda=0.5, **kwargs):
        fp_weight, fn_weight = 1 - lamda, lamda
        super().__init__(alpha=fp_weight, beta=fn_weight, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        """Update ``lamda`` and the ``alpha``/``beta`` attributes derived from it."""
        self.alpha, self.beta = 1 - lamda, lamda
        self.lamda = lamda
        
# criterion = DynamicTverskyLoss(
#     lamda=0.5,
#     include_background=False,
#     reduction="mean",
#     softmax=True
# )

class CombinedCETverskyLoss(nn.Module):
    """Weighted sum of cross-entropy and a ``DynamicTverskyLoss``.

    ``forward(inputs, targets)`` returns
    ``ce_weight * CE(inputs, targets) + (1 - ce_weight) * Tversky(inputs, targets)``.

    Args:
        lamda: initial Tversky trade-off passed to ``DynamicTverskyLoss``.
        ce_weight: share of the cross-entropy term; the Tversky term gets ``1 - ce_weight``.
        **kwargs: forwarded to ``DynamicTverskyLoss`` (e.g. ``include_background``,
            ``reduction``, ``softmax``).
    """

    def __init__(self, lamda=0.5, ce_weight=0.5, **kwargs):
        super().__init__()
        # Mirror of the current lambda, exposed through the read-only `lamda` property.
        self._lamda = lamda
        self.tversky = DynamicTverskyLoss(lamda=lamda, **kwargs)
        self.ce = nn.CrossEntropyLoss()
        self.ce_weight = ce_weight

    def forward(self, inputs, targets):
        # As called from `processing`, `inputs` are logits and `targets` are one-hot
        # floats with the class axis at dim 1. CrossEntropyLoss accepts such
        # probability targets.
        tversky_term = self.tversky(inputs, targets)
        ce_term = self.ce(inputs, targets)
        mix = self.ce_weight
        return mix * ce_term + (1 - mix) * tversky_term

    def set_lamda(self, lamda):
        """Change the Tversky trade-off used by subsequent forward passes."""
        self._lamda = lamda
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        """Current Tversky trade-off (use ``set_lamda`` to change it)."""
        return self._lamda

# Combined CE + Tversky criterion.
# The initial Tversky lambda now comes from the config cell (`lamda`) instead of a
# hard-coded 0.5. That way the value logged to wandb and encoded in the checkpoint
# directory name (`a{lamda:.2f}`) matches the value actually trained with.
# NOTE(review): reduction="mean" makes this criterion return a scalar. Any per-class
# weighting applied to its output afterwards therefore cannot distinguish classes.
criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,  # CE : Tversky = ce_weight : (1 - ce_weight)
    include_background=False,
    reduction="mean",
    softmax=True,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=feature_size,
    # Follows the config flag (it was hard-coded to True, which is the same value today).
    use_checkpoint=use_checkpoint,
    drop_rate=drop_rate,
    attn_drop_rate=attn_drop_rate,
    use_v2=use_v2,
).to(device)

# NOTE(review): `use_checkpoint` is SwinUNETR's gradient-checkpointing switch, not a
# "pretrained weights" switch. The `p{yes|no}` tag is kept unchanged so existing
# checkpoint directories still resolve.
pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# The checkpoint directory name encodes the main hyper-parameters of this run.
checkpoint_base_dir = Path("./model_checkpoints")
checkpoint_dir = checkpoint_base_dir / f"SwinUNETR_v2_step4_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_lr{lr:.0e}_a{lamda:.2f}_ce{ce_weight:.2f}_batch{batch_size}"
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
# NOTE(review): the scheduler state is neither saved nor restored with the checkpoint.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Resume from this configuration's best checkpoint if one exists. This is best
# effort: failures are printed, not raised. If loading fails midway, the model
# weights may already be restored while the epoch/best-score counters are not.
best_model_path = checkpoint_dir / 'best_model.pt'
if best_model_path.exists():
    print(f"기존 best model 발견: {best_model_path}")
    try:
        checkpoint = torch.load(best_model_path, map_location=device)
        # Keys written by the training loop's checkpoint saves.
        required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
        if all(k in checkpoint for k in required_keys):
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            # The best F-beta is saved alongside but was previously ignored, so a resumed
            # run compared its first epoch against 0. Older checkpoints may lack the key.
            best_val_fbeta_score = checkpoint.get('best_val_fbeta_score', best_val_fbeta_score)
            print("기존 학습된 가중치를 성공적으로 로드했습니다.")
        else:
            raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
    except Exception as e:
        print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")


In [None]:
# Sanity check: pull one validation batch and show the image/label tensor shapes.
batch = next(iter(val_loader))
images = batch["image"]
labels = batch["label"]
print(images.shape, labels.shape)

torch.Size([2, 1, 96, 96, 96]) torch.Size([2, 1, 96, 96, 96])


In [None]:
# Let cuDNN benchmark convolution algorithms and reuse the fastest one; this pays off because the patch size is fixed (96 x 96 x 96 here).
torch.backends.cudnn.benchmark = True

In [None]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f'SwinUNETR_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_lr{lr:.0e}_ce{ce_weight:.2f}_batch{batch_size}_{current_time}'

# Start the wandb run and record the hyper-parameters.
wandb.init(
    project='czii_SwinUnetR',
    name=run_name,
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        # Settings that appear in the run/checkpoint names or shape the data but
        # were previously missing from the config (needed for filtering runs).
        "ce_weight": ce_weight,
        "img_depth": img_depth,
        "n_classes": n_classes,
        "num_samples": num_samples,
        "loader_batch": loader_batch,
        # Lambda the criterion actually starts with (it may differ from `lamda`).
        "criterion_lambda": criterion.lamda,
    },
)
# Track the model's gradients and parameters in wandb.
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from monai.metrics import DiceMetric

def create_metric_dict(num_classes):
    """Create one MONAI ``DiceMetric`` per class.

    Args:
        num_classes: number of classes to create metrics for.

    Returns:
        dict mapping ``f"dice_class_{i}"`` (for ``i`` in ``range(num_classes)``) to a
        fresh ``DiceMetric`` with ``include_background=False``, ``reduction="mean"``
        and ``get_not_nans=False``.
    """
    # The old `False if i == 0 else False` evaluated to False for every class.
    # NOTE(review): if class 0 was meant to include background, the intent was
    # probably `i == 0`; confirm before changing.
    return {
        f'dice_class_{i}': DiceMetric(
            include_background=False,
            reduction="mean",
            get_not_nans=False,
        )
        for i in range(num_classes)
    }
    
def processing(batch_data, model, criterion, device):
    """Run one forward pass and compute the weighted loss for a batch.

    Args:
        batch_data: dict with "image" and "label" tensors. Both are expected as
            (B, 1, ...); the label holds integer class ids and its channel axis is
            squeezed out here.
        model: segmentation network returning logits with the class axis at dim 1.
        criterion: loss called as ``criterion(logits, one_hot_targets)``.
        device: device the batch is moved to.

    Returns:
        tuple ``(loss, outputs, labels, preds)``:
        - ``loss``: weighted loss from ``loss_fn``;
        - ``outputs``: raw logits;
        - ``labels``: integer labels without the channel axis;
        - ``preds``: ``outputs.argmax(dim=1)``.
    """
    images = batch_data['image'].to(device)
    # (B, 1, ...) -> (B, ...) integer class ids.
    labels = batch_data['label'].to(device).squeeze(1).long()

    outputs = model(images)  # logits, (B, num_classes, ...)

    # One-hot targets with the class axis moved to dim 1. The class count is taken
    # from the model output, and `movedim` works for any number of spatial dims
    # (the previous permute was hard-wired to 3-D volumes).
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=outputs.shape[1])
    labels_onehot = labels_onehot.movedim(-1, 1).float()

    loss = loss_fn(criterion(outputs, labels_onehot), class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

# def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
#     model.train()
#     epoch_loss = 0
#     optimizer.zero_grad()  # 그래디언트 초기화
#     with tqdm(train_loader, desc='Training') as pbar:
#         for i, batch_data in enumerate(pbar):
#             # 손실 계산
#             loss, _, _, _ = processing(batch_data, model, criterion, device)

#             # 그래디언트를 계산하고 누적
#             loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
#             loss.backward()  # 그래디언트 계산 및 누적
            
#             # 그래디언트 업데이트 (accumulation_steps마다 한 번)
#             if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
#                 optimizer.step()  # 파라미터 업데이트
#                 optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
#             # 손실값 누적 (스케일링 복구)
#             epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
#             pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
#     avg_loss = epoch_loss / len(train_loader)
#     wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
#     return avg_loss

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    """Train for one "epoch" of ``len(train_loader) * accumulation_steps`` micro-batches.

    Gradients of ``accumulation_steps`` consecutive micro-batches are accumulated
    before each optimizer step, so one epoch performs ``len(train_loader)``
    optimizer steps. The loader is restarted whenever it is exhausted.

    Args:
        model, criterion, optimizer, device: training components.
        train_loader: DataLoader yielding dicts consumed by ``processing``.
        epoch: 0-based epoch index (logged as ``epoch + 1``).
        accumulation_steps: micro-batches per optimizer step (>= 1).

    Returns:
        float: mean unscaled micro-batch loss over the epoch (also logged to wandb).

    Raises:
        ValueError: if ``accumulation_steps < 1`` or ``train_loader`` is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError("train_loader is empty")
    total_steps = n_batches * accumulation_steps

    model.train()
    epoch_loss = 0.0
    optimizer.zero_grad()
    data_iter = iter(train_loader)
    with tqdm(total=total_steps, desc='Training') as pbar:
        for step in range(total_steps):
            try:
                batch_data = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over it (presumably with fresh
                # random crops/augmentations).
                data_iter = iter(train_loader)
                batch_data = next(data_iter)

            loss, _, _, _ = processing(batch_data, model, criterion, device)
            raw_loss = loss.item()

            # Scale so the accumulated gradient is the mean over the micro-batches.
            (loss / accumulation_steps).backward()

            # Step after every `accumulation_steps` micro-batches, and at the very end.
            # The old check against `len(train_loader)` fired an extra, partial step
            # mid-epoch whenever len(train_loader) was not a multiple of accumulation_steps.
            if (step + 1) % accumulation_steps == 0 or (step + 1) == total_steps:
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += raw_loss
            # Show the real (unscaled) loss, as the progress bar label intends.
            pbar.set_postfix(loss=raw_loss)
            pbar.update(1)

    # Same quantity as before: the mean unscaled loss per micro-batch.
    avg_loss = epoch_loss / total_steps
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss

def _per_class_dice_fbeta(preds, labels, num_classes, beta=4.0, eps=1e-8):
    """Return per-class ``(dice, f_beta)`` float lists for one batch of integer label maps."""
    # NOTE(review): a class absent from both prediction and label scores Dice 0 but
    # F-beta ~1 with these formulas (kept as-is so logged numbers stay comparable).
    beta_sq = beta ** 2
    dice_scores, fbeta_scores = [], []
    for c in range(num_classes):
        pred_c = preds == c
        label_c = labels == c
        tp = torch.sum(pred_c & label_c)
        n_pred = torch.sum(pred_c)
        n_label = torch.sum(label_c)
        dice_scores.append(((2.0 * tp) / (n_pred + n_label + eps)).item())
        precision = (tp + eps) / (n_pred + eps)
        recall = (tp + eps) / (n_label + eps)
        fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall + eps)
        fbeta_scores.append(fbeta.item())
    return dice_scores, fbeta_scores


def _print_class_scores(title, means):
    """Print per-class means, four classes on the first line, in the notebook's log layout."""
    print(title)
    for i, mean in enumerate(means):
        print(f"Class {i}: {mean:.4f}", end=", ")
        if i == 3:
            print()


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    """Evaluate the model on ``val_loader``, optionally with per-class Dice and F-beta (beta=4).

    Scores are computed when ``epoch % calculate_dice_interval == 0``. Classes 0 and 2
    are excluded from the overall means. All metrics for the epoch are logged to wandb
    in a single call.

    NOTE(review): the shape-check cell shows validation batches of 2 x 96^3 patches per
    volume, so validation may be using random crops; metrics can then vary between runs.

    Returns:
        tuple ``(avg_val_loss, overall_mean_fbeta)``. ``overall_mean_fbeta`` is 0 on
        epochs where scores are not computed (previously this raised UnboundLocalError).

    Raises:
        ValueError: if ``val_loader`` is empty.
    """
    n_batches = len(val_loader)
    if n_batches == 0:
        raise ValueError("val_loader is empty")
    excluded = (0, 2)  # classes left out of the overall means
    compute_scores = epoch % calculate_dice_interval == 0

    model.eval()
    val_loss = 0.0
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                if compute_scores:
                    dice, fbeta = _per_class_dice_fbeta(preds, labels, n_classes)
                    for i in range(n_classes):
                        class_dice_scores[i].append(dice[i])
                        class_f_beta_scores[i].append(fbeta[i])

    avg_loss = val_loss / n_batches
    log_data = {'val_epoch_loss': avg_loss, 'epoch': epoch + 1}
    overall_mean_fbeta = 0

    if compute_scores:
        mean_dice = [np.mean(class_dice_scores[i]) for i in range(n_classes)]
        mean_fbeta = [np.mean(class_f_beta_scores[i]) for i in range(n_classes)]

        _print_class_scores("Validation Dice Score", mean_dice)
        print()
        _print_class_scores("Validation F-beta Score", mean_fbeta)

        kept = [i for i in range(n_classes) if i not in excluded]
        overall_mean_dice = np.mean([mean_dice[i] for i in kept])
        overall_mean_fbeta = np.mean([mean_fbeta[i] for i in kept])

        for i in range(n_classes):
            log_data[f'class_{i}_dice_score'] = mean_dice[i]
            log_data[f'class_{i}_f_beta_score'] = mean_fbeta[i]
        log_data['overall_mean_f_beta_score'] = overall_mean_fbeta
        log_data['overall_mean_dice_score'] = overall_mean_dice
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    wandb.log(log_data)
    return avg_loss, overall_mean_fbeta


def _save_checkpoint(path, epoch, model, optimizer, best_val_loss, best_val_fbeta_score):
    """Save model/optimizer state and best-so-far metrics; ``epoch`` is stored 1-based."""
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'best_val_fbeta_score': best_val_fbeta_score,
    }, path)


def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience,
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """Train and validate the model, with checkpointing and early stopping.

    Checkpoints written to ``checkpoint_dir``:
    - ``best_model.pt`` when both val loss and F-beta improve;
    - ``best_model_val_loss.pt`` whenever val loss improves;
    - ``last.pt`` when training ends (early stop or all epochs done).

    Args:
        model: model to train.
        train_loader: training DataLoader.
        val_loader: validation DataLoader.
        criterion: loss function; must expose ``lamda`` and ``set_lamda``.
        optimizer: optimizer.
        num_epochs: total number of epochs (exclusive upper bound of the 0-based epoch index).
        patience: stop after this many consecutive epochs without any improvement.
        device: GPU/CPU device.
        start_epoch: 0-based epoch to start from.
        best_val_loss: best validation loss seen so far.
        best_val_fbeta_score: best validation F-beta seen so far.
        calculate_dice_interval: compute Dice/F-beta every N epochs.
        accumulation_steps: micro-batches per optimizer step.

    Returns:
        tuple ``(best_val_loss, best_val_fbeta_score)`` after training.
    """
    banner = "========================================================"
    epochs_no_improve = 0
    epoch = start_epoch - 1  # stays below start_epoch if the loop never runs

    try:
        for epoch in range(start_epoch, num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"Current lambda: {criterion.lamda:.4f}")

            train_loss = train_one_epoch(
                model=model,
                train_loader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                device=device,
                epoch=epoch,
                accumulation_steps=accumulation_steps,
            )
            # NOTE(review): ReduceLROnPlateau is stepped on the *training* loss; stepping
            # on val_loss is the more common choice. Left unchanged to keep the LR schedule.
            scheduler.step(train_loss)

            val_loss, overall_mean_fbeta_score = validate_one_epoch(
                model=model,
                val_loader=val_loader,
                criterion=criterion,
                device=device,
                epoch=epoch,
                calculate_dice_interval=calculate_dice_interval,
            )

            print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

            # Compare against the bests *before* updating them. The old code updated
            # first, so on "SUPER" epochs best_model_val_loss.pt was skipped, and the
            # patience counter was incremented even on improving epochs.
            loss_improved = val_loss < best_val_loss
            fbeta_improved = overall_mean_fbeta_score > best_val_fbeta_score
            if loss_improved:
                best_val_loss = val_loss
            if fbeta_improved:
                best_val_fbeta_score = overall_mean_fbeta_score

            if loss_improved and fbeta_improved:
                _save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pt'), epoch,
                                 model, optimizer, best_val_loss, best_val_fbeta_score)
                print(banner)
                print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
                print(banner)
            elif fbeta_improved:
                print(banner)
                print(f"NEW validation fbeta score: {best_val_fbeta_score:.4f}")
                print(banner)

            if loss_improved:
                _save_checkpoint(os.path.join(checkpoint_dir, 'best_model_val_loss.pt'), epoch,
                                 model, optimizer, best_val_loss, best_val_fbeta_score)
                print(banner)
                print(f"Best model saved based on validation loss: {best_val_loss:.4f}")
                print(banner)

            # Reset on any improvement; otherwise count a stagnant epoch.
            epochs_no_improve = 0 if (loss_improved or fbeta_improved) else epochs_no_improve + 1

            if epochs_no_improve >= patience:
                print("Early stopping")
                break
            # Every 6 consecutive stagnant epochs, shift the Tversky trade-off (floor 0.1).
            if epochs_no_improve > 0 and epochs_no_improve % 6 == 0:
                new_lamda = max(criterion.lamda - 0.01, 0.1)
                criterion.set_lamda(new_lamda)
                print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

        # Persist the final state whether training stopped early or ran all epochs.
        if epoch >= start_epoch:
            _save_checkpoint(os.path.join(checkpoint_dir, 'last.pt'), epoch,
                             model, optimizer, best_val_loss, best_val_fbeta_score)
    finally:
        # Close the wandb run even if training is interrupted or raises.
        wandb.finish()

    return best_val_loss, best_val_fbeta_score


In [None]:
# Launch training: early-stopping patience of 10 epochs, Dice/F-beta computed every epoch.
train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    patience=10,
    calculate_dice_interval=1,
    accumulation_steps=accumulation_steps,
)

Epoch 1/4000
Current lambda: 0.5000


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training: 100%|██████████| 288/288 [07:54<00:00,  1.65s/it, loss=0.185]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.717]


Validation Dice Score
Class 0: 0.5921, Class 1: 0.0192, Class 2: 0.0007, Class 3: 0.0011, 
Class 4: 0.3519, Class 5: 0.0617, Class 6: 0.0029, 
Validation F-beta Score
Class 0: 0.4371, Class 1: 0.0962, Class 2: 0.0056, Class 3: 0.0074, 
Class 4: 0.5433, Class 5: 0.2312, Class 6: 0.0065, 
Overall Mean Dice Score: 0.0874
Overall Mean F-beta Score: 0.1769

Training Loss: 0.7203, Validation Loss: 0.6979, Validation F-beta: 0.1769
SUPER Best model saved. Loss:0.6979, Score:0.1769
Epoch 2/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.157]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.62] 


Validation Dice Score
Class 0: 0.6310, Class 1: 0.0354, Class 2: 0.0007, Class 3: 0.0012, 
Class 4: 0.4243, Class 5: 0.2246, Class 6: 0.1292, 
Validation F-beta Score
Class 0: 0.4791, Class 1: 0.0728, Class 2: 0.0054, Class 3: 0.0074, 
Class 4: 0.4746, Class 5: 0.2623, Class 6: 0.1982, 
Overall Mean Dice Score: 0.1629
Overall Mean F-beta Score: 0.2031

Training Loss: 0.6871, Validation Loss: 0.6607, Validation F-beta: 0.2031
SUPER Best model saved. Loss:0.6607, Score:0.2031
Epoch 3/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.18] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.43it/s, loss=0.599]


Validation Dice Score
Class 0: 0.6348, Class 1: 0.0327, Class 2: 0.0012, Class 3: 0.0031, 
Class 4: 0.3718, Class 5: 0.2613, Class 6: 0.1836, 
Validation F-beta Score
Class 0: 0.4841, Class 1: 0.0354, Class 2: 0.0097, Class 3: 0.0155, 
Class 4: 0.3590, Class 5: 0.2879, Class 6: 0.2289, 
Overall Mean Dice Score: 0.1705
Overall Mean F-beta Score: 0.1854

Training Loss: 0.6565, Validation Loss: 0.6602, Validation F-beta: 0.1854
Best model saved based on validation loss: 0.6602
Epoch 4/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:40<00:00,  1.60s/it, loss=0.135]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.42it/s, loss=0.567]


Validation Dice Score
Class 0: 0.6592, Class 1: 0.3241, Class 2: 0.0010, Class 3: 0.0372, 
Class 4: 0.3225, Class 5: 0.2962, Class 6: 0.2139, 
Validation F-beta Score
Class 0: 0.5126, Class 1: 0.5617, Class 2: 0.0081, Class 3: 0.0679, 
Class 4: 0.2800, Class 5: 0.3355, Class 6: 0.2690, 
Overall Mean Dice Score: 0.2388
Overall Mean F-beta Score: 0.3028

Training Loss: 0.6373, Validation Loss: 0.6308, Validation F-beta: 0.3028
SUPER Best model saved. Loss:0.6308, Score:0.3028
Epoch 5/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.137]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.678]


Validation Dice Score
Class 0: 0.6707, Class 1: 0.5264, Class 2: 0.0008, Class 3: 0.0521, 
Class 4: 0.3916, Class 5: 0.3885, Class 6: 0.3109, 
Validation F-beta Score
Class 0: 0.5247, Class 1: 0.4820, Class 2: 0.0065, Class 3: 0.1271, 
Class 4: 0.4575, Class 5: 0.3942, Class 6: 0.3926, 
Overall Mean Dice Score: 0.3339
Overall Mean F-beta Score: 0.3707

Training Loss: 0.6075, Validation Loss: 0.5844, Validation F-beta: 0.3707
SUPER Best model saved. Loss:0.5844, Score:0.3707
Epoch 6/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.138]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.42it/s, loss=0.658]


Validation Dice Score
Class 0: 0.6757, Class 1: 0.5128, Class 2: 0.0009, Class 3: 0.1617, 
Class 4: 0.3819, Class 5: 0.2885, Class 6: 0.3961, 
Validation F-beta Score
Class 0: 0.5310, Class 1: 0.5966, Class 2: 0.0071, Class 3: 0.2084, 
Class 4: 0.3571, Class 5: 0.3019, Class 6: 0.5327, 
Overall Mean Dice Score: 0.3482
Overall Mean F-beta Score: 0.3993

Training Loss: 0.5960, Validation Loss: 0.5713, Validation F-beta: 0.3993
SUPER Best model saved. Loss:0.5713, Score:0.3993
Epoch 7/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.137]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.45it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.6815, Class 1: 0.5935, Class 2: 0.0009, Class 3: 0.2256, 
Class 4: 0.3717, Class 5: 0.3242, Class 6: 0.3370, 
Validation F-beta Score
Class 0: 0.5363, Class 1: 0.6828, Class 2: 0.0076, Class 3: 0.3114, 
Class 4: 0.4345, Class 5: 0.3376, Class 6: 0.4141, 
Overall Mean Dice Score: 0.3704
Overall Mean F-beta Score: 0.4361

Training Loss: 0.5899, Validation Loss: 0.5605, Validation F-beta: 0.4361
SUPER Best model saved. Loss:0.5605, Score:0.4361
Epoch 8/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.146]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.572]


Validation Dice Score
Class 0: 0.6788, Class 1: 0.3457, Class 2: 0.0008, Class 3: 0.2711, 
Class 4: 0.4563, Class 5: 0.3605, Class 6: 0.4496, 
Validation F-beta Score
Class 0: 0.5330, Class 1: 0.6019, Class 2: 0.0065, Class 3: 0.3085, 
Class 4: 0.5304, Class 5: 0.3693, Class 6: 0.4672, 
Overall Mean Dice Score: 0.3766
Overall Mean F-beta Score: 0.4554

Training Loss: 0.5772, Validation Loss: 0.5569, Validation F-beta: 0.4554
SUPER Best model saved. Loss:0.5569, Score:0.4554
Epoch 9/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.159]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.44it/s, loss=0.478]


Validation Dice Score
Class 0: 0.6917, Class 1: 0.3968, Class 2: 0.0015, Class 3: 0.1817, 
Class 4: 0.3325, Class 5: 0.3243, Class 6: 0.3737, 
Validation F-beta Score
Class 0: 0.5495, Class 1: 0.4880, Class 2: 0.0120, Class 3: 0.2822, 
Class 4: 0.2705, Class 5: 0.4425, Class 6: 0.4951, 
Overall Mean Dice Score: 0.3218
Overall Mean F-beta Score: 0.3957

Training Loss: 0.5721, Validation Loss: 0.6001, Validation F-beta: 0.3957
Epoch 10/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.135]
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.32it/s, loss=0.639]


Validation Dice Score
Class 0: 0.6877, Class 1: 0.3067, Class 2: 0.0011, Class 3: 0.1149, 
Class 4: 0.5048, Class 5: 0.3360, Class 6: 0.4985, 
Validation F-beta Score
Class 0: 0.5435, Class 1: 0.4030, Class 2: 0.0090, Class 3: 0.1198, 
Class 4: 0.5929, Class 5: 0.3465, Class 6: 0.4723, 
Overall Mean Dice Score: 0.3522
Overall Mean F-beta Score: 0.3869

Training Loss: 0.5701, Validation Loss: 0.5805, Validation F-beta: 0.3869
Epoch 11/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.14] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.487]


Validation Dice Score
Class 0: 0.6926, Class 1: 0.6367, Class 2: 0.0006, Class 3: 0.2179, 
Class 4: 0.5340, Class 5: 0.3886, Class 6: 0.5367, 
Validation F-beta Score
Class 0: 0.5494, Class 1: 0.7463, Class 2: 0.0048, Class 3: 0.2702, 
Class 4: 0.5760, Class 5: 0.4203, Class 6: 0.5592, 
Overall Mean Dice Score: 0.4628
Overall Mean F-beta Score: 0.5144

Training Loss: 0.5678, Validation Loss: 0.5239, Validation F-beta: 0.5144
SUPER Best model saved. Loss:0.5239, Score:0.5144
Epoch 12/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.148]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.479]


Validation Dice Score
Class 0: 0.7001, Class 1: 0.6230, Class 2: 0.0010, Class 3: 0.2683, 
Class 4: 0.4976, Class 5: 0.3420, Class 6: 0.3708, 
Validation F-beta Score
Class 0: 0.5574, Class 1: 0.6413, Class 2: 0.0083, Class 3: 0.3183, 
Class 4: 0.5181, Class 5: 0.5006, Class 6: 0.5569, 
Overall Mean Dice Score: 0.4203
Overall Mean F-beta Score: 0.5070

Training Loss: 0.5614, Validation Loss: 0.5431, Validation F-beta: 0.5070
Epoch 13/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:40<00:00,  1.60s/it, loss=0.115] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.479]


Validation Dice Score
Class 0: 0.7036, Class 1: 0.4634, Class 2: 0.0009, Class 3: 0.2719, 
Class 4: 0.5013, Class 5: 0.3478, Class 6: 0.5100, 
Validation F-beta Score
Class 0: 0.5638, Class 1: 0.5791, Class 2: 0.0073, Class 3: 0.3801, 
Class 4: 0.5079, Class 5: 0.4228, Class 6: 0.5789, 
Overall Mean Dice Score: 0.4189
Overall Mean F-beta Score: 0.4937

Training Loss: 0.5519, Validation Loss: 0.5431, Validation F-beta: 0.4937
Epoch 14/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:39<00:00,  1.59s/it, loss=0.139]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.34it/s, loss=0.622]


Validation Dice Score
Class 0: 0.7210, Class 1: 0.6187, Class 2: 0.0009, Class 3: 0.2929, 
Class 4: 0.4864, Class 5: 0.3414, Class 6: 0.6266, 
Validation F-beta Score
Class 0: 0.5826, Class 1: 0.6484, Class 2: 0.0071, Class 3: 0.4708, 
Class 4: 0.5446, Class 5: 0.4351, Class 6: 0.6232, 
Overall Mean Dice Score: 0.4732
Overall Mean F-beta Score: 0.5444

Training Loss: 0.5590, Validation Loss: 0.5134, Validation F-beta: 0.5444
SUPER Best model saved. Loss:0.5134, Score:0.5444
Epoch 15/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:39<00:00,  1.59s/it, loss=0.111] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.56] 


Validation Dice Score
Class 0: 0.7379, Class 1: 0.4485, Class 2: 0.0005, Class 3: 0.2057, 
Class 4: 0.3791, Class 5: 0.4222, Class 6: 0.5292, 
Validation F-beta Score
Class 0: 0.6048, Class 1: 0.4344, Class 2: 0.0038, Class 3: 0.2189, 
Class 4: 0.3731, Class 5: 0.4862, Class 6: 0.5284, 
Overall Mean Dice Score: 0.3970
Overall Mean F-beta Score: 0.4082

Training Loss: 0.5509, Validation Loss: 0.5524, Validation F-beta: 0.4082
Epoch 16/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.142]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.567]


Validation Dice Score
Class 0: 0.7566, Class 1: 0.5140, Class 2: 0.0009, Class 3: 0.1939, 
Class 4: 0.5087, Class 5: 0.3507, Class 6: 0.6374, 
Validation F-beta Score
Class 0: 0.6279, Class 1: 0.6826, Class 2: 0.0070, Class 3: 0.1901, 
Class 4: 0.5648, Class 5: 0.4013, Class 6: 0.7651, 
Overall Mean Dice Score: 0.4409
Overall Mean F-beta Score: 0.5208

Training Loss: 0.5560, Validation Loss: 0.5478, Validation F-beta: 0.5208
Epoch 17/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.143]
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.33it/s, loss=0.606]


Validation Dice Score
Class 0: 0.7945, Class 1: 0.4637, Class 2: 0.0017, Class 3: 0.2711, 
Class 4: 0.5574, Class 5: 0.3801, Class 6: 0.5429, 
Validation F-beta Score
Class 0: 0.6781, Class 1: 0.5058, Class 2: 0.0139, Class 3: 0.3290, 
Class 4: 0.5614, Class 5: 0.4296, Class 6: 0.5228, 
Overall Mean Dice Score: 0.4430
Overall Mean F-beta Score: 0.4697

Training Loss: 0.5530, Validation Loss: 0.5206, Validation F-beta: 0.4697
Epoch 18/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.144] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.555]


Validation Dice Score
Class 0: 0.8478, Class 1: 0.5915, Class 2: 0.0012, Class 3: 0.2585, 
Class 4: 0.4322, Class 5: 0.2587, Class 6: 0.4753, 
Validation F-beta Score
Class 0: 0.7525, Class 1: 0.7190, Class 2: 0.0095, Class 3: 0.3269, 
Class 4: 0.4736, Class 5: 0.2926, Class 6: 0.5231, 
Overall Mean Dice Score: 0.4032
Overall Mean F-beta Score: 0.4671

Training Loss: 0.5457, Validation Loss: 0.5445, Validation F-beta: 0.4671
Epoch 19/4000
Current lambda: 0.5000


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.167] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.563]


Validation Dice Score
Class 0: 0.9142, Class 1: 0.5335, Class 2: 0.0015, Class 3: 0.3470, 
Class 4: 0.4027, Class 5: 0.3381, Class 6: 0.5903, 
Validation F-beta Score
Class 0: 0.8607, Class 1: 0.4828, Class 2: 0.0111, Class 3: 0.4413, 
Class 4: 0.3325, Class 5: 0.3637, Class 6: 0.7071, 
Overall Mean Dice Score: 0.4423
Overall Mean F-beta Score: 0.4655

Training Loss: 0.5455, Validation Loss: 0.5452, Validation F-beta: 0.4655
Validation loss did not improve. Reducing lambda to 0.4900
Epoch 20/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.14]  
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.30it/s, loss=0.611]


Validation Dice Score
Class 0: 0.9832, Class 1: 0.5200, Class 2: 0.0025, Class 3: 0.3986, 
Class 4: 0.6277, Class 5: 0.3168, Class 6: 0.4470, 
Validation F-beta Score
Class 0: 0.9765, Class 1: 0.5430, Class 2: 0.0079, Class 3: 0.4985, 
Class 4: 0.6600, Class 5: 0.3692, Class 6: 0.4522, 
Overall Mean Dice Score: 0.4620
Overall Mean F-beta Score: 0.5046

Training Loss: 0.5367, Validation Loss: 0.5190, Validation F-beta: 0.5046
Epoch 21/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.111] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.504]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.4673, Class 2: 0.0506, Class 3: 0.3616, 
Class 4: 0.5457, Class 5: 0.4383, Class 6: 0.4759, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.5142, Class 2: 0.0953, Class 3: 0.4418, 
Class 4: 0.5605, Class 5: 0.4746, Class 6: 0.5527, 
Overall Mean Dice Score: 0.4578
Overall Mean F-beta Score: 0.5088

Training Loss: 0.5380, Validation Loss: 0.5179, Validation F-beta: 0.5088
Epoch 22/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.11]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.378]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.5317, Class 2: 0.0204, Class 3: 0.3191, 
Class 4: 0.5868, Class 5: 0.3907, Class 6: 0.4851, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.5317, Class 2: 0.0484, Class 3: 0.3780, 
Class 4: 0.5537, Class 5: 0.4588, Class 6: 0.5103, 
Overall Mean Dice Score: 0.4627
Overall Mean F-beta Score: 0.4865

Training Loss: 0.5431, Validation Loss: 0.5211, Validation F-beta: 0.4865
Epoch 23/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:39<00:00,  1.59s/it, loss=0.142] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.552]


Validation Dice Score
Class 0: 0.9848, Class 1: 0.4802, Class 2: 0.0378, Class 3: 0.2114, 
Class 4: 0.6118, Class 5: 0.3221, Class 6: 0.4735, 
Validation F-beta Score
Class 0: 0.9808, Class 1: 0.6318, Class 2: 0.0742, Class 3: 0.3084, 
Class 4: 0.6888, Class 5: 0.3479, Class 6: 0.4704, 
Overall Mean Dice Score: 0.4198
Overall Mean F-beta Score: 0.4895

Training Loss: 0.5349, Validation Loss: 0.5291, Validation F-beta: 0.4895
Epoch 24/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.118] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.607]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.3983, Class 2: 0.0843, Class 3: 0.3918, 
Class 4: 0.5901, Class 5: 0.3704, Class 6: 0.3934, 
Validation F-beta Score
Class 0: 0.9866, Class 1: 0.5082, Class 2: 0.1298, Class 3: 0.5087, 
Class 4: 0.5504, Class 5: 0.4264, Class 6: 0.4457, 
Overall Mean Dice Score: 0.4288
Overall Mean F-beta Score: 0.4879

Training Loss: 0.5336, Validation Loss: 0.5499, Validation F-beta: 0.4879
Epoch 25/4000
Current lambda: 0.4900


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.131] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.5]  


Validation Dice Score
Class 0: 0.9881, Class 1: 0.5030, Class 2: 0.1161, Class 3: 0.3018, 
Class 4: 0.4848, Class 5: 0.4168, Class 6: 0.5854, 
Validation F-beta Score
Class 0: 0.9837, Class 1: 0.5322, Class 2: 0.1534, Class 3: 0.2901, 
Class 4: 0.6224, Class 5: 0.4714, Class 6: 0.6564, 
Overall Mean Dice Score: 0.4584
Overall Mean F-beta Score: 0.5145

Training Loss: 0.5322, Validation Loss: 0.5205, Validation F-beta: 0.5145
Validation loss did not improve. Reducing lambda to 0.4800
Epoch 26/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.12]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.491]


Validation Dice Score
Class 0: 0.9870, Class 1: 0.7699, Class 2: 0.0738, Class 3: 0.2911, 
Class 4: 0.4420, Class 5: 0.4091, Class 6: 0.7362, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.8634, Class 2: 0.1406, Class 3: 0.3144, 
Class 4: 0.4969, Class 5: 0.4751, Class 6: 0.8461, 
Overall Mean Dice Score: 0.5296
Overall Mean F-beta Score: 0.5992

Training Loss: 0.5377, Validation Loss: 0.4779, Validation F-beta: 0.5992
SUPER Best model saved. Loss:0.4779, Score:0.5992
Epoch 27/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.142] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.4761, Class 2: 0.0563, Class 3: 0.2203, 
Class 4: 0.5394, Class 5: 0.4051, Class 6: 0.4272, 
Validation F-beta Score
Class 0: 0.9860, Class 1: 0.5197, Class 2: 0.0973, Class 3: 0.2632, 
Class 4: 0.5842, Class 5: 0.4059, Class 6: 0.4873, 
Overall Mean Dice Score: 0.4136
Overall Mean F-beta Score: 0.4521

Training Loss: 0.5263, Validation Loss: 0.5308, Validation F-beta: 0.4521
Epoch 28/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.121] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.561]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.6384, Class 2: 0.1046, Class 3: 0.2976, 
Class 4: 0.4569, Class 5: 0.4067, Class 6: 0.6385, 
Validation F-beta Score
Class 0: 0.9844, Class 1: 0.7494, Class 2: 0.2543, Class 3: 0.3627, 
Class 4: 0.4830, Class 5: 0.5094, Class 6: 0.6017, 
Overall Mean Dice Score: 0.4876
Overall Mean F-beta Score: 0.5413

Training Loss: 0.5242, Validation Loss: 0.5082, Validation F-beta: 0.5413
Epoch 29/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.137] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.472]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.3899, Class 2: 0.0614, Class 3: 0.3519, 
Class 4: 0.5040, Class 5: 0.3661, Class 6: 0.6123, 
Validation F-beta Score
Class 0: 0.9881, Class 1: 0.5187, Class 2: 0.0955, Class 3: 0.3947, 
Class 4: 0.4960, Class 5: 0.4041, Class 6: 0.6618, 
Overall Mean Dice Score: 0.4448
Overall Mean F-beta Score: 0.4951

Training Loss: 0.5314, Validation Loss: 0.5228, Validation F-beta: 0.4951
Epoch 30/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.137] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.33it/s, loss=0.612]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.4530, Class 2: 0.0549, Class 3: 0.1896, 
Class 4: 0.6355, Class 5: 0.2779, Class 6: 0.6034, 
Validation F-beta Score
Class 0: 0.9863, Class 1: 0.4774, Class 2: 0.0606, Class 3: 0.2083, 
Class 4: 0.6410, Class 5: 0.3940, Class 6: 0.6350, 
Overall Mean Dice Score: 0.4319
Overall Mean F-beta Score: 0.4711

Training Loss: 0.5260, Validation Loss: 0.5363, Validation F-beta: 0.4711
Epoch 31/4000
Current lambda: 0.4800


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.123] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.591]


Validation Dice Score
Class 0: 0.9889, Class 1: 0.5178, Class 2: 0.1927, Class 3: 0.1884, 
Class 4: 0.4822, Class 5: 0.3475, Class 6: 0.5179, 
Validation F-beta Score
Class 0: 0.9863, Class 1: 0.6011, Class 2: 0.2111, Class 3: 0.2459, 
Class 4: 0.4326, Class 5: 0.4235, Class 6: 0.5919, 
Overall Mean Dice Score: 0.4108
Overall Mean F-beta Score: 0.4590

Training Loss: 0.5187, Validation Loss: 0.5327, Validation F-beta: 0.4590
Validation loss did not improve. Reducing lambda to 0.4700
Epoch 32/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.125] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.39] 


Validation Dice Score
Class 0: 0.9906, Class 1: 0.6816, Class 2: 0.0678, Class 3: 0.2260, 
Class 4: 0.5150, Class 5: 0.3934, Class 6: 0.7467, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.6496, Class 2: 0.1036, Class 3: 0.2215, 
Class 4: 0.4686, Class 5: 0.4494, Class 6: 0.7958, 
Overall Mean Dice Score: 0.5125
Overall Mean F-beta Score: 0.5170

Training Loss: 0.5209, Validation Loss: 0.4982, Validation F-beta: 0.5170
Epoch 33/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:39<00:00,  1.59s/it, loss=0.113] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.571]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.5996, Class 2: 0.0864, Class 3: 0.3825, 
Class 4: 0.5534, Class 5: 0.4399, Class 6: 0.5483, 
Validation F-beta Score
Class 0: 0.9888, Class 1: 0.6437, Class 2: 0.1284, Class 3: 0.3913, 
Class 4: 0.5507, Class 5: 0.4723, Class 6: 0.5549, 
Overall Mean Dice Score: 0.5047
Overall Mean F-beta Score: 0.5226

Training Loss: 0.5253, Validation Loss: 0.4957, Validation F-beta: 0.5226
Epoch 34/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.141] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.642]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.6114, Class 2: 0.1387, Class 3: 0.3466, 
Class 4: 0.4284, Class 5: 0.3672, Class 6: 0.6007, 
Validation F-beta Score
Class 0: 0.9877, Class 1: 0.6791, Class 2: 0.1828, Class 3: 0.3818, 
Class 4: 0.4721, Class 5: 0.4854, Class 6: 0.5816, 
Overall Mean Dice Score: 0.4709
Overall Mean F-beta Score: 0.5200

Training Loss: 0.5331, Validation Loss: 0.5287, Validation F-beta: 0.5200
Epoch 35/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.119] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.545]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.5212, Class 2: 0.1691, Class 3: 0.3835, 
Class 4: 0.5688, Class 5: 0.3122, Class 6: 0.4535, 
Validation F-beta Score
Class 0: 0.9879, Class 1: 0.5682, Class 2: 0.2051, Class 3: 0.4496, 
Class 4: 0.5379, Class 5: 0.3686, Class 6: 0.5041, 
Overall Mean Dice Score: 0.4478
Overall Mean F-beta Score: 0.4857

Training Loss: 0.5261, Validation Loss: 0.5082, Validation F-beta: 0.4857
Epoch 36/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.118] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.484]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.5237, Class 2: 0.0482, Class 3: 0.2799, 
Class 4: 0.6313, Class 5: 0.4034, Class 6: 0.5117, 
Validation F-beta Score
Class 0: 0.9889, Class 1: 0.6400, Class 2: 0.0946, Class 3: 0.2613, 
Class 4: 0.6243, Class 5: 0.4009, Class 6: 0.5480, 
Overall Mean Dice Score: 0.4700
Overall Mean F-beta Score: 0.4949

Training Loss: 0.5210, Validation Loss: 0.5287, Validation F-beta: 0.4949
Epoch 37/4000
Current lambda: 0.4700


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.13]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.545]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.6995, Class 2: 0.1176, Class 3: 0.2909, 
Class 4: 0.4568, Class 5: 0.4077, Class 6: 0.6535, 
Validation F-beta Score
Class 0: 0.9865, Class 1: 0.7994, Class 2: 0.1828, Class 3: 0.3525, 
Class 4: 0.5219, Class 5: 0.5272, Class 6: 0.6898, 
Overall Mean Dice Score: 0.5017
Overall Mean F-beta Score: 0.5782

Training Loss: 0.5168, Validation Loss: 0.4960, Validation F-beta: 0.5782
Validation loss did not improve. Reducing lambda to 0.4600
Epoch 38/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.138] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.451]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.6758, Class 2: 0.1475, Class 3: 0.3577, 
Class 4: 0.5082, Class 5: 0.4083, Class 6: 0.7604, 
Validation F-beta Score
Class 0: 0.9855, Class 1: 0.7798, Class 2: 0.2619, Class 3: 0.3755, 
Class 4: 0.5899, Class 5: 0.4526, Class 6: 0.7838, 
Overall Mean Dice Score: 0.5421
Overall Mean F-beta Score: 0.5963

Training Loss: 0.5238, Validation Loss: 0.4733, Validation F-beta: 0.5963
Best model saved based on validation loss: 0.4733
Epoch 39/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.128] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.499]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.4814, Class 2: 0.1199, Class 3: 0.2678, 
Class 4: 0.6169, Class 5: 0.4718, Class 6: 0.5254, 
Validation F-beta Score
Class 0: 0.9894, Class 1: 0.5442, Class 2: 0.1480, Class 3: 0.3089, 
Class 4: 0.6239, Class 5: 0.5172, Class 6: 0.5770, 
Overall Mean Dice Score: 0.4726
Overall Mean F-beta Score: 0.5142

Training Loss: 0.5231, Validation Loss: 0.5102, Validation F-beta: 0.5142
Epoch 40/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.131] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.29it/s, loss=0.398]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.6074, Class 2: 0.0728, Class 3: 0.2514, 
Class 4: 0.4420, Class 5: 0.3897, Class 6: 0.5767, 
Validation F-beta Score
Class 0: 0.9875, Class 1: 0.6647, Class 2: 0.1050, Class 3: 0.3066, 
Class 4: 0.4418, Class 5: 0.3714, Class 6: 0.6023, 
Overall Mean Dice Score: 0.4535
Overall Mean F-beta Score: 0.4773

Training Loss: 0.5243, Validation Loss: 0.5184, Validation F-beta: 0.4773
Epoch 41/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.134] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.6686, Class 2: 0.1043, Class 3: 0.3946, 
Class 4: 0.5741, Class 5: 0.4176, Class 6: 0.5016, 
Validation F-beta Score
Class 0: 0.9901, Class 1: 0.7069, Class 2: 0.1395, Class 3: 0.4038, 
Class 4: 0.5460, Class 5: 0.4468, Class 6: 0.5344, 
Overall Mean Dice Score: 0.5113
Overall Mean F-beta Score: 0.5276

Training Loss: 0.5217, Validation Loss: 0.4960, Validation F-beta: 0.5276
Epoch 42/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.125] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.633]


Validation Dice Score
Class 0: 0.9880, Class 1: 0.4187, Class 2: 0.1663, Class 3: 0.4121, 
Class 4: 0.5943, Class 5: 0.3990, Class 6: 0.4574, 
Validation F-beta Score
Class 0: 0.9849, Class 1: 0.5147, Class 2: 0.2533, Class 3: 0.4987, 
Class 4: 0.7582, Class 5: 0.4207, Class 6: 0.5084, 
Overall Mean Dice Score: 0.4563
Overall Mean F-beta Score: 0.5401

Training Loss: 0.5189, Validation Loss: 0.5189, Validation F-beta: 0.5401
Epoch 43/4000
Current lambda: 0.4600


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.123] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.541]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.5808, Class 2: 0.1085, Class 3: 0.1956, 
Class 4: 0.4735, Class 5: 0.4428, Class 6: 0.5254, 
Validation F-beta Score
Class 0: 0.9872, Class 1: 0.6472, Class 2: 0.1459, Class 3: 0.2254, 
Class 4: 0.5309, Class 5: 0.4983, Class 6: 0.5704, 
Overall Mean Dice Score: 0.4436
Overall Mean F-beta Score: 0.4944

Training Loss: 0.5252, Validation Loss: 0.5347, Validation F-beta: 0.4944
Validation loss did not improve. Reducing lambda to 0.4500
Epoch 44/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.106] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.321]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.6752, Class 2: 0.2456, Class 3: 0.3860, 
Class 4: 0.6122, Class 5: 0.4866, Class 6: 0.4546, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.6851, Class 2: 0.2875, Class 3: 0.4550, 
Class 4: 0.6409, Class 5: 0.5251, Class 6: 0.5113, 
Overall Mean Dice Score: 0.5229
Overall Mean F-beta Score: 0.5635

Training Loss: 0.5169, Validation Loss: 0.4868, Validation F-beta: 0.5635
Epoch 45/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.131] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.375]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.5653, Class 2: 0.1013, Class 3: 0.4180, 
Class 4: 0.5445, Class 5: 0.4145, Class 6: 0.5045, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.5978, Class 2: 0.1247, Class 3: 0.4885, 
Class 4: 0.5305, Class 5: 0.4590, Class 6: 0.5193, 
Overall Mean Dice Score: 0.4894
Overall Mean F-beta Score: 0.5190

Training Loss: 0.5109, Validation Loss: 0.4985, Validation F-beta: 0.5190
Epoch 46/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.126] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.635]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.6165, Class 2: 0.1268, Class 3: 0.1579, 
Class 4: 0.6069, Class 5: 0.4120, Class 6: 0.4582, 
Validation F-beta Score
Class 0: 0.9877, Class 1: 0.7399, Class 2: 0.2175, Class 3: 0.2033, 
Class 4: 0.5522, Class 5: 0.4705, Class 6: 0.4752, 
Overall Mean Dice Score: 0.4503
Overall Mean F-beta Score: 0.4882

Training Loss: 0.5103, Validation Loss: 0.5161, Validation F-beta: 0.4882
Epoch 47/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.118] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.485]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.4562, Class 2: 0.1489, Class 3: 0.2855, 
Class 4: 0.6002, Class 5: 0.3367, Class 6: 0.5857, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.4899, Class 2: 0.1769, Class 3: 0.3944, 
Class 4: 0.6801, Class 5: 0.4418, Class 6: 0.6109, 
Overall Mean Dice Score: 0.4529
Overall Mean F-beta Score: 0.5234

Training Loss: 0.5099, Validation Loss: 0.5159, Validation F-beta: 0.5234
Epoch 48/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.126] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.536]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.4757, Class 2: 0.1529, Class 3: 0.3470, 
Class 4: 0.6236, Class 5: 0.4235, Class 6: 0.5811, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.6265, Class 2: 0.2241, Class 3: 0.3671, 
Class 4: 0.6192, Class 5: 0.5404, Class 6: 0.6219, 
Overall Mean Dice Score: 0.4902
Overall Mean F-beta Score: 0.5550

Training Loss: 0.5041, Validation Loss: 0.5074, Validation F-beta: 0.5550
Epoch 49/4000
Current lambda: 0.4500


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.143] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.391]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.5641, Class 2: 0.0849, Class 3: 0.4100, 
Class 4: 0.4787, Class 5: 0.4757, Class 6: 0.7255, 
Validation F-beta Score
Class 0: 0.9875, Class 1: 0.6603, Class 2: 0.1024, Class 3: 0.4858, 
Class 4: 0.5183, Class 5: 0.5758, Class 6: 0.7349, 
Overall Mean Dice Score: 0.5308
Overall Mean F-beta Score: 0.5950

Training Loss: 0.5151, Validation Loss: 0.4763, Validation F-beta: 0.5950
Validation loss did not improve. Reducing lambda to 0.4400
Epoch 50/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:39<00:00,  1.60s/it, loss=0.124] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.29it/s, loss=0.544]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.3960, Class 2: 0.1081, Class 3: 0.3388, 
Class 4: 0.5052, Class 5: 0.3662, Class 6: 0.5205, 
Validation F-beta Score
Class 0: 0.9886, Class 1: 0.4032, Class 2: 0.1586, Class 3: 0.3594, 
Class 4: 0.5268, Class 5: 0.3770, Class 6: 0.5077, 
Overall Mean Dice Score: 0.4254
Overall Mean F-beta Score: 0.4348

Training Loss: 0.5029, Validation Loss: 0.5399, Validation F-beta: 0.4348
Epoch 51/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:40<00:00,  1.60s/it, loss=0.115] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.313]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.6766, Class 2: 0.1060, Class 3: 0.3685, 
Class 4: 0.5925, Class 5: 0.4907, Class 6: 0.7794, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.7305, Class 2: 0.1411, Class 3: 0.4256, 
Class 4: 0.6819, Class 5: 0.5763, Class 6: 0.7975, 
Overall Mean Dice Score: 0.5816
Overall Mean F-beta Score: 0.6424

Training Loss: 0.5059, Validation Loss: 0.4557, Validation F-beta: 0.6424
SUPER Best model saved. Loss:0.4557, Score:0.6424
Epoch 52/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:34<00:00,  1.58s/it, loss=0.126] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.552]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.5655, Class 2: 0.1100, Class 3: 0.2985, 
Class 4: 0.6361, Class 5: 0.4075, Class 6: 0.5149, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.6556, Class 2: 0.1541, Class 3: 0.3443, 
Class 4: 0.6051, Class 5: 0.4750, Class 6: 0.5796, 
Overall Mean Dice Score: 0.4845
Overall Mean F-beta Score: 0.5319

Training Loss: 0.5124, Validation Loss: 0.5171, Validation F-beta: 0.5319
Epoch 53/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.127] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.41it/s, loss=0.475]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.4839, Class 2: 0.0717, Class 3: 0.2875, 
Class 4: 0.5730, Class 5: 0.4211, Class 6: 0.3483, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.5244, Class 2: 0.0930, Class 3: 0.3971, 
Class 4: 0.5400, Class 5: 0.4728, Class 6: 0.3905, 
Overall Mean Dice Score: 0.4228
Overall Mean F-beta Score: 0.4650

Training Loss: 0.5012, Validation Loss: 0.5467, Validation F-beta: 0.4650
Epoch 54/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.127] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.491]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.6278, Class 2: 0.1514, Class 3: 0.3106, 
Class 4: 0.6054, Class 5: 0.4205, Class 6: 0.6980, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.6633, Class 2: 0.2053, Class 3: 0.3298, 
Class 4: 0.5451, Class 5: 0.4354, Class 6: 0.7208, 
Overall Mean Dice Score: 0.5325
Overall Mean F-beta Score: 0.5389

Training Loss: 0.5052, Validation Loss: 0.4952, Validation F-beta: 0.5389
Epoch 55/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.141] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.506]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.4648, Class 2: 0.1516, Class 3: 0.3144, 
Class 4: 0.5260, Class 5: 0.3855, Class 6: 0.5596, 
Validation F-beta Score
Class 0: 0.9880, Class 1: 0.4752, Class 2: 0.2174, Class 3: 0.4173, 
Class 4: 0.5011, Class 5: 0.4394, Class 6: 0.6761, 
Overall Mean Dice Score: 0.4500
Overall Mean F-beta Score: 0.5018

Training Loss: 0.5071, Validation Loss: 0.5100, Validation F-beta: 0.5018
Epoch 56/4000
Current lambda: 0.4400


Training: 100%|██████████| 288/288 [07:35<00:00,  1.58s/it, loss=0.129] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.47] 


Validation Dice Score
Class 0: 0.9910, Class 1: 0.6185, Class 2: 0.0713, Class 3: 0.4051, 
Class 4: 0.7479, Class 5: 0.4427, Class 6: 0.5108, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.7103, Class 2: 0.0906, Class 3: 0.4962, 
Class 4: 0.7394, Class 5: 0.5416, Class 6: 0.6133, 
Overall Mean Dice Score: 0.5450
Overall Mean F-beta Score: 0.6201

Training Loss: 0.5008, Validation Loss: 0.4849, Validation F-beta: 0.6201
Validation loss did not improve. Reducing lambda to 0.4300
Epoch 57/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.15]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.393]


Validation Dice Score
Class 0: 0.9916, Class 1: 0.5204, Class 2: 0.1790, Class 3: 0.4099, 
Class 4: 0.3649, Class 5: 0.4445, Class 6: 0.7269, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.5473, Class 2: 0.1961, Class 3: 0.3844, 
Class 4: 0.4033, Class 5: 0.4617, Class 6: 0.7595, 
Overall Mean Dice Score: 0.4933
Overall Mean F-beta Score: 0.5112

Training Loss: 0.5087, Validation Loss: 0.4815, Validation F-beta: 0.5112
Epoch 58/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:36<00:00,  1.58s/it, loss=0.108] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.41it/s, loss=0.506]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.6756, Class 2: 0.1508, Class 3: 0.3249, 
Class 4: 0.4979, Class 5: 0.4138, Class 6: 0.7917, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.7069, Class 2: 0.1721, Class 3: 0.3906, 
Class 4: 0.4493, Class 5: 0.4173, Class 6: 0.7953, 
Overall Mean Dice Score: 0.5408
Overall Mean F-beta Score: 0.5519

Training Loss: 0.5074, Validation Loss: 0.4787, Validation F-beta: 0.5519
Epoch 59/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:35<00:00,  1.58s/it, loss=0.14]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.503]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7382, Class 2: 0.1714, Class 3: 0.3534, 
Class 4: 0.5075, Class 5: 0.3466, Class 6: 0.7094, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.7754, Class 2: 0.1642, Class 3: 0.3934, 
Class 4: 0.4741, Class 5: 0.4060, Class 6: 0.6745, 
Overall Mean Dice Score: 0.5310
Overall Mean F-beta Score: 0.5447

Training Loss: 0.5025, Validation Loss: 0.4578, Validation F-beta: 0.5447
Epoch 60/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.14]  
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.30it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9901, Class 1: 0.5933, Class 2: 0.1366, Class 3: 0.3370, 
Class 4: 0.6304, Class 5: 0.4069, Class 6: 0.7732, 
Validation F-beta Score
Class 0: 0.9895, Class 1: 0.6206, Class 2: 0.2108, Class 3: 0.3699, 
Class 4: 0.6354, Class 5: 0.4598, Class 6: 0.8228, 
Overall Mean Dice Score: 0.5482
Overall Mean F-beta Score: 0.5817

Training Loss: 0.5024, Validation Loss: 0.4733, Validation F-beta: 0.5817
Epoch 61/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:35<00:00,  1.58s/it, loss=0.129] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.548]


Validation Dice Score
Class 0: 0.9889, Class 1: 0.5162, Class 2: 0.1327, Class 3: 0.2817, 
Class 4: 0.6237, Class 5: 0.4151, Class 6: 0.5077, 
Validation F-beta Score
Class 0: 0.9880, Class 1: 0.6898, Class 2: 0.1949, Class 3: 0.3364, 
Class 4: 0.5825, Class 5: 0.5322, Class 6: 0.5336, 
Overall Mean Dice Score: 0.4689
Overall Mean F-beta Score: 0.5349

Training Loss: 0.4951, Validation Loss: 0.5101, Validation F-beta: 0.5349
Epoch 62/4000
Current lambda: 0.4300


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.112] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.494]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.5733, Class 2: 0.1782, Class 3: 0.3862, 
Class 4: 0.6102, Class 5: 0.4791, Class 6: 0.6469, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.6517, Class 2: 0.2027, Class 3: 0.4246, 
Class 4: 0.6160, Class 5: 0.5022, Class 6: 0.6643, 
Overall Mean Dice Score: 0.5392
Overall Mean F-beta Score: 0.5718

Training Loss: 0.4981, Validation Loss: 0.4866, Validation F-beta: 0.5718
Validation loss did not improve. Reducing lambda to 0.4200
Epoch 63/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.139] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.282]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.7690, Class 2: 0.1277, Class 3: 0.4644, 
Class 4: 0.5960, Class 5: 0.4319, Class 6: 0.6224, 
Validation F-beta Score
Class 0: 0.9908, Class 1: 0.7933, Class 2: 0.1555, Class 3: 0.4779, 
Class 4: 0.5757, Class 5: 0.4110, Class 6: 0.6845, 
Overall Mean Dice Score: 0.5767
Overall Mean F-beta Score: 0.5885

Training Loss: 0.5034, Validation Loss: 0.4528, Validation F-beta: 0.5885
Best model saved based on validation loss: 0.4528
Epoch 64/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.13]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.522]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.6438, Class 2: 0.2094, Class 3: 0.3729, 
Class 4: 0.4913, Class 5: 0.4363, Class 6: 0.4786, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.6228, Class 2: 0.2437, Class 3: 0.3958, 
Class 4: 0.4694, Class 5: 0.5482, Class 6: 0.4489, 
Overall Mean Dice Score: 0.4846
Overall Mean F-beta Score: 0.4970

Training Loss: 0.4951, Validation Loss: 0.4884, Validation F-beta: 0.4970
Epoch 65/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:35<00:00,  1.58s/it, loss=0.122] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.477]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.4585, Class 2: 0.1744, Class 3: 0.3627, 
Class 4: 0.6793, Class 5: 0.4250, Class 6: 0.6337, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.5211, Class 2: 0.2251, Class 3: 0.4018, 
Class 4: 0.6550, Class 5: 0.3991, Class 6: 0.6509, 
Overall Mean Dice Score: 0.5118
Overall Mean F-beta Score: 0.5256

Training Loss: 0.4997, Validation Loss: 0.4958, Validation F-beta: 0.5256
Epoch 66/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.121] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.503]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.5263, Class 2: 0.2443, Class 3: 0.3233, 
Class 4: 0.6038, Class 5: 0.3901, Class 6: 0.4061, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.5163, Class 2: 0.2910, Class 3: 0.3911, 
Class 4: 0.5819, Class 5: 0.4377, Class 6: 0.4723, 
Overall Mean Dice Score: 0.4499
Overall Mean F-beta Score: 0.4799

Training Loss: 0.5016, Validation Loss: 0.5153, Validation F-beta: 0.4799
Epoch 67/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:38<00:00,  1.59s/it, loss=0.119] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.40it/s, loss=0.411]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.6808, Class 2: 0.0540, Class 3: 0.3046, 
Class 4: 0.6872, Class 5: 0.4071, Class 6: 0.7408, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.7048, Class 2: 0.0780, Class 3: 0.3196, 
Class 4: 0.6565, Class 5: 0.4212, Class 6: 0.7294, 
Overall Mean Dice Score: 0.5641
Overall Mean F-beta Score: 0.5663

Training Loss: 0.5010, Validation Loss: 0.4797, Validation F-beta: 0.5663
Epoch 68/4000
Current lambda: 0.4200


Training: 100%|██████████| 288/288 [07:36<00:00,  1.58s/it, loss=0.112] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.452]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.6806, Class 2: 0.1335, Class 3: 0.4102, 
Class 4: 0.6652, Class 5: 0.4063, Class 6: 0.6222, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.7219, Class 2: 0.1659, Class 3: 0.4635, 
Class 4: 0.6020, Class 5: 0.4099, Class 6: 0.6263, 
Overall Mean Dice Score: 0.5569
Overall Mean F-beta Score: 0.5647

Training Loss: 0.4882, Validation Loss: 0.4937, Validation F-beta: 0.5647
Validation loss did not improve. Reducing lambda to 0.4100
Epoch 69/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [07:37<00:00,  1.59s/it, loss=0.109] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.458]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.7547, Class 2: 0.2073, Class 3: 0.4025, 
Class 4: 0.4828, Class 5: 0.4023, Class 6: 0.6880, 
Validation F-beta Score
Class 0: 0.9894, Class 1: 0.8413, Class 2: 0.2911, Class 3: 0.3620, 
Class 4: 0.4910, Class 5: 0.4667, Class 6: 0.7253, 
Overall Mean Dice Score: 0.5461
Overall Mean F-beta Score: 0.5773

Training Loss: 0.4928, Validation Loss: 0.4746, Validation F-beta: 0.5773
Epoch 70/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [07:36<00:00,  1.58s/it, loss=0.113] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.32it/s, loss=0.543]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.7053, Class 2: 0.0324, Class 3: 0.3046, 
Class 4: 0.5983, Class 5: 0.3193, Class 6: 0.6137, 
Validation F-beta Score
Class 0: 0.9897, Class 1: 0.7075, Class 2: 0.0338, Class 3: 0.3349, 
Class 4: 0.6076, Class 5: 0.3223, Class 6: 0.6273, 
Overall Mean Dice Score: 0.5082
Overall Mean F-beta Score: 0.5199

Training Loss: 0.4949, Validation Loss: 0.4978, Validation F-beta: 0.5199
Epoch 71/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [07:36<00:00,  1.59s/it, loss=0.117] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.55] 


Validation Dice Score
Class 0: 0.9916, Class 1: 0.4462, Class 2: 0.1350, Class 3: 0.3224, 
Class 4: 0.5497, Class 5: 0.4385, Class 6: 0.7059, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.5677, Class 2: 0.1313, Class 3: 0.3352, 
Class 4: 0.5460, Class 5: 0.4620, Class 6: 0.7640, 
Overall Mean Dice Score: 0.4925
Overall Mean F-beta Score: 0.5350

Training Loss: 0.4912, Validation Loss: 0.4919, Validation F-beta: 0.5350
Epoch 72/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [07:36<00:00,  1.58s/it, loss=0.108] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.39it/s, loss=0.612]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.5380, Class 2: 0.2433, Class 3: 0.3712, 
Class 4: 0.4590, Class 5: 0.4005, Class 6: 0.6405, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.6002, Class 2: 0.2923, Class 3: 0.3388, 
Class 4: 0.5094, Class 5: 0.4462, Class 6: 0.6282, 
Overall Mean Dice Score: 0.4818
Overall Mean F-beta Score: 0.5045

Training Loss: 0.4936, Validation Loss: 0.5141, Validation F-beta: 0.5045
Epoch 73/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [07:45<00:00,  1.62s/it, loss=0.0976]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.526]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.6918, Class 2: 0.1826, Class 3: 0.2422, 
Class 4: 0.5416, Class 5: 0.3973, Class 6: 0.6601, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.7127, Class 2: 0.2023, Class 3: 0.2575, 
Class 4: 0.5373, Class 5: 0.3944, Class 6: 0.6720, 
Overall Mean Dice Score: 0.5066
Overall Mean F-beta Score: 0.5148

Training Loss: 0.4863, Validation Loss: 0.5103, Validation F-beta: 0.5148
Epoch 74/4000
Current lambda: 0.4100


Training: 100%|██████████| 288/288 [08:05<00:00,  1.69s/it, loss=0.127] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.522]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.7736, Class 2: 0.1234, Class 3: 0.3607, 
Class 4: 0.5364, Class 5: 0.4206, Class 6: 0.7848, 
Validation F-beta Score
Class 0: 0.9908, Class 1: 0.8163, Class 2: 0.1493, Class 3: 0.3929, 
Class 4: 0.5477, Class 5: 0.4078, Class 6: 0.7555, 
Overall Mean Dice Score: 0.5752
Overall Mean F-beta Score: 0.5840

Training Loss: 0.4919, Validation Loss: 0.4638, Validation F-beta: 0.5840
Validation loss did not improve. Reducing lambda to 0.4000
Epoch 75/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [08:05<00:00,  1.69s/it, loss=0.0969]
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.482]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.6055, Class 2: 0.2644, Class 3: 0.5221, 
Class 4: 0.5648, Class 5: 0.4701, Class 6: 0.7375, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.6450, Class 2: 0.2946, Class 3: 0.5286, 
Class 4: 0.5233, Class 5: 0.5324, Class 6: 0.7334, 
Overall Mean Dice Score: 0.5800
Overall Mean F-beta Score: 0.5926

Training Loss: 0.4922, Validation Loss: 0.4367, Validation F-beta: 0.5926
Best model saved based on validation loss: 0.4367
Epoch 76/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [08:07<00:00,  1.69s/it, loss=0.133] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.492]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.4021, Class 2: 0.1556, Class 3: 0.2540, 
Class 4: 0.6587, Class 5: 0.2660, Class 6: 0.4550, 
Validation F-beta Score
Class 0: 0.9908, Class 1: 0.5404, Class 2: 0.1761, Class 3: 0.2916, 
Class 4: 0.6570, Class 5: 0.2570, Class 6: 0.4482, 
Overall Mean Dice Score: 0.4071
Overall Mean F-beta Score: 0.4388

Training Loss: 0.4946, Validation Loss: 0.5508, Validation F-beta: 0.4388
Epoch 77/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [08:05<00:00,  1.69s/it, loss=0.14]  
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.555]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.6568, Class 2: 0.0414, Class 3: 0.3420, 
Class 4: 0.6034, Class 5: 0.3893, Class 6: 0.8268, 
Validation F-beta Score
Class 0: 0.9906, Class 1: 0.6678, Class 2: 0.0577, Class 3: 0.3425, 
Class 4: 0.5732, Class 5: 0.4300, Class 6: 0.8153, 
Overall Mean Dice Score: 0.5637
Overall Mean F-beta Score: 0.5658

Training Loss: 0.5001, Validation Loss: 0.4810, Validation F-beta: 0.5658
Epoch 78/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [08:03<00:00,  1.68s/it, loss=0.113] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.38it/s, loss=0.495]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.6173, Class 2: 0.2399, Class 3: 0.2462, 
Class 4: 0.5381, Class 5: 0.4149, Class 6: 0.6966, 
Validation F-beta Score
Class 0: 0.9897, Class 1: 0.6663, Class 2: 0.2100, Class 3: 0.2486, 
Class 4: 0.5608, Class 5: 0.4616, Class 6: 0.7514, 
Overall Mean Dice Score: 0.5026
Overall Mean F-beta Score: 0.5377

Training Loss: 0.4949, Validation Loss: 0.4948, Validation F-beta: 0.5377
Epoch 79/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [07:57<00:00,  1.66s/it, loss=0.129] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.36it/s, loss=0.411]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.5833, Class 2: 0.1647, Class 3: 0.3801, 
Class 4: 0.5532, Class 5: 0.4661, Class 6: 0.5845, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.6967, Class 2: 0.1972, Class 3: 0.4235, 
Class 4: 0.5470, Class 5: 0.4659, Class 6: 0.5976, 
Overall Mean Dice Score: 0.5135
Overall Mean F-beta Score: 0.5461

Training Loss: 0.4903, Validation Loss: 0.4850, Validation F-beta: 0.5461
Epoch 80/4000
Current lambda: 0.4000


Training: 100%|██████████| 288/288 [08:00<00:00,  1.67s/it, loss=0.122] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.31it/s, loss=0.619]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.4376, Class 2: 0.1209, Class 3: 0.3535, 
Class 4: 0.6161, Class 5: 0.3815, Class 6: 0.5298, 
Validation F-beta Score
Class 0: 0.9900, Class 1: 0.4184, Class 2: 0.1391, Class 3: 0.4047, 
Class 4: 0.5733, Class 5: 0.4850, Class 6: 0.5288, 
Overall Mean Dice Score: 0.4637
Overall Mean F-beta Score: 0.4820

Training Loss: 0.4843, Validation Loss: 0.5018, Validation F-beta: 0.4820
Validation loss did not improve. Reducing lambda to 0.3900
Epoch 81/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [07:59<00:00,  1.67s/it, loss=0.139] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.489]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.6508, Class 2: 0.1148, Class 3: 0.5056, 
Class 4: 0.7648, Class 5: 0.4014, Class 6: 0.5501, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.6961, Class 2: 0.1348, Class 3: 0.5928, 
Class 4: 0.7232, Class 5: 0.3734, Class 6: 0.5501, 
Overall Mean Dice Score: 0.5745
Overall Mean F-beta Score: 0.5871

Training Loss: 0.4890, Validation Loss: 0.4605, Validation F-beta: 0.5871
Epoch 82/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [08:06<00:00,  1.69s/it, loss=0.109] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.559]


Validation Dice Score
Class 0: 0.9926, Class 1: 0.5169, Class 2: 0.2946, Class 3: 0.2512, 
Class 4: 0.4930, Class 5: 0.5230, Class 6: 0.4618, 
Validation F-beta Score
Class 0: 0.9928, Class 1: 0.6900, Class 2: 0.3308, Class 3: 0.2663, 
Class 4: 0.4585, Class 5: 0.5743, Class 6: 0.4643, 
Overall Mean Dice Score: 0.4492
Overall Mean F-beta Score: 0.4907

Training Loss: 0.4846, Validation Loss: 0.5016, Validation F-beta: 0.4907
Epoch 83/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [08:03<00:00,  1.68s/it, loss=0.142] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.587]


Validation Dice Score
Class 0: 0.9921, Class 1: 0.6625, Class 2: 0.2360, Class 3: 0.4899, 
Class 4: 0.5713, Class 5: 0.3607, Class 6: 0.5922, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.6815, Class 2: 0.2901, Class 3: 0.5055, 
Class 4: 0.5625, Class 5: 0.3743, Class 6: 0.6014, 
Overall Mean Dice Score: 0.5353
Overall Mean F-beta Score: 0.5450

Training Loss: 0.4766, Validation Loss: 0.4827, Validation F-beta: 0.5450
Epoch 84/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [08:05<00:00,  1.68s/it, loss=0.122] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.32it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.6807, Class 2: 0.1682, Class 3: 0.3753, 
Class 4: 0.6738, Class 5: 0.3522, Class 6: 0.5121, 
Validation F-beta Score
Class 0: 0.9917, Class 1: 0.6830, Class 2: 0.1578, Class 3: 0.3725, 
Class 4: 0.5980, Class 5: 0.3712, Class 6: 0.5035, 
Overall Mean Dice Score: 0.5188
Overall Mean F-beta Score: 0.5056

Training Loss: 0.4918, Validation Loss: 0.4840, Validation F-beta: 0.5056
Epoch 85/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [08:04<00:00,  1.68s/it, loss=0.122] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.35it/s, loss=0.568]


Validation Dice Score
Class 0: 0.9932, Class 1: 0.6460, Class 2: 0.1415, Class 3: 0.3583, 
Class 4: 0.4326, Class 5: 0.5033, Class 6: 0.6354, 
Validation F-beta Score
Class 0: 0.9935, Class 1: 0.7067, Class 2: 0.1564, Class 3: 0.3686, 
Class 4: 0.4111, Class 5: 0.5058, Class 6: 0.6748, 
Overall Mean Dice Score: 0.5151
Overall Mean F-beta Score: 0.5334

Training Loss: 0.4865, Validation Loss: 0.4958, Validation F-beta: 0.5334
Epoch 86/4000
Current lambda: 0.3900


Training: 100%|██████████| 288/288 [07:59<00:00,  1.67s/it, loss=0.122] 
Validation: 100%|██████████| 12/12 [00:08<00:00,  1.37it/s, loss=0.394]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.7254, Class 2: 0.0429, Class 3: 0.2528, 
Class 4: 0.5549, Class 5: 0.5032, Class 6: 0.6219, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.8478, Class 2: 0.0585, Class 3: 0.2522, 
Class 4: 0.4997, Class 5: 0.4896, Class 6: 0.6640, 
Overall Mean Dice Score: 0.5316
Overall Mean F-beta Score: 0.5507

Training Loss: 0.4859, Validation Loss: 0.4843, Validation F-beta: 0.5507
Validation loss did not improve. Reducing lambda to 0.3800
Epoch 87/4000
Current lambda: 0.3800


Training: 100%|██████████| 288/288 [08:02<00:00,  1.68s/it, loss=0.104] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.31it/s, loss=0.531]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.5998, Class 2: 0.1689, Class 3: 0.4812, 
Class 4: 0.5007, Class 5: 0.4226, Class 6: 0.6676, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.6546, Class 2: 0.1942, Class 3: 0.4894, 
Class 4: 0.4980, Class 5: 0.4452, Class 6: 0.6905, 
Overall Mean Dice Score: 0.5344
Overall Mean F-beta Score: 0.5555

Training Loss: 0.4857, Validation Loss: 0.4792, Validation F-beta: 0.5555
Epoch 88/4000
Current lambda: 0.3800


Training:   6%|▋         | 18/288 [00:31<07:45,  1.73s/it, loss=0.107] 

In [None]:
# Intentionally empty cell: it used to contain a stray `if:` that raised
# SyntaxError and stopped "Restart & Run All" from reaching the cells below.
pass

SyntaxError: invalid syntax (879943805.py, line 1)

# Validation

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

# ---- Configuration ----------------------------------------------------------
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2  # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size  # number of patches cropped from each volume
loader_batch = 1
lamda = 0.52
VAL_SEED = 0  # fixes the random crops/rotations/flips so validation scores are comparable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the run name; the checkpoint directory and weights
# path are derived from it instead of repeating the literal three times.
run_name = 'SwinUNETR96_96_lr0.001_lambda0.52_batch2'
checkpoint_dir = f"./model_checkpoints/{run_name}"
pretrain_path = f"{checkpoint_dir}/best_model.pt"

# `ratios_list` is not defined in this cell (presumably computed from class_info
# in an earlier cell); RandCropByLabelClassesd needs one sampling ratio per class.
assert len(ratios_list) == n_classes, "ratios_list must have one entry per class"

wandb.init(
    project='czii_SwinUnetR_val',  # wandb project name
    name=run_name,                 # wandb run name
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': str(device),     # log the device actually used, not a hard-coded "cuda"
        "checkpoint_dir": checkpoint_dir,
    }
)

# ---- Data -------------------------------------------------------------------
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],         # keys the transform is applied to
        sigma=[1.0, 1.0, 1.0],  # sigma for each axis (x, y, z)
    ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples,
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])
# Validation still samples random patches, so seed them; otherwise every
# evaluation of the same checkpoint sees different crops and the scores drift.
random_transforms.set_random_state(seed=VAL_SEED)
# Also seed torch's global RNG: if make_val_dataloader uses MONAI's DataLoader,
# with num_workers=0 it reseeds the transforms from torch's generator.
torch.manual_seed(VAL_SEED)

val_loader = make_val_dataloader(
    val_img_dir,
    val_label_dir,
    non_random_transforms=non_random_transforms,
    random_transforms=random_transforms,
    batch_size=loader_batch,
    num_workers=0,
)
criterion = TverskyLoss(
    alpha=1 - lamda,           # weight on false positives
    beta=lamda,                # weight on false negatives
    include_background=False,  # exclude the background class
    softmax=True,
)

# ---- Model ------------------------------------------------------------------
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Load pretrained weights. `weights_only` is passed explicitly because PyTorch
# warns its default is switching to True; False keeps today's full-unpickle
# behavior, so only use it on checkpoints you produced yourself.
checkpoint = torch.load(pretrain_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

# ---- Evaluate ---------------------------------------------------------------
val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model,
    val_loader=val_loader,
    criterion=criterion,
    device=device,
    epoch=0,
    calculate_dice_interval=1,
)
wandb.finish()  # close this validation run so its logged metrics are flushed

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
# Inference-time dependencies, grouped stdlib -> third-party.
from pathlib import Path

import copick
import numpy as np
import torch
from monai.data import CacheDataset, DataLoader, Dataset
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    GaussianSmoothd,
    NormalizeIntensityd,
    Orientationd,
)

print("Done.")

Done.


In [None]:
# copick project configuration for inference. Particle labels 1-6 use the same
# names as `id_to_name` in the inference cell (the model's non-background classes).
# The text is handed verbatim to Preprocessor, which (per this cell's output)
# writes it to `copick_config_path`.
# NOTE(review): "description" says training data while "static_root" points at
# the test split — metadata only, but confirm it is intended.
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
# NOTE(review): the inference cell iterates `preprocessor.root.runs`, so this
# presumably opens the copick project from the written config; see
# src.dataset.preprocessing for the actual behavior.
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
# Inference preprocessing: the same deterministic chain as the validation cell
# (channel-first, intensity normalization, RAS orientation, Gaussian smoothing)
# but applied to the "image" key only, since no label is loaded at inference.
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # keys the transform is applied to
        sigma=[1.0, 1.0, 1.0]  # sigma for each axis (x, y, z)
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Load pretrained weights. `weights_only` is passed explicitly: PyTorch warns
# (see the FutureWarning in this cell's output) that its default is flipping to
# True, which could start rejecting non-tensor entries stored in the checkpoint.
# False pins today's full-unpickle behavior — only use it on trusted, self-produced files.
checkpoint = torch.load(pretrain_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
# Re-run validation on the loaded checkpoint.
# `calculate_dice_interval` must be >= 1: with 0 the call failed with
# "ZeroDivisionError: integer modulo by zero" (presumably `epoch % interval`
# inside validate_one_epoch). validate_one_epoch returns (loss, mean F-beta),
# as unpacked by the earlier validation cell.
val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model,
    val_loader=val_loader,
    criterion=criterion,
    device=device,
    epoch=1,
    calculate_dice_interval=1,
)

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    """Flatten per-particle-type coordinates into one long-format DataFrame.

    Args:
        coord_dict: mapping ``particle_type -> coords``. Each ``coords`` is an
            array-like whose rows are ``(x, y, z)``. A type with no coordinates
            (an empty array or list) contributes no rows.
        experiment_name: value written into every row's ``experiment`` column.

    Returns:
        DataFrame with columns ``experiment, particle_type, x, y, z`` and one row
        per coordinate, in the dict's iteration order. If there are no coordinates
        at all, an empty DataFrame with the same columns is returned. The original
        version raised ``ValueError`` from ``np.vstack([])`` in that case.
    """
    columns = ['experiment', 'particle_type', 'x', 'y', 'z']
    all_coords = []
    all_labels = []

    # `particle_type` rather than `label`, so the loop does not shadow the
    # `scipy.ndimage.label` imported in this cell.
    for particle_type, coords in coord_dict.items():
        coords = np.asarray(coords)
        if coords.size == 0:
            # A 1-D empty array (e.g. from []) cannot be vstacked with (n, 3) rows.
            continue
        all_coords.append(coords)
        all_labels.extend([particle_type] * len(coords))

    if not all_coords:
        return pd.DataFrame(columns=columns)

    stacked = np.vstack(all_coords)
    return pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': stacked[:, 0],
        'y': stacked[:, 1],
        'z': stacked[:, 2],
    })

id_to_name = {1: "apo-ferritin",
              2: "beta-amylase",
              3: "beta-galactosidase",
              4: "ribosome",
              5: "thyroglobulin",
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200  # minimum connected-component size (voxels) kept as a detection
CERTAINTY_THRESHOLD = 0.05  # NOTE(review): not used below — picks come from the hard argmax
# Factor applied to voxel-space centroids before export; presumably the tomogram
# voxel spacing (Å) expected by the submission — confirm against the dataset.
VOXEL_SPACING = 10.012444
# Sliding-window patch size; should match the patch size the model was trained on.
ROI_SIZE = (96, 96, 96)

classes = [1, 2, 3, 4, 5, 6]

# Run on whatever device holds the model weights. The original hard-coded "cuda",
# which crashed whenever the model-loading cell had fallen back to CPU.
infer_device = next(model.parameters()).device

model.eval()
location_dfs = []  # one DataFrame of picks per run (tomogram)
with torch.no_grad():
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)

        for task_data in task_loader:
            images = task_data['image'].to(infer_device)
            logits = sliding_window_inference(
                inputs=images,
                roi_size=ROI_SIZE,
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device=infer_device,  # where each window is evaluated
                device="cpu",            # where the stitched full-volume output is accumulated
                buffer_steps=1,
                buffer_dim=-1,
            )
            # Per-voxel class id: argmax over the class channel (dim 1), then drop
            # the batch dimension (the loader uses batch_size=1).
            pred_labels = logits.argmax(dim=1).squeeze(0).cpu().numpy()

            location = {}  # particle name -> array of (x, y, z) rows
            for c in classes:
                cc = cc3d.connected_components(pred_labels == c)
                stats = cc3d.statistics(cc)
                # cc3d labels the background as component 0, so [1:] skips it.
                zyx = stats['centroids'][1:] * VOXEL_SPACING
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # drop small blobs
                # Reverse the axis order (z, y, x) -> (x, y, z) for the submission columns.
                location[id_to_name[c]] = np.ascontiguousarray(zyx_large[:, ::-1])

            location_dfs.append(dict_to_df(location, run.name))

# Merge per-run picks; pd.concat([]) raises, so handle "no runs" explicitly.
if location_dfs:
    final_df = pd.concat(location_dfs, ignore_index=True)
else:
    final_df = pd.DataFrame(columns=['experiment', 'particle_type', 'x', 'y', 'z'])

# Add a sequential id column and write the submission file.
final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
final_df.to_csv("submission.csv", index=False)
print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
