In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


import torch

print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\Users\<username>\.conda\envs\UM\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: 2.18.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.17.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.

In [2]:
# class_info = {
#     0: {"name": "background", "weight": 10000},  # weight 없음
#     1: {"name": "apo-ferritin", "weight": 300},
#     2: {"name": "beta-amylase", "weight": 100}, # 4130
#     3: {"name": "beta-galactosidase", "weight": 150}, #3080
#     4: {"name": "ribosome", "weight": 6000},
#     5: {"name": "thyroglobulin", "weight": 4000},
#     6: {"name": "virus-like-particle", "weight": 2000},
# }

# # 가중치에 비례한 비율 계산
# raw_ratios = {
#     k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값
#     for k, v in class_info.items()
# }
# total = sum(raw_ratios.values())
# ratios = {k: v / total for k, v in raw_ratios.items()}

# # 최종 합계가 1인지 확인
# final_total = sum(ratios.values())
# print("클래스 비율:", ratios)
# print("최종 합계:", final_total)

# # 비율을 리스트로 변환
# ratios_list = [ratios[k] for k in sorted(ratios.keys())]
# print("클래스 비율 리스트:", ratios_list)

In [3]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 2000}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 2000},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.14084507042253522, 2: 0.014084507042253521, 3: 0.28169014084507044, 4: 0.14084507042253522, 5: 0.28169014084507044, 6: 0.14084507042253522}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.14084507042253522, 0.014084507042253521, 0.28169014084507044, 0.14084507042253522, 0.28169014084507044, 0.14084507042253522]


In [4]:
from src.dataset.dataset import create_dataloaders
import numpy as np
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, CastToTyped,
)

train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_depth = 96
img_size =  96 # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
# MODEL CONFIG
num_epochs = 4000
lamda = 0.52
lr = 0.001
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25
# CLASS_WEIGHTS
class_weights = None
class_weights = torch.tensor([0.001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치

accumulation_steps = 4
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # CastToTyped(keys=["image"], dtype=np.float16),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
])


In [5]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0)

Loading dataset: 100%|██████████| 24/24 [00:42<00:00,  1.77s/it]
Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.72s/it]


https://monai.io/model-zoo.html

In [6]:
from monai.losses import TverskyLoss
import torch

def loss_fn(loss, class_weights, device):
    """
    Tversky 손실에 클래스별 가중치를 적용하여 최종 스칼라 값을 반환합니다.

    Args:
        loss: Tversky 손실 텐서 (B, num_classes, H, W, D).
        class_weights: 클래스별 가중치 텐서 (num_classes,).
        device: 사용할 장치 (예: 'cuda' 또는 'cpu').

    Returns:
        torch.Tensor: 최종 가중 평균 손실 값 (스칼라).
    """
    # 가중치를 device로 이동
    class_weights = class_weights.to(device)

    # 클래스 차원에 가중치 적용 (B, num_classes, ...)
    class_weights = class_weights.view(1, n_classes, 1, 1, 1)  # [1, num_classes, 1, 1, 1]
    weighted_loss = loss * class_weights

    # 모든 차원을 평균 내어 스칼라 손실 반환
    final_loss = torch.mean(weighted_loss)
    return final_loss

In [7]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.metrics import DiceMetric
from monai.losses import TverskyLoss
import torch.nn as nn
import torch.nn.functional as F

# TverskyLoss 설정
'''
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    reduction="none",  # 각 픽셀에 대한 손실 반환
    softmax=True
)
'''

class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda
        
# criterion = DynamicTverskyLoss(
#     lamda=0.5,
#     include_background=False,
#     reduction="mean",
#     softmax=True
# )

class CombinedCETverskyLoss(nn.Module):
    def __init__(self, lamda=0.5, ce_weight=0.5, **kwargs):
        super().__init__()
        self._lamda = lamda  # lamda 값 저장
        self.tversky = DynamicTverskyLoss(lamda=lamda, **kwargs)
        self.ce = nn.CrossEntropyLoss()
        self.ce_weight = ce_weight
        
    def forward(self, inputs, targets):
        tversky_loss = self.tversky(inputs, targets)
        ce_loss = self.ce(inputs, targets)
        return self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss
    
    def set_lamda(self, lamda):
        self._lamda = lamda
        self.tversky.set_lamda(lamda)
    
    @property
    def lamda(self):
        return self._lamda

# 사용 예시
criterion = CombinedCETverskyLoss(
    lamda=0.5,
    ce_weight=0.2,  # CE Loss와 Tversky Loss의 비중을 0.5:0.5로 설정
    include_background=False,
    reduction="mean",
    softmax=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=feature_size,
    use_checkpoint=True,
    drop_rate = drop_rate,
    attn_drop_rate = attn_drop_rate,
    use_v2 = use_v2,
).to(device)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
checkpoint_dir = checkpoint_base_dir / f"SwinUNETR_v2_step4_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_batch{batch_size}"
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")



기존 best model 발견: model_checkpoints\SwinUNETR_v2_step4_pyes_weighted_f48_d96_s96_lr1e-03_a0.52_b0.48_batch2\best_model.pt


  checkpoint = torch.load(best_model_path, map_location=device)


기존 학습된 가중치를 성공적으로 로드했습니다.


In [8]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([2, 1, 96, 96, 96]) torch.Size([2, 1, 96, 96, 96])


In [9]:
torch.backends.cudnn.benchmark = True

In [10]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f'SwinUNETR_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}_s{img_size}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_batch{batch_size}_{current_time}'

# wandb 초기화
wandb.init(
    project='czii_SwinUnetR',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpook0612[0m ([33mlimbw[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from monai.metrics import DiceMetric

def create_metric_dict(num_classes):
    """각 클래스별 DiceMetric 생성"""
    metrics = {}
    for i in range(num_classes):
        metrics[f'dice_class_{i}'] = DiceMetric(
            include_background=False if i == 0 else False,
            reduction="mean",
            get_not_nans=False
        )
    return metrics
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    # loss = loss_fn(criterion(outputs, labels_onehot), class_weights, device)
    loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

# def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
#     model.train()
#     epoch_loss = 0
#     optimizer.zero_grad()  # 그래디언트 초기화
#     with tqdm(train_loader, desc='Training') as pbar:
#         for i, batch_data in enumerate(pbar):
#             # 손실 계산
#             loss, _, _, _ = processing(batch_data, model, criterion, device)

#             # 그래디언트를 계산하고 누적
#             loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
#             loss.backward()  # 그래디언트 계산 및 누적
            
#             # 그래디언트 업데이트 (accumulation_steps마다 한 번)
#             if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
#                 optimizer.step()  # 파라미터 업데이트
#                 optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
#             # 손실값 누적 (스케일링 복구)
#             epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
#             pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
#     avg_loss = epoch_loss / len(train_loader)
#     wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
#     return avg_loss

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    data_iter = iter(train_loader)  # 이터레이터 생성
    with tqdm(total=len(train_loader)*4, desc='Training') as pbar:
        for i in range(len(train_loader)*4):
            try:
                batch_data = next(data_iter)
            except StopIteration:
                # DataLoader를 전부 소진한 경우, 다시 이터레이터를 만들어 이어서 사용
                data_iter = iter(train_loader)
                batch_data = next(data_iter)
            
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
            pbar.update(1)

    avg_loss = epoch_loss / len(train_loader)*4
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss

def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    metrics = create_metric_dict(n_classes)
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            if i == 3:
                print()
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
            if i == 3:
                print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta


def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
         # 현재 lambda 값 출력
        print(f"Current lambda: {criterion.lamda:.4f}")
        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps,
        )
        scheduler.step(train_loss)
        
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")
        
        # Validation Loss 기준 모델 저장
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model_val_loss.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"Best model saved based on validation loss: {best_val_loss:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        if epochs_no_improve%6 == 0:
            # 손실이 개선되지 않았으므로 lambda 감소
            new_lamda = max(criterion.lamda - 0.01, 0.1)  # 최소값은 0.1로 설정
            criterion.set_lamda(new_lamda)
            print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")


    wandb.finish()


In [12]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=30,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
)

Epoch 2/4000
Current lambda: 0.5000


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training: 100%|██████████| 96/96 [02:49<00:00,  1.77s/it, loss=0.575]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.604]


Validation Dice Score
Class 0: 0.9816, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, 
Class 4: 0.2763, Class 5: 0.0000, Class 6: 0.0852, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, 
Class 4: 0.2900, Class 5: 0.0000, Class 6: 0.1811, 
Overall Mean Dice Score: 0.0723
Overall Mean F-beta Score: 0.0942

Training Loss: 2.4272, Validation Loss: 0.5989, Validation F-beta: 0.0942
SUPER Best model saved. Loss:0.5989, Score:0.0942
Epoch 3/4000
Current lambda: 0.5000


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.57] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.535]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, 
Class 4: 0.2708, Class 5: 0.2593, Class 6: 0.5608, 
Validation F-beta Score
Class 0: 0.9852, Class 1: 0.0000, Class 2: 0.5000, Class 3: 0.0000, 
Class 4: 0.3938, Class 5: 0.2484, Class 6: 0.5447, 
Overall Mean Dice Score: 0.2182
Overall Mean F-beta Score: 0.2374

Training Loss: 2.3360, Validation Loss: 0.5356, Validation F-beta: 0.2374
SUPER Best model saved. Loss:0.5356, Score:0.2374
Epoch 4/4000
Current lambda: 0.5000


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.587]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.548]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, 
Class 4: 0.2289, Class 5: 0.1375, Class 6: 0.2473, 
Validation F-beta Score
Class 0: 0.9778, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, 
Class 4: 0.2748, Class 5: 0.1692, Class 6: 0.3740, 
Overall Mean Dice Score: 0.1227
Overall Mean F-beta Score: 0.1636

Training Loss: 2.2584, Validation Loss: 0.5671, Validation F-beta: 0.1636
Epoch 5/4000
Current lambda: 0.5000


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.582]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.545]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0027, 
Class 4: 0.5121, Class 5: 0.2689, Class 6: 0.1319, 
Validation F-beta Score
Class 0: 0.9845, Class 1: 0.5000, Class 2: 0.0000, Class 3: 0.0015, 
Class 4: 0.5193, Class 5: 0.3466, Class 6: 0.1527, 
Overall Mean Dice Score: 0.1831
Overall Mean F-beta Score: 0.3040

Training Loss: 2.2340, Validation Loss: 0.5411, Validation F-beta: 0.3040
Validation loss did not improve. Reducing lambda to 0.4900
Epoch 6/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.537]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s, loss=0.511]


Validation Dice Score
Class 0: 0.9836, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0680, 
Class 4: 0.4457, Class 5: 0.3479, Class 6: 0.5125, 
Validation F-beta Score
Class 0: 0.9829, Class 1: 0.0000, Class 2: 0.2500, Class 3: 0.0888, 
Class 4: 0.5736, Class 5: 0.3340, Class 6: 0.4813, 
Overall Mean Dice Score: 0.2748
Overall Mean F-beta Score: 0.2955

Training Loss: 2.1940, Validation Loss: 0.5145, Validation F-beta: 0.2955
SUPER Best model saved. Loss:0.5145, Score:0.2955
Epoch 7/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.532]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.583]


Validation Dice Score
Class 0: 0.9839, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0215, 
Class 4: 0.3228, Class 5: 0.2817, Class 6: 0.2317, 
Validation F-beta Score
Class 0: 0.9857, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0146, 
Class 4: 0.3960, Class 5: 0.2534, Class 6: 0.2177, 
Overall Mean Dice Score: 0.1715
Overall Mean F-beta Score: 0.1763

