In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch


from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from src.models import MiTCSPUnet, MiTUnet

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = img_size
n_classes = 7
batch_size = 1 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
num_repeat = 60
# MODEL CONFIG
num_epochs = 4000
lamda = 0.52
ce_weight = 0.4
lr = 0.0001
use_checkpoint = False
feature_size = 24
reduction_ratio = (16, 8, 4, 2, 1)
ff_expansion = (2, 8, 8, 4, 4)
heads = (1, 1, 2, 4, 8)
stage_kernel_stride_pad = ((7,1,3), (3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1))
num_mit_layers = (1, 2, 2, 2, 2)
num_bottle_layers = 2
# CLASS_WEIGHTS
class_weights = None
class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치
accumulation_steps = 8
# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    CastToTyped(keys=["image"], dtype=np.float16),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
])


In [None]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0,train_num_repeat=num_repeat)

Loading dataset: 100%|██████████| 24/24 [00:38<00:00,  1.58s/it]
Loading dataset: 100%|██████████| 4/4 [00:07<00:00,  1.86s/it]


https://monai.io/model-zoo.html

In [5]:
from monai.losses import TverskyLoss
import torch
import torch.nn as nn

def loss_fn(loss, class_weights, device):
    """
    Tversky 손실에 클래스별 가중치를 적용하여 최종 스칼라 값을 반환합니다.

    Args:
        loss: Tversky 손실 텐서 (B, num_classes, H, W, D).
        class_weights: 클래스별 가중치 텐서 (num_classes,).
        device: 사용할 장치 (예: 'cuda' 또는 'cpu').

    Returns:
        torch.Tensor: 최종 가중 평균 손실 값 (스칼라).
    """
    # 가중치를 device로 이동
    class_weights = class_weights.to(device)

    # 클래스 차원에 가중치 적용 (B, num_classes, ...)
    class_weights = class_weights.view(1, n_classes, 1, 1, 1)  # [1, num_classes, 1, 1, 1]
    weighted_loss = loss * class_weights

    # 모든 차원을 평균 내어 스칼라 손실 반환
    final_loss = torch.mean(weighted_loss)
    return final_loss

class DynamicTverskyLoss(TverskyLoss):
    def __init__(self, lamda=0.5, **kwargs):
        super().__init__(alpha=1 - lamda, beta=lamda, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        self.lamda = lamda
        self.alpha = 1 - lamda
        self.beta = lamda
        
# criterion = DynamicTverskyLoss(
#     lamda=0.5,
#     include_background=False,
#     reduction="mean",
#     softmax=True
# )

class CombinedCETverskyLoss(nn.Module):
    def __init__(self, lamda=0.5, ce_weight=0.5, **kwargs):
        super().__init__()
        self._lamda = lamda  # lamda 값 저장
        self.tversky = DynamicTverskyLoss(lamda=lamda, **kwargs)
        self.ce = nn.CrossEntropyLoss()
        self.ce_weight = ce_weight
        
    def forward(self, inputs, targets):
        tversky_loss = self.tversky(inputs, targets)
        ce_loss = self.ce(inputs, targets)
        return self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss
    
    def set_lamda(self, lamda):
        self._lamda = lamda
        self.tversky.set_lamda(lamda)
    
    @property
    def lamda(self):
        return self._lamda

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.metrics import DiceMetric

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MiTUnet(
    img_size=(img_depth, img_size, img_size),
    
    out_channels=n_classes,
    feature_size=feature_size,
    
    heads = heads,
    ff_expansion=ff_expansion,
    reduction_ratio=reduction_ratio,
    num_layers=num_mit_layers,
    channels=1,
    stage_kernel_stride_pad = stage_kernel_stride_pad,
    
    spatial_dims=3,
    norm_name="instance",
    act_name = ("leakyrelu ", {"inplace": True, "negative_slope": 0.01}),
    n=num_bottle_layers,
).to(device)
# Pretrained weights 불러오기
# if use_checkpoint:
#     pretrain_path = "./swin_unetr_btcv_segmentation/models/model.pt"
#     weight = torch.load(pretrain_path, map_location=device)

#     # 출력 레이어의 키를 제외한 나머지 가중치만 로드
#     filtered_weights = {k: v for k, v in weight.items() if "out.conv.conv" not in k}

#     # strict=False로 로드하여 불일치하는 부분 무시
#     model.load_state_dict(filtered_weights, strict=False)
#     print("Filtered weights loaded successfully. Output layer will be trained from scratch.")

# Load pretrained weights
# model.load_from(weights=np.load(config_vit.real_pretrained_path, allow_pickle=True))
# TverskyLoss 설정
# 사용 예시
criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,  # CE Loss와 Tversky Loss의 비중을 0.5:0.5로 설정
    include_background=False,
    reduction="mean",
    softmax=True
)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"MitUnet_noencodLK_p{pretrain_str}_{weight_str}_f{feature_size}_d{img_depth}s{img_size}_lr{lr:.0e}_a{lamda:.2f}_b{1-lamda:.2f}_b{batch_size}_r{num_repeat}_ce{ce_weight}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")

