In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch


from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


print_config()

MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\ProgramData\anaconda3\envs\ship\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the 

In [2]:
class_info = {
    0: {"name": "background", "weight": 0},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# 가중치에 비례한 비율 계산
raw_ratios = {
    k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값a
    for k, v in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {k: v / total for k, v in raw_ratios.items()}

# 최종 합계가 1인지 확인
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# 비율을 리스트로 변환
ratios_list = [ratios[k] for k in sorted(ratios.keys())]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 설정

In [3]:
from src.dataset.dataset import create_dataloaders
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.transforms import CastToTyped
import numpy as np

train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = img_size
n_classes = 7
batch_size = 1 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
num_repeat = 60
accumulation_steps = 16

# MODEL CONFIG
feature_size = 48
use_checkpoint = True
use_v2 = True
drop_rate= 0.25
attn_drop_rate = 0.25

# TRAINING CONFIG
num_epochs = 4000

lr = 0.001

# LOSS
warmup_epochs=5,
schedule_epochs=10,
warmup_ce=1.0,
warmup_tv=0.1,
warmup_hd=0.1,
ce_end=0.3,
tv_end=1.0,
hd_end=1.0,
include_background=False,
reduction="mean",
softmax=True,
tversky_alpha = 0.52, # Tversky loss의 alpha 값 = lamda
tversky_beta = 1-tversky_alpha,
tversky_smooth=1e-5,
tv_boost = 1.2,
hd_boost = 1.2,


class_weights = None
class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # 클래스별 가중치

# INIT
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    CastToTyped(keys=["image"], dtype=np.float16),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
])

In [4]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0,train_num_repeat=num_repeat)

Loading dataset: 100%|██████████| 24/24 [00:40<00:00,  1.68s/it]
Loading dataset: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it]


https://monai.io/model-zoo.html

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.losses import TverskyLoss, LogHausdorffDTLoss

class AdaptiveCombinedLoss(nn.Module):
    """
    요구사항:
    1) Warm-up 단계 (0 ~ warmup_epochs):
       - CE 위주로 학습(예: CE=1.0), Tversky/HD는 작게(예: 0.1)
    2) Warm-up 종료 시점:
       - 해당 시점 손실값 참고해 Kendall(\log\sigma) 초기값/가중치 보정
    3) Kendall(2018) 식:
       - (1/(2σ_i^2))*L_i + log(σ_i^2)
       - Tversky/HD는 더 강조(상수 w_i를 크게) + CE 최소 한도 유지
    4) Warm-up 이후 스케줄링:
       - CE 점진 감소(1.0→0.3), Tversky/HD 점진 증가(0.1→1.0)
       - 필요시 Tversky/HD를 추가로 boost
    """

    def __init__(
        self,
        # 1) Warm-up & 스케줄
        warmup_epochs: int = 5,
        schedule_epochs: int = 10,

        # 2) Warm-up 시 고정 가중치
        warmup_ce=1.0,
        warmup_tv=0.1,
        warmup_hd=0.1,

        # 3) 스케줄에서의 최종 가중치
        ce_end=0.3,
        tv_end=1.0,
        hd_end=1.0,

        # 4) Tversky/HD를 추가로 더 키우고 싶다면 boost 사용
        #    (예: 2.0이면 TV/HD가 2배 더 강조)
        tv_boost=1.2,
        hd_boost=1.2,

        # MONAI Loss 설정
        include_background=True,
        reduction="mean",
        softmax=True,

        # Tversky 파라미터
        tversky_alpha=0.52,
        tversky_beta=None,  # None이면 1 - alpha
        tversky_smooth=1e-5,

        # Kendall(2018) 식 사용 여부
        use_uncertainty=True,
    ):
        super().__init__()
        self.warmup_epochs = warmup_epochs
        self.schedule_epochs = schedule_epochs

        # Warm-up 고정 가중치
        self.warmup_ce = warmup_ce
        self.warmup_tv = warmup_tv
        self.warmup_hd = warmup_hd

        # Warm-up 이후 스케줄 start/end
        self.ce_start, self.ce_end = warmup_ce, ce_end
        self.tv_start, self.tv_end = warmup_tv, tv_end
        self.hd_start, self.hd_end = warmup_hd, hd_end

        # Tversky/HD 추가 배율(강조)
        self.tv_boost = tv_boost
        self.hd_boost = hd_boost

        self.use_uncertainty = use_uncertainty

        if tversky_beta is None:
            tversky_beta = 1.0 - tversky_alpha

        # 배경 채널 무시 설정
        if not include_background:
            ignore_index = 0
        else:
            ignore_index = -100

        # (1) 개별 손실 정의
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.tversky_loss = TverskyLoss(
            alpha=tversky_alpha,
            beta=tversky_beta,
            smooth_nr=tversky_smooth,
            smooth_dr=tversky_smooth,
            softmax=softmax,
            reduction=reduction,
            include_background=include_background
        )
        self.haus_loss = LogHausdorffDTLoss(
            softmax=softmax,
            reduction=reduction,
            include_background=include_background
        )

        # (2) Kendall을 위한 log_sigma 파라미터
        if self.use_uncertainty:
            self.log_sigma_ce   = nn.Parameter(torch.zeros(1))
            self.log_sigma_tv   = nn.Parameter(torch.zeros(1))
            self.log_sigma_haus = nn.Parameter(torch.zeros(1))

        # (3) 현재 epoch, warm-up 손실 기록
        self.current_epoch = 0
        self.last_warmup_ce   = 0.0
        self.last_warmup_tv   = 0.0
        self.last_warmup_haus = 0.0

    def set_epoch(self, epoch: int):
        """학습 루프에서 매 epoch마다 호출하여 현재 epoch 갱신."""
        self.current_epoch = epoch

    def record_warmup_losses(self, ce_val, tv_val, hd_val):
        """
        Warm-up 단계 손실값을 기록 -> 이후 log_sigma 초기값 조정 등 활용 가능
        """
        self.last_warmup_ce   = ce_val
        self.last_warmup_tv   = tv_val
        self.last_warmup_haus = hd_val

    def end_of_warmup_init(self):
        """
        Warm-up → Kendall 전환 시점에서,
        warm-up 손실값 등을 참고해 log_sigma 등 초기 설정
        """
        if self.use_uncertainty:
            with torch.no_grad():
                # 예시: warm-up에서 CE가 안정적, TV/HD가 크면
                # TV/HD 강조 위해 log_sigma를 음수로 조정
                self.log_sigma_ce[0]   = 0.0   # CE
                self.log_sigma_tv[0]   = -0.5  # TV
                self.log_sigma_haus[0] = -0.5  # HD

    def forward(self, preds, targets):
        """
        preds: (B, C, D, H, W) - logit
        targets: (B, D, H, W) - 정수 라벨
        """
        # 1) 개별 손실 계산
        loss_ce   = self.ce_loss(preds, targets)
        loss_tv   = self.tversky_loss(preds, targets)
        loss_haus = self.haus_loss(preds, targets)

        # 2) Warm-up 단계
        if self.current_epoch < self.warmup_epochs:
            total_loss = (
                self.warmup_ce * loss_ce
                + self.warmup_tv * loss_tv
                + self.warmup_hd * loss_haus
            )
        else:
            # Warm-up 이후 → Kendall + 스케줄
            # (a) ratio 계산
            progress = self.current_epoch - self.warmup_epochs
            ratio = float(progress) / float(self.schedule_epochs)
            ratio = max(0.0, min(1.0, ratio))  # 0 ~ 1

            # (b) CE, TV, HAUS 선형 보간 가중치
            w_ce   = self.ce_start + (self.ce_end - self.ce_start)*ratio
            w_tv   = self.tv_start + (self.tv_end - self.tv_start)*ratio
            w_haus = self.hd_start + (self.hd_end - self.hd_start)*ratio

            # 추가 배율(boost)로 Tversky/HD 더 강조
            w_tv   *= self.tv_boost
            w_haus *= self.hd_boost

            if self.use_uncertainty:
                # (c) Kendall 식
                sigma_ce   = torch.exp(self.log_sigma_ce)
                sigma_tv   = torch.exp(self.log_sigma_tv)
                sigma_haus = torch.exp(self.log_sigma_haus)

                ce_term   = (1.0/(2.0*sigma_ce**2))   * loss_ce   + torch.log(sigma_ce**2)
                tv_term   = (1.0/(2.0*sigma_tv**2))   * loss_tv   + torch.log(sigma_tv**2)
                haus_term = (1.0/(2.0*sigma_haus**2)) * loss_haus + torch.log(sigma_haus**2)

                # (d) 최종 합산 (TV/HD > CE가 되도록 w_tv, w_haus를 크게)
                total_loss = w_ce*ce_term + w_tv*tv_term + w_haus*haus_term
            else:
                # Kendall 미사용 시 단순 가중합
                total_loss = w_ce*loss_ce + w_tv*loss_tv + w_haus*loss_haus

        return total_loss

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
from monai.metrics import DiceMetric

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=feature_size,
    use_checkpoint=True,
    drop_rate = drop_rate,
    attn_drop_rate = attn_drop_rate,
    use_v2 = use_v2,
).to(device)
# Pretrained weights 불러오기
# if use_checkpoint:
#     pretrain_path = "./swin_unetr_btcv_segmentation/models/model.pt"
#     weight = torch.load(pretrain_path, map_location=device)

