In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR, SwinUNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


import torch

print_config()

MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: c:\ProgramData\anaconda3\envs\ship\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the 

In [2]:
# class_info = {
#     0: {"name": "background", "weight": 10000},  # weight 없음
#     1: {"name": "apo-ferritin", "weight": 300},
#     2: {"name": "beta-amylase", "weight": 100}, # 4130
#     3: {"name": "beta-galactosidase", "weight": 150}, #3080
#     4: {"name": "ribosome", "weight": 6000},
#     5: {"name": "thyroglobulin", "weight": 4000},
#     6: {"name": "virus-like-particle", "weight": 2000},
# }

# # 가중치에 비례한 비율 계산
# raw_ratios = {
#     k: (v["weight"] if v["weight"] is not None else 0.01)  # 가중치 비례, None일 경우 기본값
#     for k, v in class_info.items()
# }
# total = sum(raw_ratios.values())
# ratios = {k: v / total for k, v in raw_ratios.items()}

# # 최종 합계가 1인지 확인
# final_total = sum(ratios.values())
# print("클래스 비율:", ratios)
# print("최종 합계:", final_total)

# # 비율을 리스트로 변환
# ratios_list = [ratios[k] for k in sorted(ratios.keys())]
# print("클래스 비율 리스트:", ratios_list)

In [3]:
class_info = {
    0: {"name": "background", "weight": 10000},  # weight 없음
    1: {"name": "apo-ferritin", "weight": 300},
    2: {"name": "beta-amylase", "weight": 10000}, # 100
    3: {"name": "beta-galactosidase", "weight": 150}, #3080
    4: {"name": "ribosome", "weight": 6000},
    5: {"name": "thyroglobulin", "weight": 4000},
    6: {"name": "virus-like-particle", "weight": 2000},
}

# 가중치의 역수 계산 (0인 경우 작은 값을 대체)
inverse_ratios = {
    k: (1 / v["weight"] if v["weight"] > 0 else 1e-6)
    for k, v in class_info.items()
}

# 정규화하여 비율 계산
total_inverse = sum(inverse_ratios.values())
ratios_inverse = {k: v / total_inverse for k, v in inverse_ratios.items()}

# 최종 합계가 1인지 확인
final_total_inverse = sum(ratios_inverse.values())
print("반비례 클래스 비율:", ratios_inverse)
print("최종 합계:", final_total_inverse)

# 비율을 리스트로 변환
ratios_list = [ratios_inverse[k] for k in sorted(ratios_inverse.keys())]
print("반비례 클래스 비율 리스트:", ratios_list)


반비례 클래스 비율: {0: 0.008995502248875561, 1: 0.29985007496251875, 2: 0.008995502248875561, 3: 0.5997001499250375, 4: 0.014992503748125935, 5: 0.022488755622188904, 6: 0.04497751124437781}
최종 합계: 1.0
반비례 클래스 비율 리스트: [0.008995502248875561, 0.29985007496251875, 0.008995502248875561, 0.5997001499250375, 0.014992503748125935, 0.022488755622188904, 0.04497751124437781]


In [4]:
from src.dataset.dataset import create_dataloaders
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)

train_img_dir = "./datasets/train/images"
train_label_dir = "./datasets/train/labels"
val_img_dir = "./datasets/val/images"
val_label_dir = "./datasets/val/labels"
img_depth = 96
img_size = 96  # Match your patch size
n_classes = 7
batch_size = 2 # 13.8GB GPU memory required for 128x128 img size
num_samples = batch_size # 한 이미지에서 뽑을 샘플 수
loader_batch = 1

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
])
random_transforms = Compose([
    
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[img_depth, img_size, img_size],
        num_classes=n_classes,
        num_samples=num_samples, 
        ratios=ratios_list,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])


In [5]:
train_loader, val_loader = None, None
train_loader, val_loader = create_dataloaders(
    train_img_dir, 
    train_label_dir, 
    val_img_dir, 
    val_label_dir, 
    non_random_transforms = non_random_transforms, 
    random_transforms = random_transforms, 
    batch_size = loader_batch,
    num_workers=0)

Loading dataset: 100%|██████████| 24/24 [00:37<00:00,  1.56s/it]
Loading dataset: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


https://monai.io/model-zoo.html

In [6]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from monai.losses import TverskyLoss
from pathlib import Path
from monai.metrics import DiceMetric
# Model Configuration
start_epoch = 0
best_val_loss = float('inf')

# Training setup
num_epochs = 4000
lamda = 0.52
lr = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=36,
    use_checkpoint=False,
).to(device)
# Pretrained weights 불러오기
# pretrain_path = "./swin_unetr_btcv_segmentation/models/model.pt"
# weight = torch.load(pretrain_path, map_location=device)

# # 출력 레이어의 키를 제외한 나머지 가중치만 로드
# filtered_weights = {k: v for k, v in weight.items() if "out.conv.conv" not in k}

# # strict=False로 로드하여 불일치하는 부분 무시
# model.load_state_dict(filtered_weights, strict=False)
# print("Filtered weights loaded successfully. Output layer will be trained from scratch.")

# Load pretrained weights
# model.load_from(weights=np.load(config_vit.real_pretrained_path, allow_pickle=True))

# TverskyLoss 설정
criterion = TverskyLoss(
    alpha=1- lamda,  # FP에 대한 가중치
    beta= lamda,  # FN에 대한 가중치
    include_background=True,
    softmax=True
)

# 체크포인트 디렉토리 및 파일 설정
checkpoint_base_dir = Path("./model_checkpoints")
checkpoint_dir = checkpoint_base_dir / f"SwinUNETR_no_pretrain_f36_newratio_bTrue_{img_depth}_{img_size}_lr{lr}_lambda{lamda}_batch{batch_size}"
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

# 체크포인트 디렉토리 생성
checkpoint_dir.mkdir(parents=True, exist_ok=True)

if checkpoint_dir.exists():
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"기존 best model 발견: {best_model_path}")
        try:
            checkpoint = torch.load(best_model_path, map_location=device)
            # 체크포인트 내부 키 검증
            required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'best_val_loss']
            if all(k in checkpoint for k in required_keys):
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']
                best_val_loss = checkpoint['best_val_loss']
                print("기존 학습된 가중치를 성공적으로 로드했습니다.")
            else:
                raise ValueError("체크포인트 파일에 필요한 key가 없습니다.")
        except Exception as e:
            print(f"체크포인트 파일을 로드하는 중 오류 발생: {e}")




In [7]:
batch = next(iter(val_loader))
images, labels = batch["image"], batch["label"]
print(images.shape, labels.shape)

torch.Size([2, 1, 96, 96, 96]) torch.Size([2, 1, 96, 96, 96])


In [8]:
torch.backends.cudnn.benchmark = True

In [None]:
from monai.metrics import DiceMetric

def create_metric_dict(num_classes):
    """각 클래스별 DiceMetric 생성"""
    metrics = {}
    for i in range(num_classes):
        metrics[f'dice_class_{i}'] = DiceMetric(
            include_background=False if i == 0 else False,
            reduction="mean",
            get_not_nans=False
        )
    return metrics
    