In [None]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([1, 1, 96, 96, 96]) torch.Size([1, 1, 96, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [None]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_SwinUnetR',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': lamda,
        "cross_entropy_weight": ce_weight,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(checkpoint_dir),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "kernel_strade_pad": stage_kernel_stride_pad,   
        "heads": heads,
        "ff_expansion": ff_expansion,
        "reduction_ratio": reduction_ratio,
        "num_mit_layers": num_mit_layers,
        "num_bottle_layers": num_bottle_layers,
        
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwoow070840[0m ([33mwaooang[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

# 학습

In [10]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    # loss = criterion(outputs, labels_onehot)
    loss = loss_fn(criterion(outputs, labels_onehot),class_weights=class_weights, device=device)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve%6 == 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.1)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [None]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
    ) 


Epoch 1/4000


Training: 100%|██████████| 1440/1440 [13:12<00:00,  1.82it/s, loss=0.465]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.67it/s, loss=0.515]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0000, Class 5: 0.0000, Class 6: 0.0000, 
Validation F-beta Score
Class 0: 0.9982, Class 1: 0.5000, Class 2: 0.6667, Class 3: 0.3333, Class 4: 0.3333, Class 5: 0.3333, Class 6: 0.5000, 

Overall Mean Dice Score: 0.0000
Overall Mean F-beta Score: 0.4000

Training Loss: 0.4908, Validation Loss: 0.4923, Validation F-beta: 0.4000
SUPER Best model saved. Loss:0.4923, Score:0.4000
Epoch 2/4000


Training: 100%|██████████| 1440/1440 [12:33<00:00,  1.91it/s, loss=0.47] 
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.75it/s, loss=0.481]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0660, Class 5: 0.0000, Class 6: 0.0001, 
Validation F-beta Score
Class 0: 0.9967, Class 1: 0.5833, Class 2: 0.4167, Class 3: 0.2500, Class 4: 0.1527, Class 5: 0.0833, Class 6: 0.2501, 

Overall Mean Dice Score: 0.0132
Overall Mean F-beta Score: 0.2639

Training Loss: 0.4732, Validation Loss: 0.4728, Validation F-beta: 0.2639
Epoch 3/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.477]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.82it/s, loss=0.504]


Validation Dice Score
Class 0: 0.9795, Class 1: 0.0063, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.2496, Class 5: 0.0269, Class 6: 0.0510, 
Validation F-beta Score
Class 0: 0.9786, Class 1: 0.0035, Class 2: 0.4167, Class 3: 0.5833, Class 4: 0.2357, Class 5: 0.1034, Class 6: 0.0859, 

Overall Mean Dice Score: 0.0668
Overall Mean F-beta Score: 0.2024

Training Loss: 0.4639, Validation Loss: 0.4722, Validation F-beta: 0.2024
Epoch 4/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.439]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.95it/s, loss=0.404]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.0529, Class 2: 0.0000, Class 3: 0.0634, Class 4: 0.2984, Class 5: 0.0470, Class 6: 0.2439, 
Validation F-beta Score
Class 0: 0.9899, Class 1: 0.1088, Class 2: 0.5000, Class 3: 0.0517, Class 4: 0.2704, Class 5: 0.0318, Class 6: 0.2432, 

Overall Mean Dice Score: 0.1411
Overall Mean F-beta Score: 0.1412

Training Loss: 0.4459, Validation Loss: 0.4340, Validation F-beta: 0.1412
Epoch 5/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.403]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.88it/s, loss=0.431]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.0444, Class 2: 0.0016, Class 3: 0.0814, Class 4: 0.3147, Class 5: 0.0388, Class 6: 0.1283, 
Validation F-beta Score
Class 0: 0.9820, Class 1: 0.1537, Class 2: 0.2509, Class 3: 0.1692, Class 4: 0.3351, Class 5: 0.0267, Class 6: 0.1087, 

Overall Mean Dice Score: 0.1215
Overall Mean F-beta Score: 0.1587

Training Loss: 0.4366, Validation Loss: 0.4339, Validation F-beta: 0.1587
Epoch 6/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.454]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.82it/s, loss=0.442]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.0890, Class 2: 0.0004, Class 3: 0.1146, Class 4: 0.3098, Class 5: 0.0462, Class 6: 0.1760, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.0598, Class 2: 0.0002, Class 3: 0.2935, Class 4: 0.2784, Class 5: 0.0371, Class 6: 0.1549, 

Overall Mean Dice Score: 0.1471
Overall Mean F-beta Score: 0.1647

Training Loss: 0.4298, Validation Loss: 0.4217, Validation F-beta: 0.1647
Epoch 7/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.401]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.75it/s, loss=0.4]  


Validation Dice Score
Class 0: 0.9891, Class 1: 0.0719, Class 2: 0.0000, Class 3: 0.1354, Class 4: 0.1529, Class 5: 0.1509, Class 6: 0.1420, 
Validation F-beta Score
Class 0: 0.9894, Class 1: 0.0613, Class 2: 0.3333, Class 3: 0.1037, Class 4: 0.1647, Class 5: 0.1243, Class 6: 0.1125, 

Overall Mean Dice Score: 0.1306
Overall Mean F-beta Score: 0.1133

Training Loss: 0.4219, Validation Loss: 0.4291, Validation F-beta: 0.1133
Epoch 8/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.447]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.82it/s, loss=0.353]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.1537, Class 2: 0.0203, Class 3: 0.1125, Class 4: 0.1225, Class 5: 0.0796, Class 6: 0.2549, 
Validation F-beta Score
Class 0: 0.9830, Class 1: 0.2303, Class 2: 0.0192, Class 3: 0.1312, Class 4: 0.1129, Class 5: 0.0576, Class 6: 0.2452, 

Overall Mean Dice Score: 0.1446
Overall Mean F-beta Score: 0.1554