#     # 출력 레이어의 키를 제외한 나머지 가중치만 로드
#     filtered_weights = {k: v for k, v in weight.items() if "out.conv.conv" not in k}

#     # strict=False로 로드하여 불일치하는 부분 무시
#     model.load_state_dict(filtered_weights, strict=False)
#     print("Filtered weights loaded successfully. Output layer will be trained from scratch.")


# Loss function
criterion = AdaptiveCombinedLoss(
    warmup_epochs=warmup_epochs,
    schedule_epochs=schedule_epochs,
    warmup_ce=warmup_ce,
    warmup_tv=warmup_tv,
    warmup_hd=warmup_hd,
    ce_end=ce_end,
    tv_end=tv_end,
    hd_end=hd_end,
    include_background=include_background,
    reduction=reduction,
    tversky_alpha=tversky_alpha,
    tversky_beta=tversky_beta,
    tversky_smooth=tversky_smooth,
    tv_boost=tv_boost,
    hd_boost=hd_boost,
)

pretrain_str = "yes" if use_checkpoint else "no"
weight_str = "weighted" if class_weights is not None else ""
if tv_boost == hd_boost == 1.0:
    boost_str = "b{tv_boost:.2f}"
else:
    boost_str = "tvb{tv_boost:.2f}_hb{hd_boost:.2f}"
# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"SwinUNETRv2_CETVHF_{weight_str}_f{feature_size}s{img_size}lr{lr:.0e}_T-a{tversky_alpha:.2f}b{tversky_beta:.2f}Wc{warmup_ce}_Wt{warmup_tv}Wh{warmup_hd}_We{warmup_epochs}_Se{schedule_epochs}_{boost_str}_b{batch_size}_r{num_repeat}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(list(model.parameters()) + list(criterion.parameters()), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
                checkpoint= None
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")