def processing(batch_data, model, criterion, device):
    images = batch_data['image'].to(device)  # Input 이미지 (B, 1, 96, 96, 96)
    labels = batch_data['label'].to(device)  # 라벨 (B, 96, 96, 96)

    labels = labels.squeeze(1)  # (B, 1, 96, 96, 96) → (B, 96, 96, 96)
    labels = labels.long()  # 라벨을 정수형으로 변환

    # 원핫 인코딩 (B, H, W, D) → (B, num_classes, H, W, D)
    
    labels_onehot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    labels_onehot = labels_onehot.permute(0, 4, 1, 2, 3).float()  # (B, num_classes, H, W, D)

    # 모델 예측
    outputs = model(images)  # outputs: (B, num_classes, H, W, D)

    # Loss 계산
    loss = criterion(outputs, labels_onehot)

    return loss, outputs, labels, outputs.argmax(dim=1)

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    with tqdm(train_loader, desc='Training') as pbar:
        for batch_data in pbar:
            optimizer.zero_grad()
            loss, _, _, _ = processing(batch_data, model, criterion, device)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

    return epoch_loss / len(train_loader)

def validate_one_epoch(model, val_loader, criterion, device, epoch, calculate_dice_interval):
    model.eval()
    val_loss = 0
    metrics = create_metric_dict(n_classes)
    class_dice_scores = {i: [] for i in range(n_classes)}
    with torch.no_grad():
        with tqdm(val_loader, desc='Validation') as pbar:
            for batch_data in pbar:
                loss, _, labels, preds = processing(batch_data, model, criterion, device)
                val_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

                # 각 클래스별 Dice 점수 계산
                if epoch % calculate_dice_interval == 0:
                    for i in range(n_classes):
                        pred_i = (preds == i)
                        label_i = (labels == i)
                        dice_score = (2.0 * torch.sum(pred_i & label_i)) / (torch.sum(pred_i) + torch.sum(label_i) + 1e-8)
                        class_dice_scores[i].append(dice_score.item())

    # 각 클래스별 평균 Dice 점수 출력
    if epoch % calculate_dice_interval == 0:
        print("Validation Dice Score")
        for i in range(n_classes):
            mean_dice = np.mean(class_dice_scores[i])
            print(f"Class {i}: {mean_dice:.4f}", end=", ")
            if i == 3:
                print()
        print()

    return val_loss / len(val_loader)

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, 
    device, start_epoch, best_val_loss, calculate_dice_interval=1
):
    """
    모델을 학습하고 검증하는 함수
    Args:
        model: 학습할 모델
        train_loader: 학습 데이터 로더
        val_loader: 검증 데이터 로더
        criterion: 손실 함수
        optimizer: 최적화 알고리즘
        num_epochs: 총 학습 epoch 수
        patience: early stopping 기준
        device: GPU/CPU 장치
        start_epoch: 시작 epoch
        best_val_loss: 이전 최적 validation loss
        calculate_dice_interval: Dice 점수 계산 주기
    """

    epochs_no_improve = 0

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Train One Epoch
        train_loss = train_one_epoch(
            model=model, 
            train_loader=train_loader, 
            criterion=criterion, 
            optimizer=optimizer, 
            device=device
        )

        # Validate One Epoch
        val_loss = validate_one_epoch(
            model=model, 
            val_loader=val_loader, 
            criterion=criterion, 
            device=device, 
            epoch=epoch, 
            calculate_dice_interval=calculate_dice_interval
        )

        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Early stopping 및 모델 저장 로직
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }, checkpoint_path)
            print(f"========================================================")
            print(f"Best model saved with validation loss: {best_val_loss:.4f}")
            print(f"========================================================")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping")
            break

train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    patience=50,
    device=device,
    start_epoch=start_epoch,
    best_val_loss=best_val_loss,
    calculate_dice_interval=1
)

Epoch 1/4000


Training: 100%|██████████| 24/24 [00:47<00:00,  1.97s/it, loss=0.87] 
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s, loss=0.889]


Validation Dice Score
Class 0: 0.9291, Class 1: 0.0069, Class 2: 0.0021, Class 3: 0.0002, 
Class 4: 0.1660, Class 5: 0.0393, Class 6: 0.0143, 
Training Loss: 0.9101, Validation Loss: 0.8837
Best model saved with validation loss: 0.8837
Epoch 2/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.62s/it, loss=0.838]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s, loss=0.851]


Validation Dice Score
Class 0: 0.9686, Class 1: 0.0047, Class 2: 0.0000, Class 3: 0.0115, 
Class 4: 0.2959, Class 5: 0.0724, Class 6: 0.0071, 
Training Loss: 0.8705, Validation Loss: 0.8415
Best model saved with validation loss: 0.8415
Epoch 3/4000


Training: 100%|██████████| 24/24 [00:52<00:00,  2.17s/it, loss=0.859]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, loss=0.844]


Validation Dice Score
Class 0: 0.9856, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.0051, 
Class 4: 0.3635, Class 5: 0.0012, Class 6: 0.0000, 
Training Loss: 0.8451, Validation Loss: 0.8278
Best model saved with validation loss: 0.8278
Epoch 4/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.61s/it, loss=0.855]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, loss=0.828]


Validation Dice Score
Class 0: 0.9800, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.2080, 
Class 4: 0.1997, Class 5: 0.0098, Class 6: 0.0000, 
Training Loss: 0.8244, Validation Loss: 0.8154
Best model saved with validation loss: 0.8154
Epoch 5/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.852]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, loss=0.835]


Validation Dice Score
Class 0: 0.9857, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.2124, 
Class 4: 0.0965, Class 5: 0.0621, Class 6: 0.0226, 
Training Loss: 0.8067, Validation Loss: 0.8209
Epoch 6/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.799]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.82] 


Validation Dice Score
Class 0: 0.9886, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.2033, 
Class 4: 0.1665, Class 5: 0.0015, Class 6: 0.0262, 
Training Loss: 0.8013, Validation Loss: 0.8097
Best model saved with validation loss: 0.8097
Epoch 7/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.784]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.769]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.0000, Class 2: 0.0000, Class 3: 0.1291, 
Class 4: 0.1005, Class 5: 0.0385, Class 6: 0.0453, 
Training Loss: 0.7790, Validation Loss: 0.8196
Epoch 8/4000


Training: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it, loss=0.705]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.788]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.0106, Class 2: 0.0000, Class 3: 0.1714, 
Class 4: 0.2930, Class 5: 0.0638, Class 6: 0.0156, 
Training Loss: 0.7886, Validation Loss: 0.7903
Best model saved with validation loss: 0.7903
Epoch 9/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.783]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.729]


Validation Dice Score
Class 0: 0.9874, Class 1: 0.0897, Class 2: 0.0000, Class 3: 0.4465, 
Class 4: 0.1452, Class 5: 0.1080, Class 6: 0.2983, 
Training Loss: 0.7717, Validation Loss: 0.7409
Best model saved with validation loss: 0.7409
Epoch 10/4000


Training: 100%|██████████| 24/24 [00:35<00:00,  1.46s/it, loss=0.719]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.02it/s, loss=0.708]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.0303, Class 2: 0.0000, Class 3: 0.4255, 
Class 4: 0.3207, Class 5: 0.0426, Class 6: 0.2264, 
Training Loss: 0.7705, Validation Loss: 0.7301
Best model saved with validation loss: 0.7301
Epoch 11/4000