Training Loss: 0.4258, Validation Loss: 0.4227, Validation F-beta: 0.1554
Epoch 9/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.425]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.93it/s, loss=0.397]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.0931, Class 2: 0.0000, Class 3: 0.1586, Class 4: 0.3055, Class 5: 0.2799, Class 6: 0.4251, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.0675, Class 2: 0.0000, Class 3: 0.1764, Class 4: 0.3482, Class 5: 0.3418, Class 6: 0.3883, 

Overall Mean Dice Score: 0.2525
Overall Mean F-beta Score: 0.2644

Training Loss: 0.4157, Validation Loss: 0.3888, Validation F-beta: 0.2644
Epoch 10/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.456]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.80it/s, loss=0.389]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.1180, Class 2: 0.0000, Class 3: 0.1036, Class 4: 0.4181, Class 5: 0.2959, Class 6: 0.3031, 
Validation F-beta Score
Class 0: 0.9902, Class 1: 0.1141, Class 2: 0.0000, Class 3: 0.0785, Class 4: 0.3340, Class 5: 0.3412, Class 6: 0.2826, 

Overall Mean Dice Score: 0.2477
Overall Mean F-beta Score: 0.2301

Training Loss: 0.4094, Validation Loss: 0.3903, Validation F-beta: 0.2301
Epoch 11/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.416]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.87it/s, loss=0.433]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.1772, Class 2: 0.0166, Class 3: 0.2457, Class 4: 0.2822, Class 5: 0.1517, Class 6: 0.2398, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.2342, Class 2: 0.0943, Class 3: 0.2169, Class 4: 0.4004, Class 5: 0.1382, Class 6: 0.3041, 

Overall Mean Dice Score: 0.2193
Overall Mean F-beta Score: 0.2588

Training Loss: 0.4052, Validation Loss: 0.3928, Validation F-beta: 0.2588
Epoch 12/4000


Training: 100%|██████████| 1440/1440 [12:29<00:00,  1.92it/s, loss=0.459]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.78it/s, loss=0.374]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.1655, Class 2: 0.0402, Class 3: 0.2013, Class 4: 0.2779, Class 5: 0.1902, Class 6: 0.2484, 
Validation F-beta Score
Class 0: 0.9930, Class 1: 0.1729, Class 2: 0.0259, Class 3: 0.3084, Class 4: 0.3871, Class 5: 0.1908, Class 6: 0.2200, 

Overall Mean Dice Score: 0.2167
Overall Mean F-beta Score: 0.2558

Training Loss: 0.4038, Validation Loss: 0.3943, Validation F-beta: 0.2558
Epoch 13/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.396]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.94it/s, loss=0.341]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.1582, Class 2: 0.0173, Class 3: 0.1368, Class 4: 0.3749, Class 5: 0.1901, Class 6: 0.2187, 
Validation F-beta Score
Class 0: 0.9881, Class 1: 0.1997, Class 2: 0.0111, Class 3: 0.1460, Class 4: 0.3693, Class 5: 0.1881, Class 6: 0.1959, 

Overall Mean Dice Score: 0.2157
Overall Mean F-beta Score: 0.2198

Training Loss: 0.3989, Validation Loss: 0.3951, Validation F-beta: 0.2198
Epoch 14/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.395]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.85it/s, loss=0.465]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.0562, Class 2: 0.0146, Class 3: 0.0394, Class 4: 0.3048, Class 5: 0.1537, Class 6: 0.0718, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.0747, Class 2: 0.0930, Class 3: 0.2775, Class 4: 0.4279, Class 5: 0.1247, Class 6: 0.0693, 

Overall Mean Dice Score: 0.1252
Overall Mean F-beta Score: 0.1948

Training Loss: 0.3964, Validation Loss: 0.4249, Validation F-beta: 0.1948
Epoch 15/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.333]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.79it/s, loss=0.399]


Validation Dice Score
Class 0: 0.9889, Class 1: 0.1627, Class 2: 0.0418, Class 3: 0.1523, Class 4: 0.2418, Class 5: 0.1459, Class 6: 0.3683, 
Validation F-beta Score
Class 0: 0.9877, Class 1: 0.1826, Class 2: 0.0288, Class 3: 0.1914, Class 4: 0.3234, Class 5: 0.1119, Class 6: 0.5148, 

Overall Mean Dice Score: 0.2142
Overall Mean F-beta Score: 0.2648

Training Loss: 0.3953, Validation Loss: 0.3921, Validation F-beta: 0.2648
Epoch 16/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.378]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.87it/s, loss=0.419]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.1479, Class 2: 0.0222, Class 3: 0.1890, Class 4: 0.3332, Class 5: 0.1538, Class 6: 0.1811, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.1510, Class 2: 0.0196, Class 3: 0.2410, Class 4: 0.4127, Class 5: 0.1514, Class 6: 0.3303, 

Overall Mean Dice Score: 0.2010
Overall Mean F-beta Score: 0.2573

Training Loss: 0.3912, Validation Loss: 0.3942, Validation F-beta: 0.2573
Epoch 17/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.362]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.73it/s, loss=0.326]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.0441, Class 2: 0.0000, Class 3: 0.1325, Class 4: 0.2248, Class 5: 0.2197, Class 6: 0.2319, 
Validation F-beta Score
Class 0: 0.9926, Class 1: 0.0341, Class 2: 0.0000, Class 3: 0.0888, Class 4: 0.2461, Class 5: 0.2120, Class 6: 0.4157, 

Overall Mean Dice Score: 0.1706
Overall Mean F-beta Score: 0.1993