Training Loss: 2.1529, Validation Loss: 0.5512, Validation F-beta: 0.1763
Epoch 8/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.58s/it, loss=0.52] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.557]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.0015, Class 2: 0.0000, Class 3: 0.2471, 
Class 4: 0.2588, Class 5: 0.2829, Class 6: 0.3428, 
Validation F-beta Score
Class 0: 0.9812, Class 1: 0.0014, Class 2: 0.5000, Class 3: 0.2715, 
Class 4: 0.3982, Class 5: 0.3659, Class 6: 0.3997, 
Overall Mean Dice Score: 0.2266
Overall Mean F-beta Score: 0.2874

Training Loss: 2.1391, Validation Loss: 0.5249, Validation F-beta: 0.2874
Epoch 9/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.536]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.511]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.1113, 
Class 4: 0.4444, Class 5: 0.3055, Class 6: 0.2442, 
Validation F-beta Score
Class 0: 0.9804, Class 1: 0.0000, Class 2: 0.5000, Class 3: 0.0867, 
Class 4: 0.4842, Class 5: 0.3341, Class 6: 0.2900, 
Overall Mean Dice Score: 0.2211
Overall Mean F-beta Score: 0.2390

Training Loss: 2.1359, Validation Loss: 0.5288, Validation F-beta: 0.2390
Epoch 10/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.527]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s, loss=0.484]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.0247, Class 2: 0.0000, Class 3: 0.2001, 
Class 4: 0.3525, Class 5: 0.3474, Class 6: 0.7821, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.0333, Class 2: 0.0000, Class 3: 0.1625, 
Class 4: 0.3558, Class 5: 0.3641, Class 6: 0.8450, 
Overall Mean Dice Score: 0.3414
Overall Mean F-beta Score: 0.3521

Training Loss: 2.1088, Validation Loss: 0.4778, Validation F-beta: 0.3521
SUPER Best model saved. Loss:0.4778, Score:0.3521
Epoch 11/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.48] 
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.22it/s, loss=0.518]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.0092, Class 2: 0.0000, Class 3: 0.1158, 
Class 4: 0.3978, Class 5: 0.1304, Class 6: 0.4298, 
Validation F-beta Score
Class 0: 0.9882, Class 1: 0.0059, Class 2: 0.5000, Class 3: 0.2462, 
Class 4: 0.3466, Class 5: 0.2133, Class 6: 0.5671, 
Overall Mean Dice Score: 0.2166
Overall Mean F-beta Score: 0.2758

Training Loss: 2.0472, Validation Loss: 0.5356, Validation F-beta: 0.2758
Epoch 12/4000
Current lambda: 0.4900


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.539]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.448]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.0516, Class 2: 0.0000, Class 3: 0.2342, 
Class 4: 0.5414, Class 5: 0.3138, Class 6: 0.5109, 
Validation F-beta Score
Class 0: 0.9848, Class 1: 0.0372, Class 2: 0.0000, Class 3: 0.3321, 
Class 4: 0.6420, Class 5: 0.3566, Class 6: 0.4210, 
Overall Mean Dice Score: 0.3304
Overall Mean F-beta Score: 0.3578

Training Loss: 2.0394, Validation Loss: 0.4832, Validation F-beta: 0.3578
Validation loss did not improve. Reducing lambda to 0.4800
Epoch 13/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.495]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.469]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.1657, Class 2: 0.0001, Class 3: 0.2285, 
Class 4: 0.5165, Class 5: 0.3571, Class 6: 0.4944, 
Validation F-beta Score
Class 0: 0.9826, Class 1: 0.1845, Class 2: 0.0001, Class 3: 0.2849, 
Class 4: 0.5932, Class 5: 0.5180, Class 6: 0.5802, 
Overall Mean Dice Score: 0.3524
Overall Mean F-beta Score: 0.4322

Training Loss: 2.0482, Validation Loss: 0.4725, Validation F-beta: 0.4322
SUPER Best model saved. Loss:0.4725, Score:0.4322
Epoch 14/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.482]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.49] 


Validation Dice Score
Class 0: 0.9893, Class 1: 0.2356, Class 2: 0.0000, Class 3: 0.1807, 
Class 4: 0.5456, Class 5: 0.3063, Class 6: 0.8344, 
Validation F-beta Score
Class 0: 0.9904, Class 1: 0.2724, Class 2: 0.0000, Class 3: 0.1830, 
Class 4: 0.6125, Class 5: 0.2709, Class 6: 0.8200, 
Overall Mean Dice Score: 0.4205
Overall Mean F-beta Score: 0.4318

Training Loss: 1.9681, Validation Loss: 0.4561, Validation F-beta: 0.4318
Best model saved based on validation loss: 0.4561
Epoch 15/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.416]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.476]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.1674, Class 2: 0.0093, Class 3: 0.2469, 
Class 4: 0.4004, Class 5: 0.4403, Class 6: 0.6063, 
Validation F-beta Score
Class 0: 0.9906, Class 1: 0.4205, Class 2: 0.0062, Class 3: 0.3929, 
Class 4: 0.3055, Class 5: 0.4017, Class 6: 0.5529, 
Overall Mean Dice Score: 0.3723
Overall Mean F-beta Score: 0.4147

Training Loss: 1.9758, Validation Loss: 0.4770, Validation F-beta: 0.4147
Epoch 16/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.434]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.483]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.0237, Class 2: 0.0117, Class 3: 0.1471, 
Class 4: 0.5297, Class 5: 0.2754, Class 6: 0.8626, 
Validation F-beta Score
Class 0: 0.9878, Class 1: 0.0186, Class 2: 0.0079, Class 3: 0.1432, 
Class 4: 0.5019, Class 5: 0.3003, Class 6: 0.8424, 
Overall Mean Dice Score: 0.3677
Overall Mean F-beta Score: 0.3613

Training Loss: 1.9578, Validation Loss: 0.4754, Validation F-beta: 0.3613
Epoch 17/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.573]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.479]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.4398, Class 2: 0.0878, Class 3: 0.3037, 
Class 4: 0.6279, Class 5: 0.4015, Class 6: 0.7199, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.4342, Class 2: 0.0823, Class 3: 0.3170, 
Class 4: 0.6426, Class 5: 0.3366, Class 6: 0.6028, 
Overall Mean Dice Score: 0.4986
Overall Mean F-beta Score: 0.4667

Training Loss: 1.9463, Validation Loss: 0.4220, Validation F-beta: 0.4667
SUPER Best model saved. Loss:0.4220, Score:0.4667
Epoch 18/4000
Current lambda: 0.4800


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.488]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.454]


Validation Dice Score
Class 0: 0.9888, Class 1: 0.3417, Class 2: 0.0405, Class 3: 0.3549, 
Class 4: 0.3398, Class 5: 0.3922, Class 6: 0.6190, 
Validation F-beta Score
Class 0: 0.9854, Class 1: 0.4972, Class 2: 0.0585, Class 3: 0.3525, 
Class 4: 0.3517, Class 5: 0.6493, Class 6: 0.6365, 
Overall Mean Dice Score: 0.4095
Overall Mean F-beta Score: 0.4974

Training Loss: 1.9054, Validation Loss: 0.4592, Validation F-beta: 0.4974
Validation loss did not improve. Reducing lambda to 0.4700
Epoch 19/4000
Current lambda: 0.4700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.461]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.41] 


Validation Dice Score
Class 0: 0.9867, Class 1: 0.3820, Class 2: 0.0301, Class 3: 0.1257, 
Class 4: 0.4990, Class 5: 0.4164, Class 6: 0.3767, 
Validation F-beta Score
Class 0: 0.9904, Class 1: 0.5868, Class 2: 0.0389, Class 3: 0.1477, 
Class 4: 0.3782, Class 5: 0.4989, Class 6: 0.4773, 
Overall Mean Dice Score: 0.3599
Overall Mean F-beta Score: 0.4178

Training Loss: 1.8880, Validation Loss: 0.4847, Validation F-beta: 0.4178
Epoch 20/4000
Current lambda: 0.4700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.432]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.574]


Validation Dice Score
Class 0: 0.9855, Class 1: 0.3750, Class 2: 0.1126, Class 3: 0.3682, 
Class 4: 0.1475, Class 5: 0.3135, Class 6: 0.5629, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.4664, Class 2: 0.0831, Class 3: 0.3953, 
Class 4: 0.3774, Class 5: 0.3166, Class 6: 0.5200, 
Overall Mean Dice Score: 0.3534
Overall Mean F-beta Score: 0.4151

Training Loss: 1.8837, Validation Loss: 0.4879, Validation F-beta: 0.4151
Epoch 21/4000
Current lambda: 0.4700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.497]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.23it/s, loss=0.484]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.4098, Class 2: 0.0622, Class 3: 0.1890, 
Class 4: 0.4578, Class 5: 0.3701, Class 6: 0.4662, 
Validation F-beta Score
Class 0: 0.9868, Class 1: 0.4352, Class 2: 0.0477, Class 3: 0.1658, 
Class 4: 0.4002, Class 5: 0.3487, Class 6: 0.5766, 
Overall Mean Dice Score: 0.3786
Overall Mean F-beta Score: 0.3853

Training Loss: 1.8817, Validation Loss: 0.4634, Validation F-beta: 0.3853
Epoch 22/4000
Current lambda: 0.4700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.471]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.364]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.4994, Class 2: 0.0000, Class 3: 0.1127, 
Class 4: 0.5861, Class 5: 0.2683, Class 6: 0.7708, 
Validation F-beta Score
Class 0: 0.9836, Class 1: 0.4891, Class 2: 0.0000, Class 3: 0.1493, 
Class 4: 0.7924, Class 5: 0.2756, Class 6: 0.8634, 
Overall Mean Dice Score: 0.4475
Overall Mean F-beta Score: 0.5140

Training Loss: 1.8702, Validation Loss: 0.4493, Validation F-beta: 0.5140
Validation loss did not improve. Reducing lambda to 0.4600
Epoch 23/4000
Current lambda: 0.4600


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.521]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.486]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.5005, Class 2: 0.0117, Class 3: 0.1120, 
Class 4: 0.3696, Class 5: 0.2572, Class 6: 0.8026, 
Validation F-beta Score
Class 0: 0.9893, Class 1: 0.7043, Class 2: 0.0071, Class 3: 0.1058, 
Class 4: 0.3768, Class 5: 0.2720, Class 6: 0.8197, 
Overall Mean Dice Score: 0.4084
Overall Mean F-beta Score: 0.4557

Training Loss: 1.8731, Validation Loss: 0.4648, Validation F-beta: 0.4557
Epoch 24/4000
Current lambda: 0.4600


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.427]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s, loss=0.425]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.5220, Class 2: 0.1393, Class 3: 0.2876, 
Class 4: 0.5105, Class 5: 0.2092, Class 6: 0.3854, 
Validation F-beta Score
Class 0: 0.9889, Class 1: 0.8550, Class 2: 0.1503, Class 3: 0.2731, 
Class 4: 0.4334, Class 5: 0.2196, Class 6: 0.4046, 
Overall Mean Dice Score: 0.3829
Overall Mean F-beta Score: 0.4371

Training Loss: 1.8555, Validation Loss: 0.4816, Validation F-beta: 0.4371
Epoch 25/4000
Current lambda: 0.4600


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.529]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.1838, Class 2: 0.0523, Class 3: 0.0394, 
Class 4: 0.4914, Class 5: 0.1789, Class 6: 0.6377, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.1657, Class 2: 0.0698, Class 3: 0.0460, 
Class 4: 0.7579, Class 5: 0.1834, Class 6: 0.6315, 
Overall Mean Dice Score: 0.3062
Overall Mean F-beta Score: 0.3569