Training: 100%|██████████| 24/24 [00:35<00:00,  1.47s/it, loss=0.749]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.759]


Validation Dice Score
Class 0: 0.9891, Class 1: 0.1243, Class 2: 0.0000, Class 3: 0.2086, 
Class 4: 0.3657, Class 5: 0.1234, Class 6: 0.1255, 
Training Loss: 0.7673, Validation Loss: 0.7763
Epoch 12/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.678]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.777]


Validation Dice Score
Class 0: 0.9856, Class 1: 0.0067, Class 2: 0.0000, Class 3: 0.2542, 
Class 4: 0.3451, Class 5: 0.0944, Class 6: 0.1806, 
Training Loss: 0.7581, Validation Loss: 0.7577
Epoch 13/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.715]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.696]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.1596, Class 2: 0.0000, Class 3: 0.2377, 
Class 4: 0.4693, Class 5: 0.0994, Class 6: 0.0041, 
Training Loss: 0.7516, Validation Loss: 0.7522
Epoch 14/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.767]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.621]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.0285, Class 2: 0.0000, Class 3: 0.4823, 
Class 4: 0.3231, Class 5: 0.1785, Class 6: 0.3325, 
Training Loss: 0.7341, Validation Loss: 0.7074
Best model saved with validation loss: 0.7074
Epoch 15/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.757]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.752]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.0494, Class 2: 0.0003, Class 3: 0.2981, 
Class 4: 0.3011, Class 5: 0.1189, Class 6: 0.0860, 
Training Loss: 0.7493, Validation Loss: 0.7663
Epoch 16/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.76] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.719]


Validation Dice Score
Class 0: 0.9839, Class 1: 0.1220, Class 2: 0.0010, Class 3: 0.4345, 
Class 4: 0.3158, Class 5: 0.1197, Class 6: 0.1549, 
Training Loss: 0.7340, Validation Loss: 0.7259
Epoch 17/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.767]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.736]


Validation Dice Score
Class 0: 0.9880, Class 1: 0.1924, Class 2: 0.0108, Class 3: 0.4532, 
Class 4: 0.2890, Class 5: 0.0904, Class 6: 0.1076, 
Training Loss: 0.7503, Validation Loss: 0.7477
Epoch 18/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.797]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.782]


Validation Dice Score
Class 0: 0.9848, Class 1: 0.0796, Class 2: 0.0592, Class 3: 0.2955, 
Class 4: 0.2756, Class 5: 0.1362, Class 6: 0.0131, 
Training Loss: 0.7473, Validation Loss: 0.7660
Epoch 19/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.721]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.697]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.2982, Class 2: 0.0357, Class 3: 0.2275, 
Class 4: 0.0436, Class 5: 0.0818, Class 6: 0.5837, 
Training Loss: 0.7455, Validation Loss: 0.7037
Best model saved with validation loss: 0.7037
Epoch 20/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.749]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.693]


Validation Dice Score
Class 0: 0.9845, Class 1: 0.1149, Class 2: 0.0000, Class 3: 0.2592, 
Class 4: 0.2249, Class 5: 0.0840, Class 6: 0.1462, 
Training Loss: 0.7494, Validation Loss: 0.7603
Epoch 21/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.815]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.71] 


Validation Dice Score
Class 0: 0.9870, Class 1: 0.2203, Class 2: 0.0863, Class 3: 0.3532, 
Class 4: 0.2547, Class 5: 0.0958, Class 6: 0.3139, 
Training Loss: 0.7342, Validation Loss: 0.7073
Epoch 22/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.678]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.696]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.4321, Class 2: 0.0239, Class 3: 0.4850, 
Class 4: 0.2782, Class 5: 0.1344, Class 6: 0.2699, 
Training Loss: 0.7079, Validation Loss: 0.6767
Best model saved with validation loss: 0.6767
Epoch 23/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.698]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.701]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.3225, Class 2: 0.0392, Class 3: 0.1916, 
Class 4: 0.2207, Class 5: 0.1310, Class 6: 0.1434, 
Training Loss: 0.7245, Validation Loss: 0.7403
Epoch 24/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.727]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.67] 


Validation Dice Score
Class 0: 0.9872, Class 1: 0.2059, Class 2: 0.0687, Class 3: 0.5268, 
Class 4: 0.3331, Class 5: 0.1539, Class 6: 0.1647, 
Training Loss: 0.7143, Validation Loss: 0.6715
Best model saved with validation loss: 0.6715
Epoch 25/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.717]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.783]


Validation Dice Score
Class 0: 0.9837, Class 1: 0.3861, Class 2: 0.0361, Class 3: 0.4637, 
Class 4: 0.2256, Class 5: 0.1604, Class 6: 0.2674, 
Training Loss: 0.7106, Validation Loss: 0.7063
Epoch 26/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.712]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s, loss=0.743]


Validation Dice Score
Class 0: 0.9847, Class 1: 0.3125, Class 2: 0.0518, Class 3: 0.2331, 
Class 4: 0.2707, Class 5: 0.1418, Class 6: 0.1598, 
Training Loss: 0.6977, Validation Loss: 0.7230
Epoch 27/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.7]  
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.771]


Validation Dice Score
Class 0: 0.9864, Class 1: 0.4678, Class 2: 0.0193, Class 3: 0.3772, 
Class 4: 0.3726, Class 5: 0.1203, Class 6: 0.1598, 
Training Loss: 0.7277, Validation Loss: 0.7020
Epoch 28/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.66] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.25it/s, loss=0.622]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.5578, Class 2: 0.0329, Class 3: 0.3830, 
Class 4: 0.2539, Class 5: 0.1269, Class 6: 0.3517, 
Training Loss: 0.6830, Validation Loss: 0.6795
Epoch 29/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.693]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.28it/s, loss=0.746]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.4475, Class 2: 0.1265, Class 3: 0.2862, 
Class 4: 0.1635, Class 5: 0.1690, Class 6: 0.2237, 
Training Loss: 0.7003, Validation Loss: 0.6958
Epoch 30/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.719]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.684]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.4052, Class 2: 0.0725, Class 3: 0.6287, 
Class 4: 0.2541, Class 5: 0.1832, Class 6: 0.2991, 
Training Loss: 0.7017, Validation Loss: 0.6511
Best model saved with validation loss: 0.6511
Epoch 31/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.716]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.688]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.3386, Class 2: 0.0588, Class 3: 0.2147, 
Class 4: 0.1387, Class 5: 0.1929, Class 6: 0.3672, 
Training Loss: 0.6834, Validation Loss: 0.6923
Epoch 32/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.819]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.655]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.6087, Class 2: 0.0012, Class 3: 0.2262, 
Class 4: 0.2703, Class 5: 0.1936, Class 6: 0.3437, 
Training Loss: 0.6982, Validation Loss: 0.6519
Epoch 33/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.645]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.76] 