Training Loss: 0.3885, Validation Loss: 0.4070, Validation F-beta: 0.1993
Epoch 18/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.4]  
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.85it/s, loss=0.422]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.2184, Class 2: 0.0126, Class 3: 0.2031, Class 4: 0.1588, Class 5: 0.2089, Class 6: 0.3122, 
Validation F-beta Score
Class 0: 0.9903, Class 1: 0.2093, Class 2: 0.0080, Class 3: 0.2141, Class 4: 0.1329, Class 5: 0.2148, Class 6: 0.3346, 

Overall Mean Dice Score: 0.2203
Overall Mean F-beta Score: 0.2212

Training Loss: 0.3883, Validation Loss: 0.3922, Validation F-beta: 0.2212
Epoch 19/4000


Training: 100%|██████████| 1440/1440 [12:30<00:00,  1.92it/s, loss=0.398]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.81it/s, loss=0.437]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.2194, Class 2: 0.0258, Class 3: 0.2152, Class 4: 0.2703, Class 5: 0.1922, Class 6: 0.2494, 
Validation F-beta Score
Class 0: 0.9884, Class 1: 0.2595, Class 2: 0.0186, Class 3: 0.2502, Class 4: 0.4022, Class 5: 0.1972, Class 6: 0.2332, 

Overall Mean Dice Score: 0.2293
Overall Mean F-beta Score: 0.2684

Training Loss: 0.3844, Validation Loss: 0.3934, Validation F-beta: 0.2684
Epoch 20/4000


Training: 100%|██████████| 1440/1440 [12:29<00:00,  1.92it/s, loss=0.346]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.92it/s, loss=0.421]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.1945, Class 2: 0.0568, Class 3: 0.2194, Class 4: 0.3251, Class 5: 0.2614, Class 6: 0.2700, 
Validation F-beta Score
Class 0: 0.9851, Class 1: 0.2179, Class 2: 0.0447, Class 3: 0.2674, Class 4: 0.4012, Class 5: 0.2300, Class 6: 0.3695, 

Overall Mean Dice Score: 0.2541
Overall Mean F-beta Score: 0.2972

Training Loss: 0.3830, Validation Loss: 0.3790, Validation F-beta: 0.2972
Epoch 21/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.415]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.71it/s, loss=0.357]


Validation Dice Score
Class 0: 0.9847, Class 1: 0.2004, Class 2: 0.0000, Class 3: 0.2277, Class 4: 0.2870, Class 5: 0.1447, Class 6: 0.4369, 
Validation F-beta Score
Class 0: 0.9855, Class 1: 0.2074, Class 2: 0.1667, Class 3: 0.3219, Class 4: 0.3486, Class 5: 0.1896, Class 6: 0.4982, 

Overall Mean Dice Score: 0.2593
Overall Mean F-beta Score: 0.3131

Training Loss: 0.3832, Validation Loss: 0.3862, Validation F-beta: 0.3131
Epoch 22/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.437]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.83it/s, loss=0.413]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.2711, Class 2: 0.0378, Class 3: 0.2381, Class 4: 0.2811, Class 5: 0.2230, Class 6: 0.4845, 
Validation F-beta Score
Class 0: 0.9877, Class 1: 0.2703, Class 2: 0.0272, Class 3: 0.2725, Class 4: 0.2981, Class 5: 0.2163, Class 6: 0.6021, 

Overall Mean Dice Score: 0.2996
Overall Mean F-beta Score: 0.3319

Training Loss: 0.3786, Validation Loss: 0.3584, Validation F-beta: 0.3319
Epoch 23/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.388]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.94it/s, loss=0.351]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.2830, Class 2: 0.0209, Class 3: 0.2293, Class 4: 0.1448, Class 5: 0.2504, Class 6: 0.5980, 
Validation F-beta Score
Class 0: 0.9847, Class 1: 0.2947, Class 2: 0.0126, Class 3: 0.4682, Class 4: 0.3151, Class 5: 0.2096, Class 6: 0.6169, 

Overall Mean Dice Score: 0.3011
Overall Mean F-beta Score: 0.3809

Training Loss: 0.3767, Validation Loss: 0.3654, Validation F-beta: 0.3809
Epoch 24/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.344]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.86it/s, loss=0.391]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.2055, Class 2: 0.0647, Class 3: 0.2052, Class 4: 0.2988, Class 5: 0.3179, Class 6: 0.4087, 
Validation F-beta Score
Class 0: 0.9828, Class 1: 0.2489, Class 2: 0.1068, Class 3: 0.3023, Class 4: 0.4673, Class 5: 0.3132, Class 6: 0.4944, 

Overall Mean Dice Score: 0.2872
Overall Mean F-beta Score: 0.3652

Training Loss: 0.3767, Validation Loss: 0.3728, Validation F-beta: 0.3652
Epoch 25/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.427]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.91it/s, loss=0.402]


Validation Dice Score
Class 0: 0.9909, Class 1: 0.0768, Class 2: 0.0156, Class 3: 0.2954, Class 4: 0.5677, Class 5: 0.3298, Class 6: 0.2097, 
Validation F-beta Score
Class 0: 0.9924, Class 1: 0.0950, Class 2: 0.0097, Class 3: 0.3194, Class 4: 0.5848, Class 5: 0.2840, Class 6: 0.3746, 

Overall Mean Dice Score: 0.2959
Overall Mean F-beta Score: 0.3315

