In [1]:
from typing import List, Tuple, Union
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
import os
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd, 
    Orientationd,  
    AsDiscrete,  
    RandFlipd, 
    RandRotate90d, 
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"

train_list = os.listdir(TRAIN_IMG_DIR)
val_list = os.listdir(VAL_IMG_DIR)
train_files = []
valid_files = []


for name in train_list:
    train_image = np.load(os.path.join(TRAIN_IMG_DIR, f"{name}"))    
    train_label = np.load(os.path.join(TRAIN_LABEL_DIR, f"{name.replace("image", "label")}"))

    train_files.append({"image": train_image, "label": train_label})    

for name in val_list:
    valid_image = np.load(os.path.join(VAL_IMG_DIR, f"{name}"))
    valid_label = np.load(os.path.join(VAL_LABEL_DIR, f"{name.replace("image", "label")}"))

    valid_files.append({"image": valid_image, "label": valid_label})

In [3]:
# Non-random transforms to be cached
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim='no_channel'),
    NormalizeIntensityd(keys="image"),
    # Orientationd(keys=["image", "label"], axcodes="ASR")
])

raw_train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)


my_num_samples = 1
train_batch_size = 1

xy_patch = 96
z_patch = 3
# Random transforms to be applied during training
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[z_patch, xy_patch, xy_patch],
        num_classes=7,
        num_samples=my_num_samples
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),    
])


train_ds = Dataset(data=raw_train_ds, transform=random_transforms)


# DataLoader remains the same
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 24/24 [00:01<00:00, 12.18it/s]


In [4]:
# 데이터 검증 함수
def inspect_batch(loader):
    # 첫 번째 배치 가져오기
    batch = next(iter(loader))
    
    print("=== 배치 데이터 검증 ===")
    print(f"이미지 shape: {batch['image'].shape}")
    print(f"이미지 dtype: {batch['image'].dtype}")
    print(f"이미지 값 범위: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
    print("\n")
    print(f"라벨 shape: {batch['label'].shape}")
    print(f"라벨 dtype: {batch['label'].dtype}")
    print(f"라벨 고유값: {torch.unique(batch['label'])}")

# 실행
inspect_batch(train_loader)
inspect_batch(train_loader)
inspect_batch(train_loader)

=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 96, 96, 3])
이미지 dtype: torch.float32
이미지 값 범위: [-7.705, 3.054]


라벨 shape: torch.Size([1, 1, 96, 96, 3])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 6], dtype=torch.uint8)
=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 3, 96, 96])
이미지 dtype: torch.float32
이미지 값 범위: [-12.571, 1.839]


라벨 shape: torch.Size([1, 1, 3, 96, 96])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 1], dtype=torch.uint8)
=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 96, 96, 3])
이미지 dtype: torch.float32
이미지 값 범위: [-5.037, 5.835]


라벨 shape: torch.Size([1, 1, 96, 96, 3])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 3, 6], dtype=torch.uint8)


# Validation DataLoader

In [5]:
import os
import numpy as np
import json
from pathlib import Path

valid_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image")
])

def load_validation_patches(patch_dir):
    """
    저장된 validation 패치들을 로드하여 데이터셋용 리스트 생성
    Args:
        patch_dir: 패치가 저장된 디렉토리 (images, labels 서브디렉토리와 coordinates.json 포함)
    """
    patch_dir = Path(patch_dir)
    val_patched_data = []
    
    # coordinates.json 로드
    with open(patch_dir / "coordinates.json", 'r') as f:
        coordinates = json.load(f)
    
    # 각 패치에 대해
    for coord in coordinates:
        image = np.load(patch_dir / "images" / coord["patch_file"])
        label = np.load(patch_dir / "labels" / coord["patch_file"])
        
        val_patched_data.append({
            "image": image,      # shape: (11, 96, 96)
            "label": label,      # shape: (96, 96)
            "coords": coord      # 원본 위치 정보 (옵션)
        })
    
    return val_patched_data

# 패치 데이터 로드
val_patched_data = load_validation_patches("./datasets/val_patches")

# Dataset과 DataLoader 설정
valid_ds = CacheDataset(data=val_patched_data, transform=valid_transforms, cache_rate=1.0)