Validation Dice Score
Class 0: 0.9867, Class 1: 0.4171, Class 2: 0.0702, Class 3: 0.3604, 
Class 4: 0.3615, Class 5: 0.1256, Class 6: 0.1769, 
Training Loss: 0.6729, Validation Loss: 0.6941
Epoch 34/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.785]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.733]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.6066, Class 2: 0.0796, Class 3: 0.4619, 
Class 4: 0.1527, Class 5: 0.2076, Class 6: 0.1336, 
Training Loss: 0.6785, Validation Loss: 0.6918
Epoch 35/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.701]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.711]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.5405, Class 2: 0.0081, Class 3: 0.3590, 
Class 4: 0.4323, Class 5: 0.2264, Class 6: 0.4464, 
Training Loss: 0.6824, Validation Loss: 0.6405
Best model saved with validation loss: 0.6405
Epoch 36/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.66] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.728]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.3972, Class 2: 0.0785, Class 3: 0.3047, 
Class 4: 0.3541, Class 5: 0.1215, Class 6: 0.0724, 
Training Loss: 0.6977, Validation Loss: 0.7074
Epoch 37/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.704]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, loss=0.584]


Validation Dice Score
Class 0: 0.9865, Class 1: 0.3693, Class 2: 0.0000, Class 3: 0.4248, 
Class 4: 0.4140, Class 5: 0.2173, Class 6: 0.4118, 
Training Loss: 0.6922, Validation Loss: 0.6341
Best model saved with validation loss: 0.6341
Epoch 38/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.731]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.585]


Validation Dice Score
Class 0: 0.9908, Class 1: 0.5088, Class 2: 0.0342, Class 3: 0.4098, 
Class 4: 0.4071, Class 5: 0.2267, Class 6: 0.3053, 
Training Loss: 0.6774, Validation Loss: 0.6417
Epoch 39/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.692]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.652]


Validation Dice Score
Class 0: 0.9863, Class 1: 0.2911, Class 2: 0.0340, Class 3: 0.3089, 
Class 4: 0.2280, Class 5: 0.1508, Class 6: 0.1952, 
Training Loss: 0.6841, Validation Loss: 0.7278
Epoch 40/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.719]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s, loss=0.655]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.5787, Class 2: 0.0242, Class 3: 0.3956, 
Class 4: 0.1231, Class 5: 0.2537, Class 6: 0.2352, 
Training Loss: 0.6879, Validation Loss: 0.6571
Epoch 41/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.629]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.628]


Validation Dice Score
Class 0: 0.9893, Class 1: 0.6095, Class 2: 0.0754, Class 3: 0.4171, 
Class 4: 0.2396, Class 5: 0.2329, Class 6: 0.4736, 
Training Loss: 0.6696, Validation Loss: 0.6203
Best model saved with validation loss: 0.6203
Epoch 42/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.705]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.25it/s, loss=0.678]


Validation Dice Score
Class 0: 0.9841, Class 1: 0.6123, Class 2: 0.1043, Class 3: 0.4249, 
Class 4: 0.1910, Class 5: 0.1222, Class 6: 0.2229, 
Training Loss: 0.6700, Validation Loss: 0.6670
Epoch 43/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.716]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s, loss=0.658]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.6378, Class 2: 0.1185, Class 3: 0.2959, 
Class 4: 0.2272, Class 5: 0.2044, Class 6: 0.1155, 
Training Loss: 0.6690, Validation Loss: 0.6873
Epoch 44/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.673]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.66] 


Validation Dice Score
Class 0: 0.9839, Class 1: 0.7034, Class 2: 0.0764, Class 3: 0.3368, 
Class 4: 0.3811, Class 5: 0.1618, Class 6: 0.3167, 
Training Loss: 0.6839, Validation Loss: 0.6443
Epoch 45/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.673]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.693]


Validation Dice Score
Class 0: 0.9894, Class 1: 0.4246, Class 2: 0.0333, Class 3: 0.3609, 
Class 4: 0.2584, Class 5: 0.2297, Class 6: 0.1371, 
Training Loss: 0.6683, Validation Loss: 0.7016
Epoch 46/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.572]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s, loss=0.654]


Validation Dice Score
Class 0: 0.9856, Class 1: 0.4566, Class 2: 0.0452, Class 3: 0.3728, 
Class 4: 0.2989, Class 5: 0.1972, Class 6: 0.1174, 
Training Loss: 0.6517, Validation Loss: 0.6612
Epoch 47/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.679]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.706]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.3226, Class 2: 0.0855, Class 3: 0.2723, 
Class 4: 0.4762, Class 5: 0.1765, Class 6: 0.0104, 
Training Loss: 0.6592, Validation Loss: 0.6985
Epoch 48/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.626]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.742]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.3823, Class 2: 0.1370, Class 3: 0.6114, 
Class 4: 0.1908, Class 5: 0.2702, Class 6: 0.5858, 
Training Loss: 0.6683, Validation Loss: 0.6287
Epoch 49/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.72] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.666]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.5510, Class 2: 0.1638, Class 3: 0.3203, 
Class 4: 0.2043, Class 5: 0.1004, Class 6: 0.1479, 
Training Loss: 0.6888, Validation Loss: 0.7098
Epoch 50/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.66] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.27it/s, loss=0.768]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.4958, Class 2: 0.1137, Class 3: 0.3997, 
Class 4: 0.0740, Class 5: 0.1414, Class 6: 0.0288, 
Training Loss: 0.6545, Validation Loss: 0.7117
Epoch 51/4000


Training: 100%|██████████| 24/24 [00:35<00:00,  1.46s/it, loss=0.648]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.25it/s, loss=0.705]


Validation Dice Score
Class 0: 0.9842, Class 1: 0.2047, Class 2: 0.1484, Class 3: 0.4310, 
Class 4: 0.4434, Class 5: 0.1924, Class 6: 0.3455, 
Training Loss: 0.6602, Validation Loss: 0.6672
Epoch 52/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.589]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, loss=0.667]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.5040, Class 2: 0.0420, Class 3: 0.5425, 
Class 4: 0.3819, Class 5: 0.1367, Class 6: 0.3316, 
Training Loss: 0.6566, Validation Loss: 0.6532
Epoch 53/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.678]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, loss=0.716]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.6845, Class 2: 0.1182, Class 3: 0.4833, 
Class 4: 0.4006, Class 5: 0.1636, Class 6: 0.4367, 
Training Loss: 0.6613, Validation Loss: 0.5850
Best model saved with validation loss: 0.5850
Epoch 54/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.78] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.14it/s, loss=0.668]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.6290, Class 2: 0.0960, Class 3: 0.4394, 
Class 4: 0.1656, Class 5: 0.2002, Class 6: 0.3253, 
Training Loss: 0.6817, Validation Loss: 0.6340
Epoch 55/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.666]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.792]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.5054, Class 2: 0.1241, Class 3: 0.2524, 
Class 4: 0.0960, Class 5: 0.1838, Class 6: 0.0443, 
Training Loss: 0.6624, Validation Loss: 0.7226
Epoch 56/4000


Training: 100%|██████████| 24/24 [00:35<00:00,  1.46s/it, loss=0.544]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s, loss=0.625]


Validation Dice Score
Class 0: 0.9900, Class 1: 0.5309, Class 2: 0.0515, Class 3: 0.4600, 
Class 4: 0.1827, Class 5: 0.1505, Class 6: 0.4193, 
Training Loss: 0.6485, Validation Loss: 0.6496
Epoch 57/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.694]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.588]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.6401, Class 2: 0.1046, Class 3: 0.3342, 
Class 4: 0.3618, Class 5: 0.1899, Class 6: 0.4783, 
Training Loss: 0.6654, Validation Loss: 0.6025
Epoch 58/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.601]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, loss=0.629]