Training Loss: 1.7898, Validation Loss: 0.4797, Validation F-beta: 0.3569
Epoch 26/4000
Current lambda: 0.4600


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.44] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.483]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.1791, Class 2: 0.1443, Class 3: 0.2270, 
Class 4: 0.5069, Class 5: 0.5536, Class 6: 0.1487, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.2538, Class 2: 0.2019, Class 3: 0.2315, 
Class 4: 0.4274, Class 5: 0.6167, Class 6: 0.3412, 
Overall Mean Dice Score: 0.3231
Overall Mean F-beta Score: 0.3741

Training Loss: 1.8126, Validation Loss: 0.4732, Validation F-beta: 0.3741
Epoch 27/4000
Current lambda: 0.4600


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.496]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.428]


Validation Dice Score
Class 0: 0.9920, Class 1: 0.7382, Class 2: 0.0000, Class 3: 0.4121, 
Class 4: 0.4260, Class 5: 0.3921, Class 6: 0.5353, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.8124, Class 2: 0.0000, Class 3: 0.5033, 
Class 4: 0.4331, Class 5: 0.4475, Class 6: 0.6625, 
Overall Mean Dice Score: 0.5007
Overall Mean F-beta Score: 0.5718

Training Loss: 1.8203, Validation Loss: 0.4408, Validation F-beta: 0.5718
Validation loss did not improve. Reducing lambda to 0.4500
Epoch 28/4000
Current lambda: 0.4500


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.397]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.4921, Class 2: 0.1424, Class 3: 0.1438, 
Class 4: 0.4510, Class 5: 0.4322, Class 6: 0.6086, 
Validation F-beta Score
Class 0: 0.9899, Class 1: 0.5129, Class 2: 0.1372, Class 3: 0.2038, 
Class 4: 0.4160, Class 5: 0.4782, Class 6: 0.5529, 
Overall Mean Dice Score: 0.4255
Overall Mean F-beta Score: 0.4327

Training Loss: 1.8208, Validation Loss: 0.4556, Validation F-beta: 0.4327
Epoch 29/4000
Current lambda: 0.4500


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.509]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.489]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.4363, Class 2: 0.0429, Class 3: 0.2941, 
Class 4: 0.5299, Class 5: 0.3685, Class 6: 0.6143, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.5796, Class 2: 0.0457, Class 3: 0.4295, 
Class 4: 0.5218, Class 5: 0.4514, Class 6: 0.5471, 
Overall Mean Dice Score: 0.4486
Overall Mean F-beta Score: 0.5059

Training Loss: 1.8619, Validation Loss: 0.4473, Validation F-beta: 0.5059
Validation loss did not improve. Reducing lambda to 0.4400
Epoch 30/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.502]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.366]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7042, Class 2: 0.0091, Class 3: 0.2578, 
Class 4: 0.3551, Class 5: 0.4060, Class 6: 0.4092, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.7581, Class 2: 0.0106, Class 3: 0.2110, 
Class 4: 0.3376, Class 5: 0.4025, Class 6: 0.4264, 
Overall Mean Dice Score: 0.4264
Overall Mean F-beta Score: 0.4271

Training Loss: 1.8201, Validation Loss: 0.4361, Validation F-beta: 0.4271
Epoch 31/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.476]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.24it/s, loss=0.409]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.3039, Class 2: 0.0000, Class 3: 0.2371, 
Class 4: 0.3631, Class 5: 0.3622, Class 6: 0.4248, 
Validation F-beta Score
Class 0: 0.9864, Class 1: 0.3175, Class 2: 0.0000, Class 3: 0.2744, 
Class 4: 0.4184, Class 5: 0.3979, Class 6: 0.4258, 
Overall Mean Dice Score: 0.3382
Overall Mean F-beta Score: 0.3668

Training Loss: 1.8103, Validation Loss: 0.4810, Validation F-beta: 0.3668
Epoch 32/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.498]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.472]


Validation Dice Score
Class 0: 0.9839, Class 1: 0.3376, Class 2: 0.0338, Class 3: 0.3606, 
Class 4: 0.5356, Class 5: 0.3830, Class 6: 0.1936, 
Validation F-beta Score
Class 0: 0.9851, Class 1: 0.4185, Class 2: 0.0719, Class 3: 0.3792, 
Class 4: 0.4820, Class 5: 0.4250, Class 6: 0.2116, 
Overall Mean Dice Score: 0.3621
Overall Mean F-beta Score: 0.3833

Training Loss: 1.7840, Validation Loss: 0.4647, Validation F-beta: 0.3833
Epoch 33/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.448]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.388]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.3590, Class 2: 0.1273, Class 3: 0.3669, 
Class 4: 0.4645, Class 5: 0.5165, Class 6: 0.7529, 
Validation F-beta Score
Class 0: 0.9888, Class 1: 0.4389, Class 2: 0.1228, Class 3: 0.4443, 
Class 4: 0.5389, Class 5: 0.5101, Class 6: 0.8908, 
Overall Mean Dice Score: 0.4920
Overall Mean F-beta Score: 0.5646

Training Loss: 1.7958, Validation Loss: 0.4075, Validation F-beta: 0.5646
SUPER Best model saved. Loss:0.4075, Score:0.5646
Epoch 34/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:28<00:00,  1.55s/it, loss=0.457]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.344]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.5681, Class 2: 0.0253, Class 3: 0.2827, 
Class 4: 0.3589, Class 5: 0.4178, Class 6: 0.6694, 
Validation F-beta Score
Class 0: 0.9868, Class 1: 0.7652, Class 2: 0.0244, Class 3: 0.2614, 
Class 4: 0.3347, Class 5: 0.4435, Class 6: 0.6315, 
Overall Mean Dice Score: 0.4594
Overall Mean F-beta Score: 0.4873

Training Loss: 1.7674, Validation Loss: 0.4356, Validation F-beta: 0.4873
Epoch 35/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:40<00:00,  1.67s/it, loss=0.529]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.455]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.1314, Class 2: 0.1320, Class 3: 0.2817, 
Class 4: 0.5309, Class 5: 0.5542, Class 6: 0.6069, 
Validation F-beta Score
Class 0: 0.9846, Class 1: 0.1772, Class 2: 0.1777, Class 3: 0.3191, 
Class 4: 0.6978, Class 5: 0.5088, Class 6: 0.5960, 
Overall Mean Dice Score: 0.4210
Overall Mean F-beta Score: 0.4598

Training Loss: 1.7644, Validation Loss: 0.4700, Validation F-beta: 0.4598
Epoch 36/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:45<00:00,  1.72s/it, loss=0.499]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.31it/s, loss=0.516]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.5983, Class 2: 0.1576, Class 3: 0.2454, 
Class 4: 0.2801, Class 5: 0.4124, Class 6: 0.6439, 
Validation F-beta Score
Class 0: 0.9900, Class 1: 0.5944, Class 2: 0.1876, Class 3: 0.2284, 
Class 4: 0.2655, Class 5: 0.4104, Class 6: 0.7036, 
Overall Mean Dice Score: 0.4360
Overall Mean F-beta Score: 0.4405

Training Loss: 1.7122, Validation Loss: 0.4480, Validation F-beta: 0.4405
Epoch 37/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:46<00:00,  1.73s/it, loss=0.45] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s, loss=0.403]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.4480, Class 2: 0.2021, Class 3: 0.2029, 
Class 4: 0.4091, Class 5: 0.5621, Class 6: 0.7966, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.4955, Class 2: 0.2149, Class 3: 0.2117, 
Class 4: 0.5963, Class 5: 0.5602, Class 6: 0.8398, 
Overall Mean Dice Score: 0.4838
Overall Mean F-beta Score: 0.5407

Training Loss: 1.7427, Validation Loss: 0.4156, Validation F-beta: 0.5407
Epoch 38/4000
Current lambda: 0.4400


Training: 100%|██████████| 96/96 [02:45<00:00,  1.72s/it, loss=0.338]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.503]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7116, Class 2: 0.2200, Class 3: 0.3202, 
Class 4: 0.5646, Class 5: 0.4419, Class 6: 0.3841, 
Validation F-beta Score
Class 0: 0.9901, Class 1: 0.7783, Class 2: 0.2171, Class 3: 0.2832, 
Class 4: 0.5214, Class 5: 0.4696, Class 6: 0.4698, 
Overall Mean Dice Score: 0.4845
Overall Mean F-beta Score: 0.5045

Training Loss: 1.7703, Validation Loss: 0.4212, Validation F-beta: 0.5045
Validation loss did not improve. Reducing lambda to 0.4300
Epoch 39/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.51] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.7368, Class 2: 0.1928, Class 3: 0.2544, 
Class 4: 0.4741, Class 5: 0.3976, Class 6: 0.5578, 
Validation F-beta Score
Class 0: 0.9893, Class 1: 0.8402, Class 2: 0.2376, Class 3: 0.3208, 
Class 4: 0.4965, Class 5: 0.4210, Class 6: 0.5605, 
Overall Mean Dice Score: 0.4841
Overall Mean F-beta Score: 0.5278

Training Loss: 1.7551, Validation Loss: 0.4112, Validation F-beta: 0.5278
Epoch 40/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:29<00:00,  1.55s/it, loss=0.485]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.484]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.5655, Class 2: 0.0596, Class 3: 0.2539, 
Class 4: 0.1726, Class 5: 0.4061, Class 6: 0.6055, 
Validation F-beta Score
Class 0: 0.9904, Class 1: 0.6588, Class 2: 0.0981, Class 3: 0.2750, 
Class 4: 0.1742, Class 5: 0.3656, Class 6: 0.6633, 
Overall Mean Dice Score: 0.4007
Overall Mean F-beta Score: 0.4274

Training Loss: 1.7153, Validation Loss: 0.4752, Validation F-beta: 0.4274
Epoch 41/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.499]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.25it/s, loss=0.566]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.3526, Class 2: 0.1227, Class 3: 0.2685, 
Class 4: 0.2209, Class 5: 0.4974, Class 6: 0.6711, 
Validation F-beta Score
Class 0: 0.9850, Class 1: 0.3314, Class 2: 0.1219, Class 3: 0.3221, 
Class 4: 0.2412, Class 5: 0.5055, Class 6: 0.6681, 
Overall Mean Dice Score: 0.4021
Overall Mean F-beta Score: 0.4137

Training Loss: 1.7246, Validation Loss: 0.4202, Validation F-beta: 0.4137
Epoch 42/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:45<00:00,  1.72s/it, loss=0.542]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.555]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.2996, Class 2: 0.0632, Class 3: 0.3663, 
Class 4: 0.5125, Class 5: 0.2311, Class 6: 0.2217, 
Validation F-beta Score
Class 0: 0.9884, Class 1: 0.3547, Class 2: 0.0463, Class 3: 0.3349, 
Class 4: 0.5229, Class 5: 0.2857, Class 6: 0.2313, 
Overall Mean Dice Score: 0.3263
Overall Mean F-beta Score: 0.3459

Training Loss: 1.6963, Validation Loss: 0.4585, Validation F-beta: 0.3459
Epoch 43/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.503]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.499]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.0000, Class 2: 0.1236, Class 3: 0.3822, 
Class 4: 0.5903, Class 5: 0.4295, Class 6: 0.2223, 
Validation F-beta Score
Class 0: 0.9900, Class 1: 0.0000, Class 2: 0.1476, Class 3: 0.4015, 
Class 4: 0.6613, Class 5: 0.3997, Class 6: 0.2379, 
Overall Mean Dice Score: 0.3248
Overall Mean F-beta Score: 0.3401

Training Loss: 1.7167, Validation Loss: 0.4895, Validation F-beta: 0.3401
Epoch 44/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.65s/it, loss=0.382]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.36] 


