In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

import random
import numpy as np
import torch
# NOTE(review): cuDNN benchmark mode auto-selects convolution algorithms per input shape and
# may reduce run-to-run reproducibility even though set_seed() is called further down --
# confirm that speed is preferred over determinism for this notebook.
torch.backends.cudnn.benchmark = True

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)
from monai.networks.layers.factories import Act, Norm
from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNet

from src.models import UNet_CBAM

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# 랜덤 시드 고정
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch generators with one value.

    Args:
        seed: Seed handed unchanged to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Also seed every CUDA device explicitly when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

from src.models import UNet_CBAM
print_config()

import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import TverskyLoss

# DynamicTverskyLoss 클래스 정의
class DynamicTverskyLoss(TverskyLoss):
    """Tversky loss whose ``alpha``/``beta`` are both derived from one value, ``lamda``.

    ``alpha`` is kept at ``1 - lamda`` and ``beta`` at ``lamda``; ``set_lamda``
    re-derives both after construction.
    """

    def __init__(self, lamda=0.5, **kwargs):
        """
        Args:
            lamda: Used as ``beta``; ``alpha`` becomes ``1 - lamda``.
            **kwargs: Forwarded unchanged to ``TverskyLoss.__init__``.
        """
        alpha, beta = 1 - lamda, lamda
        super().__init__(alpha=alpha, beta=beta, **kwargs)
        self.lamda = lamda

    def set_lamda(self, lamda):
        """Store a new ``lamda`` and refresh ``alpha``/``beta`` to match it."""
        self.lamda = lamda
        self.alpha, self.beta = 1 - lamda, lamda


import torch
import torch.nn as nn