Validation Dice Score
Class 0: 0.9896, Class 1: 0.4730, Class 2: 0.0000, Class 3: 0.5558, 
Class 4: 0.4550, Class 5: 0.3049, Class 6: 0.4548, 
Training Loss: 0.6687, Validation Loss: 0.5995
Epoch 59/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.663]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.744]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.5777, Class 2: 0.0000, Class 3: 0.4801, 
Class 4: 0.3591, Class 5: 0.1021, Class 6: 0.3235, 
Training Loss: 0.6695, Validation Loss: 0.6623
Epoch 60/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.631]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.597]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.4802, Class 2: 0.0510, Class 3: 0.5290, 
Class 4: 0.4391, Class 5: 0.1813, Class 6: 0.1254, 
Training Loss: 0.6450, Validation Loss: 0.6415
Epoch 61/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.635]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, loss=0.665]


Validation Dice Score
Class 0: 0.9889, Class 1: 0.7047, Class 2: 0.0506, Class 3: 0.2726, 
Class 4: 0.3411, Class 5: 0.2766, Class 6: 0.4567, 
Training Loss: 0.6544, Validation Loss: 0.6298
Epoch 62/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.698]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.59] 


Validation Dice Score
Class 0: 0.9868, Class 1: 0.6946, Class 2: 0.0335, Class 3: 0.5114, 
Class 4: 0.5235, Class 5: 0.2910, Class 6: 0.2872, 
Training Loss: 0.6611, Validation Loss: 0.6045
Epoch 63/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.613]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.624]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.7478, Class 2: 0.0900, Class 3: 0.5160, 
Class 4: 0.4770, Class 5: 0.2737, Class 6: 0.3019, 
Training Loss: 0.6577, Validation Loss: 0.5809
Best model saved with validation loss: 0.5809
Epoch 64/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.629]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.542]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.5086, Class 2: 0.0952, Class 3: 0.4836, 
Class 4: 0.5254, Class 5: 0.2233, Class 6: 0.1933, 
Training Loss: 0.6735, Validation Loss: 0.6159
Epoch 65/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.781]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.691]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.4138, Class 2: 0.0943, Class 3: 0.2992, 
Class 4: 0.2996, Class 5: 0.2638, Class 6: 0.2786, 
Training Loss: 0.6454, Validation Loss: 0.6750
Epoch 66/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.709]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, loss=0.662]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.5047, Class 2: 0.2456, Class 3: 0.4520, 
Class 4: 0.3622, Class 5: 0.3532, Class 6: 0.1521, 
Training Loss: 0.6558, Validation Loss: 0.6193
Epoch 67/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.579]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.639]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.5689, Class 2: 0.0806, Class 3: 0.4251, 
Class 4: 0.2940, Class 5: 0.1353, Class 6: 0.1980, 
Training Loss: 0.6299, Validation Loss: 0.6658
Epoch 68/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.667]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.643]


Validation Dice Score
Class 0: 0.9853, Class 1: 0.5004, Class 2: 0.2014, Class 3: 0.5346, 
Class 4: 0.3293, Class 5: 0.2475, Class 6: 0.1666, 
Training Loss: 0.6585, Validation Loss: 0.6480
Epoch 69/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.762]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.25it/s, loss=0.64] 


Validation Dice Score
Class 0: 0.9915, Class 1: 0.6566, Class 2: 0.0765, Class 3: 0.6128, 
Class 4: 0.1835, Class 5: 0.2971, Class 6: 0.4208, 
Training Loss: 0.6542, Validation Loss: 0.6248
Epoch 70/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.535]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.703]


Validation Dice Score
Class 0: 0.9881, Class 1: 0.4304, Class 2: 0.0913, Class 3: 0.3540, 
Class 4: 0.5351, Class 5: 0.2294, Class 6: 0.1860, 
Training Loss: 0.6305, Validation Loss: 0.6510
Epoch 71/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.62] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.656]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.5445, Class 2: 0.0303, Class 3: 0.2045, 
Class 4: 0.2497, Class 5: 0.2523, Class 6: 0.4341, 
Training Loss: 0.6580, Validation Loss: 0.6651
Epoch 72/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.692]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.22it/s, loss=0.672]


Validation Dice Score
Class 0: 0.9862, Class 1: 0.5756, Class 2: 0.0057, Class 3: 0.3239, 
Class 4: 0.1940, Class 5: 0.1992, Class 6: 0.2677, 
Training Loss: 0.6394, Validation Loss: 0.6637
Epoch 73/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.46s/it, loss=0.701]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.662]


Validation Dice Score
Class 0: 0.9861, Class 1: 0.7064, Class 2: 0.0811, Class 3: 0.2704, 
Class 4: 0.4929, Class 5: 0.2318, Class 6: 0.3124, 
Training Loss: 0.6502, Validation Loss: 0.6090
Epoch 74/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.625]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.27it/s, loss=0.682]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.4681, Class 2: 0.0981, Class 3: 0.2782, 
Class 4: 0.1579, Class 5: 0.2721, Class 6: 0.2691, 
Training Loss: 0.6530, Validation Loss: 0.6938
Epoch 75/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.662]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.653]


Validation Dice Score
Class 0: 0.9885, Class 1: 0.4687, Class 2: 0.0928, Class 3: 0.4825, 
Class 4: 0.4063, Class 5: 0.3420, Class 6: 0.0666, 
Training Loss: 0.6491, Validation Loss: 0.6278
Epoch 76/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.58] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s, loss=0.686]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.5014, Class 2: 0.0706, Class 3: 0.3520, 
Class 4: 0.4088, Class 5: 0.2117, Class 6: 0.0375, 
Training Loss: 0.6491, Validation Loss: 0.6834
Epoch 77/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.728]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.749]


Validation Dice Score
Class 0: 0.9901, Class 1: 0.6400, Class 2: 0.0000, Class 3: 0.2535, 
Class 4: 0.4351, Class 5: 0.1819, Class 6: 0.3399, 
Training Loss: 0.6644, Validation Loss: 0.6579
Epoch 78/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.613]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.24it/s, loss=0.658]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.2887, Class 2: 0.1069, Class 3: 0.2794, 
Class 4: 0.2263, Class 5: 0.2550, Class 6: 0.3552, 
Training Loss: 0.6503, Validation Loss: 0.6799
Epoch 79/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.698]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.654]


Validation Dice Score
Class 0: 0.9878, Class 1: 0.4963, Class 2: 0.0701, Class 3: 0.6145, 
Class 4: 0.4751, Class 5: 0.2094, Class 6: 0.2716, 
Training Loss: 0.6639, Validation Loss: 0.6122
Epoch 80/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.732]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.579]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.6844, Class 2: 0.1559, Class 3: 0.3924, 
Class 4: 0.3659, Class 5: 0.2156, Class 6: 0.1172, 
Training Loss: 0.6536, Validation Loss: 0.6263
Epoch 81/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.45s/it, loss=0.628]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.647]