Validation Dice Score
Class 0: 0.9928, Class 1: 0.8009, Class 2: 0.0000, Class 3: 0.0627, 
Class 4: 0.6943, Class 5: 0.5791, Class 6: 0.6523, 
Validation F-beta Score
Class 0: 0.9934, Class 1: 0.8398, Class 2: 0.0000, Class 3: 0.0939, 
Class 4: 0.6719, Class 5: 0.5421, Class 6: 0.6532, 
Overall Mean Dice Score: 0.5579
Overall Mean F-beta Score: 0.5602

Training Loss: 1.7481, Validation Loss: 0.4074, Validation F-beta: 0.5602
Best model saved based on validation loss: 0.4074
Epoch 45/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:35<00:00,  1.62s/it, loss=0.414]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s, loss=0.341]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7457, Class 2: 0.0366, Class 3: 0.4124, 
Class 4: 0.5015, Class 5: 0.4082, Class 6: 0.8560, 
Validation F-beta Score
Class 0: 0.9929, Class 1: 0.7031, Class 2: 0.0452, Class 3: 0.4551, 
Class 4: 0.3955, Class 5: 0.4111, Class 6: 0.8239, 
Overall Mean Dice Score: 0.5847
Overall Mean F-beta Score: 0.5577

Training Loss: 1.7392, Validation Loss: 0.4124, Validation F-beta: 0.5577
Epoch 46/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.428]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.30it/s, loss=0.302]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.5982, Class 2: 0.3284, Class 3: 0.2198, 
Class 4: 0.4642, Class 5: 0.3732, Class 6: 0.6127, 
Validation F-beta Score
Class 0: 0.9864, Class 1: 0.7258, Class 2: 0.3470, Class 3: 0.3175, 
Class 4: 0.5492, Class 5: 0.5172, Class 6: 0.6970, 
Overall Mean Dice Score: 0.4536
Overall Mean F-beta Score: 0.5613

Training Loss: 1.7480, Validation Loss: 0.4182, Validation F-beta: 0.5613
Epoch 47/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.475]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.33it/s, loss=0.506]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.4677, Class 2: 0.0016, Class 3: 0.2089, 
Class 4: 0.5476, Class 5: 0.4555, Class 6: 0.4579, 
Validation F-beta Score
Class 0: 0.9898, Class 1: 0.5521, Class 2: 0.0027, Class 3: 0.2970, 
Class 4: 0.5405, Class 5: 0.5173, Class 6: 0.4689, 
Overall Mean Dice Score: 0.4275
Overall Mean F-beta Score: 0.4752

Training Loss: 1.7042, Validation Loss: 0.4415, Validation F-beta: 0.4752
Epoch 48/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.451]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.344]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7150, Class 2: 0.1802, Class 3: 0.3848, 
Class 4: 0.6991, Class 5: 0.5493, Class 6: 0.6317, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.7207, Class 2: 0.2859, Class 3: 0.4082, 
Class 4: 0.6615, Class 5: 0.5127, Class 6: 0.5813, 
Overall Mean Dice Score: 0.5960
Overall Mean F-beta Score: 0.5769

Training Loss: 1.7167, Validation Loss: 0.3794, Validation F-beta: 0.5769
SUPER Best model saved. Loss:0.3794, Score:0.5769
Epoch 49/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.498]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.5]  


Validation Dice Score
Class 0: 0.9903, Class 1: 0.3844, Class 2: 0.0000, Class 3: 0.2724, 
Class 4: 0.6017, Class 5: 0.4878, Class 6: 0.4190, 
Validation F-beta Score
Class 0: 0.9902, Class 1: 0.4515, Class 2: 0.0000, Class 3: 0.2797, 
Class 4: 0.5967, Class 5: 0.4785, Class 6: 0.4807, 
Overall Mean Dice Score: 0.4331
Overall Mean F-beta Score: 0.4574

Training Loss: 1.7477, Validation Loss: 0.4262, Validation F-beta: 0.4574
Epoch 50/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.395]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.372]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.3174, Class 2: 0.2266, Class 3: 0.1876, 
Class 4: 0.4896, Class 5: 0.5191, Class 6: 0.6011, 
Validation F-beta Score
Class 0: 0.9895, Class 1: 0.3088, Class 2: 0.2938, Class 3: 0.1436, 
Class 4: 0.4687, Class 5: 0.5514, Class 6: 0.6915, 
Overall Mean Dice Score: 0.4230
Overall Mean F-beta Score: 0.4328

Training Loss: 1.7247, Validation Loss: 0.4283, Validation F-beta: 0.4328
Epoch 51/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:35<00:00,  1.62s/it, loss=0.487]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.19it/s, loss=0.461]


Validation Dice Score
Class 0: 0.9922, Class 1: 0.6152, Class 2: 0.1597, Class 3: 0.5272, 
Class 4: 0.3515, Class 5: 0.3865, Class 6: 0.4444, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.6238, Class 2: 0.1239, Class 3: 0.4944, 
Class 4: 0.2943, Class 5: 0.5197, Class 6: 0.4481, 
Overall Mean Dice Score: 0.4650
Overall Mean F-beta Score: 0.4761

Training Loss: 1.7078, Validation Loss: 0.4059, Validation F-beta: 0.4761
Epoch 52/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.498]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.529]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.5038, Class 2: 0.2221, Class 3: 0.3446, 
Class 4: 0.3119, Class 5: 0.5191, Class 6: 0.6406, 
Validation F-beta Score
Class 0: 0.9925, Class 1: 0.5388, Class 2: 0.2751, Class 3: 0.3686, 
Class 4: 0.2859, Class 5: 0.6184, Class 6: 0.6500, 
Overall Mean Dice Score: 0.4640
Overall Mean F-beta Score: 0.4923

Training Loss: 1.7345, Validation Loss: 0.4139, Validation F-beta: 0.4923
Epoch 53/4000
Current lambda: 0.4300


Training: 100%|██████████| 96/96 [02:33<00:00,  1.60s/it, loss=0.404]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.416]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.7102, Class 2: 0.1707, Class 3: 0.2833, 
Class 4: 0.5620, Class 5: 0.4171, Class 6: 0.6110, 
Validation F-beta Score
Class 0: 0.9906, Class 1: 0.7792, Class 2: 0.1645, Class 3: 0.2571, 
Class 4: 0.5718, Class 5: 0.4691, Class 6: 0.6711, 
Overall Mean Dice Score: 0.5167
Overall Mean F-beta Score: 0.5497

Training Loss: 1.7031, Validation Loss: 0.4005, Validation F-beta: 0.5497
Validation loss did not improve. Reducing lambda to 0.4200
Epoch 54/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:35<00:00,  1.62s/it, loss=0.422]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.478]


Validation Dice Score
Class 0: 0.9880, Class 1: 0.7853, Class 2: 0.1247, Class 3: 0.2493, 
Class 4: 0.4402, Class 5: 0.3996, Class 6: 0.6342, 
Validation F-beta Score
Class 0: 0.9891, Class 1: 0.8432, Class 2: 0.1236, Class 3: 0.2654, 
Class 4: 0.4713, Class 5: 0.4072, Class 6: 0.6719, 
Overall Mean Dice Score: 0.5017
Overall Mean F-beta Score: 0.5318

Training Loss: 1.6993, Validation Loss: 0.4179, Validation F-beta: 0.5318
Epoch 55/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:39<00:00,  1.67s/it, loss=0.382]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.393]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.5405, Class 2: 0.3168, Class 3: 0.3244, 
Class 4: 0.5265, Class 5: 0.5425, Class 6: 0.6822, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.6171, Class 2: 0.3776, Class 3: 0.3120, 
Class 4: 0.5641, Class 5: 0.6176, Class 6: 0.7244, 
Overall Mean Dice Score: 0.5232
Overall Mean F-beta Score: 0.5670

Training Loss: 1.6694, Validation Loss: 0.4319, Validation F-beta: 0.5670
Epoch 56/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:41<00:00,  1.69s/it, loss=0.393]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.366]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.6203, Class 2: 0.2456, Class 3: 0.4791, 
Class 4: 0.5643, Class 5: 0.3589, Class 6: 0.5653, 
Validation F-beta Score
Class 0: 0.9911, Class 1: 0.6134, Class 2: 0.3684, Class 3: 0.4211, 
Class 4: 0.5065, Class 5: 0.3821, Class 6: 0.5894, 
Overall Mean Dice Score: 0.5176
Overall Mean F-beta Score: 0.5025

Training Loss: 1.6763, Validation Loss: 0.3684, Validation F-beta: 0.5025
Best model saved based on validation loss: 0.3684
Epoch 57/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:43<00:00,  1.70s/it, loss=0.39] 
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.31it/s, loss=0.449]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.0000, Class 2: 0.2127, Class 3: 0.2616, 
Class 4: 0.4368, Class 5: 0.3504, Class 6: 0.5600, 
Validation F-beta Score
Class 0: 0.9914, Class 1: 0.0000, Class 2: 0.3437, Class 3: 0.2523, 
Class 4: 0.4507, Class 5: 0.3039, Class 6: 0.6813, 
Overall Mean Dice Score: 0.3218
Overall Mean F-beta Score: 0.3376

Training Loss: 1.6630, Validation Loss: 0.4797, Validation F-beta: 0.3376
Epoch 58/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:46<00:00,  1.73s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9921, Class 1: 0.7675, Class 2: 0.1720, Class 3: 0.1778, 
Class 4: 0.4928, Class 5: 0.4478, Class 6: 0.5477, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.8890, Class 2: 0.2131, Class 3: 0.1476, 
Class 4: 0.4433, Class 5: 0.4592, Class 6: 0.6371, 
Overall Mean Dice Score: 0.4867
Overall Mean F-beta Score: 0.5152

Training Loss: 1.6892, Validation Loss: 0.4227, Validation F-beta: 0.5152
Epoch 59/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:48<00:00,  1.76s/it, loss=0.437]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.33it/s, loss=0.286]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.5983, Class 2: 0.2393, Class 3: 0.3621, 
Class 4: 0.4918, Class 5: 0.2822, Class 6: 0.6498, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6363, Class 2: 0.2566, Class 3: 0.3790, 
Class 4: 0.4146, Class 5: 0.2189, Class 6: 0.6575, 
Overall Mean Dice Score: 0.4768
Overall Mean F-beta Score: 0.4613

Training Loss: 1.6916, Validation Loss: 0.4018, Validation F-beta: 0.4613
Epoch 60/4000
Current lambda: 0.4200


Training: 100%|██████████| 96/96 [02:43<00:00,  1.71s/it, loss=0.348]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.355]


Validation Dice Score
Class 0: 0.9916, Class 1: 0.6964, Class 2: 0.1028, Class 3: 0.2057, 
Class 4: 0.5228, Class 5: 0.5270, Class 6: 0.6472, 
Validation F-beta Score
Class 0: 0.9895, Class 1: 0.7819, Class 2: 0.1000, Class 3: 0.1721, 
Class 4: 0.6452, Class 5: 0.6077, Class 6: 0.7259, 
Overall Mean Dice Score: 0.5198
Overall Mean F-beta Score: 0.5866

Training Loss: 1.6905, Validation Loss: 0.4065, Validation F-beta: 0.5866
Validation loss did not improve. Reducing lambda to 0.4100
Epoch 61/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:45<00:00,  1.72s/it, loss=0.401]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.13it/s, loss=0.472]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.3919, Class 2: 0.3838, Class 3: 0.5152, 
Class 4: 0.5482, Class 5: 0.1279, Class 6: 0.4624, 
Validation F-beta Score
Class 0: 0.9894, Class 1: 0.3702, Class 2: 0.4053, Class 3: 0.5168, 
Class 4: 0.5509, Class 5: 0.1264, Class 6: 0.4757, 
Overall Mean Dice Score: 0.4091
Overall Mean F-beta Score: 0.4080