class CombinedCETverskyLoss(nn.Module):
    """Weighted blend of cross-entropy and Tversky loss.

    ``forward`` returns ``ce_weight * CE + (1 - ce_weight) * Tversky``.

    Args:
        lamda: Handed to ``DynamicTverskyLoss``.
        ce_weight: Weight of the cross-entropy term; the Tversky term gets ``1 - ce_weight``.
        n_classes: Stored as ``self.n_classes``; nothing else in this class reads it.
        class_weights: Passed as ``weight`` to ``nn.CrossEntropyLoss``.
        ignore_index: Passed as ``ignore_index`` to ``nn.CrossEntropyLoss``.
        **kwargs: Extra constructor options. Names listed in ``_CE_ONLY_KWARGS`` are given
            to ``nn.CrossEntropyLoss``; every other name is given to ``DynamicTverskyLoss``.
            ``reduction`` and ``softmax`` are fixed by this class and must not be passed.
    """

    # Keywords that only ``nn.CrossEntropyLoss`` understands. Previously the same ``**kwargs``
    # was forwarded to both constructors, so any extra keyword made one of them raise TypeError.
    _CE_ONLY_KWARGS = frozenset({"label_smoothing", "size_average", "reduce"})

    def __init__(self, lamda=0.5, ce_weight=0.5, n_classes=7, class_weights=None, ignore_index=-1, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index

        # Split the extra options so each loss only receives keywords it can accept.
        ce_kwargs = {k: v for k, v in kwargs.items() if k in self._CE_ONLY_KWARGS}
        tversky_kwargs = {k: v for k, v in kwargs.items() if k not in self._CE_ONLY_KWARGS}

        # Cross-entropy term with optional per-class weights.
        self.ce = nn.CrossEntropyLoss(
            weight=class_weights, ignore_index=self.ignore_index, reduction='mean', **ce_kwargs
        )

        # Tversky term; softmax is applied to the inputs inside the loss.
        self.tversky = DynamicTverskyLoss(lamda=lamda, reduction="mean", softmax=True, **tversky_kwargs)

    def forward(self, inputs, targets):
        """Return the blended loss for ``inputs`` against ``targets``.

        The same ``targets`` tensor is handed to both terms.
        NOTE(review): presumably ``targets`` is one-hot with the same shape as ``inputs``
        (no ``to_onehot_y`` is set for the Tversky term); in that case cross-entropy treats
        it as class probabilities and ``ignore_index`` has no effect -- confirm against the caller.
        """
        ce_loss = self.ce(inputs, targets)
        tversky_loss = self.tversky(inputs, targets)

        # Convex combination of the two terms.
        final_loss = self.ce_weight * ce_loss + (1 - self.ce_weight) * tversky_loss
        return final_loss

    def set_lamda(self, lamda):
        """Forward a new ``lamda`` to the Tversky term."""
        self.tversky.set_lamda(lamda)

    @property
    def lamda(self):
        """Current ``lamda`` of the Tversky term."""
        return self.tversky.lamda


  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.5.1+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\ProgramData\anaconda3\envs\czii\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.0
scipy version: 1.15.1
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.1+cu124
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.1.1
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

In [2]:
# Label id -> class name and sampling weight.
class_info = {
    0: {"name": "background", "weight": 0},  # no weight
    1: {"name": "apo-ferritin", "weight": 1000},
    2: {"name": "beta-amylase", "weight": 100}, # 4130
    3: {"name": "beta-galactosidase", "weight": 1500}, #3080
    4: {"name": "ribosome", "weight": 1000},
    5: {"name": "thyroglobulin", "weight": 1500},
    6: {"name": "virus-like-particle", "weight": 1000},
}

# Raw share per class, proportional to its weight; a weight of None falls back to 0.01.
raw_ratios = {
    label: (0.01 if info["weight"] is None else info["weight"])
    for label, info in class_info.items()
}
total = sum(raw_ratios.values())
ratios = {label: share / total for label, share in raw_ratios.items()}

# Sanity check: the normalised ratios should add up to 1.
final_total = sum(ratios.values())
print("클래스 비율:", ratios)
print("최종 합계:", final_total)

# Same ratios as a list ordered by label id.
ratios_list = [ratios[label] for label in sorted(ratios)]
print("클래스 비율 리스트:", ratios_list)

클래스 비율: {0: 0.0, 1: 0.16393442622950818, 2: 0.01639344262295082, 3: 0.2459016393442623, 4: 0.16393442622950818, 5: 0.2459016393442623, 6: 0.16393442622950818}
최종 합계: 1.0
클래스 비율 리스트: [0.0, 0.16393442622950818, 0.01639344262295082, 0.2459016393442623, 0.16393442622950818, 0.2459016393442623, 0.16393442622950818]


# 모델 선언

In [3]:
from __future__ import annotations

import warnings
from collections.abc import Sequence

# 불필요한 import 제거
# from monai.networks.blocks.convolutions import Convolution, ResidualUnit
# from monai.networks.layers.simplelayers import SkipConnection
from monai.networks.layers.factories import Act, Norm

from src.models.unet_block import Encoder, Decoder, get_conv_layer
from src.models.cbam import CBAM3D

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn

class LayerNorm3d(nn.Module):
    """Layer normalization for 5-D inputs laid out as (N, C, D, H, W).

    Each sample is normalised over all of its channel and spatial positions:

        y = (x - mean) / sqrt(var + eps) * weight + bias

    where ``mean`` and ``var`` are taken over dims (1, 2, 3, 4) and ``var`` is the
    biased estimate (``unbiased=False``).

    Args:
        num_channels (int): Number of input channels; sets the size of ``weight``/``bias``.
        eps (float): Constant added to the variance before the square root. (default: 1e-5)
        elementwise_affine (bool): If True, learn a per-channel ``weight`` (init 1) and
            ``bias`` (init 0); otherwise both are registered as ``None``. (default: True)
    """

    def __init__(self, num_channels, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            return

        # One scale/shift per channel, broadcast over batch and spatial dims.
        param_shape = (1, num_channels, 1, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Normalise ``x`` (expected shape (N, C, D, H, W)) and apply the optional affine."""
        # Statistics over every dim except the batch dim.
        reduce_dims = [1, 2, 3, 4]
        mean = x.mean(dim=reduce_dims, keepdim=True)
        var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)

        normalized = (x - mean) / torch.sqrt(var + self.eps)
        if not self.elementwise_affine:
            return normalized
        return normalized * self.weight + self.bias


    
from collections.abc import Sequence

def autopad(kernel_size: int | Sequence[int], padding: int | Sequence[int] | None = None) -> int | list[int]:
    """
    padding이 None일 경우, kernel_size에 기반해 동일한 출력 shape을 만들기 위한 패딩 크기를 반환합니다.
    보통 커널 사이즈가 홀수일 때 kernel_size//2로 설정합니다.
    
    Args:
        kernel_size (int 또는 Sequence[int]): 커널 사이즈.
        padding (int 또는 Sequence[int] 또는 None): 이미 지정된 패딩 값. None이면 자동으로 계산합니다.
        
    Returns:
        계산된 padding 값.
    """
    if padding is None:
        if isinstance(kernel_size, int):
            return kernel_size // 2
        elif isinstance(kernel_size, Sequence):
            return [k // 2 for k in kernel_size]
        else:
            raise TypeError("kernel_size must be int or sequence of ints")
    else:
        return padding

class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        stride: int,
        dropout: tuple | str | float | None = None,
    ):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels = out_channels,
            kernel_size= kernel_size,
            stride = stride,
            padding = autopad(kernel_size)
        )
        self.norm1 = nn.InstanceNorm3d(out_channels)
        self.act1 = nn.PReLU(out_channels)
        if dropout is not None:
            self.dropout1 = nn.Dropout3d(dropout)
    def forward(self, x):
        x = self.conv(x)
        x = self.norm1(x)
        if hasattr(self, "dropout1"):
            x = self.dropout1(x)
        x = self.act1(x)
        return x
        

class c_Decoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        stride: int,
        dropout: tuple | str | float | None = None,
        conv_only = False
        
    ):
        super().__init__()
        self.conv1 = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=autopad(kernel_size),
            output_padding=1 if kernel_size // 2 == 1 else 0,
            bias=False,
        )
        if not conv_only:
            self.norm1 = nn.InstanceNorm3d(out_channels)
            self.act1 = nn.PReLU(out_channels)
            if dropout is not None:
                self.dropout1 = nn.Dropout3d(dropout)
        self.cbam = CBAM3D(channels=out_channels, reduction=8, spatial_kernel_size=3)

    def forward(self, x, skip):
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        
        if not hasattr(self, "norm1"):  # conv_only인 경우 norm1, act1, dropout1이 없을 것임
            pass
        else:
            x = self.norm1(x)
            x = self.act1(x)
            if hasattr(self, "dropout1"):
                x = self.dropout1(x)

        return self.cbam(x)  # norm, act 후 CBAM 적용



In [4]:
# import warnings
# from collections.abc import Sequence
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np

# from monai.networks.blocks.convolutions import Convolution
# from monai.networks.layers.factories import Act, Norm

# # ---------------------------
# # 1) LayerNorm3D (옵션)
# # ---------------------------
# class LayerNorm3D(nn.Module):
#     def __init__(self, num_channels: int):
#         super().__init__()
#         self.layer_norm = nn.LayerNorm(num_channels)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         N, C, *spatial_dims = x.shape
#         x = x.permute(0, 2, 3, 4, 1).contiguous()
#         x = self.layer_norm(x)
#         x = x.permute(0, 4, 1, 2, 3).contiguous()
#         return x


# # ---------------------------
# # 2) 헬퍼 함수: padding 계산
# # ---------------------------
# def get_padding(kernel_size: Sequence[int] | int, stride: Sequence[int] | int):
#     kernel_size_np = np.atleast_1d(kernel_size)
#     stride_np = np.atleast_1d(stride)
#     padding_np = (kernel_size_np - stride_np + 1) / 2
#     if np.min(padding_np) < 0:
#         raise AssertionError("padding must not be negative.")
#     padding = tuple(int(p) for p in padding_np)
#     return padding if len(padding) > 1 else padding[0]

# def get_output_padding(kernel_size, stride, padding):
#     kernel_size_np = np.atleast_1d(kernel_size)
#     stride_np = np.atleast_1d(stride)
#     padding_np = np.atleast_1d(padding)
#     out_padding_np = 2 * padding_np + stride_np - kernel_size_np
#     if np.min(out_padding_np) < 0:
#         raise AssertionError("out_padding must not be negative.")
#     out_padding = tuple(int(p) for p in out_padding_np)
#     return out_padding if len(out_padding) > 1 else out_padding[0]


# # ---------------------------
# # 3) get_conv_layer
# # ---------------------------
# def get_conv_layer(
#     spatial_dims: int,
#     in_channels: int,
#     out_channels: int,
#     kernel_size: int | Sequence[int] = 3,
#     stride: int | Sequence[int] = 1,
#     act: tuple | str | None = Act.PRELU,
#     norm: tuple | str | None = Norm.INSTANCE,
#     dropout: float | None = 0.0,
#     bias: bool = True,
#     conv_only: bool = False,
#     is_transposed: bool = False,
# ):
#     padding = get_padding(kernel_size, stride)
#     output_padding = None
#     if is_transposed:
#         output_padding = get_output_padding(kernel_size, stride, padding)

#     return Convolution(
#         spatial_dims=spatial_dims,
#         in_channels=in_channels,
#         out_channels=out_channels,
#         strides=stride,
#         kernel_size=kernel_size,
#         act=act,
#         norm=norm,
#         dropout=dropout,
#         bias=bias,
#         conv_only=conv_only,
#         is_transposed=is_transposed,
#         padding=padding,
#         output_padding=output_padding,
#     )


# # ---------------------------
# # 4) ResizeLayer(nn.Upsample)
# # ---------------------------
# class ResizeLayer(nn.Module):
#     def __init__(self, mode="trilinear", align_corners=False):
#         super().__init__()
#         self.mode = mode
#         self.align_corners = align_corners

#     def forward(self, x: torch.Tensor, size: tuple[int, ...]) -> torch.Tensor:
#         up = nn.Upsample(size=size, mode=self.mode, align_corners=self.align_corners)
#         return up(x)


# # ---------------------------
# # 5) SkipAlign
# # ---------------------------
# class SkipAlign(nn.Module):
#     def __init__(
#         self,
#         skip_in_channels: int,
#         out_channels: int,
#         spatial_dims: int = 3,
#         channel_match: bool = True,
#         mode: str = "trilinear",
#         align_corners: bool = False,
#     ):
#         super().__init__()
#         self.spatial_dims = spatial_dims
#         self.channel_match = channel_match

#         self.upsample = ResizeLayer(mode=mode, align_corners=align_corners)

#         if channel_match and (skip_in_channels != out_channels):
#             if spatial_dims == 3:
#                 self.conv1x1 = nn.Conv3d(skip_in_channels, out_channels, kernel_size=1, bias=True)
#             else:
#                 self.conv1x1 = nn.Conv2d(skip_in_channels, out_channels, kernel_size=1, bias=True)
#         else:
#             self.conv1x1 = None

#     def forward(self, skip_tensor: torch.Tensor, target_tensor: torch.Tensor) -> torch.Tensor:
#         device = target_tensor.device
#         dtype = target_tensor.dtype
#         skip_tensor = skip_tensor.to(device=device, dtype=dtype)

#         # 업샘플
#         D_out, H_out, W_out = target_tensor.shape[2:]
#         if skip_tensor.shape[2:] != (D_out, H_out, W_out):
#             skip_tensor = self.upsample(skip_tensor, (D_out, H_out, W_out))

#         # 1×1 Conv
#         if self.conv1x1 is not None:
#             skip_tensor = self.conv1x1(skip_tensor)

#         return skip_tensor


# # ---------------------------
# # 6) build_conv_stack
# # ---------------------------
# def build_conv_stack(
#     spatial_dims: int,
#     in_channels: int,
#     out_channels: int,
#     num_layers: int,
#     kernel_size: int | Sequence[int],
#     stride: int | Sequence[int],
#     act: tuple | str | None,
#     norm: tuple | str | None,
#     dropout: float,
#     bias: bool,
#     is_transposed: bool = False,
#     use_cbam: bool = False,
# ):
#     layers = []
#     for i in range(num_layers):
#         if i == 0:
#             layers.append(
#                 get_conv_layer(
#                     spatial_dims=spatial_dims,
#                     in_channels=in_channels,
#                     out_channels=out_channels,
#                     kernel_size=kernel_size,
#                     stride=stride,
#                     act=act,
#                     norm=norm,
#                     dropout=dropout,
#                     bias=bias,
#                     conv_only=False,
#                     is_transposed=is_transposed,
#                 )
#             )
#         else:
#             layers.append(
#                 get_conv_layer(
#                     spatial_dims=spatial_dims,
#                     in_channels=out_channels,
#                     out_channels=out_channels,
#                     kernel_size=kernel_size,
#                     stride=1,
#                     act=act,
#                     norm=norm,
#                     dropout=dropout,
#                     bias=bias,
#                     conv_only=False,
#                     is_transposed=False,
#                 )
#             )
#         if use_cbam:
#             layers.append(CBAM3D(channels=out_channels, reduction=8, spatial_kernel_size=3))
#     return nn.Sequential(*layers)


# # ---------------------------
# # 7) SingleEncoderBlock
# # ---------------------------
# class SingleEncoderBlock(nn.Module):
#     def __init__(
#         self,
#         spatial_dims: int,
#         in_channels: int,
#         out_channels: int,
#         num_layers: int,
#         kernel_size: int | Sequence[int],
#         stride: int | Sequence[int],
#         act: tuple | str | None,
#         norm: tuple | str | None,
#         dropout: float,
#         bias: bool = True,
#         use_cbam: bool = False,
#     ):
#         super().__init__()
#         self.stack = build_conv_stack(
#             spatial_dims=spatial_dims,
#             in_channels=in_channels,
#             out_channels=out_channels,
#             num_layers=num_layers,
#             kernel_size=kernel_size,
#             stride=stride,
#             act=act,
#             norm=norm,
#             dropout=dropout,
#             bias=bias,
#             is_transposed=False,
#             use_cbam=use_cbam,
#         )

#     def forward(self, x):
#         return self.stack(x)


# # ---------------------------
# # 8) SingleDecoderBlock
# # ---------------------------
# class SingleDecoderBlock(nn.Module):
#     def __init__(
#         self,
#         spatial_dims: int,
#         main_in_channels: int,
#         core_channels: int,
#         out_channels: int,
#         skip_in_channels_list: list[int],
#         num_layers: int,
#         kernel_size: int | Sequence[int],
#         stride: int | Sequence[int],
#         act: tuple | str | None,
#         norm: tuple | str | None,
#         dropout: float,
#         bias: bool,
#         mode: str = "trilinear",
#         align_corners: bool = False,
#         use_cbam: bool = True,
#     ):
#         super().__init__()
#         self.spatial_dims = spatial_dims
#         self.out_channels = out_channels
#         self.skip_count = len(skip_in_channels_list)
        
#         # (1) Main input channel aligner (1x1 conv로 core_channels로 맞춤)
#         if spatial_dims == 3:
#             self.main_aligner = nn.Conv3d(main_in_channels, core_channels, kernel_size=1, bias=True)
#         else:
#             self.main_aligner = nn.Conv2d(main_in_channels, core_channels, kernel_size=1, bias=True)
        
#         # (2) Skip aligners (채널 매칭을 위한)
#         self.skip_aligners = nn.ModuleList()
#         for s_in_ch in skip_in_channels_list:
#             aligner = SkipAlign(
#                 skip_in_channels=s_in_ch,
#                 out_channels=core_channels,
#                 spatial_dims=spatial_dims,
#                 channel_match=True,
#                 mode=mode,
#                 align_corners=align_corners,
#             )
#             self.skip_aligners.append(aligner)

#         # (3) Main conv stack with transposed conv
#         total_in_channels = core_channels * (1 + self.skip_count)  # main(core_ch) + skips(core_ch each)
        
        
#         self.conv_stack = build_conv_stack(
#             spatial_dims=spatial_dims,
#             in_channels=total_in_channels,
#             out_channels=out_channels,
#             num_layers=num_layers,  # Already used one layer for upsampling
#             kernel_size=kernel_size,
#             stride=stride,
#             act=act,
#             norm=norm,
#             dropout=dropout,
#             bias=bias,
#             is_transposed=True,
#             use_cbam=use_cbam,
#         )

#     def forward(self, x_main: torch.Tensor, skip_tensors: list[torch.Tensor]) -> torch.Tensor:
#         # (1) Main input을 core_channels로 변환
#         x_main = self.main_aligner(x_main)  # (N, core_channels, ...)
        
#         # (2) 모든 skip connection을 현재 입력 크기로 맞춤
#         aligned_skips = []
        
#         for i, s in enumerate(skip_tensors):
#             aligned_s = self.skip_aligners[i](s, x_main)  # (N, core_channels, ...)
#             aligned_skips.append(aligned_s)
            
        
#         # (3) Concatenate main input with aligned skips
#         cat_list = [x_main] + aligned_skips
#         cat_input = torch.cat(cat_list, dim=1)  # (N, core_channels * (1 + skip_count), ...)
        
#         # (4) Apply transposed conv for upsampling
#         out = self.conv_stack(cat_input)
    
#         return out


# # ---------------------------
# # 9) FlexibleUNet
# # ---------------------------
# class FlexibleUNet(nn.Module):
#     """
#     디코더 간 스킵 연결:
#       skip_connections = {
#          dec_idx: [
#            ("enc", enc_i),  # 인코더 레벨 enc_i
#            ("dec", dec_j),  # 디코더 레벨 dec_j
#          ],
#          ...
#       }
#     """
#     def __init__(
#         self,
#         spatial_dims: int = 3,
#         in_channels: int = 1,
#         out_channels: int = 2,
#         encoder_channels: Sequence[int] = (32, 64, 128, 256),
#         encoder_strides: Sequence[int] = (2, 2, 2),
#         core_channels: int = 64,
#         decoder_channels: Sequence[int] = (128, 64, 32),
#         decoder_strides: Sequence[int] = (2, 2, 2),
#         num_layers_encoder: Sequence[int] = (1, 1, 1, 1),
#         num_layers_decoder: Sequence[int] = (1, 1, 1),
#         skip_connections: dict[int, list[tuple[str, int]]] | None = None,
#         kernel_size: int | Sequence[int] = 3,
#         up_kernel_size: int | Sequence[int] = 3,
#         act: tuple | str = Act.PRELU,
#         norm: tuple | str = Norm.INSTANCE,
#         dropout: float = 0.0,
#         bias: bool = True,
#         mode: str = "trilinear",
#         align_corners: bool = False,
#         encoder_use_cbam: bool = True,
#         decoder_use_cbam: bool = True
#     ):
#         """
#         skip_connections: {
#           decoder_index: [("enc", i), ("dec", j), ...],
#           ...
#         }
#         """
#         super().__init__()
#         if len(encoder_channels) != len(num_layers_encoder):
#             raise ValueError("encoder_channels와 num_layers_encoder 길이가 맞지 않습니다.")
#         if len(encoder_strides) != len(encoder_channels) - 1:
#             raise ValueError("encoder_strides 길이는 (len(encoder_channels) - 1)이어야 합니다.")
#         if len(decoder_channels) != len(num_layers_decoder):
#             raise ValueError("decoder_channels와 num_layers_decoder 길이가 맞지 않습니다.")
#         if len(decoder_strides) != len(decoder_channels):
#             raise ValueError("decoder_strides 길이는 len(decoder_channels)와 같아야 합니다.")

#         self.spatial_dims = spatial_dims
#         self.skip_connections = skip_connections if skip_connections else {}

#         # ---------------------- 인코더 구성 ----------------------
#         self.encoder_blocks = nn.ModuleList()
#         prev_ch = in_channels
#         for i, out_ch in enumerate(encoder_channels):
#             stride = encoder_strides[i] if i < len(encoder_strides) else 1
#             block = SingleEncoderBlock(
#                 spatial_dims=spatial_dims,
#                 in_channels=prev_ch,
#                 out_channels=out_ch,
#                 num_layers=num_layers_encoder[i],
#                 kernel_size=kernel_size,
#                 stride=stride,
#                 act=act,
#                 norm=norm,
#                 dropout=dropout,
#                 bias=bias,
#                 use_cbam=encoder_use_cbam,
#             )
#             self.encoder_blocks.append(block)
#             prev_ch = out_ch

#         # ---------------------- 디코더 구성 ----------------------
#         self.decoder_blocks = nn.ModuleList()
#         main_in_ch = encoder_channels[-1]  # bottleneck
#         for dec_i in range(len(decoder_channels)):
#             out_ch = decoder_channels[dec_i]
#             stride = decoder_strides[dec_i]
#             nlayer = num_layers_decoder[dec_i]

#             # skip 인덱스 -> skip_in_channels
#             # "enc" -> encoder_channels[idx], "dec" -> decoder_channels[idx]
#             skip_info_list = skip_connections.get(dec_i, [])
#             skip_in_channels_list = []
#             for (typ, idx) in skip_info_list:
#                 if typ == "enc":
#                     skip_in_channels_list.append(encoder_channels[idx])
#                 elif typ == "dec":
#                     # 디코더 레벨 idx의 출력 채널
#                     skip_in_channels_list.append(decoder_channels[idx])
#                 else:
#                     raise ValueError(f"Invalid skip type: {typ}, must be 'enc' or 'dec'.")
            
#             block = SingleDecoderBlock(
#                 spatial_dims=spatial_dims,
#                 main_in_channels=main_in_ch,
#                 out_channels=out_ch,
#                 core_channels=core_channels,
#                 skip_in_channels_list=skip_in_channels_list,
#                 num_layers=nlayer,
#                 kernel_size=up_kernel_size,
#                 stride=stride,
#                 act=act,
#                 norm=norm,
#                 dropout=dropout,
#                 bias=bias,
#                 mode=mode,
#                 align_corners=align_corners,
#                 use_cbam=decoder_use_cbam,
#             )
#             self.decoder_blocks.append(block)
#             main_in_ch = out_ch

#         # 최종 Conv
#         self.final_conv = get_conv_layer(
#             spatial_dims=spatial_dims,
#             in_channels=decoder_channels[-1],
#             out_channels=out_channels,
#             kernel_size=1,
#             stride=1,
#             norm=None,
#             act=None,
#             dropout=0.0,
#             bias=True,
#             conv_only=True,
#             is_transposed=False,
#         )

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # device/dtype
#         device = next(self.parameters()).device
#         dtype = next(self.parameters()).dtype
#         x = x.to(device=device, dtype=dtype)

#         # 1) 인코더
#         encoder_outputs = []
#         out = x
#         for enc in self.encoder_blocks:
#             out = enc(out)
#             encoder_outputs.append(out)

#         # 2) 디코더
#         decoder_outputs = []
#         # out = encoder_outputs[-1]  # bottleneck
#         # decoder_outputs.append(out)  # 0번 디코더가 시작하기 전(bottleneck)에 쓸 수도 있지만, 여기선 i=0부터 맞춰줄 수도 있음.

#         for dec_i, dec_block in enumerate(self.decoder_blocks):
#             # skip에 "enc" => encoder_outputs, "dec" => decoder_outputs
#             skip_info_list = self.skip_connections.get(dec_i, [])
#             skip_list = []
#             for (typ, idx) in skip_info_list:
#                 if typ == "enc":
#                     skip_list.append(encoder_outputs[idx])
#                 elif typ == "dec":
#                     # 디코더 idx는 0.. dec_i-1 범위여야함
#                     skip_list.append(decoder_outputs[idx])
#                 else:
#                     raise ValueError(f"Invalid skip type: {typ}.")

#             out = dec_block(out, skip_list)
#             # 디코더 i번 블록 결과를 decoder_outputs에 저장
#             decoder_outputs.append(out)

#         # 3) 최종 Conv
#         out = self.final_conv(out)
#         return out


# # ---------------------------
# # 10) 테스트
# # ---------------------------
# # if __name__ == "__main__":
# #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# #     enc_channels = (32, 64, 128, 256)
# #     enc_strides = (2, 2, 2)
# #     num_layers_enc = (1, 1, 1, 1)

# #     core_channels = 64
# #     dec_channels = (256, 256, 256)
# #     dec_strides = (2, 2, 2)
# #     num_layers_dec = (1, 1, 1)

# #     # 디코더간 스킵 예시:
# #     #   디코더0: ("enc",2), ("dec", ?) -- 가능하지만 보통 dec_? < dec_0
# #     #   디코더1: ("enc",1), ("dec",0)
# #     #   디코더2: ("enc",0), ("dec",1)
# #     # 반드시 "dec", j => j < i 여야함
# #     skip_map = {
# #         0: [("enc", 2), ("enc", 0), ("enc", 1)],       # 디코더0 => 인코더2
# #         1: [("enc", 3), ("enc", 0), ("enc", 1)],  # 디코더1 => 인코더1 + 디코더0
# #         2: [("enc", 3), ("dec", 0), ("enc", 0)]   # 디코더2 => 인코더0 + 디코더1
# #     }

# #     # net = FlexibleUNet(
# #     #     spatial_dims=3,
# #     #     in_channels=1,
# #     #     out_channels=2,
# #     #     encoder_channels=enc_channels,
# #     #     encoder_strides=enc_strides,
# #     #     core_channels=core_channels,
# #     #     decoder_channels=dec_channels,
# #     #     decoder_strides=dec_strides,
# #     #     num_layers_encoder=num_layers_enc,
# #     #     num_layers_decoder=num_layers_dec,
# #     #     skip_connections=skip_map,
# #     #     kernel_size=3,
# #     #     up_kernel_size=3,
# #     #     dropout=0.0,
# #     #     bias=True,
# #     #     mode="trilinear",
# #     #     align_corners=False,
# #     # ).to(device)

# #     # x = torch.randn((1, 1, 64, 64, 32), device=device)
# #     # with torch.no_grad():
# #     #     out = net(x)
# #     #     print(out.shape)

In [5]:
class Unet_CBAM_ds(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Sequence[int] | int = 3,
        up_kernel_size: Sequence[int] | int = 3,

        dropout: float = 0.0,
        bias: bool = True,
        use_supervision: bool = False,

    ) -> None:
        super().__init__()

        if len(channels) < 2:
            raise ValueError("the length of `channels` should be no less than 2.")
        
        # 기존 코드와 동일한 검사
        delta = len(strides) - (len(channels) - 1)
        if delta < 0:
            raise ValueError("the length of `strides` should equal `len(channels) - 1`.")
        if delta > 0:
            warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.")
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size

        self.dropout = dropout
        self.bias = bias


        # ---------------------
        # Encoder
        # ---------------------
        self.encoder1 = Encoder(
            in_channels,
            channels[0],
            kernel_size,
            strides[0],
            dropout,
        )
        self.encoder2 = Encoder(
            channels[0],
            channels[1],
            kernel_size,
            strides[1],
            dropout,
        )
        self.encoder3 = Encoder(
            channels[1],
            channels[2],
            kernel_size,
            strides[2],
            dropout,
        )
        # encoder4는 strides[3]를 사용할 수 있도록 strides에 4개 값을 넣어주거나, 아래처럼 stride=1로 따로 설정 가능
        # 여기서는 strides에 4개 값을 넣어준다고 가정함
        self.bottleneck = Encoder(
            channels[2],
            channels[3],
            kernel_size,
            1, 
            dropout,
        )

        # self.cbam = CBAM3D(channels=channels[3], reduction=8, spatial_kernel_size=3)
        # ---------------------
        # Decoder
        # ---------------------
        self.decoder3 = c_Decoder(
            channels[3] + channels[2],
            channels[1],
            up_kernel_size,
            strides[2],
            dropout,
        )
        self.decoder2 = c_Decoder(
            channels[1] + channels[1],
            channels[0],
            up_kernel_size,
            strides[1],
            dropout,
        )
        self.decoder1 = c_Decoder(
            channels[0] + channels[0],
            out_channels,
            up_kernel_size,
            strides[0],
            conv_only=True,
        )
        if use_supervision:
            self.supervision2 = nn.Conv3d(channels[0], out_channels, kernel_size=1, stride=1, padding=0)
            self.supervision3 = nn.Conv3d(channels[1], out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder
        x1 = self.encoder1(x)
        x2 = self.encoder2(x1)
        x3 = self.encoder3(x2)
        
        x = self.bottleneck(x3)

        # Decoder: 이전 decoder 출력과 skip connection을 제대로 연결
        x = self.decoder3(x, x3)   # bottleneck + encoder3
        x = self.decoder2(x, x2)    # decoder3 출력 + encoder2
        x = self.decoder1(x, x1)    # decoder2 출력 + encoder1
        if hasattr(self, "supervision2"):
            out2 = self.supervision2(x1)
            out2 = F.interpolate(out2, x.shape[2:], mode="trilinear", align_corners=False)
            out3 = self.supervision3(x2)
            out3 = F.interpolate(out3, x.shape[2:], mode="trilinear", align_corners=False)
            return [x, out2, out3]
        return x

    
    
# 예제 사용법:
# if __name__ == "__main__":
#     # 임의의 3D 데이터: 배치 크기 2, 채널 4, 깊이 8, 높이 16, 너비 16
#     print("LayerNorm3d 사용 예제")
#     input_tensor = torch.randn(2, 1, 32, 96, 96)
#     layer_norm3d = LayerNorm3d(num_channels=1)
#     output = layer_norm3d(input_tensor)
#     print("입력 텐서 shape:", input_tensor.shape)
#     print("출력 텐서 shape:", output.shape)
    
#     # Encoder, Decoder, Unet_CBAM_LAYERNORM 사용 예제
#     print("\nEncoder, Decoder, Unet_CBAM_LAYERNORM 사용 예제")
#     model = Unet_CBAM_ds(
#         in_channels=1,
#         out_channels=4,
#         channels=[16, 32, 64, 128],
#         strides=[2, 2, 2, 2],
#         kernel_size=3,
#         up_kernel_size=3,
#         dropout=0.1,
#         use_supervision=True,
#     )
    
#     output,ds2,ds3 = model(input_tensor)
#     print("입력 텐서 shape:", input_tensor.shape)
#     print("출력 텐서 shape:", output.shape)
#     print("출력 텐서 shape:", ds2.shape)
#     print("출력 텐서 shape:", ds3.shape)
    

In [6]:
from thop import profile
import torch
from torch.profiler import ProfilerActivity
from torch.profiler import profile as profilee
    

def print_model_summary(model, input_size):
    """Print the FLOPs and parameter count that ``thop.profile`` reports for ``model``.

    Args:
        model: Module to analyse; must own at least one parameter, because the
            device of its first parameter decides where the dummy input lives.
        input_size: Shape passed to ``torch.randn`` to build the dummy input.
    """
    dummy_input = torch.randn(input_size)
    dummy_input = dummy_input.to(next(model.parameters()).device)
    flops, params = profile(model, inputs=(dummy_input,))

    report = (
        f"Model: {model.__class__.__name__}",
        f"FLOPs: {flops:,}, GFLOPs: {flops / 1e9:.2f}",
        f"Parameters: {params:,}",
        "-" * 50,
    )
    for line in report:
        print(line)

def profile_model(model, input_size, log_dir='./log'):
    """Profile one forward pass of ``model`` and print the ten most expensive ops.

    A random tensor of shape ``input_size`` is created on the device of the
    model's first parameter. The trace is handed to
    ``torch.profiler.tensorboard_trace_handler(log_dir)`` so it can be opened in
    TensorBoard.

    CUDA activity is only requested when CUDA is actually available; on a
    CPU-only host the table is sorted by CPU time instead of an all-zero CUDA
    column.

    Args:
        model: Module to profile; must own at least one parameter.
        input_size: Shape passed to ``torch.randn`` to build the dummy input.
        log_dir: Directory the TensorBoard trace handler writes to.
    """
    input_tensor = torch.randn(input_size)
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)

    cuda_available = torch.cuda.is_available()
    activities = [ProfilerActivity.CPU]
    if cuda_available:
        activities.append(ProfilerActivity.CUDA)

    # Profile a single forward pass.
    with profilee(
        activities=activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),  # TensorBoard export
        record_shapes=True,
        with_stack=True
    ) as prof:
        model(input_tensor)

    # Print the profiling summary.
    sort_key = "cuda_time_total" if cuda_available else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

# 모델 설정

In [7]:
from src.dataset.dataset import create_dataloaders
from src.dataset.dataset_csv import create_dataloaders_from_csv
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd, RandCropd,RandCropByPosNegLabeld, RandGaussianSmoothd, RandCoarseDropoutd
)
from monai.transforms import CastToTyped
import numpy as np

# Run configuration for this notebook: dataset CSVs, patch/sampling settings,
# model/optimisation hyper-parameters and the initial training-loop state.
train_csv = "./datasets/aug_train.csv"
val_csv = "./datasets/val_647.csv"
# DATA CONFIG
img_size =  96 # Match your patch size
img_depth = 32
n_classes = 7
batch_size = 32 # 13.8GB GPU memory required for 128x128 img size
loader_batch = 1
num_samples = batch_size // loader_batch # number of patches sampled from each image
num_repeat = 3  # passed to create_dataloaders_from_csv as train_num_repeat
val_num_repeat = 20  # passed to create_dataloaders_from_csv as val_num_repeat
# MODEL CONFIG
num_epochs = 4000
lamda = 0.5  # forwarded to CombinedCETverskyLoss as `lamda`
ce_weight = 0.4  # forwarded to the loss; validate_one_epoch also uses it to blend F-beta and IoU
lr = 0.001
feature_size = [32, 64, 128, 256]  # passed to the model as `channels`
dropout= 0.25
use_ds = True  # NOTE(review): only logged to wandb here; the model's use_supervision argument is commented out
# CLASS_WEIGHTS
class_weights = None
# class_weights = torch.tensor([0.0001, 1, 0.001, 1.1, 1, 1.1, 1], dtype=torch.float32)  # per-class weights
sigma = 1.5  # referenced only by the commented-out Gaussian smoothing transforms


accumulation_steps = 1
# INIT
# Starting state for training; the checkpoint-loading cell may overwrite start_epoch / best_val_loss.
start_epoch = 0
best_val_loss = float('inf')
best_val_fbeta_score = 0

def _label_class_crop():
    """Build the RandCropByLabelClassesd patch sampler shared by train and validation."""
    return RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples,
        ratios=ratios_list,
    )