Training Loss: 0.3765, Validation Loss: 0.3658, Validation F-beta: 0.3315
Epoch 26/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.379]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.85it/s, loss=0.341]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.0150, Class 2: 0.0001, Class 3: 0.2211, Class 4: 0.5052, Class 5: 0.1994, Class 6: 0.2505, 
Validation F-beta Score
Class 0: 0.9909, Class 1: 0.0126, Class 2: 0.0001, Class 3: 0.2168, Class 4: 0.4675, Class 5: 0.1867, Class 6: 0.2357, 

Overall Mean Dice Score: 0.2383
Overall Mean F-beta Score: 0.2239

Training Loss: 0.3738, Validation Loss: 0.3856, Validation F-beta: 0.2239
Epoch 27/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.326]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.94it/s, loss=0.464]


Validation Dice Score
Class 0: 0.9853, Class 1: 0.2811, Class 2: 0.0551, Class 3: 0.0999, Class 4: 0.2137, Class 5: 0.2070, Class 6: 0.2672, 
Validation F-beta Score
Class 0: 0.9800, Class 1: 0.2683, Class 2: 0.0600, Class 3: 0.0754, Class 4: 0.2362, Class 5: 0.2114, Class 6: 0.2838, 

Overall Mean Dice Score: 0.2138
Overall Mean F-beta Score: 0.2150

Training Loss: 0.3733, Validation Loss: 0.3950, Validation F-beta: 0.2150
Epoch 28/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.398]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.90it/s, loss=0.318]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.1110, Class 2: 0.0000, Class 3: 0.1461, Class 4: 0.4296, Class 5: 0.1444, Class 6: 0.4969, 
Validation F-beta Score
Class 0: 0.9891, Class 1: 0.1734, Class 2: 0.2500, Class 3: 0.2400, Class 4: 0.4785, Class 5: 0.1610, Class 6: 0.5673, 

Overall Mean Dice Score: 0.2656
Overall Mean F-beta Score: 0.3240

Training Loss: 0.3718, Validation Loss: 0.3745, Validation F-beta: 0.3240
Epoch 29/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.317]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.79it/s, loss=0.401]


Validation Dice Score
Class 0: 0.9912, Class 1: 0.2304, Class 2: 0.0703, Class 3: 0.1976, Class 4: 0.2217, Class 5: 0.2877, Class 6: 0.2182, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.3674, Class 2: 0.1375, Class 3: 0.1967, Class 4: 0.2902, Class 5: 0.3308, Class 6: 0.4595, 

Overall Mean Dice Score: 0.2311
Overall Mean F-beta Score: 0.3289

Training Loss: 0.3722, Validation Loss: 0.3797, Validation F-beta: 0.3289
Epoch 30/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.393]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.92it/s, loss=0.469]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.2348, Class 2: 0.0118, Class 3: 0.2944, Class 4: 0.1779, Class 5: 0.3845, Class 6: 0.4728, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.2088, Class 2: 0.1737, Class 3: 0.2900, Class 4: 0.3984, Class 5: 0.3468, Class 6: 0.6446, 

Overall Mean Dice Score: 0.3129
Overall Mean F-beta Score: 0.3777

Training Loss: 0.3687, Validation Loss: 0.3581, Validation F-beta: 0.3777
Epoch 31/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.341]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.87it/s, loss=0.299]


Validation Dice Score
Class 0: 0.9902, Class 1: 0.3289, Class 2: 0.1010, Class 3: 0.1845, Class 4: 0.4183, Class 5: 0.3106, Class 6: 0.4365, 
Validation F-beta Score
Class 0: 0.9874, Class 1: 0.3845, Class 2: 0.0824, Class 3: 0.3391, Class 4: 0.6179, Class 5: 0.3793, Class 6: 0.5703, 

Overall Mean Dice Score: 0.3358
Overall Mean F-beta Score: 0.4582

Training Loss: 0.3672, Validation Loss: 0.3422, Validation F-beta: 0.4582
SUPER Best model saved. Loss:0.3422, Score:0.4582
Epoch 32/4000


Training: 100%|██████████| 1440/1440 [12:32<00:00,  1.91it/s, loss=0.351]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.82it/s, loss=0.363]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.2106, Class 2: 0.0881, Class 3: 0.1952, Class 4: 0.2994, Class 5: 0.1575, Class 6: 0.1195, 
Validation F-beta Score
Class 0: 0.9797, Class 1: 0.2706, Class 2: 0.1186, Class 3: 0.2955, Class 4: 0.3937, Class 5: 0.1509, Class 6: 0.4425, 

Overall Mean Dice Score: 0.1964
Overall Mean F-beta Score: 0.3106

Training Loss: 0.3662, Validation Loss: 0.4012, Validation F-beta: 0.3106
Epoch 33/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.357]
Validation: 100%|██████████| 12/12 [00:03<00:00,  3.07it/s, loss=0.395]


Validation Dice Score
Class 0: 0.9882, Class 1: 0.2728, Class 2: 0.0907, Class 3: 0.1037, Class 4: 0.4647, Class 5: 0.1122, Class 6: 0.3813, 
Validation F-beta Score
Class 0: 0.9889, Class 1: 0.3858, Class 2: 0.0685, Class 3: 0.1773, Class 4: 0.5784, Class 5: 0.0807, Class 6: 0.4267, 

Overall Mean Dice Score: 0.2670
Overall Mean F-beta Score: 0.3298

Training Loss: 0.3635, Validation Loss: 0.3711, Validation F-beta: 0.3298
Epoch 34/4000