Training Loss: 1.6756, Validation Loss: 0.4346, Validation F-beta: 0.4080
Epoch 62/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:46<00:00,  1.74s/it, loss=0.497]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.30it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.7675, Class 2: 0.0000, Class 3: 0.1739, 
Class 4: 0.2727, Class 5: 0.4765, Class 6: 0.8644, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.7284, Class 2: 0.0000, Class 3: 0.1817, 
Class 4: 0.2139, Class 5: 0.4954, Class 6: 0.9308, 
Overall Mean Dice Score: 0.5110
Overall Mean F-beta Score: 0.5100

Training Loss: 1.7131, Validation Loss: 0.4284, Validation F-beta: 0.5100
Epoch 63/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:45<00:00,  1.72s/it, loss=0.418]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.338]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.6094, Class 2: 0.0697, Class 3: 0.3855, 
Class 4: 0.6004, Class 5: 0.4302, Class 6: 0.8296, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.6810, Class 2: 0.0772, Class 3: 0.3578, 
Class 4: 0.5061, Class 5: 0.3524, Class 6: 0.8864, 
Overall Mean Dice Score: 0.5710
Overall Mean F-beta Score: 0.5567

Training Loss: 1.7112, Validation Loss: 0.4042, Validation F-beta: 0.5567
Epoch 64/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:41<00:00,  1.68s/it, loss=0.434]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.29it/s, loss=0.462]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.5782, Class 2: 0.2228, Class 3: 0.1541, 
Class 4: 0.2664, Class 5: 0.5239, Class 6: 0.6655, 
Validation F-beta Score
Class 0: 0.9927, Class 1: 0.6367, Class 2: 0.2347, Class 3: 0.1197, 
Class 4: 0.2412, Class 5: 0.4860, Class 6: 0.7160, 
Overall Mean Dice Score: 0.4376
Overall Mean F-beta Score: 0.4399

Training Loss: 1.6649, Validation Loss: 0.4395, Validation F-beta: 0.4399
Epoch 65/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:40<00:00,  1.67s/it, loss=0.457]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.421]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.3282, Class 2: 0.1107, Class 3: 0.4797, 
Class 4: 0.4237, Class 5: 0.3860, Class 6: 0.2289, 
Validation F-beta Score
Class 0: 0.9928, Class 1: 0.4336, Class 2: 0.1271, Class 3: 0.4804, 
Class 4: 0.4005, Class 5: 0.3750, Class 6: 0.2385, 
Overall Mean Dice Score: 0.3693
Overall Mean F-beta Score: 0.3856

Training Loss: 1.6840, Validation Loss: 0.4367, Validation F-beta: 0.3856
Epoch 66/4000
Current lambda: 0.4100


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.406]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.441]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.3658, Class 2: 0.0001, Class 3: 0.4008, 
Class 4: 0.5560, Class 5: 0.5240, Class 6: 0.4532, 
Validation F-beta Score
Class 0: 0.9913, Class 1: 0.3557, Class 2: 0.0002, Class 3: 0.4133, 
Class 4: 0.5188, Class 5: 0.6617, Class 6: 0.4413, 
Overall Mean Dice Score: 0.4600
Overall Mean F-beta Score: 0.4782

Training Loss: 1.6703, Validation Loss: 0.4622, Validation F-beta: 0.4782
Validation loss did not improve. Reducing lambda to 0.4000
Epoch 67/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.499]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.5789, Class 2: 0.2083, Class 3: 0.2296, 
Class 4: 0.4762, Class 5: 0.3330, Class 6: 0.4279, 
Validation F-beta Score
Class 0: 0.9914, Class 1: 0.6904, Class 2: 0.2461, Class 3: 0.2028, 
Class 4: 0.4702, Class 5: 0.4129, Class 6: 0.4613, 
Overall Mean Dice Score: 0.4091
Overall Mean F-beta Score: 0.4475

Training Loss: 1.6918, Validation Loss: 0.4487, Validation F-beta: 0.4475
Epoch 68/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.439]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.446]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.8095, Class 2: 0.2713, Class 3: 0.2729, 
Class 4: 0.5408, Class 5: 0.5090, Class 6: 0.0129, 
Validation F-beta Score
Class 0: 0.9937, Class 1: 0.8502, Class 2: 0.2874, Class 3: 0.3197, 
Class 4: 0.4887, Class 5: 0.4747, Class 6: 0.0215, 
Overall Mean Dice Score: 0.4290
Overall Mean F-beta Score: 0.4310

Training Loss: 1.6599, Validation Loss: 0.4453, Validation F-beta: 0.4310
Epoch 69/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.446]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.476]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.5624, Class 2: 0.1019, Class 3: 0.4021, 
Class 4: 0.5092, Class 5: 0.3076, Class 6: 0.8438, 
Validation F-beta Score
Class 0: 0.9880, Class 1: 0.5792, Class 2: 0.0986, Class 3: 0.4467, 
Class 4: 0.5768, Class 5: 0.3009, Class 6: 0.8976, 
Overall Mean Dice Score: 0.5250
Overall Mean F-beta Score: 0.5603

Training Loss: 1.6758, Validation Loss: 0.4131, Validation F-beta: 0.5603
Epoch 70/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.395]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.361]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.3930, Class 2: 0.1371, Class 3: 0.5581, 
Class 4: 0.7132, Class 5: 0.4848, Class 6: 0.8243, 
Validation F-beta Score
Class 0: 0.9931, Class 1: 0.3781, Class 2: 0.1558, Class 3: 0.5912, 
Class 4: 0.7132, Class 5: 0.4067, Class 6: 0.9206, 
Overall Mean Dice Score: 0.5947
Overall Mean F-beta Score: 0.6020

Training Loss: 1.6603, Validation Loss: 0.3512, Validation F-beta: 0.6020
SUPER Best model saved. Loss:0.3512, Score:0.6020
Epoch 71/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.472]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.18it/s, loss=0.486]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.7537, Class 2: 0.0000, Class 3: 0.4210, 
Class 4: 0.2115, Class 5: 0.3144, Class 6: 0.6219, 
Validation F-beta Score
Class 0: 0.9931, Class 1: 0.7474, Class 2: 0.0000, Class 3: 0.4239, 
Class 4: 0.4008, Class 5: 0.3448, Class 6: 0.6393, 
Overall Mean Dice Score: 0.4645
Overall Mean F-beta Score: 0.5112

Training Loss: 1.7096, Validation Loss: 0.4320, Validation F-beta: 0.5112
Epoch 72/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:31<00:00,  1.58s/it, loss=0.456]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.393]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.7410, Class 2: 0.2261, Class 3: 0.5404, 
Class 4: 0.5348, Class 5: 0.4532, Class 6: 0.6880, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.6592, Class 2: 0.2325, Class 3: 0.6027, 
Class 4: 0.5022, Class 5: 0.3828, Class 6: 0.7448, 
Overall Mean Dice Score: 0.5915
Overall Mean F-beta Score: 0.5783

Training Loss: 1.6769, Validation Loss: 0.3584, Validation F-beta: 0.5783
Epoch 73/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.476]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.341]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.5282, Class 2: 0.0953, Class 3: 0.4562, 
Class 4: 0.3006, Class 5: 0.4941, Class 6: 0.9211, 
Validation F-beta Score
Class 0: 0.9929, Class 1: 0.4620, Class 2: 0.1044, Class 3: 0.4392, 
Class 4: 0.2924, Class 5: 0.4887, Class 6: 0.9123, 
Overall Mean Dice Score: 0.5400
Overall Mean F-beta Score: 0.5189

Training Loss: 1.6759, Validation Loss: 0.3910, Validation F-beta: 0.5189
Epoch 74/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.413]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.424]


Validation Dice Score
Class 0: 0.9912, Class 1: 0.6071, Class 2: 0.0656, Class 3: 0.1014, 
Class 4: 0.5019, Class 5: 0.5327, Class 6: 0.6896, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.6508, Class 2: 0.0917, Class 3: 0.1085, 
Class 4: 0.5315, Class 5: 0.4778, Class 6: 0.9587, 
Overall Mean Dice Score: 0.4865
Overall Mean F-beta Score: 0.5455

Training Loss: 1.6623, Validation Loss: 0.4219, Validation F-beta: 0.5455
Epoch 75/4000
Current lambda: 0.4000


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.491]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.4948, Class 2: 0.1953, Class 3: 0.1894, 
Class 4: 0.2512, Class 5: 0.3290, Class 6: 0.8560, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.5396, Class 2: 0.2665, Class 3: 0.1664, 
Class 4: 0.1938, Class 5: 0.2882, Class 6: 0.8421, 
Overall Mean Dice Score: 0.4241
Overall Mean F-beta Score: 0.4060

Training Loss: 1.6763, Validation Loss: 0.4572, Validation F-beta: 0.4060
Validation loss did not improve. Reducing lambda to 0.3900
Epoch 76/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.391]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9901, Class 1: 0.7650, Class 2: 0.2276, Class 3: 0.0458, 
Class 4: 0.5595, Class 5: 0.3174, Class 6: 0.6595, 
Validation F-beta Score
Class 0: 0.9899, Class 1: 0.7765, Class 2: 0.1965, Class 3: 0.0409, 
Class 4: 0.5928, Class 5: 0.3234, Class 6: 0.6659, 
Overall Mean Dice Score: 0.4694
Overall Mean F-beta Score: 0.4799

Training Loss: 1.6931, Validation Loss: 0.4459, Validation F-beta: 0.4799
Epoch 77/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.393]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.511]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.3520, Class 2: 0.2783, Class 3: 0.4412, 
Class 4: 0.6222, Class 5: 0.4861, Class 6: 0.2172, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.3783, Class 2: 0.3334, Class 3: 0.3747, 
Class 4: 0.6614, Class 5: 0.4307, Class 6: 0.2208, 
Overall Mean Dice Score: 0.4238
Overall Mean F-beta Score: 0.4132

Training Loss: 1.7196, Validation Loss: 0.4410, Validation F-beta: 0.4132
Epoch 78/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.453]


Validation Dice Score
Class 0: 0.9905, Class 1: 0.7819, Class 2: 0.0237, Class 3: 0.2774, 
Class 4: 0.5236, Class 5: 0.4061, Class 6: 0.8556, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.7819, Class 2: 0.0265, Class 3: 0.2940, 
Class 4: 0.5540, Class 5: 0.4458, Class 6: 0.8548, 
Overall Mean Dice Score: 0.5689
Overall Mean F-beta Score: 0.5861

Training Loss: 1.7030, Validation Loss: 0.3855, Validation F-beta: 0.5861
Epoch 79/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.434]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.42] 


Validation Dice Score
Class 0: 0.9899, Class 1: 0.6984, Class 2: 0.1991, Class 3: 0.3907, 
Class 4: 0.5843, Class 5: 0.4144, Class 6: 0.7180, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.6343, Class 2: 0.2505, Class 3: 0.3311, 
Class 4: 0.5167, Class 5: 0.4471, Class 6: 0.7373, 
Overall Mean Dice Score: 0.5612
Overall Mean F-beta Score: 0.5333

Training Loss: 1.6646, Validation Loss: 0.3852, Validation F-beta: 0.5333
Epoch 80/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.383]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.525]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.3518, Class 2: 0.0000, Class 3: 0.2787, 
Class 4: 0.5250, Class 5: 0.3695, Class 6: 0.6308, 
Validation F-beta Score
Class 0: 0.9891, Class 1: 0.4315, Class 2: 0.0000, Class 3: 0.2537, 
Class 4: 0.4758, Class 5: 0.3601, Class 6: 0.6984, 
Overall Mean Dice Score: 0.4312
Overall Mean F-beta Score: 0.4439