Validation Dice Score
Class 0: 0.9843, Class 1: 0.5741, Class 2: 0.2025, Class 3: 0.2935, 
Class 4: 0.3092, Class 5: 0.3572, Class 6: 0.1741, 
Training Loss: 0.6537, Validation Loss: 0.6296
Epoch 82/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.583]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, loss=0.685]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7463, Class 2: 0.1179, Class 3: 0.4563, 
Class 4: 0.2513, Class 5: 0.1599, Class 6: 0.1802, 
Training Loss: 0.6276, Validation Loss: 0.6506
Epoch 83/4000


Training: 100%|██████████| 24/24 [00:34<00:00,  1.44s/it, loss=0.675]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.32it/s, loss=0.636]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7106, Class 2: 0.0449, Class 3: 0.4855, 
Class 4: 0.2349, Class 5: 0.3029, Class 6: 0.4568, 
Training Loss: 0.6486, Validation Loss: 0.6323
Epoch 84/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.576]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.30it/s, loss=0.556]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.7473, Class 2: 0.0230, Class 3: 0.3253, 
Class 4: 0.4152, Class 5: 0.2009, Class 6: 0.2363, 
Training Loss: 0.6526, Validation Loss: 0.6373
Epoch 85/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.576]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.32it/s, loss=0.518]


Validation Dice Score
Class 0: 0.9890, Class 1: 0.6516, Class 2: 0.1266, Class 3: 0.4792, 
Class 4: 0.2857, Class 5: 0.2636, Class 6: 0.2301, 
Training Loss: 0.6271, Validation Loss: 0.6300
Epoch 86/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.631]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s, loss=0.723]


Validation Dice Score
Class 0: 0.9824, Class 1: 0.6297, Class 2: 0.0429, Class 3: 0.2224, 
Class 4: 0.2340, Class 5: 0.2522, Class 6: 0.0140, 
Training Loss: 0.6367, Validation Loss: 0.6799
Epoch 87/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.702]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.35it/s, loss=0.646]


Validation Dice Score
Class 0: 0.9835, Class 1: 0.4485, Class 2: 0.0961, Class 3: 0.3935, 
Class 4: 0.2349, Class 5: 0.2145, Class 6: 0.0003, 
Training Loss: 0.6451, Validation Loss: 0.7036
Epoch 88/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.657]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s, loss=0.675]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.6263, Class 2: 0.0642, Class 3: 0.3786, 
Class 4: 0.0049, Class 5: 0.2879, Class 6: 0.3207, 
Training Loss: 0.6351, Validation Loss: 0.6573
Epoch 89/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.673]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.27it/s, loss=0.578]


Validation Dice Score
Class 0: 0.9875, Class 1: 0.7324, Class 2: 0.0142, Class 3: 0.5486, 
Class 4: 0.2361, Class 5: 0.3210, Class 6: 0.3154, 
Training Loss: 0.6464, Validation Loss: 0.6018
Epoch 90/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.697]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.32it/s, loss=0.52] 


Validation Dice Score
Class 0: 0.9895, Class 1: 0.5029, Class 2: 0.1420, Class 3: 0.4584, 
Class 4: 0.3032, Class 5: 0.3459, Class 6: 0.4972, 
Training Loss: 0.6613, Validation Loss: 0.6065
Epoch 91/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.646]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s, loss=0.619]


Validation Dice Score
Class 0: 0.9842, Class 1: 0.6721, Class 2: 0.1180, Class 3: 0.2278, 
Class 4: 0.1444, Class 5: 0.2168, Class 6: 0.3172, 
Training Loss: 0.6313, Validation Loss: 0.6455
Epoch 92/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.637]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.30it/s, loss=0.566]


Validation Dice Score
Class 0: 0.9876, Class 1: 0.5521, Class 2: 0.1641, Class 3: 0.4872, 
Class 4: 0.3861, Class 5: 0.1779, Class 6: 0.0136, 
Training Loss: 0.6399, Validation Loss: 0.6535
Epoch 93/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.502]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.33it/s, loss=0.789]


Validation Dice Score
Class 0: 0.9838, Class 1: 0.5179, Class 2: 0.0078, Class 3: 0.3684, 
Class 4: 0.2721, Class 5: 0.1663, Class 6: 0.2330, 
Training Loss: 0.6423, Validation Loss: 0.6537
Epoch 94/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.704]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.31it/s, loss=0.653]


Validation Dice Score
Class 0: 0.9846, Class 1: 0.7175, Class 2: 0.0606, Class 3: 0.4104, 
Class 4: 0.3845, Class 5: 0.2604, Class 6: 0.4717, 
Training Loss: 0.6480, Validation Loss: 0.5799
Best model saved with validation loss: 0.5799
Epoch 95/4000


Training: 100%|██████████| 24/24 [00:33<00:00,  1.39s/it, loss=0.567]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s, loss=0.641]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.7212, Class 2: 0.0949, Class 3: 0.2022, 
Class 4: 0.4907, Class 5: 0.2560, Class 6: 0.3245, 
Training Loss: 0.6464, Validation Loss: 0.6153
Epoch 96/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.56] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.30it/s, loss=0.709]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.4803, Class 2: 0.1798, Class 3: 0.3200, 
Class 4: 0.3026, Class 5: 0.2016, Class 6: 0.2965, 
Training Loss: 0.6410, Validation Loss: 0.6572
Epoch 97/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.572]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.36it/s, loss=0.625]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.6993, Class 2: 0.2213, Class 3: 0.4822, 
Class 4: 0.3486, Class 5: 0.2533, Class 6: 0.3796, 
Training Loss: 0.6379, Validation Loss: 0.5656
Best model saved with validation loss: 0.5656
Epoch 98/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.664]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.33it/s, loss=0.692]


Validation Dice Score
Class 0: 0.9907, Class 1: 0.3768, Class 2: 0.0000, Class 3: 0.3643, 
Class 4: 0.3161, Class 5: 0.2122, Class 6: 0.0872, 
Training Loss: 0.6335, Validation Loss: 0.6946
Epoch 99/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.36s/it, loss=0.597]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s, loss=0.628]


Validation Dice Score
Class 0: 0.9883, Class 1: 0.5303, Class 2: 0.1940, Class 3: 0.3358, 
Class 4: 0.3588, Class 5: 0.1285, Class 6: 0.1777, 
Training Loss: 0.6382, Validation Loss: 0.6594
Epoch 100/4000


Training: 100%|██████████| 24/24 [00:32<00:00,  1.37s/it, loss=0.639]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.34it/s, loss=0.683]


Validation Dice Score
Class 0: 0.9855, Class 1: 0.6893, Class 2: 0.1667, Class 3: 0.6379, 
Class 4: 0.3093, Class 5: 0.2631, Class 6: 0.2124, 
Training Loss: 0.6385, Validation Loss: 0.5810
Epoch 101/4000


Training: 100%|██████████| 24/24 [00:39<00:00,  1.64s/it, loss=0.589]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.09it/s, loss=0.67] 


Validation Dice Score
Class 0: 0.9873, Class 1: 0.4517, Class 2: 0.1281, Class 3: 0.5215, 
Class 4: 0.5171, Class 5: 0.1611, Class 6: 0.1536, 
Training Loss: 0.6463, Validation Loss: 0.6419
Epoch 102/4000


Training: 100%|██████████| 24/24 [00:36<00:00,  1.52s/it, loss=0.734]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.13it/s, loss=0.582]