# Deterministic preprocessing, reused for both the training and validation sets.
# Disabled option: GaussianSmoothd(keys=["image"], sigma=[sigma, sigma, sigma]).
_non_random_steps = [
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
]
non_random_transforms = Compose(_non_random_steps)

# Training augmentation: patch sampling, coarse dropout on the image only,
# then a random in-plane 90-degree rotation and a random flip along every spatial axis.
# Disabled option: RandGaussianSmoothd(keys=["image"], sigma_x/y/z=(0.5, sigma), prob=0.5).
_train_random_steps = [
    _label_class_crop(),
    RandCoarseDropoutd(
        keys=["image"],
        holes=3,
        spatial_size=(4, 4, 4),
        prob=0.5,
        fill_value=255,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    *[
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=flip_axis)
        for flip_axis in (0, 1, 2)
    ],
]
random_transforms = Compose(_train_random_steps)

# Validation only samples patches; the rotation, flip and smoothing augmentations
# used for training are intentionally left out.
# Disabled option: RandGaussianSmoothd(keys=["image"], sigma_x/y/z=(0.0, sigma), prob=1.0).
val_random_transforms = Compose([_label_class_crop()])


In [8]:
# Reset both names first (presumably so loaders from an earlier run of this cell
# can be released before new ones are built).
train_loader = val_loader = None