Training Loss: 1.6832, Validation Loss: 0.4420, Validation F-beta: 0.4439
Epoch 81/4000
Current lambda: 0.3900


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.301]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.22it/s, loss=0.395]


Validation Dice Score
Class 0: 0.9937, Class 1: 0.7263, Class 2: 0.1650, Class 3: 0.5656, 
Class 4: 0.6471, Class 5: 0.4813, Class 6: 0.6320, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.7493, Class 2: 0.1642, Class 3: 0.5923, 
Class 4: 0.7425, Class 5: 0.4821, Class 6: 0.6222, 
Overall Mean Dice Score: 0.6104
Overall Mean F-beta Score: 0.6377

Training Loss: 1.6789, Validation Loss: 0.3590, Validation F-beta: 0.6377
Validation loss did not improve. Reducing lambda to 0.3800
Epoch 82/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s, loss=0.439]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.5866, Class 2: 0.0149, Class 3: 0.2254, 
Class 4: 0.6236, Class 5: 0.5123, Class 6: 0.8868, 
Validation F-beta Score
Class 0: 0.9937, Class 1: 0.5445, Class 2: 0.0135, Class 3: 0.2299, 
Class 4: 0.5489, Class 5: 0.5416, Class 6: 0.8936, 
Overall Mean Dice Score: 0.5670
Overall Mean F-beta Score: 0.5517

Training Loss: 1.6802, Validation Loss: 0.3711, Validation F-beta: 0.5517
Epoch 83/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.532]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s, loss=0.448]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.7802, Class 2: 0.0869, Class 3: 0.2588, 
Class 4: 0.5657, Class 5: 0.5264, Class 6: 0.5020, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.8949, Class 2: 0.0743, Class 3: 0.2002, 
Class 4: 0.5973, Class 5: 0.4576, Class 6: 0.5468, 
Overall Mean Dice Score: 0.5266
Overall Mean F-beta Score: 0.5394

Training Loss: 1.6590, Validation Loss: 0.4117, Validation F-beta: 0.5394
Epoch 84/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:40<00:00,  1.67s/it, loss=0.461]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.433]


Validation Dice Score
Class 0: 0.9915, Class 1: 0.6960, Class 2: 0.2224, Class 3: 0.6437, 
Class 4: 0.3882, Class 5: 0.3661, Class 6: 0.3727, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.6714, Class 2: 0.3020, Class 3: 0.6308, 
Class 4: 0.3106, Class 5: 0.4610, Class 6: 0.4235, 
Overall Mean Dice Score: 0.4933
Overall Mean F-beta Score: 0.4994

Training Loss: 1.6221, Validation Loss: 0.3991, Validation F-beta: 0.4994
Epoch 85/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:39<00:00,  1.66s/it, loss=0.454]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.347]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.4624, Class 2: 0.1463, Class 3: 0.2612, 
Class 4: 0.4494, Class 5: 0.5642, Class 6: 0.8464, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.4060, Class 2: 0.1479, Class 3: 0.2905, 
Class 4: 0.4626, Class 5: 0.5673, Class 6: 0.9088, 
Overall Mean Dice Score: 0.5167
Overall Mean F-beta Score: 0.5270

Training Loss: 1.7008, Validation Loss: 0.4170, Validation F-beta: 0.5270
Epoch 86/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.451]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.7498, Class 2: 0.4444, Class 3: 0.4623, 
Class 4: 0.4848, Class 5: 0.4807, Class 6: 0.4148, 
Validation F-beta Score
Class 0: 0.9934, Class 1: 0.7660, Class 2: 0.4557, Class 3: 0.5233, 
Class 4: 0.4041, Class 5: 0.4345, Class 6: 0.4517, 
Overall Mean Dice Score: 0.5185
Overall Mean F-beta Score: 0.5159

Training Loss: 1.6809, Validation Loss: 0.4061, Validation F-beta: 0.5159
Epoch 87/4000
Current lambda: 0.3800


Training: 100%|██████████| 96/96 [02:30<00:00,  1.57s/it, loss=0.456]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.535]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.4150, Class 2: 0.1146, Class 3: 0.2631, 
Class 4: 0.6745, Class 5: 0.2921, Class 6: 0.4509, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.3917, Class 2: 0.1615, Class 3: 0.2418, 
Class 4: 0.6332, Class 5: 0.2651, Class 6: 0.5500, 
Overall Mean Dice Score: 0.4191
Overall Mean F-beta Score: 0.4164

Training Loss: 1.6932, Validation Loss: 0.4345, Validation F-beta: 0.4164
Validation loss did not improve. Reducing lambda to 0.3700
Epoch 88/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.447]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.423]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.4399, Class 2: 0.0076, Class 3: 0.1896, 
Class 4: 0.4237, Class 5: 0.3329, Class 6: 0.6798, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.3581, Class 2: 0.0066, Class 3: 0.2279, 
Class 4: 0.3415, Class 5: 0.2932, Class 6: 0.6832, 
Overall Mean Dice Score: 0.4132
Overall Mean F-beta Score: 0.3808

Training Loss: 1.6524, Validation Loss: 0.4483, Validation F-beta: 0.3808
Epoch 89/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:38<00:00,  1.66s/it, loss=0.403]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.435]


Validation Dice Score
Class 0: 0.9934, Class 1: 0.4181, Class 2: 0.0924, Class 3: 0.3931, 
Class 4: 0.4881, Class 5: 0.5378, Class 6: 0.2012, 
Validation F-beta Score
Class 0: 0.9939, Class 1: 0.4990, Class 2: 0.0921, Class 3: 0.3893, 
Class 4: 0.4251, Class 5: 0.5114, Class 6: 0.2097, 
Overall Mean Dice Score: 0.4077
Overall Mean F-beta Score: 0.4069

Training Loss: 1.6494, Validation Loss: 0.4480, Validation F-beta: 0.4069
Epoch 90/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.384]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.389]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.5927, Class 2: 0.2723, Class 3: 0.3231, 
Class 4: 0.4177, Class 5: 0.2916, Class 6: 0.7204, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.6089, Class 2: 0.3039, Class 3: 0.2580, 
Class 4: 0.4095, Class 5: 0.2498, Class 6: 0.7818, 
Overall Mean Dice Score: 0.4691
Overall Mean F-beta Score: 0.4616

Training Loss: 1.6541, Validation Loss: 0.4184, Validation F-beta: 0.4616
Epoch 91/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:32<00:00,  1.59s/it, loss=0.446]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.23it/s, loss=0.496]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.5433, Class 2: 0.1307, Class 3: 0.4046, 
Class 4: 0.6308, Class 5: 0.2885, Class 6: 0.6441, 
Validation F-beta Score
Class 0: 0.9891, Class 1: 0.4962, Class 2: 0.0948, Class 3: 0.4917, 
Class 4: 0.6624, Class 5: 0.2625, Class 6: 0.6530, 
Overall Mean Dice Score: 0.5023
Overall Mean F-beta Score: 0.5132

Training Loss: 1.7199, Validation Loss: 0.4137, Validation F-beta: 0.5132
Epoch 92/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.466]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.38] 


Validation Dice Score
Class 0: 0.9887, Class 1: 0.7540, Class 2: 0.0274, Class 3: 0.1870, 
Class 4: 0.5015, Class 5: 0.3243, Class 6: 0.6317, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.7817, Class 2: 0.0436, Class 3: 0.1870, 
Class 4: 0.4632, Class 5: 0.2878, Class 6: 0.6254, 
Overall Mean Dice Score: 0.4797
Overall Mean F-beta Score: 0.4690

Training Loss: 1.6684, Validation Loss: 0.3944, Validation F-beta: 0.4690
Epoch 93/4000
Current lambda: 0.3700


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.508]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.419]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.7402, Class 2: 0.1401, Class 3: 0.3451, 
Class 4: 0.6347, Class 5: 0.4841, Class 6: 0.9262, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.6751, Class 2: 0.1767, Class 3: 0.3360, 
Class 4: 0.6097, Class 5: 0.4922, Class 6: 0.9125, 
Overall Mean Dice Score: 0.6260
Overall Mean F-beta Score: 0.6051

Training Loss: 1.6799, Validation Loss: 0.3604, Validation F-beta: 0.6051
Validation loss did not improve. Reducing lambda to 0.3600
Epoch 94/4000
Current lambda: 0.3600


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.517]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9941, Class 1: 0.6248, Class 2: 0.2401, Class 3: 0.5042, 
Class 4: 0.5844, Class 5: 0.4362, Class 6: 0.6430, 
Validation F-beta Score
Class 0: 0.9937, Class 1: 0.6108, Class 2: 0.2686, Class 3: 0.4958, 
Class 4: 0.8321, Class 5: 0.4286, Class 6: 0.6553, 
Overall Mean Dice Score: 0.5585
Overall Mean F-beta Score: 0.6045

Training Loss: 1.6954, Validation Loss: 0.3844, Validation F-beta: 0.6045
Validation loss did not improve. Reducing lambda to 0.3500
Epoch 95/4000
Current lambda: 0.3500


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.422]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.469]


Validation Dice Score
Class 0: 0.9929, Class 1: 0.5121, Class 2: 0.2087, Class 3: 0.1746, 
Class 4: 0.2243, Class 5: 0.3611, Class 6: 0.4551, 
Validation F-beta Score
Class 0: 0.9927, Class 1: 0.5991, Class 2: 0.3954, Class 3: 0.1700, 
Class 4: 0.1891, Class 5: 0.3555, Class 6: 0.7087, 
Overall Mean Dice Score: 0.3454
Overall Mean F-beta Score: 0.4045

Training Loss: 1.6820, Validation Loss: 0.4415, Validation F-beta: 0.4045
Epoch 96/4000
Current lambda: 0.3500


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.44] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.455]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.3946, Class 2: 0.0018, Class 3: 0.3748, 
Class 4: 0.6058, Class 5: 0.4105, Class 6: 0.6436, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.4012, Class 2: 0.0028, Class 3: 0.3033, 
Class 4: 0.5424, Class 5: 0.4563, Class 6: 0.7081, 
Overall Mean Dice Score: 0.4859
Overall Mean F-beta Score: 0.4823

Training Loss: 1.6754, Validation Loss: 0.4301, Validation F-beta: 0.4823
Epoch 97/4000
Current lambda: 0.3500


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.498]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.3993, Class 2: 0.0755, Class 3: 0.4675, 
Class 4: 0.6516, Class 5: 0.3128, Class 6: 0.4317, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.3807, Class 2: 0.0751, Class 3: 0.5295, 
Class 4: 0.5709, Class 5: 0.3242, Class 6: 0.4412, 
Overall Mean Dice Score: 0.4526
Overall Mean F-beta Score: 0.4493

Training Loss: 1.6747, Validation Loss: 0.4399, Validation F-beta: 0.4493
Epoch 98/4000
Current lambda: 0.3500


Training: 100%|██████████| 96/96 [02:29<00:00,  1.56s/it, loss=0.312]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s, loss=0.387]


Validation Dice Score
Class 0: 0.9935, Class 1: 0.7453, Class 2: 0.0668, Class 3: 0.3296, 
Class 4: 0.7833, Class 5: 0.5010, Class 6: 0.6958, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.8595, Class 2: 0.0910, Class 3: 0.2952, 
Class 4: 0.7398, Class 5: 0.6372, Class 6: 0.8582, 
Overall Mean Dice Score: 0.6110
Overall Mean F-beta Score: 0.6780