In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([1, 1, 96, 96, 96]) torch.Size([1, 1, 96, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [9]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name

# wandb 초기화
wandb.init(
    project='czii_SwinUnetR',  # 프로젝트 이름 설정
    name=run_name,         # 실행(run) 이름 설정
    config={
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'batch_size': batch_size,
        'lambda': tversky_alpha,
        "cross_entropy_weight": warmup_ce,
        "tversky_weight": warmup_tv,
        "hausdorff_weight": warmup_hd,
        "cross_entropy_weight_end": ce_end,
        "tversky_weight_end": tv_end,
        "hausdorff_weight_end": hd_end,
        "include_background": include_background,
        "wramup_epochs": warmup_epochs,
        "schedule_epochs": schedule_epochs,
        "include_background": include_background,
        "reduction": reduction,
        'feature_size': feature_size,
        'img_size': img_size,
        'sampling_ratio': ratios_list,
        'device': device.type,
        "checkpoint_dir": str(folder_name),
        "class_weights": class_weights.tolist() if class_weights is not None else None,
        "use_checkpoint": use_checkpoint,
        "drop_rate": drop_rate,
        "attn_drop_rate": attn_drop_rate,
        "use_v2": use_v2,
        "accumulation_steps": accumulation_steps,
        "num_repeat": num_repeat,
        
        # 필요한 하이퍼파라미터 추가
    }
)
# 모델을 wandb에 연결
wandb.watch(model, log='all')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwoow070840[0m ([33mwaooang[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

# 학습

In [10]:
from monai.metrics import DiceMetric
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    # loss = criterion(outputs, labels_onehot)
    loss = criterion(outputs, labels_onehot)
    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()  # 그래디언트 초기화
    with tqdm(train_loader, desc='Training') as pbar:
        for i, batch_data in enumerate(pbar):
            # 손실 계산
            loss, _, _, _ = processing(batch_data, model, criterion, device)

            # 그래디언트를 계산하고 누적
            loss = loss / accumulation_steps  # 그래디언트 누적을 위한 스케일링
            loss.backward()  # 그래디언트 계산 및 누적
            
            # 그래디언트 업데이트 (accumulation_steps마다 한 번)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()  # 파라미터 업데이트
                optimizer.zero_grad()  # 누적된 그래디언트 초기화
            
            # 손실값 누적 (스케일링 복구)
            epoch_loss += loss.item() * accumulation_steps  # 실제 손실값 반영
            pbar.set_postfix(loss=loss.item() * accumulation_steps)  # 실제 손실값 출력
    avg_loss = epoch_loss / len(train_loader)
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    
    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())
                        precision = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(pred_i) + 1e-8)
                        recall = (torch.sum(pred_i & label_i) + 1e-8) / (torch.sum(label_i) + 1e-8)
                        f_beta_score = (1 + 4**2) * (precision * recall) / (4**2 * precision + recall + 1e-8)
                        class_f_beta_scores[i].append(f_beta_score.item())

    avg_loss = val_loss / len(val_loader)
    # 에포크별 평균 손실 로깅
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})
    
    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        all_classes_dice_scores = []
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            wandb.log({f'class_{i}_dice_score': mean_dice, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_dice_scores.append(mean_dice)
            
        print()
    if epoch % calculate_dice_interval == 0:
        print("Validation F-beta Score")
        all_classes_fbeta_scores = []
        for i in range(n_classes):
            mean_fbeta = np.mean(class_f_beta_scores[i])
            wandb.log({f'class_{i}_f_beta_score': mean_fbeta, 'epoch': epoch + 1})
            print(f"Class {i}: {mean_fbeta:.4f}", end=", ")
            if i not in [0, 2]:  # 평균에 포함할 클래스만 추가
                all_classes_fbeta_scores.append(mean_fbeta)
        print()
        overall_mean_dice = np.mean(all_classes_dice_scores)
        overall_mean_fbeta = np.mean(all_classes_fbeta_scores)
        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\n")

    if overall_mean_fbeta is None:
        overall_mean_fbeta = 0

    return val_loss / len(val_loader), overall_mean_fbeta

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        best_val_fbeta_score: 이전 최적 validation f-beta score
        calculate_dice_interval: Dice 점수 계산 주기
    """
    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate One Epoch
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        if val_loss < best_val_loss and overall_mean_fbeta_score > best_val_fbeta_score:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early stopping 조건 체크
        if val_loss >= best_val_loss and overall_mean_fbeta_score <= best_val_fbeta_score:
            epochs_no_improve += 1
        else:
            epochs_no_improve = 0

        if epochs_no_improve >= patience:
            print("Early stopping")
            checkpoint_path = os.path.join(checkpoint_dir, 'last.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'best_val_fbeta_score': best_val_fbeta_score
            }, checkpoint_path)
            break
        # if epochs_no_improve%6 == 0:
        #     # 손실이 개선되지 않았으므로 lambda 감소
        #     new_lamda = max(criterion.lamda - 0.01, 0.1)  # 최소값은 0.1로 설정
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [None]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps = accumulation_steps
     ) 

Epoch 1/4000


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training: 100%|██████████| 1440/1440 [24:08<00:00,  1.01s/it, loss=0.114]
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.31it/s, loss=0.179]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0000, Class 5: 0.0000, Class 6: 0.0000, 
Validation F-beta Score
Class 0: 0.9982, Class 1: 0.5000, Class 2: 0.6667, Class 3: 0.3333, Class 4: 0.3333, Class 5: 0.4167, Class 6: 0.5000, 

Overall Mean Dice Score: 0.0000
Overall Mean F-beta Score: 0.4167

Training Loss: 0.2826, Validation Loss: 0.1562, Validation F-beta: 0.4167
SUPER Best model saved. Loss:0.1562, Score:0.4167
Epoch 2/4000


Training: 100%|██████████| 1440/1440 [23:30<00:00,  1.02it/s, loss=0.112] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.23it/s, loss=0.126]


Validation Dice Score
Class 0: 0.9905, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0000, Class 5: 0.0000, Class 6: 0.0000, 
Validation F-beta Score
Class 0: 0.9989, Class 1: 0.5833, Class 2: 0.4167, Class 3: 0.2500, Class 4: 0.5000, Class 5: 0.0833, Class 6: 0.4167, 

Overall Mean Dice Score: 0.0000
Overall Mean F-beta Score: 0.3667

Training Loss: 0.1336, Validation Loss: 0.1301, Validation F-beta: 0.3667
Epoch 3/4000


Training: 100%|██████████| 1440/1440 [24:31<00:00,  1.02s/it, loss=0.143] 
Validation: 100%|██████████| 12/12 [00:09<00:00,  1.28it/s, loss=0.195]


Validation Dice Score
Class 0: 0.9867, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0000, Class 5: 0.0000, Class 6: 0.0406, 
Validation F-beta Score
Class 0: 0.9984, Class 1: 0.3333, Class 2: 0.4167, Class 3: 0.7500, Class 4: 0.3333, Class 5: 0.2500, Class 6: 0.6074, 

Overall Mean Dice Score: 0.0081
Overall Mean F-beta Score: 0.4548

Training Loss: 0.1334, Validation Loss: 0.1434, Validation F-beta: 0.4548
SUPER Best model saved. Loss:0.1434, Score:0.4548
Epoch 4/4000


Training: 100%|██████████| 1440/1440 [25:02<00:00,  1.04s/it, loss=0.199] 
Validation: 100%|██████████| 12/12 [00:11<00:00,  1.05it/s, loss=0.0933]


Validation Dice Score
Class 0: 0.9858, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0000, Class 5: 0.0000, Class 6: 0.4129, 
Validation F-beta Score
Class 0: 0.9977, Class 1: 0.3333, Class 2: 0.5833, Class 3: 0.4167, Class 4: 0.2500, Class 5: 0.0833, Class 6: 0.3826, 

Overall Mean Dice Score: 0.0826
Overall Mean F-beta Score: 0.2932

Training Loss: 0.1294, Validation Loss: 0.1378, Validation F-beta: 0.2932
Epoch 5/4000


Training: 100%|██████████| 1440/1440 [24:48<00:00,  1.03s/it, loss=0.114] 
Validation: 100%|██████████| 12/12 [00:10<00:00,  1.14it/s, loss=0.094]


Validation Dice Score
Class 0: 0.9884, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0597, Class 5: 0.0000, Class 6: 0.1931, 
Validation F-beta Score
Class 0: 0.9982, Class 1: 0.5833, Class 2: 0.3333, Class 3: 0.2500, Class 4: 0.3752, Class 5: 0.1667, Class 6: 0.6532, 

Overall Mean Dice Score: 0.0506
Overall Mean F-beta Score: 0.4057

Training Loss: 0.1255, Validation Loss: 0.1525, Validation F-beta: 0.4057
Epoch 6/4000


Training: 100%|██████████| 1440/1440 [26:14<00:00,  1.09s/it, loss=0.188] 
Validation: 100%|██████████| 12/12 [00:10<00:00,  1.18it/s, loss=0.133]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0000, Class 4: 0.0007, Class 5: 0.0000, Class 6: 0.1762, 
Validation F-beta Score
Class 0: 0.9983, Class 1: 0.5833, Class 2: 0.2500, Class 3: 0.6667, Class 4: 0.2504, Class 5: 0.1667, Class 6: 0.4036, 

Overall Mean Dice Score: 0.0354
Overall Mean F-beta Score: 0.4141

Training Loss: 0.1338, Validation Loss: 0.1502, Validation F-beta: 0.4141
Epoch 7/4000


Training: 100%|██████████| 1440/1440 [26:34<00:00,  1.11s/it, loss=0.0918]
Validation: 100%|██████████| 12/12 [00:10<00:00,  1.11it/s, loss=0.115]


Validation Dice Score
Class 0: 0.9926, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0967, Class 4: 0.0841, Class 5: 0.0000, Class 6: 0.2846, 
Validation F-beta Score
Class 0: 0.9984, Class 1: 0.3333, Class 2: 0.7500, Class 3: 0.3144, Class 4: 0.4768, Class 5: 0.2500, Class 6: 0.7785, 

Overall Mean Dice Score: 0.0931
Overall Mean F-beta Score: 0.4306

Training Loss: 0.1283, Validation Loss: 0.1411, Validation F-beta: 0.4306
Epoch 8/4000


Training: 100%|██████████| 1440/1440 [27:09<00:00,  1.13s/it, loss=0.0997]
Validation: 100%|██████████| 12/12 [00:11<00:00,  1.01it/s, loss=0.224]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0668, Class 4: 0.0177, Class 5: 0.0000, Class 6: 0.3096, 
Validation F-beta Score
Class 0: 0.9988, Class 1: 0.1667, Class 2: 0.5833, Class 3: 0.2266, Class 4: 0.4271, Class 5: 0.1667, Class 6: 0.5934, 

Overall Mean Dice Score: 0.0788
Overall Mean F-beta Score: 0.3161

Training Loss: 0.1257, Validation Loss: 0.1649, Validation F-beta: 0.3161
Epoch 9/4000


Training: 100%|██████████| 1440/1440 [27:22<00:00,  1.14s/it, loss=0.103] 
Validation: 100%|██████████| 12/12 [00:11<00:00,  1.02it/s, loss=0.0915]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.1031, Class 4: 0.1559, Class 5: 0.0002, Class 6: 0.5472, 
Validation F-beta Score
Class 0: 0.9980, Class 1: 0.4167, Class 2: 0.7500, Class 3: 0.3512, Class 4: 0.5235, Class 5: 0.0834, Class 6: 0.8057, 

Overall Mean Dice Score: 0.1613
Overall Mean F-beta Score: 0.4361

Training Loss: 0.1208, Validation Loss: 0.1570, Validation F-beta: 0.4361
Epoch 10/4000


Training: 100%|██████████| 1440/1440 [29:21<00:00,  1.22s/it, loss=0.115] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.16s/it, loss=0.231]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.0982, Class 2: 0.0000, Class 3: 0.1079, Class 4: 0.1560, Class 5: 0.0355, Class 6: 0.4137, 
Validation F-beta Score
Class 0: 0.9980, Class 1: 0.1478, Class 2: 0.5000, Class 3: 0.1782, Class 4: 0.1933, Class 5: 0.1902, Class 6: 0.6787, 

Overall Mean Dice Score: 0.1623
Overall Mean F-beta Score: 0.2776

Training Loss: 0.1360, Validation Loss: 0.1851, Validation F-beta: 0.2776
Epoch 11/4000


Training: 100%|██████████| 1440/1440 [30:36<00:00,  1.28s/it, loss=0.0791]
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.18s/it, loss=0.0977]


Validation Dice Score
Class 0: 0.9936, Class 1: 0.4140, Class 2: 0.0000, Class 3: 0.2046, Class 4: 0.2154, Class 5: 0.0918, Class 6: 0.4177, 
Validation F-beta Score
Class 0: 0.9983, Class 1: 0.5400, Class 2: 0.5000, Class 3: 0.3613, Class 4: 0.4907, Class 5: 0.3222, Class 6: 0.6875, 

Overall Mean Dice Score: 0.2687
Overall Mean F-beta Score: 0.4804

Training Loss: 0.1334, Validation Loss: 0.1401, Validation F-beta: 0.4804
SUPER Best model saved. Loss:0.1401, Score:0.4804
Epoch 12/4000


Training: 100%|██████████| 1440/1440 [30:43<00:00,  1.28s/it, loss=0.241] 
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.112] 


Validation Dice Score
Class 0: 0.9908, Class 1: 0.2643, Class 2: 0.0000, Class 3: 0.1897, Class 4: 0.0075, Class 5: 0.0108, Class 6: 0.3138, 
Validation F-beta Score
Class 0: 0.9984, Class 1: 0.5529, Class 2: 0.3333, Class 3: 0.3297, Class 4: 0.3375, Class 5: 0.0061, Class 6: 0.7995, 

Overall Mean Dice Score: 0.1572
Overall Mean F-beta Score: 0.4051

Training Loss: 0.1362, Validation Loss: 0.1491, Validation F-beta: 0.4051
Epoch 13/4000


Training: 100%|██████████| 1440/1440 [30:49<00:00,  1.28s/it, loss=0.233] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.11s/it, loss=0.0993]


Validation Dice Score
Class 0: 0.9916, Class 1: 0.2986, Class 2: 0.0000, Class 3: 0.0565, Class 4: 0.3102, Class 5: 0.0066, Class 6: 0.3226, 
Validation F-beta Score
Class 0: 0.9986, Class 1: 0.3671, Class 2: 0.3333, Class 3: 0.6247, Class 4: 0.4943, Class 5: 0.0870, Class 6: 0.6173, 

Overall Mean Dice Score: 0.1989
Overall Mean F-beta Score: 0.4381

Training Loss: 0.1355, Validation Loss: 0.1237, Validation F-beta: 0.4381
Epoch 14/4000


Training: 100%|██████████| 1440/1440 [30:58<00:00,  1.29s/it, loss=0.268] 
Validation: 100%|██████████| 12/12 [00:10<00:00,  1.15it/s, loss=0.39]  


Validation Dice Score
Class 0: 0.9902, Class 1: 0.1801, Class 2: 0.0000, Class 3: 0.0006, Class 4: 0.2352, Class 5: 0.0000, Class 6: 0.0743, 
Validation F-beta Score
Class 0: 0.9981, Class 1: 0.8351, Class 2: 0.5833, Class 3: 0.4170, Class 4: 0.4371, Class 5: 0.2500, Class 6: 0.5796, 

Overall Mean Dice Score: 0.0981
Overall Mean F-beta Score: 0.5038

Training Loss: 0.1350, Validation Loss: 0.1578, Validation F-beta: 0.5038
Epoch 15/4000


Training: 100%|██████████| 1440/1440 [30:57<00:00,  1.29s/it, loss=0.198] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.16s/it, loss=0.158]


Validation Dice Score
Class 0: 0.9927, Class 1: 0.3655, Class 2: 0.0000, Class 3: 0.1364, Class 4: 0.2834, Class 5: 0.1039, Class 6: 0.4567, 
Validation F-beta Score
Class 0: 0.9969, Class 1: 0.5147, Class 2: 0.3333, Class 3: 0.2132, Class 4: 0.5965, Class 5: 0.4019, Class 6: 0.8741, 

Overall Mean Dice Score: 0.2692
Overall Mean F-beta Score: 0.5201

Training Loss: 0.1325, Validation Loss: 0.1526, Validation F-beta: 0.5201
Epoch 16/4000


Training: 100%|██████████| 1440/1440 [31:34<00:00,  1.32s/it, loss=0.0937]
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.178]


Validation Dice Score
Class 0: 0.9933, Class 1: 0.3326, Class 2: 0.0000, Class 3: 0.2142, Class 4: 0.2999, Class 5: 0.0472, Class 6: 0.2089, 
Validation F-beta Score
Class 0: 0.9983, Class 1: 0.5009, Class 2: 0.5000, Class 3: 0.4123, Class 4: 0.6597, Class 5: 0.2011, Class 6: 0.8192, 

Overall Mean Dice Score: 0.2206
Overall Mean F-beta Score: 0.5186

Training Loss: 0.1276, Validation Loss: 0.1233, Validation F-beta: 0.5186
SUPER Best model saved. Loss:0.1233, Score:0.5186
Epoch 17/4000


Training: 100%|██████████| 1440/1440 [31:39<00:00,  1.32s/it, loss=0.082] 
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.01s/it, loss=0.0887]


Validation Dice Score
Class 0: 0.9923, Class 1: 0.2509, Class 2: 0.0000, Class 3: 0.0686, Class 4: 0.3343, Class 5: 0.0972, Class 6: 0.3614, 
Validation F-beta Score
Class 0: 0.9979, Class 1: 0.7260, Class 2: 0.5833, Class 3: 0.3094, Class 4: 0.5304, Class 5: 0.3177, Class 6: 0.7878, 

Overall Mean Dice Score: 0.2225
Overall Mean F-beta Score: 0.5343

Training Loss: 0.1252, Validation Loss: 0.1313, Validation F-beta: 0.5343
Epoch 18/4000


Training: 100%|██████████| 1440/1440 [31:33<00:00,  1.31s/it, loss=0.0963]
Validation: 100%|██████████| 12/12 [00:15<00:00,  1.29s/it, loss=0.205]


Validation Dice Score
Class 0: 0.9914, Class 1: 0.6388, Class 2: 0.0000, Class 3: 0.1754, Class 4: 0.2481, Class 5: 0.0255, Class 6: 0.4892, 
Validation F-beta Score
Class 0: 0.9971, Class 1: 0.8094, Class 2: 0.5000, Class 3: 0.2431, Class 4: 0.4395, Class 5: 0.0987, Class 6: 0.6804, 

Overall Mean Dice Score: 0.3154
Overall Mean F-beta Score: 0.4542

Training Loss: 0.1263, Validation Loss: 0.1657, Validation F-beta: 0.4542
Epoch 19/4000


Training: 100%|██████████| 1440/1440 [31:54<00:00,  1.33s/it, loss=0.187] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.13s/it, loss=0.332]


Validation Dice Score
Class 0: 0.9904, Class 1: 0.3877, Class 2: 0.0000, Class 3: 0.3134, Class 4: 0.1891, Class 5: 0.0414, Class 6: 0.2822, 
Validation F-beta Score
Class 0: 0.9974, Class 1: 0.6392, Class 2: 0.6667, Class 3: 0.5590, Class 4: 0.4022, Class 5: 0.1925, Class 6: 0.7248, 

Overall Mean Dice Score: 0.2428
Overall Mean F-beta Score: 0.5036

Training Loss: 0.1237, Validation Loss: 0.1808, Validation F-beta: 0.5036
Epoch 20/4000


Training: 100%|██████████| 1440/1440 [31:45<00:00,  1.32s/it, loss=0.287] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.13s/it, loss=0.133] 


Validation Dice Score
Class 0: 0.9932, Class 1: 0.3253, Class 2: 0.0000, Class 3: 0.2075, Class 4: 0.3085, Class 5: 0.2538, Class 6: 0.2852, 
Validation F-beta Score
Class 0: 0.9963, Class 1: 0.6634, Class 2: 0.5000, Class 3: 0.3501, Class 4: 0.6064, Class 5: 0.2349, Class 6: 0.9498, 

Overall Mean Dice Score: 0.2760
Overall Mean F-beta Score: 0.5609

Training Loss: 0.1266, Validation Loss: 0.1409, Validation F-beta: 0.5609
Epoch 21/4000


Training: 100%|██████████| 1440/1440 [31:30<00:00,  1.31s/it, loss=0.0755]
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.25s/it, loss=0.0634]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.5223, Class 2: 0.0000, Class 3: 0.2408, Class 4: 0.3272, Class 5: 0.0000, Class 6: 0.5877, 
Validation F-beta Score
Class 0: 0.9971, Class 1: 0.5375, Class 2: 0.4167, Class 3: 0.4114, Class 4: 0.4444, Class 5: 0.3333, Class 6: 0.9368, 

Overall Mean Dice Score: 0.3356
Overall Mean F-beta Score: 0.5327

Training Loss: 0.1227, Validation Loss: 0.1675, Validation F-beta: 0.5327
Epoch 22/4000


Training: 100%|██████████| 1440/1440 [31:59<00:00,  1.33s/it, loss=0.0847]
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.10s/it, loss=0.162] 


Validation Dice Score
Class 0: 0.9930, Class 1: 0.4018, Class 2: 0.0000, Class 3: 0.1830, Class 4: 0.3057, Class 5: 0.0203, Class 6: 0.5365, 
Validation F-beta Score
Class 0: 0.9982, Class 1: 0.8054, Class 2: 0.5000, Class 3: 0.4107, Class 4: 0.6701, Class 5: 0.0121, Class 6: 0.8762, 

Overall Mean Dice Score: 0.2895
Overall Mean F-beta Score: 0.5549

Training Loss: 0.1223, Validation Loss: 0.1104, Validation F-beta: 0.5549
SUPER Best model saved. Loss:0.1104, Score:0.5549
Epoch 23/4000


Training: 100%|██████████| 1440/1440 [32:35<00:00,  1.36s/it, loss=0.134] 
Validation: 100%|██████████| 12/12 [00:15<00:00,  1.29s/it, loss=0.0746]


Validation Dice Score
Class 0: 0.9942, Class 1: 0.5664, Class 2: 0.0000, Class 3: 0.2033, Class 4: 0.4085, Class 5: 0.2879, Class 6: 0.6645, 
Validation F-beta Score
Class 0: 0.9967, Class 1: 0.6654, Class 2: 0.8333, Class 3: 0.4350, Class 4: 0.6390, Class 5: 0.4948, Class 6: 0.8641, 

Overall Mean Dice Score: 0.4261
Overall Mean F-beta Score: 0.6197

Training Loss: 0.1194, Validation Loss: 0.1220, Validation F-beta: 0.6197
Epoch 24/4000


Training: 100%|██████████| 1440/1440 [32:13<00:00,  1.34s/it, loss=0.0817]
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.21s/it, loss=0.108]


Validation Dice Score
Class 0: 0.9931, Class 1: 0.5609, Class 2: 0.0000, Class 3: 0.1210, Class 4: 0.2289, Class 5: 0.1139, Class 6: 0.4572, 
Validation F-beta Score
Class 0: 0.9983, Class 1: 0.7637, Class 2: 0.5833, Class 3: 0.0847, Class 4: 0.5328, Class 5: 0.1635, Class 6: 0.9700, 

Overall Mean Dice Score: 0.2964
Overall Mean F-beta Score: 0.5029

Training Loss: 0.1215, Validation Loss: 0.1404, Validation F-beta: 0.5029
Epoch 25/4000


Training: 100%|██████████| 1440/1440 [31:56<00:00,  1.33s/it, loss=0.0925]
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.13s/it, loss=0.0955]


Validation Dice Score
Class 0: 0.9925, Class 1: 0.2715, Class 2: 0.0000, Class 3: 0.3368, Class 4: 0.4487, Class 5: 0.1409, Class 6: 0.2564, 
Validation F-beta Score
Class 0: 0.9979, Class 1: 0.7038, Class 2: 0.5000, Class 3: 0.5655, Class 4: 0.4999, Class 5: 0.1979, Class 6: 0.9468, 

Overall Mean Dice Score: 0.2909
Overall Mean F-beta Score: 0.5828

Training Loss: 0.1218, Validation Loss: 0.1674, Validation F-beta: 0.5828
Epoch 26/4000


Training: 100%|██████████| 1440/1440 [32:22<00:00,  1.35s/it, loss=0.0841]
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.079]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.0598, Class 2: 0.0000, Class 3: 0.2948, Class 4: 0.3641, Class 5: 0.1425, Class 6: 0.2646, 
Validation F-beta Score
Class 0: 0.9978, Class 1: 0.5499, Class 2: 0.8333, Class 3: 0.5894, Class 4: 0.3860, Class 5: 0.2717, Class 6: 0.8849, 

Overall Mean Dice Score: 0.2252
Overall Mean F-beta Score: 0.5364

Training Loss: 0.1231, Validation Loss: 0.1736, Validation F-beta: 0.5364
Epoch 27/4000


Training: 100%|██████████| 1440/1440 [31:38<00:00,  1.32s/it, loss=0.153] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.15s/it, loss=0.225] 


Validation Dice Score
Class 0: 0.9940, Class 1: 0.5394, Class 2: 0.0000, Class 3: 0.1583, Class 4: 0.2000, Class 5: 0.2611, Class 6: 0.3389, 
Validation F-beta Score
Class 0: 0.9981, Class 1: 0.7837, Class 2: 0.6667, Class 3: 0.3106, Class 4: 0.6596, Class 5: 0.2166, Class 6: 0.7240, 

Overall Mean Dice Score: 0.2995
Overall Mean F-beta Score: 0.5389

Training Loss: 0.1159, Validation Loss: 0.1449, Validation F-beta: 0.5389
Epoch 28/4000


Training: 100%|██████████| 1440/1440 [31:49<00:00,  1.33s/it, loss=0.0887]
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.203] 


Validation Dice Score
Class 0: 0.9943, Class 1: 0.3202, Class 2: 0.0000, Class 3: 0.1966, Class 4: 0.4908, Class 5: 0.0693, Class 6: 0.6110, 
Validation F-beta Score
Class 0: 0.9978, Class 1: 0.8204, Class 2: 0.6667, Class 3: 0.4945, Class 4: 0.4906, Class 5: 0.6577, Class 6: 0.9519, 

Overall Mean Dice Score: 0.3376
Overall Mean F-beta Score: 0.6830

Training Loss: 0.1204, Validation Loss: 0.1145, Validation F-beta: 0.6830
Epoch 29/4000


Training: 100%|██████████| 1440/1440 [31:39<00:00,  1.32s/it, loss=0.0679]
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.08s/it, loss=0.112] 


Validation Dice Score
Class 0: 0.9937, Class 1: 0.3076, Class 2: 0.0000, Class 3: 0.2573, Class 4: 0.2010, Class 5: 0.1170, Class 6: 0.2928, 
Validation F-beta Score
Class 0: 0.9977, Class 1: 0.6964, Class 2: 0.5000, Class 3: 0.3205, Class 4: 0.7577, Class 5: 0.1704, Class 6: 0.8601, 

Overall Mean Dice Score: 0.2351
Overall Mean F-beta Score: 0.5610

Training Loss: 0.1203, Validation Loss: 0.1275, Validation F-beta: 0.5610
Epoch 30/4000


Training: 100%|██████████| 1440/1440 [32:11<00:00,  1.34s/it, loss=0.123] 
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.17s/it, loss=0.42]  


Validation Dice Score
Class 0: 0.9900, Class 1: 0.4714, Class 2: 0.0000, Class 3: 0.2283, Class 4: 0.1347, Class 5: 0.1561, Class 6: 0.5110, 
Validation F-beta Score
Class 0: 0.9977, Class 1: 0.7511, Class 2: 0.5833, Class 3: 0.4682, Class 4: 0.6054, Class 5: 0.1178, Class 6: 0.8743, 

Overall Mean Dice Score: 0.3003
Overall Mean F-beta Score: 0.5634

Training Loss: 0.1200, Validation Loss: 0.1746, Validation F-beta: 0.5634
Epoch 31/4000


Training: 100%|██████████| 1440/1440 [32:21<00:00,  1.35s/it, loss=0.0918]
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.24s/it, loss=0.0693]


Validation Dice Score
Class 0: 0.9949, Class 1: 0.5097, Class 2: 0.0000, Class 3: 0.2342, Class 4: 0.4580, Class 5: 0.3715, Class 6: 0.5203, 
Validation F-beta Score
Class 0: 0.9963, Class 1: 0.6741, Class 2: 0.4167, Class 3: 0.5487, Class 4: 0.6934, Class 5: 0.5015, Class 6: 0.9691, 

Overall Mean Dice Score: 0.4187
Overall Mean F-beta Score: 0.6773

Training Loss: 0.1206, Validation Loss: 0.1280, Validation F-beta: 0.6773
Epoch 32/4000


Training: 100%|██████████| 1440/1440 [32:21<00:00,  1.35s/it, loss=0.0978]
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.189]


Validation Dice Score
Class 0: 0.9918, Class 1: 0.3789, Class 2: 0.0000, Class 3: 0.2184, Class 4: 0.2969, Class 5: 0.0786, Class 6: 0.2119, 
Validation F-beta Score
Class 0: 0.9969, Class 1: 0.7672, Class 2: 0.5000, Class 3: 0.5475, Class 4: 0.6634, Class 5: 0.2267, Class 6: 0.7759, 

Overall Mean Dice Score: 0.2369
Overall Mean F-beta Score: 0.5962

Training Loss: 0.1184, Validation Loss: 0.1525, Validation F-beta: 0.5962
Epoch 33/4000


Training: 100%|██████████| 1440/1440 [32:33<00:00,  1.36s/it, loss=0.0804]
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.16s/it, loss=0.221] 


Validation Dice Score
Class 0: 0.9913, Class 1: 0.3472, Class 2: 0.0000, Class 3: 0.0453, Class 4: 0.4811, Class 5: 0.0439, Class 6: 0.4665, 
Validation F-beta Score
Class 0: 0.9981, Class 1: 0.7830, Class 2: 0.4167, Class 3: 0.2834, Class 4: 0.4626, Class 5: 0.3631, Class 6: 0.9693, 

Overall Mean Dice Score: 0.2768
Overall Mean F-beta Score: 0.5723

Training Loss: 0.1204, Validation Loss: 0.1564, Validation F-beta: 0.5723
Epoch 34/4000


Training: 100%|██████████| 1440/1440 [32:09<00:00,  1.34s/it, loss=0.119] 
Validation: 100%|██████████| 12/12 [00:15<00:00,  1.28s/it, loss=0.105] 


Validation Dice Score
Class 0: 0.9930, Class 1: 0.2696, Class 2: 0.0000, Class 3: 0.3199, Class 4: 0.3864, Class 5: 0.3691, Class 6: 0.5322, 
Validation F-beta Score
Class 0: 0.9968, Class 1: 0.6693, Class 2: 0.5000, Class 3: 0.3340, Class 4: 0.7084, Class 5: 0.3161, Class 6: 0.8720, 

Overall Mean Dice Score: 0.3754
Overall Mean F-beta Score: 0.5800

Training Loss: 0.1154, Validation Loss: 0.1564, Validation F-beta: 0.5800
Epoch 35/4000


Training: 100%|██████████| 1440/1440 [33:05<00:00,  1.38s/it, loss=0.114] 
Validation: 100%|██████████| 12/12 [00:14<00:00,  1.24s/it, loss=0.13]  


Validation Dice Score
Class 0: 0.9929, Class 1: 0.3679, Class 2: 0.0127, Class 3: 0.1736, Class 4: 0.4131, Class 5: 0.2145, Class 6: 0.4583, 
Validation F-beta Score
Class 0: 0.9968, Class 1: 0.7323, Class 2: 0.3403, Class 3: 0.4864, Class 4: 0.7038, Class 5: 0.5432, Class 6: 0.8792, 

Overall Mean Dice Score: 0.3255
Overall Mean F-beta Score: 0.6690

Training Loss: 0.1156, Validation Loss: 0.1212, Validation F-beta: 0.6690
Epoch 36/4000


Training: 100%|██████████| 1440/1440 [33:54<00:00,  1.41s/it, loss=0.0637]
Validation: 100%|██████████| 12/12 [00:15<00:00,  1.30s/it, loss=0.207]


Validation Dice Score
Class 0: 0.9932, Class 1: 0.4398, Class 2: 0.0256, Class 3: 0.0964, Class 4: 0.3591, Class 5: 0.3243, Class 6: 0.3467, 
Validation F-beta Score
Class 0: 0.9967, Class 1: 0.6951, Class 2: 0.1826, Class 3: 0.2482, Class 4: 0.6613, Class 5: 0.2866, Class 6: 0.7740, 

Overall Mean Dice Score: 0.3132
Overall Mean F-beta Score: 0.5330

Training Loss: 0.1172, Validation Loss: 0.1468, Validation F-beta: 0.5330
Epoch 37/4000


Training: 100%|██████████| 1440/1440 [33:15<00:00,  1.39s/it, loss=0.202] 
Validation: 100%|██████████| 12/12 [00:16<00:00,  1.37s/it, loss=0.233] 


Validation Dice Score
Class 0: 0.9898, Class 1: 0.5188, Class 2: 0.0098, Class 3: 0.2487, Class 4: 0.4467, Class 5: 0.1645, Class 6: 0.3709, 
Validation F-beta Score
Class 0: 0.9965, Class 1: 0.6374, Class 2: 0.4222, Class 3: 0.3099, Class 4: 0.5313, Class 5: 0.1434, Class 6: 0.9732, 

Overall Mean Dice Score: 0.3499
Overall Mean F-beta Score: 0.5190

Training Loss: 0.1178, Validation Loss: 0.1600, Validation F-beta: 0.5190
Epoch 38/4000


Training: 100%|██████████| 1440/1440 [32:35<00:00,  1.36s/it, loss=0.103] 
Validation: 100%|██████████| 12/12 [00:12<00:00,  1.07s/it, loss=0.363]


Validation Dice Score
Class 0: 0.9916, Class 1: 0.3901, Class 2: 0.0000, Class 3: 0.1084, Class 4: 0.1529, Class 5: 0.2446, Class 6: 0.3084, 
Validation F-beta Score
Class 0: 0.9977, Class 1: 0.8549, Class 2: 0.5000, Class 3: 0.5197, Class 4: 0.5391, Class 5: 0.4359, Class 6: 0.7387, 

Overall Mean Dice Score: 0.2409
Overall Mean F-beta Score: 0.6177

Training Loss: 0.1181, Validation Loss: 0.1627, Validation F-beta: 0.6177
Epoch 39/4000


Training: 100%|██████████| 1440/1440 [32:37<00:00,  1.36s/it, loss=0.169] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.15s/it, loss=0.107] 


Validation Dice Score
Class 0: 0.9928, Class 1: 0.4148, Class 2: 0.0012, Class 3: 0.2803, Class 4: 0.2622, Class 5: 0.1372, Class 6: 0.6135, 
Validation F-beta Score
Class 0: 0.9978, Class 1: 0.6724, Class 2: 0.7507, Class 3: 0.4585, Class 4: 0.6447, Class 5: 0.3733, Class 6: 0.8891, 

Overall Mean Dice Score: 0.3416
Overall Mean F-beta Score: 0.6076

Training Loss: 0.1154, Validation Loss: 0.1463, Validation F-beta: 0.6076
Epoch 40/4000


Training: 100%|██████████| 1440/1440 [32:28<00:00,  1.35s/it, loss=0.065] 
Validation: 100%|██████████| 12/12 [00:13<00:00,  1.14s/it, loss=0.09] 


Validation Dice Score
Class 0: 0.9934, Class 1: 0.6116, Class 2: 0.0000, Class 3: 0.0911, Class 4: 0.2380, Class 5: 0.2159, Class 6: 0.2349, 
Validation F-beta Score
Class 0: 0.9981, Class 1: 0.8212, Class 2: 0.3333, Class 3: 0.4032, Class 4: 0.5343, Class 5: 0.2052, Class 6: 0.9867, 

Overall Mean Dice Score: 0.2783
Overall Mean F-beta Score: 0.5901

Training Loss: 0.1165, Validation Loss: 0.1074, Validation F-beta: 0.5901
SUPER Best model saved. Loss:0.1074, Score:0.5901
Epoch 41/4000


Training:  18%|█▊        | 261/1440 [05:51<26:44,  1.36s/it, loss=0.18]  

In [12]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# VAl

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import TverskyLoss
import torch
import numpy as np
from tqdm import tqdm
import wandb
from src.dataset.dataset import make_val_dataloader

val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1
lamda = 0.52

wandb.init(
    project='czii_SwinUnetR_val',  # 프로젝트 이름 설정
    name='SwinUNETR96_96_lr0.001_lambda0.52_batch2',         # 실행(run) 이름 설정
    config={
        'learning_rate': 0.001,
        'batch_size': batch_size,
        'lambda': lamda,
        'img_size': img_size,
        'device': 'cuda',
        "checkpoint_dir": "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2",
        
    }
)

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

val_loader = make_val_dataloader(
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0
)
criterion = TverskyLoss(
    alpha= 1 - lamda,  # FP에 대한 가중치
    beta=lamda,       # FN에 대한 가중치
    include_background=False,  # 배경 클래스 제외
    softmax=True
)
    
    
from monai.metrics import DiceMetric

img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

val_loss, overall_mean_fbeta_score = validate_one_epoch(
    model=model, 
    val_loader=val_loader, 
    criterion=criterion, 
    device=device, 
    epoch=0, 
    calculate_dice_interval=1
)

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
class_0_dice_score,▁
class_0_f_beta_score,▁
class_1_dice_score,▁
class_1_f_beta_score,▁
class_2_dice_score,▁
class_2_f_beta_score,▁
class_3_dice_score,▁
class_3_f_beta_score,▁
class_4_dice_score,▁
class_4_f_beta_score,▁

0,1
class_0_dice_score,0.65703
class_0_f_beta_score,0.50748
class_1_dice_score,0.53332
class_1_f_beta_score,0.64703
class_2_dice_score,0.00286
class_2_f_beta_score,0.02334
class_3_dice_score,0.23703
class_3_f_beta_score,0.23033
class_4_dice_score,0.65487
class_4_f_beta_score,0.62525


Loading dataset: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]
  checkpoint = torch.load(pretrain_path, map_location=device)
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s, loss=0.865]

Validation Dice Score
Class 0: 0.6570, Class 1: 0.5333, Class 2: 0.0029, Class 3: 0.2370, 
Class 4: 0.6549, Class 5: 0.4790, Class 6: 0.4255, 
Validation F-beta Score
Class 0: 0.5075, Class 1: 0.6470, Class 2: 0.0233, Class 3: 0.2303, 
Class 4: 0.6252, Class 5: 0.5145, Class 6: 0.4720, 
Overall Mean Dice Score: 0.4659
Overall Mean F-beta Score: 0.4978






# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=1, 
            calculate_dice_interval=0
        )

Validation:   0%|          | 0/4 [00:03<?, ?it/s, loss=0.764]


ZeroDivisionError: integer modulo by zero

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