# The same deterministic preprocessing is used for both splits; only the random
# stage and the repeat counts differ.
_loader_options = dict(
    train_non_random_transforms=non_random_transforms,
    val_non_random_transforms=non_random_transforms,
    train_random_transforms=random_transforms,
    val_random_transforms=val_random_transforms,
    batch_size=loader_batch,
    num_workers=0,
    train_num_repeat=num_repeat,
    val_num_repeat=val_num_repeat,
)
train_loader, val_loader = create_dataloaders_from_csv(train_csv, val_csv, **_loader_options)

(184, 630, 630)
(184, 630, 630)


https://monai.io/model-zoo.html

# 모델 선언

In [9]:
import torch
import torch.optim as optim
from pathlib import Path


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet_CBAM(
    spatial_dims=3,
    in_channels=1,
    out_channels=n_classes,
    channels=feature_size,
    strides=[2, 2, 2, 2],
    kernel_size=3,
    up_kernel_size=3,
    dropout=dropout,
    # use_supervision=use_ds,
    ).to(device)
# x = (batch_size, 1, img_depth, img_size, img_size)
# print_model_summary(model, x)
# profile_model(model, x, './log')

# class_weights is forwarded so the loss matches what the checkpoint folder name
# ("weighted") and the logged wandb config claim; with class_weights=None this is
# identical to the loss default.
criterion = CombinedCETverskyLoss(
    lamda=lamda,
    ce_weight=ce_weight,
    n_classes=n_classes,
    class_weights=class_weights,
).to(device)