Training: 100%|██████████| 1440/1440 [12:29<00:00,  1.92it/s, loss=0.386]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.91it/s, loss=0.393]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.1610, Class 2: 0.0558, Class 3: 0.3062, Class 4: 0.4181, Class 5: 0.4077, Class 6: 0.5029, 
Validation F-beta Score
Class 0: 0.9882, Class 1: 0.1588, Class 2: 0.0656, Class 3: 0.3610, Class 4: 0.6166, Class 5: 0.4601, Class 6: 0.6912, 

Overall Mean Dice Score: 0.3592
Overall Mean F-beta Score: 0.4575

Training Loss: 0.3634, Validation Loss: 0.3365, Validation F-beta: 0.4575
Epoch 35/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.392]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.96it/s, loss=0.418]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.2877, Class 2: 0.0844, Class 3: 0.1865, Class 4: 0.3890, Class 5: 0.2107, Class 6: 0.3675, 
Validation F-beta Score
Class 0: 0.9902, Class 1: 0.3732, Class 2: 0.2261, Class 3: 0.4187, Class 4: 0.5058, Class 5: 0.3516, Class 6: 0.5863, 

Overall Mean Dice Score: 0.2883
Overall Mean F-beta Score: 0.4471

Training Loss: 0.3612, Validation Loss: 0.3582, Validation F-beta: 0.4471
Epoch 36/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.299]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.82it/s, loss=0.317]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.3189, Class 2: 0.1240, Class 3: 0.1209, Class 4: 0.4107, Class 5: 0.3644, Class 6: 0.2979, 
Validation F-beta Score
Class 0: 0.9914, Class 1: 0.3649, Class 2: 0.0895, Class 3: 0.1372, Class 4: 0.7012, Class 5: 0.3749, Class 6: 0.5491, 

Overall Mean Dice Score: 0.3025
Overall Mean F-beta Score: 0.4254

Training Loss: 0.3586, Validation Loss: 0.3521, Validation F-beta: 0.4254
Epoch 37/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.304]
Validation: 100%|██████████| 12/12 [00:03<00:00,  3.05it/s, loss=0.377]


Validation Dice Score
Class 0: 0.9840, Class 1: 0.3309, Class 2: 0.0176, Class 3: 0.1765, Class 4: 0.3303, Class 5: 0.2531, Class 6: 0.2967, 
Validation F-beta Score
Class 0: 0.9844, Class 1: 0.3299, Class 2: 0.0109, Class 3: 0.1618, Class 4: 0.5029, Class 5: 0.2245, Class 6: 0.3761, 

Overall Mean Dice Score: 0.2775
Overall Mean F-beta Score: 0.3190

Training Loss: 0.3562, Validation Loss: 0.3812, Validation F-beta: 0.3190
Epoch 38/4000


Training: 100%|██████████| 1440/1440 [12:29<00:00,  1.92it/s, loss=0.395]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.95it/s, loss=0.443]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.2264, Class 2: 0.0760, Class 3: 0.1393, Class 4: 0.2841, Class 5: 0.3721, Class 6: 0.2982, 
Validation F-beta Score
Class 0: 0.9918, Class 1: 0.2659, Class 2: 0.2286, Class 3: 0.2927, Class 4: 0.3776, Class 5: 0.3299, Class 6: 0.6494, 

Overall Mean Dice Score: 0.2640
Overall Mean F-beta Score: 0.3831

Training Loss: 0.3587, Validation Loss: 0.3657, Validation F-beta: 0.3831
Epoch 39/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.371]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.91it/s, loss=0.392]


Validation Dice Score
Class 0: 0.9898, Class 1: 0.3196, Class 2: 0.0396, Class 3: 0.2818, Class 4: 0.2973, Class 5: 0.2847, Class 6: 0.5482, 
Validation F-beta Score
Class 0: 0.9896, Class 1: 0.3447, Class 2: 0.0508, Class 3: 0.2776, Class 4: 0.3870, Class 5: 0.2831, Class 6: 0.6318, 

Overall Mean Dice Score: 0.3463
Overall Mean F-beta Score: 0.3848

Training Loss: 0.3555, Validation Loss: 0.3414, Validation F-beta: 0.3848
Epoch 40/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.288]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.85it/s, loss=0.335]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.4139, Class 2: 0.1908, Class 3: 0.1826, Class 4: 0.3758, Class 5: 0.3053, Class 6: 0.2032, 
Validation F-beta Score
Class 0: 0.9922, Class 1: 0.6428, Class 2: 0.1945, Class 3: 0.1818, Class 4: 0.3419, Class 5: 0.2930, Class 6: 0.3637, 

Overall Mean Dice Score: 0.2962
Overall Mean F-beta Score: 0.3646

Training Loss: 0.3561, Validation Loss: 0.3487, Validation F-beta: 0.3646
Epoch 41/4000


Training: 100%|██████████| 1440/1440 [12:27<00:00,  1.93it/s, loss=0.386]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.88it/s, loss=0.393]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.3473, Class 2: 0.0390, Class 3: 0.1207, Class 4: 0.3252, Class 5: 0.3728, Class 6: 0.4830, 
Validation F-beta Score
Class 0: 0.9908, Class 1: 0.4723, Class 2: 0.1968, Class 3: 0.3754, Class 4: 0.4183, Class 5: 0.4705, Class 6: 0.5795, 

Overall Mean Dice Score: 0.3298
Overall Mean F-beta Score: 0.4632