Training Loss: 1.6950, Validation Loss: 0.3775, Validation F-beta: 0.6780
Validation loss did not improve. Reducing lambda to 0.3400
Epoch 99/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:41<00:00,  1.68s/it, loss=0.571]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.463]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.3514, Class 2: 0.1270, Class 3: 0.2104, 
Class 4: 0.3956, Class 5: 0.4299, Class 6: 0.6337, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.3564, Class 2: 0.1020, Class 3: 0.1935, 
Class 4: 0.3527, Class 5: 0.4256, Class 6: 0.6474, 
Overall Mean Dice Score: 0.4042
Overall Mean F-beta Score: 0.3951

Training Loss: 1.6598, Validation Loss: 0.4342, Validation F-beta: 0.3951
Epoch 100/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.407]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.519]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.2056, Class 2: 0.1050, Class 3: 0.1863, 
Class 4: 0.3736, Class 5: 0.4366, Class 6: 0.2392, 
Validation F-beta Score
Class 0: 0.9928, Class 1: 0.2029, Class 2: 0.1206, Class 3: 0.1828, 
Class 4: 0.3293, Class 5: 0.4537, Class 6: 0.3036, 
Overall Mean Dice Score: 0.2882
Overall Mean F-beta Score: 0.2945

Training Loss: 1.6975, Validation Loss: 0.4923, Validation F-beta: 0.2945
Epoch 101/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.32] 
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.18it/s, loss=0.292]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.8085, Class 2: 0.1491, Class 3: 0.5282, 
Class 4: 0.6379, Class 5: 0.5183, Class 6: 0.9018, 
Validation F-beta Score
Class 0: 0.9933, Class 1: 0.8480, Class 2: 0.1397, Class 3: 0.5179, 
Class 4: 0.6215, Class 5: 0.4007, Class 6: 0.9404, 
Overall Mean Dice Score: 0.6789
Overall Mean F-beta Score: 0.6657

Training Loss: 1.6761, Validation Loss: 0.3059, Validation F-beta: 0.6657
SUPER Best model saved. Loss:0.3059, Score:0.6657
Epoch 102/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.432]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.486]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.4277, Class 2: 0.1216, Class 3: 0.2300, 
Class 4: 0.3973, Class 5: 0.2363, Class 6: 0.6927, 
Validation F-beta Score
Class 0: 0.9908, Class 1: 0.4491, Class 2: 0.1240, Class 3: 0.2171, 
Class 4: 0.3029, Class 5: 0.2146, Class 6: 0.7197, 
Overall Mean Dice Score: 0.3968
Overall Mean F-beta Score: 0.3807

Training Loss: 1.6803, Validation Loss: 0.4444, Validation F-beta: 0.3807
Epoch 103/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:42<00:00,  1.69s/it, loss=0.394]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s, loss=0.491]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.5184, Class 2: 0.0000, Class 3: 0.4348, 
Class 4: 0.3983, Class 5: 0.2714, Class 6: 0.1785, 
Validation F-beta Score
Class 0: 0.9880, Class 1: 0.5129, Class 2: 0.0000, Class 3: 0.4077, 
Class 4: 0.4076, Class 5: 0.3879, Class 6: 0.2197, 
Overall Mean Dice Score: 0.3603
Overall Mean F-beta Score: 0.3872

Training Loss: 1.6576, Validation Loss: 0.4730, Validation F-beta: 0.3872
Epoch 104/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:44<00:00,  1.71s/it, loss=0.432]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.30it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9922, Class 1: 0.3501, Class 2: 0.0787, Class 3: 0.1331, 
Class 4: 0.4649, Class 5: 0.5734, Class 6: 0.2249, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.3104, Class 2: 0.0938, Class 3: 0.0966, 
Class 4: 0.4087, Class 5: 0.5904, Class 6: 0.2133, 
Overall Mean Dice Score: 0.3493
Overall Mean F-beta Score: 0.3239

Training Loss: 1.6980, Validation Loss: 0.4829, Validation F-beta: 0.3239
Epoch 105/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:41<00:00,  1.69s/it, loss=0.396]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.36it/s, loss=0.383]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.8595, Class 2: 0.1677, Class 3: 0.3644, 
Class 4: 0.3361, Class 5: 0.2752, Class 6: 0.6334, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.8917, Class 2: 0.2585, Class 3: 0.2953, 
Class 4: 0.3054, Class 5: 0.2364, Class 6: 0.9403, 
Overall Mean Dice Score: 0.4937
Overall Mean F-beta Score: 0.5338

Training Loss: 1.6719, Validation Loss: 0.4064, Validation F-beta: 0.5338
Epoch 106/4000
Current lambda: 0.3400


Training: 100%|██████████| 96/96 [02:39<00:00,  1.66s/it, loss=0.389]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s, loss=0.373]


Validation Dice Score
Class 0: 0.9928, Class 1: 0.4920, Class 2: 0.1172, Class 3: 0.1091, 
Class 4: 0.8107, Class 5: 0.5492, Class 6: 0.5924, 
Validation F-beta Score
Class 0: 0.9931, Class 1: 0.4752, Class 2: 0.1268, Class 3: 0.0995, 
Class 4: 0.7515, Class 5: 0.6080, Class 6: 0.9303, 
Overall Mean Dice Score: 0.5107
Overall Mean F-beta Score: 0.5729

Training Loss: 1.6698, Validation Loss: 0.4245, Validation F-beta: 0.5729
Validation loss did not improve. Reducing lambda to 0.3300
Epoch 107/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:38<00:00,  1.66s/it, loss=0.463]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.462]


Validation Dice Score
Class 0: 0.9901, Class 1: 0.6038, Class 2: 0.1212, Class 3: 0.2888, 
Class 4: 0.6049, Class 5: 0.5056, Class 6: 0.6416, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.6254, Class 2: 0.1577, Class 3: 0.2235, 
Class 4: 0.5621, Class 5: 0.4708, Class 6: 0.6901, 
Overall Mean Dice Score: 0.5289
Overall Mean F-beta Score: 0.5144

Training Loss: 1.6600, Validation Loss: 0.3911, Validation F-beta: 0.5144
Epoch 108/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:38<00:00,  1.66s/it, loss=0.369]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.385]


Validation Dice Score
Class 0: 0.9910, Class 1: 0.5792, Class 2: 0.0952, Class 3: 0.3775, 
Class 4: 0.6886, Class 5: 0.3419, Class 6: 0.2247, 
Validation F-beta Score
Class 0: 0.9937, Class 1: 0.9106, Class 2: 0.0692, Class 3: 0.4053, 
Class 4: 0.5965, Class 5: 0.3485, Class 6: 0.2123, 
Overall Mean Dice Score: 0.4424
Overall Mean F-beta Score: 0.4946

Training Loss: 1.6396, Validation Loss: 0.4167, Validation F-beta: 0.4946
Epoch 109/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.347]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.432]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.4344, Class 2: 0.2070, Class 3: 0.5432, 
Class 4: 0.4568, Class 5: 0.3327, Class 6: 0.6970, 
Validation F-beta Score
Class 0: 0.9929, Class 1: 0.4787, Class 2: 0.2712, Class 3: 0.5458, 
Class 4: 0.4156, Class 5: 0.2818, Class 6: 0.7186, 
Overall Mean Dice Score: 0.4928
Overall Mean F-beta Score: 0.4881

Training Loss: 1.6534, Validation Loss: 0.4030, Validation F-beta: 0.4881
Epoch 110/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.43] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.449]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.5950, Class 2: 0.0000, Class 3: 0.2075, 
Class 4: 0.3486, Class 5: 0.2917, Class 6: 0.8658, 
Validation F-beta Score
Class 0: 0.9919, Class 1: 0.6699, Class 2: 0.0000, Class 3: 0.2409, 
Class 4: 0.2946, Class 5: 0.2768, Class 6: 0.9448, 
Overall Mean Dice Score: 0.4617
Overall Mean F-beta Score: 0.4854

Training Loss: 1.6991, Validation Loss: 0.4599, Validation F-beta: 0.4854
Epoch 111/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.397]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.18it/s, loss=0.531]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.5011, Class 2: 0.1051, Class 3: 0.0985, 
Class 4: 0.5667, Class 5: 0.3461, Class 6: 0.4447, 
Validation F-beta Score
Class 0: 0.9942, Class 1: 0.6013, Class 2: 0.1295, Class 3: 0.0759, 
Class 4: 0.4520, Class 5: 0.2853, Class 6: 0.4532, 
Overall Mean Dice Score: 0.3914
Overall Mean F-beta Score: 0.3735

Training Loss: 1.6874, Validation Loss: 0.4494, Validation F-beta: 0.3735
Epoch 112/4000
Current lambda: 0.3300


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.472]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.404]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.7095, Class 2: 0.1809, Class 3: 0.3910, 
Class 4: 0.5053, Class 5: 0.4157, Class 6: 0.6466, 
Validation F-beta Score
Class 0: 0.9876, Class 1: 0.6949, Class 2: 0.1941, Class 3: 0.3237, 
Class 4: 0.5236, Class 5: 0.4418, Class 6: 0.6424, 
Overall Mean Dice Score: 0.5336
Overall Mean F-beta Score: 0.5253

Training Loss: 1.6771, Validation Loss: 0.3798, Validation F-beta: 0.5253
Validation loss did not improve. Reducing lambda to 0.3200
Epoch 113/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:39<00:00,  1.66s/it, loss=0.489]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9901, Class 1: 0.5993, Class 2: 0.1397, Class 3: 0.5941, 
Class 4: 0.3701, Class 5: 0.2811, Class 6: 0.4257, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.6679, Class 2: 0.1512, Class 3: 0.5523, 
Class 4: 0.3117, Class 5: 0.2409, Class 6: 0.5583, 
Overall Mean Dice Score: 0.4541
Overall Mean F-beta Score: 0.4662

Training Loss: 1.6883, Validation Loss: 0.4293, Validation F-beta: 0.4662
Epoch 114/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:38<00:00,  1.66s/it, loss=0.486]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.449]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.5834, Class 2: 0.2283, Class 3: 0.4114, 
Class 4: 0.5730, Class 5: 0.4563, Class 6: 0.5749, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6337, Class 2: 0.2118, Class 3: 0.3507, 
Class 4: 0.5368, Class 5: 0.4135, Class 6: 0.5879, 
Overall Mean Dice Score: 0.5198
Overall Mean F-beta Score: 0.5045

Training Loss: 1.6583, Validation Loss: 0.4036, Validation F-beta: 0.5045
Epoch 115/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.498]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.4]  


Validation Dice Score
Class 0: 0.9901, Class 1: 0.6856, Class 2: 0.0000, Class 3: 0.4716, 
Class 4: 0.5715, Class 5: 0.4024, Class 6: 0.8771, 
Validation F-beta Score
Class 0: 0.9943, Class 1: 0.7073, Class 2: 0.0000, Class 3: 0.5127, 
Class 4: 0.4467, Class 5: 0.4265, Class 6: 0.8737, 
Overall Mean Dice Score: 0.6016
Overall Mean F-beta Score: 0.5934

Training Loss: 1.6762, Validation Loss: 0.4119, Validation F-beta: 0.5934
Epoch 116/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.46] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.516]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.5605, Class 2: 0.0000, Class 3: 0.3839, 
Class 4: 0.3203, Class 5: 0.3498, Class 6: 0.6400, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.5914, Class 2: 0.0000, Class 3: 0.4038, 
Class 4: 0.5185, Class 5: 0.3515, Class 6: 0.6533, 
Overall Mean Dice Score: 0.4509
Overall Mean F-beta Score: 0.5037

Training Loss: 1.6828, Validation Loss: 0.4403, Validation F-beta: 0.5037
Epoch 117/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.378]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9914, Class 1: 0.4842, Class 2: 0.0230, Class 3: 0.2882, 
Class 4: 0.6393, Class 5: 0.6011, Class 6: 0.8325, 
Validation F-beta Score
Class 0: 0.9920, Class 1: 0.5324, Class 2: 0.0152, Class 3: 0.2632, 
Class 4: 0.5752, Class 5: 0.6262, Class 6: 0.9368, 
Overall Mean Dice Score: 0.5691
Overall Mean F-beta Score: 0.5868