weight_str = "weighted" if class_weights is not None else ""

# Checkpoint directory: one folder per hyper-parameter combination.
checkpoint_base_dir = Path("./model_checkpoints")
folder_name = f"UNetCBAM_aug_hole_maxf{feature_size[-1]}_{img_depth}x{img_size}x{img_size}_e{num_epochs}_lr{lr}_lamda{lamda}_ce{ce_weight}_{weight_str}"
checkpoint_dir = checkpoint_base_dir / folder_name
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
# Create the checkpoint directory (no-op when it already exists).
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Resume from an existing best checkpoint when one is present. Any failure is
# reported and training simply starts from the current (fresh) state.
best_model_path = checkpoint_dir / 'best_model.pt'
if best_model_path.exists():
    print(f"기존 best model 발견: {best_model_path}")
    try:
        checkpoint = torch.load(best_model_path, map_location=device)
        # Validate the checkpoint layout before touching model/optimizer state.
        required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
        if not all(k in checkpoint for k in required_keys):
            raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint['best_val_loss']
        # Restore the best score as well; otherwise the resumed run compares
        # against the initial score and a worse model can overwrite the best
        # one. Checkpoints without the key keep the current value.
        best_val_fbeta_score = checkpoint.get('best_val_fbeta_score', best_val_fbeta_score)
        print("기존 학습된 가중치를 성공적으로 로드했습니다.")
        checkpoint = None  # drop the reference so the loaded tensors can be freed
    except Exception as e:
        print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")