Validation Dice Score
Class 0: 0.9834, Class 1: 0.5205, Class 2: 0.1191, Class 3: 0.3421, 
Class 4: 0.1415, Class 5: 0.2769, Class 6: 0.3399, 
Training Loss: 0.6383, Validation Loss: 0.6535
Epoch 103/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.61s/it, loss=0.714]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.06it/s, loss=0.593]


Validation Dice Score
Class 0: 0.9873, Class 1: 0.4705, Class 2: 0.0376, Class 3: 0.3447, 
Class 4: 0.6594, Class 5: 0.1663, Class 6: 0.1704, 
Training Loss: 0.6367, Validation Loss: 0.6300
Epoch 104/4000


Training: 100%|██████████| 24/24 [00:37<00:00,  1.58s/it, loss=0.598]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.13it/s, loss=0.544]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.6623, Class 2: 0.0433, Class 3: 0.6014, 
Class 4: 0.3907, Class 5: 0.2029, Class 6: 0.5949, 
Training Loss: 0.6342, Validation Loss: 0.5836
Epoch 105/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.62s/it, loss=0.6]  
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.04it/s, loss=0.603]


Validation Dice Score
Class 0: 0.9895, Class 1: 0.7404, Class 2: 0.0155, Class 3: 0.2885, 
Class 4: 0.4602, Class 5: 0.3093, Class 6: 0.2591, 
Training Loss: 0.6137, Validation Loss: 0.6187
Epoch 106/4000


Training: 100%|██████████| 24/24 [00:40<00:00,  1.71s/it, loss=0.732]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.07it/s, loss=0.709]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.5417, Class 2: 0.1574, Class 3: 0.2832, 
Class 4: 0.2892, Class 5: 0.2539, Class 6: 0.1649, 
Training Loss: 0.6612, Validation Loss: 0.6783
Epoch 107/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.701]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.04it/s, loss=0.758]


Validation Dice Score
Class 0: 0.9842, Class 1: 0.4479, Class 2: 0.0966, Class 3: 0.2631, 
Class 4: 0.0626, Class 5: 0.1347, Class 6: 0.3812, 
Training Loss: 0.6190, Validation Loss: 0.7053
Epoch 108/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.61s/it, loss=0.543]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s, loss=0.557]


Validation Dice Score
Class 0: 0.9887, Class 1: 0.6407, Class 2: 0.3580, Class 3: 0.4152, 
Class 4: 0.3357, Class 5: 0.2859, Class 6: 0.2539, 
Training Loss: 0.6412, Validation Loss: 0.5878
Epoch 109/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.549]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s, loss=0.733]


Validation Dice Score
Class 0: 0.9879, Class 1: 0.5057, Class 2: 0.0960, Class 3: 0.3824, 
Class 4: 0.2724, Class 5: 0.1659, Class 6: 0.0000, 
Training Loss: 0.6109, Validation Loss: 0.7127
Epoch 110/4000


Training: 100%|██████████| 24/24 [00:39<00:00,  1.63s/it, loss=0.728]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, loss=0.665]


Validation Dice Score
Class 0: 0.9848, Class 1: 0.4200, Class 2: 0.1388, Class 3: 0.4214, 
Class 4: 0.1667, Class 5: 0.2077, Class 6: 0.1667, 
Training Loss: 0.6425, Validation Loss: 0.6714
Epoch 111/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.61s/it, loss=0.657]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s, loss=0.723]


Validation Dice Score
Class 0: 0.9851, Class 1: 0.5139, Class 2: 0.1971, Class 3: 0.3293, 
Class 4: 0.3117, Class 5: 0.1452, Class 6: 0.0006, 
Training Loss: 0.6383, Validation Loss: 0.6961
Epoch 112/4000


Training: 100%|██████████| 24/24 [00:37<00:00,  1.58s/it, loss=0.636]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, loss=0.673]


Validation Dice Score
Class 0: 0.9886, Class 1: 0.5296, Class 2: 0.0839, Class 3: 0.3637, 
Class 4: 0.5756, Class 5: 0.2433, Class 6: 0.2975, 
Training Loss: 0.6570, Validation Loss: 0.6145
Epoch 113/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.671]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s, loss=0.657]


Validation Dice Score
Class 0: 0.9866, Class 1: 0.4166, Class 2: 0.1771, Class 3: 0.5087, 
Class 4: 0.2054, Class 5: 0.2731, Class 6: 0.1586, 
Training Loss: 0.6518, Validation Loss: 0.6508
Epoch 114/4000


Training: 100%|██████████| 24/24 [00:37<00:00,  1.57s/it, loss=0.721]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.07it/s, loss=0.652]


Validation Dice Score
Class 0: 0.9869, Class 1: 0.7760, Class 2: 0.0401, Class 3: 0.4065, 
Class 4: 0.4448, Class 5: 0.2948, Class 6: 0.5428, 
Training Loss: 0.6218, Validation Loss: 0.5810
Epoch 115/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.62s/it, loss=0.619]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.05it/s, loss=0.611]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.7145, Class 2: 0.1469, Class 3: 0.5956, 
Class 4: 0.3636, Class 5: 0.3006, Class 6: 0.5384, 
Training Loss: 0.6138, Validation Loss: 0.5844
Epoch 116/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.603]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s, loss=0.571]


Validation Dice Score
Class 0: 0.9852, Class 1: 0.7324, Class 2: 0.1643, Class 3: 0.4902, 
Class 4: 0.4298, Class 5: 0.2639, Class 6: 0.2273, 
Training Loss: 0.6336, Validation Loss: 0.5787
Epoch 117/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.60s/it, loss=0.543]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s, loss=0.729]


Validation Dice Score
Class 0: 0.9859, Class 1: 0.5963, Class 2: 0.0354, Class 3: 0.4894, 
Class 4: 0.3174, Class 5: 0.2263, Class 6: 0.0766, 
Training Loss: 0.6160, Validation Loss: 0.6918
Epoch 118/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.581]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.05it/s, loss=0.488]


Validation Dice Score
Class 0: 0.9872, Class 1: 0.5361, Class 2: 0.2364, Class 3: 0.6158, 
Class 4: 0.2550, Class 5: 0.2629, Class 6: 0.2599, 
Training Loss: 0.6074, Validation Loss: 0.6029
Epoch 119/4000


Training: 100%|██████████| 24/24 [00:39<00:00,  1.64s/it, loss=0.657]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s, loss=0.603]


Validation Dice Score
Class 0: 0.9897, Class 1: 0.7011, Class 2: 0.2070, Class 3: 0.3493, 
Class 4: 0.3250, Class 5: 0.2730, Class 6: 0.4138, 
Training Loss: 0.6394, Validation Loss: 0.6128
Epoch 120/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.60s/it, loss=0.679]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s, loss=0.717]


Validation Dice Score
Class 0: 0.9901, Class 1: 0.6231, Class 2: 0.0119, Class 3: 0.3960, 
Class 4: 0.4532, Class 5: 0.2634, Class 6: 0.3311, 
Training Loss: 0.6292, Validation Loss: 0.6495
Epoch 121/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.62s/it, loss=0.64] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s, loss=0.587]


Validation Dice Score
Class 0: 0.9868, Class 1: 0.6142, Class 2: 0.1596, Class 3: 0.4247, 
Class 4: 0.2203, Class 5: 0.3206, Class 6: 0.4530, 
Training Loss: 0.6431, Validation Loss: 0.6044
Epoch 122/4000