Training Loss: 0.3554, Validation Loss: 0.3468, Validation F-beta: 0.4632
Epoch 42/4000


Training: 100%|██████████| 1440/1440 [12:29<00:00,  1.92it/s, loss=0.311]
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.91it/s, loss=0.302]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.4417, Class 2: 0.1118, Class 3: 0.2150, Class 4: 0.2871, Class 5: 0.3772, Class 6: 0.5919, 
Validation F-beta Score
Class 0: 0.9933, Class 1: 0.5666, Class 2: 0.0863, Class 3: 0.2966, Class 4: 0.3283, Class 5: 0.3252, Class 6: 0.6749, 

Overall Mean Dice Score: 0.3826
Overall Mean F-beta Score: 0.4383

Training Loss: 0.3517, Validation Loss: 0.3216, Validation F-beta: 0.4383
Epoch 43/4000


Training: 100%|██████████| 1440/1440 [12:28<00:00,  1.92it/s, loss=0.35] 
Validation: 100%|██████████| 12/12 [00:04<00:00,  2.80it/s, loss=0.388]


Validation Dice Score
Class 0: 0.9892, Class 1: 0.2885, Class 2: 0.0712, Class 3: 0.1156, Class 4: 0.4924, Class 5: 0.4387, Class 6: 0.2926, 
Validation F-beta Score
Class 0: 0.9905, Class 1: 0.5175, Class 2: 0.0491, Class 3: 0.1214, Class 4: 0.4543, Class 5: 0.4981, Class 6: 0.4761, 

Overall Mean Dice Score: 0.3256
Overall Mean F-beta Score: 0.4135

Training Loss: 0.3532, Validation Loss: 0.3478, Validation F-beta: 0.4135
Epoch 44/4000


Training: 100%|██████████| 1440/1440 [14:41<00:00,  1.63it/s, loss=0.342]
Validation: 100%|██████████| 12/12 [00:07<00:00,  1.59it/s, loss=0.337]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.2571, Class 2: 0.0756, Class 3: 0.1155, Class 4: 0.2198, Class 5: 0.2823, Class 6: 0.2766, 
Validation F-beta Score
Class 0: 0.9925, Class 1: 0.4095, Class 2: 0.1530, Class 3: 0.1095, Class 4: 0.1840, Class 5: 0.2624, Class 6: 0.3729, 

Overall Mean Dice Score: 0.2303
Overall Mean F-beta Score: 0.2676

Training Loss: 0.3528, Validation Loss: 0.3818, Validation F-beta: 0.2676
Epoch 45/4000


Training: 100%|██████████| 1440/1440 [20:14<00:00,  1.19it/s, loss=0.338]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.03it/s, loss=0.337]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.3061, Class 2: 0.0411, Class 3: 0.3660, Class 4: 0.5135, Class 5: 0.2867, Class 6: 0.3630, 
Validation F-beta Score
Class 0: 0.9892, Class 1: 0.3337, Class 2: 0.2000, Class 3: 0.4203, Class 4: 0.4838, Class 5: 0.2964, Class 6: 0.3927, 

Overall Mean Dice Score: 0.3671
Overall Mean F-beta Score: 0.3854

Training Loss: 0.3498, Validation Loss: 0.3364, Validation F-beta: 0.3854
Epoch 46/4000


Training: 100%|██████████| 1440/1440 [25:03<00:00,  1.04s/it, loss=0.406]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.22it/s, loss=0.338]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.2268, Class 2: 0.1343, Class 3: 0.0830, Class 4: 0.5457, Class 5: 0.2993, Class 6: 0.1444, 
Validation F-beta Score
Class 0: 0.9858, Class 1: 0.2368, Class 2: 0.2054, Class 3: 0.1128, Class 4: 0.6743, Class 5: 0.3106, Class 6: 0.4815, 

Overall Mean Dice Score: 0.2598
Overall Mean F-beta Score: 0.3632

Training Loss: 0.3481, Validation Loss: 0.3691, Validation F-beta: 0.3632
Epoch 47/4000


Training: 100%|██████████| 1440/1440 [22:28<00:00,  1.07it/s, loss=0.286]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.32it/s, loss=0.472]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.5077, Class 2: 0.1697, Class 3: 0.2059, Class 4: 0.1517, Class 5: 0.2331, Class 6: 0.4575, 
Validation F-beta Score
Class 0: 0.9903, Class 1: 0.5169, Class 2: 0.1476, Class 3: 0.2008, Class 4: 0.1503, Class 5: 0.2301, Class 6: 0.4541, 

Overall Mean Dice Score: 0.3112
Overall Mean F-beta Score: 0.3104

Training Loss: 0.3487, Validation Loss: 0.3475, Validation F-beta: 0.3104
Epoch 48/4000


Training: 100%|██████████| 1440/1440 [22:29<00:00,  1.07it/s, loss=0.371]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.34it/s, loss=0.355]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.4021, Class 2: 0.0525, Class 3: 0.2485, Class 4: 0.3788, Class 5: 0.3281, Class 6: 0.2859, 
Validation F-beta Score
Class 0: 0.9932, Class 1: 0.5754, Class 2: 0.3833, Class 3: 0.3240, Class 4: 0.4985, Class 5: 0.2689, Class 6: 0.5535, 

Overall Mean Dice Score: 0.3287
Overall Mean F-beta Score: 0.4441

Training Loss: 0.3490, Validation Loss: 0.3524, Validation F-beta: 0.4441
Epoch 49/4000