valid_batch_size = train_batch_size
valid_loader = DataLoader(
    valid_ds,
    batch_size=valid_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 26496/26496 [00:14<00:00, 1778.34it/s]


In [6]:
# # 데이터 검증 함수
# def inspect_batch(loader):
#     # 첫 번째 배치 가져오기
#     batch = next(iter(loader))
    
#     print("=== 배치 데이터 검증 ===")
#     print(f"이미지 shape: {batch['image'].shape}")
#     print(f"이미지 dtype: {batch['image'].dtype}")
#     print(f"이미지 값 범위: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
#     print("\n")
#     print(f"라벨 shape: {batch['label'].shape}")
#     print(f"라벨 dtype: {batch['label'].dtype}")
#     print(f"라벨 고유값: {torch.unique(batch['label'])}")

# # 실행
# inspect_batch(valid_loader)

In [None]:
from CSANet.CSANet.networks.vit_seg_modeling import VisionTransformer as ViT_seg
from CSANet.CSANet.networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from monai.losses import TverskyLoss

# Model Configuration
vit_name = 'R50-ViT-B_16'
config_vit = CONFIGS_ViT_seg[vit_name]
config_vit.n_classes = 7  # Your number of classes
config_vit.n_skip = 3
img_size = 96  # Match your patch size
vit_patches_size = 16

# Initialize model
config_vit.patches.grid = (int(img_size / vit_patches_size), int(img_size / vit_patches_size))
model = ViT_seg(config_vit, img_size=img_size, num_classes=config_vit.n_classes)
model.cuda()

# Load pretrained weights
model.load_from(weights=np.load(config_vit.real_pretrained_path, allow_pickle=True))

# Training setup
num_epochs = 40
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# TverskyLoss 설정
criterion = TverskyLoss(
    alpha=0.3,  # FP에 대한 가중치
    beta=0.7,   # FN에 대한 가중치
    include_background=True,
    softmax=True
)


# Training loop 수정
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
        for batch_data in pbar:
            images = batch_data['image'].cuda()
            labels = batch_data['label'].cuda()
            
            # 차원 분리 및 reshape
            # print(images.shape)
            if images.shape[2] == 96:
                images = images.permute(0, 1, 4, 2, 3)
                labels = labels.permute(0, 1, 4, 2, 3)
                # print(images.shape)
            # prev_image, image, next_image = torch.split(images, 1, dim=2)
            center = images.shape[2] // 2
            prev_image = images[:, :, 0:center, :, :]
            image = images[:, :, center:center+1, :, :]
            next_image = images[:, :, center+1:, :, :]
            
            # 필요한 차원 형태로 변환
            prev_image = prev_image.squeeze(2)  # [B, C, H, W]
            image = image.squeeze(2)
            next_image = next_image.squeeze(2)
            
            # 라벨 처리
            labels = labels[:, :, center:center+1, :, :]
            labels = labels.squeeze(2)  # [B, 1, H, W]
            labels = labels.squeeze(1)  # [B, H, W]

            optimizer.zero_grad()
            outputs = model(prev_image, image, next_image)

            # 라벨을 long으로 변환하고 one-hot encoding
            labels = labels.long()
            labels_onehot = torch.nn.functional.one_hot(labels, num_classes=config_vit.n_classes)  # [B, H, W, C]
            # 차원 순서 변경: [B, H, W, C] -> [B, C, H, W]
            labels_onehot = labels_onehot.permute(0, 3, 1, 2).float()
            
            loss = criterion(outputs, labels_onehot)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
    
    # Training loop는 그대로 유지

    # Training loop 내부의 validation과 saving 코드 수정
    # Validation
    if (epoch + 1) % 1 == 0:  # Validate every 5 epochs
        model.eval()
        val_loss = 0
        with torch.no_grad():
            val_loop = tqdm(valid_loader, desc=f'Validation Epoch {epoch+1}', leave=False)
            for val_data in val_loop:
                val_images = val_data['image'].cuda()
                val_labels = val_data['label'].cuda()
                
                # Training과 동일한 방식으로 데이터 처리
                val_prev, val_curr, val_next = torch.split(val_images, 1, dim=2)
                val_prev = val_prev.squeeze(2)
                val_curr = val_curr.squeeze(2)
                val_next = val_next.squeeze(2)
                
                val_labels = val_labels.squeeze(1)
                val_labels = val_labels.long()
                
                val_outputs = model(val_prev, val_curr, val_next)
                
                val_labels_onehot = torch.nn.functional.one_hot(val_labels, num_classes=config_vit.n_classes)
                val_labels_onehot = val_labels_onehot.permute(0, 3, 1, 2).float()
                
                current_loss = criterion(val_outputs, val_labels_onehot).item()
                val_loss += current_loss
                val_loop.set_postfix({'val_loss': current_loss})
        
        print(f'Validation Loss: {val_loss/len(valid_loader):.4f}')
        
    # Save checkpoint with progress bar
    if (epoch + 1) % 10 == 0:  # Save every 10 epochs
        with tqdm(total=1, desc=f'Saving checkpoint for epoch {epoch+1}', leave=False) as pbar:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, f'checkpoint_epoch_{epoch+1}.pth')
            pbar.update(1)


load_pretrained: grid-size from 14 to 6


Epoch 1/40: 100%|██████████| 24/24 [00:20<00:00,  1.19it/s, loss=0.886]
Validation Epoch 1:   5%|▌         | 1456/26496 [05:43<1:44:17,  4.00it/s, val_loss=0.888]