Training: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it, loss=0.661]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, loss=0.671]


Validation Dice Score
Class 0: 0.9860, Class 1: 0.3619, Class 2: 0.0617, Class 3: 0.4890, 
Class 4: 0.4692, Class 5: 0.2793, Class 6: 0.3269, 
Training Loss: 0.6170, Validation Loss: 0.6463
Epoch 123/4000


Training: 100%|██████████| 24/24 [00:36<00:00,  1.51s/it, loss=0.539]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, loss=0.646]


Validation Dice Score
Class 0: 0.9871, Class 1: 0.6851, Class 2: 0.1216, Class 3: 0.2990, 
Class 4: 0.0892, Class 5: 0.2455, Class 6: 0.3118, 
Training Loss: 0.6368, Validation Loss: 0.6687
Epoch 124/4000


Training: 100%|██████████| 24/24 [00:37<00:00,  1.55s/it, loss=0.618]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.05it/s, loss=0.577]


Validation Dice Score
Class 0: 0.9899, Class 1: 0.7206, Class 2: 0.0152, Class 3: 0.2569, 
Class 4: 0.1684, Class 5: 0.1350, Class 6: 0.7041, 
Training Loss: 0.6444, Validation Loss: 0.6330
Epoch 125/4000


Training: 100%|██████████| 24/24 [00:39<00:00,  1.65s/it, loss=0.65] 
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.15it/s, loss=0.536]


Validation Dice Score
Class 0: 0.9877, Class 1: 0.8129, Class 2: 0.0478, Class 3: 0.4204, 
Class 4: 0.4128, Class 5: 0.2114, Class 6: 0.4924, 
Training Loss: 0.6450, Validation Loss: 0.5748
Epoch 126/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.702]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.13it/s, loss=0.631]


Validation Dice Score
Class 0: 0.9844, Class 1: 0.6788, Class 2: 0.0611, Class 3: 0.3793, 
Class 4: 0.4212, Class 5: 0.1872, Class 6: 0.3428, 
Training Loss: 0.6280, Validation Loss: 0.6337
Epoch 127/4000


Training: 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, loss=0.779]
Validation: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s, loss=0.665]


Validation Dice Score
Class 0: 0.9849, Class 1: 0.6429, Class 2: 0.0027, Class 3: 0.3862, 
Class 4: 0.2535, Class 5: 0.2937, Class 6: 0.3981, 
Training Loss: 0.6357, Validation Loss: 0.6415
Epoch 128/4000


Training:  38%|███▊      | 9/24 [00:14<00:24,  1.60s/it, loss=0.608]

In [None]:
if:

SyntaxError: invalid syntax (879943805.py, line 1)

# Inference

In [None]:
from src.dataset.preprocessing import Preprocessor

In [None]:
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, GaussianSmoothd
from monai.data import DataLoader, Dataset, CacheDataset
from monai.networks.nets import SwinUNETR
from pathlib import Path
import numpy as np
import copick

import torch
print("Done.")

Done.


In [None]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
          "name" : "beta-amylase",
            "is_particle": true,
            "pdb_id": "8ZRZ",
            "label": 2,
            "color": [255, 255, 255, 128],
            "radius": 90,
            "map_threshold": 0.0578  
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "./kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "./kaggle/input/czii-cryo-et-object-identification/test/static"
}"""

copick_config_path = "./kaggle/working/copick.config"
preprocessor = Preprocessor(config_blob,copick_config_path=copick_config_path)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    GaussianSmoothd(
        keys=["image"],      # 변환을 적용할 키
        sigma=[1.0, 1.0, 1.0]  # 각 축(x, y, z)의 시그마 값
        ),
    ])

Config file written to ./kaggle/working/copick.config
file length: 7


In [None]:
img_size = 96
img_depth = img_size
n_classes = 7 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_path = "./model_checkpoints/SwinUNETR96_96_lr0.001_lambda0.52_batch2/best_model.pt"
model = SwinUNETR(
    img_size=(img_depth, img_size, img_size),
    in_channels=1,
    out_channels=n_classes,
    feature_size=48,
    use_checkpoint=True,
).to(device)
# Pretrained weights 불러오기
checkpoint = torch.load(pretrain_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(pretrain_path, map_location=device)


<All keys matched successfully>

In [None]:
import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
import pandas as pd
from tqdm import tqdm
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, NormalizeIntensity
import cc3d

def dict_to_df(coord_dict, experiment_name):
    all_coords = []
    all_labels = []
    
    for label, coords in coord_dict.items():
        all_coords.append(coords)
        all_labels.extend([label] * len(coords))
    
    all_coords = np.vstack(all_coords)
    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })
    return df

id_to_name = {1: "apo-ferritin", 
              2: "beta-amylase",
              3: "beta-galactosidase", 
              4: "ribosome", 
              5: "thyroglobulin", 
              6: "virus-like-particle"}
BLOB_THRESHOLD = 200
CERTAINTY_THRESHOLD = 0.05

classes = [1, 2, 3, 4, 5, 6]

model.eval()
with torch.no_grad():
    location_dfs = []  # DataFrame 리스트로 초기화
    
    for vol_idx, run in enumerate(preprocessor.root.runs):
        print(f"Processing volume {vol_idx + 1}/{len(preprocessor.root.runs)}")
        tomogram = preprocessor.processing(run=run, task="task")
        task_files = [{"image": tomogram}]
        task_ds = CacheDataset(data=task_files, transform=non_random_transforms)
        task_loader = DataLoader(task_ds, batch_size=1, num_workers=0)
        
        for task_data in task_loader:
            images = task_data['image'].to("cuda")
            outputs = sliding_window_inference(
                inputs=images,
                roi_size=(96, 96, 96),  # ROI 크기
                sw_batch_size=4,
                predictor=model.forward,
                overlap=0.1,
                sw_device="cuda",
                device="cpu",
                buffer_steps=1,
                buffer_dim=-1
            )
            outputs = outputs.argmax(dim=1).squeeze(0).cpu().numpy()  # 클래스 채널 예측
            location = {}  # 좌표 저장용 딕셔너리
            for c in classes:
                cc = cc3d.connected_components(outputs == c)  # cc3d 라벨링
                stats = cc3d.statistics(cc)
                zyx = stats['centroids'][1:] * 10.012444  # 스케일 변환
                zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]  # 크기 필터링
                xyz = np.ascontiguousarray(zyx_large[:, ::-1])  # 좌표 스왑 (z, y, x -> x, y, z)

                location[id_to_name[c]] = xyz  # ID 이름 매칭 저장

            # 데이터프레임 변환
            df = dict_to_df(location, run.name)
            location_dfs.append(df)  # 리스트에 추가
        
        # if vol_idx == 2:
        #     break
    
    # DataFrame 병합
    final_df = pd.concat(location_dfs, ignore_index=True)
    
    # ID 추가 및 CSV 저장
    final_df.insert(loc=0, column='id', value=np.arange(len(final_df)))
    final_df.to_csv("submission.csv", index=False)
    print("Submission saved to: submission.csv")


Processing volume 1/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.94s/it]


Processing volume 2/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Processing volume 3/7


Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Submission saved to: submission.csv