Training Loss: 1.6763, Validation Loss: 0.4020, Validation F-beta: 0.5868
Epoch 118/4000
Current lambda: 0.3200


Training: 100%|██████████| 96/96 [02:39<00:00,  1.66s/it, loss=0.353]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.337]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.7117, Class 2: 0.2250, Class 3: 0.2415, 
Class 4: 0.6013, Class 5: 0.2900, Class 6: 0.7263, 
Validation F-beta Score
Class 0: 0.9894, Class 1: 0.6946, Class 2: 0.2435, Class 3: 0.1998, 
Class 4: 0.5638, Class 5: 0.2879, Class 6: 0.7148, 
Overall Mean Dice Score: 0.5142
Overall Mean F-beta Score: 0.4922

Training Loss: 1.6336, Validation Loss: 0.3999, Validation F-beta: 0.4922
Validation loss did not improve. Reducing lambda to 0.3100
Epoch 119/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:39<00:00,  1.66s/it, loss=0.413]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s, loss=0.41] 


Validation Dice Score
Class 0: 0.9884, Class 1: 0.8334, Class 2: 0.1188, Class 3: 0.0351, 
Class 4: 0.6068, Class 5: 0.3921, Class 6: 0.6046, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.8305, Class 2: 0.1516, Class 3: 0.0229, 
Class 4: 0.5274, Class 5: 0.3233, Class 6: 0.6818, 
Overall Mean Dice Score: 0.4944
Overall Mean F-beta Score: 0.4772

Training Loss: 1.6284, Validation Loss: 0.4160, Validation F-beta: 0.4772
Epoch 120/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.417]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.381]


Validation Dice Score
Class 0: 0.9913, Class 1: 0.6002, Class 2: 0.1484, Class 3: 0.4581, 
Class 4: 0.6694, Class 5: 0.3986, Class 6: 0.4059, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.6846, Class 2: 0.1473, Class 3: 0.4136, 
Class 4: 0.5950, Class 5: 0.4163, Class 6: 0.6202, 
Overall Mean Dice Score: 0.5065
Overall Mean F-beta Score: 0.5459

Training Loss: 1.6611, Validation Loss: 0.4254, Validation F-beta: 0.5459
Epoch 121/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.267]
Validation: 100%|██████████| 4/4 [00:03<00:00,  1.16it/s, loss=0.358]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.5523, Class 2: 0.1934, Class 3: 0.0570, 
Class 4: 0.5903, Class 5: 0.4512, Class 6: 0.4214, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6347, Class 2: 0.2336, Class 3: 0.0437, 
Class 4: 0.5567, Class 5: 0.4474, Class 6: 0.5162, 
Overall Mean Dice Score: 0.4144
Overall Mean F-beta Score: 0.4397

Training Loss: 1.6676, Validation Loss: 0.4457, Validation F-beta: 0.4397
Epoch 122/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.517]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s, loss=0.363]


Validation Dice Score
Class 0: 0.9920, Class 1: 0.3895, Class 2: 0.2098, Class 3: 0.4618, 
Class 4: 0.5976, Class 5: 0.5949, Class 6: 0.4189, 
Validation F-beta Score
Class 0: 0.9934, Class 1: 0.4146, Class 2: 0.2348, Class 3: 0.3681, 
Class 4: 0.6070, Class 5: 0.5831, Class 6: 0.4741, 
Overall Mean Dice Score: 0.4926
Overall Mean F-beta Score: 0.4894

Training Loss: 1.6822, Validation Loss: 0.4114, Validation F-beta: 0.4894
Epoch 123/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.412]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.43] 


Validation Dice Score
Class 0: 0.9911, Class 1: 0.7855, Class 2: 0.1125, Class 3: 0.3205, 
Class 4: 0.6509, Class 5: 0.4295, Class 6: 0.8928, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.7701, Class 2: 0.1341, Class 3: 0.2819, 
Class 4: 0.5860, Class 5: 0.4184, Class 6: 0.9200, 
Overall Mean Dice Score: 0.6158
Overall Mean F-beta Score: 0.5953

Training Loss: 1.6196, Validation Loss: 0.3766, Validation F-beta: 0.5953
Epoch 124/4000
Current lambda: 0.3100


Training: 100%|██████████| 96/96 [02:37<00:00,  1.64s/it, loss=0.451]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.484]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.6664, Class 2: 0.2356, Class 3: 0.3197, 
Class 4: 0.3241, Class 5: 0.4313, Class 6: 0.6673, 
Validation F-beta Score
Class 0: 0.9921, Class 1: 0.6918, Class 2: 0.2997, Class 3: 0.3049, 
Class 4: 0.2755, Class 5: 0.5208, Class 6: 0.6869, 
Overall Mean Dice Score: 0.4818
Overall Mean F-beta Score: 0.4960

Training Loss: 1.6913, Validation Loss: 0.4341, Validation F-beta: 0.4960
Validation loss did not improve. Reducing lambda to 0.3000
Epoch 125/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:37<00:00,  1.65s/it, loss=0.364]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.36it/s, loss=0.44] 


Validation Dice Score
Class 0: 0.9877, Class 1: 0.6716, Class 2: 0.1319, Class 3: 0.1996, 
Class 4: 0.4670, Class 5: 0.3340, Class 6: 0.6686, 
Validation F-beta Score
Class 0: 0.9923, Class 1: 0.7120, Class 2: 0.1386, Class 3: 0.1536, 
Class 4: 0.3697, Class 5: 0.2752, Class 6: 0.6692, 
Overall Mean Dice Score: 0.4681
Overall Mean F-beta Score: 0.4360

Training Loss: 1.6846, Validation Loss: 0.4299, Validation F-beta: 0.4360
Epoch 126/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:38<00:00,  1.65s/it, loss=0.503]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.414]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.7220, Class 2: 0.1869, Class 3: 0.4161, 
Class 4: 0.6210, Class 5: 0.3201, Class 6: 0.6600, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.6761, Class 2: 0.2091, Class 3: 0.3439, 
Class 4: 0.6910, Class 5: 0.3243, Class 6: 0.6457, 
Overall Mean Dice Score: 0.5479
Overall Mean F-beta Score: 0.5362

Training Loss: 1.6560, Validation Loss: 0.3771, Validation F-beta: 0.5362
Epoch 127/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.489]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.406]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.5086, Class 2: 0.2112, Class 3: 0.3411, 
Class 4: 0.5431, Class 5: 0.3578, Class 6: 0.0282, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.5402, Class 2: 0.2270, Class 3: 0.2934, 
Class 4: 0.4963, Class 5: 0.3010, Class 6: 0.0819, 
Overall Mean Dice Score: 0.3558
Overall Mean F-beta Score: 0.3426

Training Loss: 1.6684, Validation Loss: 0.4735, Validation F-beta: 0.3426
Epoch 128/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.458]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.402]


Validation Dice Score
Class 0: 0.9917, Class 1: 0.7832, Class 2: 0.0000, Class 3: 0.2258, 
Class 4: 0.3994, Class 5: 0.4622, Class 6: 0.4419, 
Validation F-beta Score
Class 0: 0.9951, Class 1: 0.8729, Class 2: 0.0000, Class 3: 0.2198, 
Class 4: 0.3014, Class 5: 0.4696, Class 6: 0.4486, 
Overall Mean Dice Score: 0.4625
Overall Mean F-beta Score: 0.4625

Training Loss: 1.6693, Validation Loss: 0.4428, Validation F-beta: 0.4625
Epoch 129/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.444]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.37it/s, loss=0.475]


Validation Dice Score
Class 0: 0.9920, Class 1: 0.6271, Class 2: 0.1286, Class 3: 0.1382, 
Class 4: 0.4039, Class 5: 0.4438, Class 6: 0.8740, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.6729, Class 2: 0.1412, Class 3: 0.1100, 
Class 4: 0.4231, Class 5: 0.4402, Class 6: 0.8964, 
Overall Mean Dice Score: 0.4974
Overall Mean F-beta Score: 0.5085

Training Loss: 1.6611, Validation Loss: 0.4156, Validation F-beta: 0.5085
Epoch 130/4000
Current lambda: 0.3000


Training: 100%|██████████| 96/96 [02:36<00:00,  1.63s/it, loss=0.455]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s, loss=0.411]


Validation Dice Score
Class 0: 0.9924, Class 1: 0.8063, Class 2: 0.1897, Class 3: 0.4056, 
Class 4: 0.2366, Class 5: 0.4676, Class 6: 0.5942, 
Validation F-beta Score
Class 0: 0.9938, Class 1: 0.8839, Class 2: 0.2380, Class 3: 0.3789, 
Class 4: 0.1707, Class 5: 0.4242, Class 6: 0.6484, 
Overall Mean Dice Score: 0.5021
Overall Mean F-beta Score: 0.5012

Training Loss: 1.6819, Validation Loss: 0.4191, Validation F-beta: 0.5012
Early stopping


0,1
class_0_dice_score,▁▃▂▆▄▅▄▆▇▅▆▇▄▆▇▅▇▄▇▅▇▇█▇▆▅▇█▇▆▆▆▆▆▅▄▇▇▆▇
class_0_f_beta_score,▁▄▇▆▅▄▆▆▇▄▆▄▆▆▇▇▆▇▇▇▇▇█▇▆▇▇█▇▇▇██▅▇▇██▇▇
class_1_dice_score,▁▁▁▂▃▁▄▅▆▆▇█▅▄▅▆▆▆▆▆▅█▄▇█▅▄▇▄▃▅▇▆▆▇▆▄▇██
class_1_f_beta_score,▁▄▆▆█▅▇▄▆▅▄▁▇▅▇▆▇▄▇█▄▆▅▆▅▇▅▇▅▇▆▄▃▄▅▅█▆▆█
class_2_dice_score,▁▁▁▁▁▂▁▂▁▃▃▂▁▁▄▄▄▁▅▁▅▅▁▁█▁▂▅▃▃▃▄▃▁▁▄▄▃▅▃
class_2_f_beta_score,▁▄▁██▁▁▄▁▃▁▂▃▃▄▅▅▃▃▃▆▇▁▂▃▅▄▆▅▅▂▂▂▁▂▂▃▄▃▄
class_3_dice_score,▁▄▂▃▄▄▄▅▅▂▅▃▄▅▄▄▅▅▇▆▃▂▆▄▇█▄▃▅▃▆▅▃▂▄▆▁▆▄▃
class_3_f_beta_score,▁▁▅▅▅▄█▄▅▄▅▅▇▅▆█▄▇▄▃▆▇▂▅▄▄▅█▅█▆▅▄▆▂▄▅▃▆▆
class_4_dice_score,▅▃▂▃▅▁▅▃▄▃▄▅▇▄▅▆▅▅▅▅▆▂▄▄▇▆▅▄▆▅▆▂▆▆█▄▃▆▆▆
class_4_f_beta_score,▂▂▄▄▆▅▃▄██▃▃▂▆▅▂▄▅▅▆▄▅▄▄▅▅▅▃▆▇▆█▂▄▅▅▂▇▅▁

0,1
class_0_dice_score,0.99238
class_0_f_beta_score,0.99385
class_1_dice_score,0.80631
class_1_f_beta_score,0.88387
class_2_dice_score,0.1897
class_2_f_beta_score,0.23797
class_3_dice_score,0.40564
class_3_f_beta_score,0.3789
class_4_dice_score,0.23663
class_4_f_beta_score,0.17075


In [13]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