기존 best model 발견: model_checkpoints\UNetCBAM_aug_hole_maxf256_32x96x96_e4000_lr0.001_lamda0.5_ce0.4_\best_model.pt

  checkpoint = torch.load(best_model_path, map_location=device)



기존 학습된 가중치를 성공적으로 로드했습니다.


In [10]:
# batch = next(iter(val_loader))
# images, labels = batch["image"], batch["label"]
# print(images.shape, labels.shape)

In [12]:
import wandb
from datetime import datetime

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = folder_name  # the run shares its name with the checkpoint folder

# Hyper-parameters recorded with the run; add further entries here as needed.
wandb_config = {
    'num_epochs': num_epochs,
    'learning_rate': lr,
    'batch_size': batch_size,
    'lambda': lamda,
    "cross_entropy_weight": ce_weight,
    'img_depth': img_depth,
    'img_size': img_size,
    'sampling_ratio': ratios_list,
    'device': device.type,
    "checkpoint_dir": str(checkpoint_dir),
    "class_weights": None if class_weights is None else class_weights.tolist(),
    "feature_size": feature_size,
    "deep_supervision": use_ds,
    "dropout": dropout,
    "accumulation_steps": accumulation_steps,
    "num_repeat": num_repeat,
}

# Start the wandb run.
wandb.init(project='czii_SwinUnetR', name=run_name, config=wandb_config)