Training: 100%|██████████| 1440/1440 [22:28<00:00,  1.07it/s, loss=0.361]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.32it/s, loss=0.345]


Validation Dice Score
Class 0: 0.9906, Class 1: 0.3165, Class 2: 0.0656, Class 3: 0.1660, Class 4: 0.4336, Class 5: 0.3620, Class 6: 0.3665, 
Validation F-beta Score
Class 0: 0.9910, Class 1: 0.4941, Class 2: 0.0480, Class 3: 0.3048, Class 4: 0.5631, Class 5: 0.3770, Class 6: 0.3942, 

Overall Mean Dice Score: 0.3289
Overall Mean F-beta Score: 0.4266

Training Loss: 0.3456, Validation Loss: 0.3459, Validation F-beta: 0.4266
Epoch 50/4000


Training: 100%|██████████| 1440/1440 [22:31<00:00,  1.07it/s, loss=0.395]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.33it/s, loss=0.329]


Validation Dice Score
Class 0: 0.9903, Class 1: 0.2710, Class 2: 0.1499, Class 3: 0.1633, Class 4: 0.5166, Class 5: 0.2061, Class 6: 0.3613, 
Validation F-beta Score
Class 0: 0.9914, Class 1: 0.3767, Class 2: 0.2023, Class 3: 0.1309, Class 4: 0.5818, Class 5: 0.2367, Class 6: 0.4387, 

Overall Mean Dice Score: 0.3037
Overall Mean F-beta Score: 0.3529

Training Loss: 0.3454, Validation Loss: 0.3480, Validation F-beta: 0.3529
Epoch 51/4000


Training: 100%|██████████| 1440/1440 [22:29<00:00,  1.07it/s, loss=0.349]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.29it/s, loss=0.418]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.1656, Class 2: 0.0953, Class 3: 0.2044, Class 4: 0.4476, Class 5: 0.3766, Class 6: 0.4425, 
Validation F-beta Score
Class 0: 0.9890, Class 1: 0.5655, Class 2: 0.1276, Class 3: 0.3253, Class 4: 0.4207, Class 5: 0.4853, Class 6: 0.7656, 

Overall Mean Dice Score: 0.3274
Overall Mean F-beta Score: 0.5125

Training Loss: 0.3438, Validation Loss: 0.3443, Validation F-beta: 0.5125
Epoch 52/4000


Training: 100%|██████████| 1440/1440 [22:31<00:00,  1.07it/s, loss=0.345]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.28it/s, loss=0.271]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.2156, Class 2: 0.1527, Class 3: 0.2260, Class 4: 0.2937, Class 5: 0.2148, Class 6: 0.3614, 
Validation F-beta Score
Class 0: 0.9912, Class 1: 0.3799, Class 2: 0.1694, Class 3: 0.1981, Class 4: 0.3388, Class 5: 0.1762, Class 6: 0.5988, 

Overall Mean Dice Score: 0.2623
Overall Mean F-beta Score: 0.3384

Training Loss: 0.3436, Validation Loss: 0.3740, Validation F-beta: 0.3384
Epoch 53/4000


Training: 100%|██████████| 1440/1440 [22:35<00:00,  1.06it/s, loss=0.352]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.23it/s, loss=0.271]


Validation Dice Score
Class 0: 0.9911, Class 1: 0.3703, Class 2: 0.0211, Class 3: 0.0826, Class 4: 0.4294, Class 5: 0.2811, Class 6: 0.2929, 
Validation F-beta Score
Class 0: 0.9916, Class 1: 0.3709, Class 2: 0.4295, Class 3: 0.3157, Class 4: 0.4262, Class 5: 0.2916, Class 6: 0.3874, 

Overall Mean Dice Score: 0.2912
Overall Mean F-beta Score: 0.3583

Training Loss: 0.3432, Validation Loss: 0.3578, Validation F-beta: 0.3583
Epoch 54/4000


Training: 100%|██████████| 1440/1440 [24:17<00:00,  1.01s/it, loss=0.34] 
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.35it/s, loss=0.386]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.1698, Class 2: 0.0162, Class 3: 0.0944, Class 4: 0.4930, Class 5: 0.3552, Class 6: 0.3821, 
Validation F-beta Score
Class 0: 0.9842, Class 1: 0.4946, Class 2: 0.0311, Class 3: 0.3206, Class 4: 0.5073, Class 5: 0.5502, Class 6: 0.5620, 

Overall Mean Dice Score: 0.2989
Overall Mean F-beta Score: 0.4869

Training Loss: 0.3421, Validation Loss: 0.3644, Validation F-beta: 0.4869
Epoch 55/4000


Training: 100%|██████████| 1440/1440 [29:48<00:00,  1.24s/it, loss=0.337]
Validation: 100%|██████████| 12/12 [00:05<00:00,  2.39it/s, loss=0.373]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.3569, Class 2: 0.0444, Class 3: 0.1805, Class 4: 0.2563, Class 5: 0.4468, Class 6: 0.2945, 
Validation F-beta Score
Class 0: 0.9936, Class 1: 0.5916, Class 2: 0.0294, Class 3: 0.2074, Class 4: 0.6248, Class 5: 0.4496, Class 6: 0.5474, 

Overall Mean Dice Score: 0.3070
Overall Mean F-beta Score: 0.4842

Training Loss: 0.3411, Validation Loss: 0.3505, Validation F-beta: 0.4842
Epoch 56/4000


Training:  66%|██████▋   | 956/1440 [19:36<10:28,  1.30s/it, loss=0.271]

In [12]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