# Attach the model to the run.
wandb.watch(model, log='all')

[34m[1mwandb[0m: Currently logged in as: [33mwoow070840[0m ([33mwaooang[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


# 학습

In [13]:
from monai.metrics import DiceMetric

def processing(batch_data, model, criterion, device, num_classes=None):
    """Run one batch through ``model`` and compute the loss and hard predictions.

    Args:
        batch_data: Mapping with ``'image'`` and ``'label'`` tensors.
        model: Callable returning either a logits tensor or a list/tuple whose
            first element is the main logits tensor (deep supervision).
        criterion: Callable ``criterion(outputs, labels_onehot)`` returning the loss.
        device: Device the image and label tensors are moved to.
        num_classes: Number of classes for one-hot encoding. Defaults to the
            notebook-level ``n_classes`` when omitted.

    Returns:
        Tuple ``(loss, outputs, labels, preds)``: the loss, the raw model
        output, the integer labels with dim 1 squeezed out, and the argmax over
        the channel dimension of the main output.
    """
    if num_classes is None:
        num_classes = n_classes  # notebook-level configuration

    images = batch_data['image'].to(device)
    labels = batch_data['label'].to(device)

    # Drop the singleton channel dim and cast to integer class indices.
    labels = labels.squeeze(1).long()

    # One-hot puts classes last; move them to dim 1 so targets are channel-first.
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=num_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()

    outputs = model(images)
    loss = criterion(outputs, labels_onehot)

    # With deep supervision the model returns several tensors; the first is the
    # main prediction. Tuples are accepted as well as lists.
    main_output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    return loss, outputs, labels, main_output.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, accumulation_steps=4):
    """Train ``model`` for one pass over ``train_loader`` with gradient accumulation.

    Each batch loss is divided by ``accumulation_steps`` before ``backward``;
    the optimizer steps every ``accumulation_steps`` batches and on the final
    batch. The mean unscaled batch loss is logged to wandb as
    ``train_epoch_loss`` and returned.
    """
    model.train()
    optimizer.zero_grad()  # start the epoch with clean gradients
    running_loss = 0
    num_batches = len(train_loader)

    with tqdm(train_loader, desc='Training') as progress:
        for step, batch_data in enumerate(progress, start=1):
            batch_loss, *_ = processing(batch_data, model, criterion, device)

            # Scale so the accumulated gradient averages over the group.
            scaled_loss = batch_loss / accumulation_steps
            scaled_loss.backward()

            # Apply the update once per group, and on the last batch.
            if step % accumulation_steps == 0 or step == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            # Undo the scaling for reporting.
            reported_loss = scaled_loss.item() * accumulation_steps
            running_loss += reported_loss
            progress.set_postfix(loss=reported_loss)

    avg_loss = running_loss / num_batches
    wandb.log({'train_epoch_loss': avg_loss, 'epoch': epoch + 1})
    return avg_loss



def _summarize_class_scores(title, metric_key, scores_by_class, epoch, excluded_classes):
    """Print and log the per-class mean of one metric; return the mean over kept classes."""
    print(title)
    kept_means = []
    for class_idx, scores in scores_by_class.items():
        class_mean = np.mean(scores)
        wandb.log({f'class_{class_idx}_{metric_key}': class_mean, 'epoch': epoch + 1})
        print(f"Class {class_idx}: {class_mean:.4f}", end=", ")
        if class_idx not in excluded_classes:  # only these contribute to the overall mean
            kept_means.append(class_mean)
    print()
    return np.mean(kept_means)


def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval, ce_weight,
                       beta=4, excluded_classes=(0, 2)):
    """Evaluate ``model`` on ``val_loader`` and log loss and per-class metrics to wandb.

    Per-class Dice, F-beta and IoU are computed only when
    ``epoch % calculate_dice_interval == 0``. On other epochs the returned score
    is ``0.0`` (previously this path raised ``UnboundLocalError``).

    Args:
        model: Model to evaluate; switched to eval mode.
        val_loader: Iterable of validation batches accepted by ``processing``.
        criterion: Loss callable forwarded to ``processing``.
        device: Device forwarded to ``processing``.
        epoch: Zero-based epoch index; logged as ``epoch + 1``.
        calculate_dice_interval: Metrics are computed every this many epochs.
        ce_weight: Weight of the IoU term in the returned score; the F-beta term
            gets ``1 - ce_weight``.
        beta: Beta of the F-beta score (recall is weighted ``beta**2`` times
            precision).
        excluded_classes: Class indices left out of the overall means.

    Returns:
        ``(avg_val_loss, final_score)`` where
        ``final_score = fbeta * (1 - ce_weight) + IoU * ce_weight``.

    Note:
        The Dice numerator carries no epsilon, so a class that is absent from
        both prediction and label scores 0 for Dice but ~1 for F-beta and IoU.
    """
    model.eval()
    compute_metrics = epoch % calculate_dice_interval == 0
    beta_sq = beta ** 2
    val_loss = 0

    class_dice_scores = {i: [] for i in range(n_classes)}
    class_f_beta_scores = {i: [] for i in range(n_classes)}
    class_mIoU_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                if not compute_metrics:
                    continue

                # Per-class overlap statistics for this batch.
                for i in range(n_classes):
                    pred_i = (preds == i)
                    label_i = (labels == i)
                    overlap = torch.sum(pred_i & label_i)
                    pred_total = torch.sum(pred_i)
                    label_total = torch.sum(label_i)

                    dice_score = (2.0 * overlap) / (pred_total + label_total + 1e-8)
                    class_dice_scores[i].append(dice_score.item())

                    precision = (overlap + 1e-8) / (pred_total + 1e-8)
                    recall = (overlap + 1e-8) / (label_total + 1e-8)
                    f_beta_score = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall + 1e-8)
                    class_f_beta_scores[i].append(f_beta_score.item())

                    union = torch.sum(pred_i | label_i).float()
                    iou = (overlap.float() + 1e-8) / (union + 1e-8)
                    class_mIoU_scores[i].append(iou.item())

    avg_loss = val_loss / len(val_loader)
    # Log the mean validation loss for this epoch.
    wandb.log({'val_epoch_loss': avg_loss, 'epoch': epoch + 1})

    # Defaults for epochs on which metrics are skipped.
    overall_mean_fbeta = 0.0
    overall_mean_IoU = 0.0

    if compute_metrics:
        overall_mean_dice = _summarize_class_scores(
            "Validation Dice Score", 'dice_score', class_dice_scores, epoch, excluded_classes)
        overall_mean_fbeta = _summarize_class_scores(
            "Validation F-beta Score", 'f_beta_score', class_f_beta_scores, epoch, excluded_classes)
        # The IoU key now receives the IoU mean (it used to log the F-beta mean).
        overall_mean_IoU = _summarize_class_scores(
            "Validation mIoU Score", 'IoU_score', class_mIoU_scores, epoch, excluded_classes)

        wandb.log({'overall_mean_f_beta_score': overall_mean_fbeta, 'overall_mean_dice_score': overall_mean_dice, 'epoch': epoch + 1, 'overall_mean_IoU_score': overall_mean_IoU})
        print(f"\nOverall Mean Dice Score: {overall_mean_dice:.4f}\nOverall Mean F-beta Score: {overall_mean_fbeta:.4f}\nOverall Mean IoU Score: {overall_mean_IoU:.4f}")

    final_score = overall_mean_fbeta * (1 - ce_weight) + overall_mean_IoU * ce_weight
    return avg_loss, final_score

def _save_training_checkpoint(path, epoch, model, optimizer, best_val_loss, best_val_fbeta_score):
    """Write model/optimizer state and the best-so-far metrics to ``path``."""
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'best_val_fbeta_score': best_val_fbeta_score
    }, path)


def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, best_val_fbeta_score, calculate_dice_interval=1,
    accumulation_steps=4, pretrained=False
):
    """
    Train and validate the model, checkpointing the best epoch and stopping early.

    A checkpoint is written only when the validation loss AND the validation
    score both beat their previous best. The early-stopping counter grows only
    on epochs where neither improved; improvement is judged against the best
    values from before the current epoch, so an epoch that has just set a new
    best no longer counts as "no improvement".

    Relies on notebook-level globals: ``scheduler`` (stepped on the training
    loss), ``checkpoint_dir``, ``ce_weight`` and the active ``wandb`` run, which
    is finished when training ends.

    Args:
        model: Model to train.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        criterion: Loss function.
        optimizer: Optimizer.
        num_epochs: Total number of epochs.
        patience: Early-stopping patience, in consecutive epochs without improvement.
        device: GPU/CPU device.
        start_epoch: Epoch index to start from.
        best_val_loss: Best validation loss so far.
        best_val_fbeta_score: Best validation score so far.
        calculate_dice_interval: Interval (in epochs) between metric computations.
        accumulation_steps: Gradient accumulation steps passed to ``train_one_epoch``.
        pretrained: When True the best checkpoint is saved as
            ``best_model_pretrained.pt`` instead of ``best_model.pt``.
    """
    epochs_no_improve = 0
    best_filename = 'best_model_pretrained.pt' if pretrained else 'best_model.pt'

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train one epoch.
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device,
            epoch=epoch,
            accumulation_steps= accumulation_steps
        )
        
        scheduler.step(train_loss)
        # Validate one epoch.
        val_loss, overall_mean_fbeta_score = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval,
            ce_weight=ce_weight
        )

        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F-beta: {overall_mean_fbeta_score:.4f}")

        # Compare against the previous bests BEFORE updating them.
        loss_improved = val_loss < best_val_loss
        score_improved = overall_mean_fbeta_score > best_val_fbeta_score

        if loss_improved and score_improved:
            best_val_loss = val_loss
            best_val_fbeta_score = overall_mean_fbeta_score
            _save_training_checkpoint(
                os.path.join(checkpoint_dir, best_filename),
                epoch, model, optimizer, best_val_loss, best_val_fbeta_score,
            )
            print(f"========================================================")
            print(f"SUPER Best model saved. Loss:{best_val_loss:.4f}, Score:{best_val_fbeta_score:.4f}")
            print(f"========================================================")

        # Early-stopping bookkeeping: count only epochs where nothing improved.
        if loss_improved or score_improved:
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping")
            _save_training_checkpoint(
                os.path.join(checkpoint_dir, 'last.pt'),
                epoch, model, optimizer, best_val_loss, best_val_fbeta_score,
            )
            break
        # if epochs_no_improve % 6 == 0 & epochs_no_improve != 0:
        #     # No improvement, so reduce lambda
        #     new_lamda = max(criterion.lamda - 0.01, 0.35)
        #     criterion.set_lamda(new_lamda)
        #     print(f"Validation loss did not improve. Reducing lambda to {new_lamda:.4f}")

    wandb.finish()


In [14]:
# Collect the arguments for the training run, then launch it.
_train_args = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=10,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    best_val_fbeta_score=best_val_fbeta_score,
    calculate_dice_interval=1,
    accumulation_steps=accumulation_steps,
)
train_model(**_train_args)

Epoch 29/4000


Training:   0%|          | 0/432 [00:10<?, ?it/s]


TypeError: CombinedCETverskyLoss.forward() got an unexpected keyword argument 'validation'